Quality of life patches for eval.py

Author: Alexander Soare
Date:   2024-04-19 11:06:29 +01:00
Parent: d5c4b0c344
Commit: 2fa693e93b
2 changed files with 53 additions and 38 deletions

eval.py

@@ -44,6 +44,7 @@ import torch
 from datasets import Dataset
 from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
+from tqdm import trange
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
@@ -64,8 +65,12 @@ def eval_policy(
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
     transform: callable = None,
+    return_episode_data: bool = False,
     seed=None,
 ):
+    """
+    set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
+    """
     fps = env.unwrapped.metadata["render_fps"]
 
     if policy is not None:
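The new `return_episode_data` flag is opt-in. A minimal usage sketch, assuming an already-built env, policy, and transform (that setup is not part of this commit):

    info = eval_policy(
        env,
        policy,
        transform=transform,
        return_episode_data=True,  # ask for the Hugging Face dataset in the output
        seed=42,
    )
    episodes = info["episodes"]  # datasets.Dataset; only present when the flag is set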
@@ -118,10 +123,13 @@ def eval_policy(
     done = torch.tensor([False for _ in env.envs])
     step = 0
+    max_steps = env.envs[0]._max_episode_steps
+    progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
     while not done.all():
         # format from env keys to lerobot keys
         observation = preprocess_observation(observation)
-        observations.append(deepcopy(observation))
+        if return_episode_data:
+            observations.append(deepcopy(observation))
 
         # apply transform to normalize the observations
         for key in observation:
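This is tqdm's manual-update pattern: build `trange` with the known step bound and call `update()` once per iteration instead of iterating over the range directly, so the bar stays accurate even when the loop exits early. A self-contained sketch (values are illustrative):

    from tqdm import trange

    max_steps = 25  # illustrative bound
    progbar = trange(max_steps, desc="rollout")
    done = False
    step = 0
    while not done:
        step += 1
        done = step >= 10  # stand-in for the env's done flags
        progbar.update()
    progbar.close()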
@@ -166,17 +174,20 @@ def eval_policy(
         successes.append(success)
 
         step += 1
+        progbar.update()
 
     env.close()
 
     # add the last observation when the env is done
-    observation = preprocess_observation(observation)
-    observations.append(deepcopy(observation))
+    if return_episode_data:
+        observation = preprocess_observation(observation)
+        observations.append(deepcopy(observation))
 
-    new_obses = {}
-    for key in observations[0].keys():  # noqa: SIM118
-        new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
-    observations = new_obses
+    if return_episode_data:
+        new_obses = {}
+        for key in observations[0].keys():  # noqa: SIM118
+            new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+        observations = new_obses
     actions = torch.stack(actions, dim=1)
     rewards = torch.stack(rewards, dim=1)
     successes = torch.stack(successes, dim=1)
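The `new_obses` block converts a per-step list of observation dicts into one `(batch, time, ...)` tensor per key. The same idiom in isolation (shapes are illustrative):

    import torch

    # 2 batched envs, 3 steps, 4-dim state: a list of {key: (batch, ...)} dicts
    observations = [{"observation.state": torch.randn(2, 4)} for _ in range(3)]
    stacked = {
        key: torch.stack([obs[key] for obs in observations], dim=1)
        for key in observations[0]
    }
    print(stacked["observation.state"].shape)  # torch.Size([2, 3, 4])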
@@ -208,40 +219,42 @@ def eval_policy(
         # TODO(rcadene): We need to add a missing last frame which is the observation
         # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
-        ep_dict = {
-            "action": actions[ep_id, :num_frames],
-            "episode_id": torch.tensor([ep_id] * num_frames),
-            "frame_id": torch.arange(0, num_frames, 1),
-            "timestamp": torch.arange(0, num_frames, 1) / fps,
-            "next.done": dones[ep_id, :num_frames],
-            "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
-            "episode_data_index_from": torch.tensor([idx_from] * num_frames),
-            "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
-        }
-        for key in observations:
-            ep_dict[key] = observations[key][ep_id][:num_frames]
-        ep_dicts.append(ep_dict)
+        if return_episode_data:
+            ep_dict = {
+                "action": actions[ep_id, :num_frames],
+                "episode_id": torch.tensor([ep_id] * num_frames),
+                "frame_id": torch.arange(0, num_frames, 1),
+                "timestamp": torch.arange(0, num_frames, 1) / fps,
+                "next.done": dones[ep_id, :num_frames],
+                "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+                "episode_data_index_from": torch.tensor([idx_from] * num_frames),
+                "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
+            }
+            for key in observations:
+                ep_dict[key] = observations[key][ep_id][:num_frames]
+            ep_dicts.append(ep_dict)
 
         idx_from += num_frames
 
     # similar logic is implemented in dataset preprocessing
-    data_dict = {}
-    keys = ep_dicts[0].keys()
-    for key in keys:
-        if "image" not in key:
-            data_dict[key] = torch.cat([x[key] for x in ep_dicts])
-        else:
-            if key not in data_dict:
-                data_dict[key] = []
-            for ep_dict in ep_dicts:
-                for x in ep_dict[key]:
-                    # c h w -> h w c
-                    img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
-                    data_dict[key].append(img)
-
-    data_dict["index"] = torch.arange(0, total_frames, 1)
-    data_dict = Dataset.from_dict(data_dict).with_format("torch")
+    if return_episode_data:
+        data_dict = {}
+        keys = ep_dicts[0].keys()
+        for key in keys:
+            if "image" not in key:
+                data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+            else:
+                if key not in data_dict:
+                    data_dict[key] = []
+                for ep_dict in ep_dicts:
+                    for x in ep_dict[key]:
+                        # c h w -> h w c
+                        img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
+                        data_dict[key].append(img)
+
+        data_dict["index"] = torch.arange(0, total_frames, 1)
+        episodes_as_hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
 
     if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
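`Dataset.from_dict(...).with_format("torch")` turns the flat column dict built above (tensors for scalar fields, PIL images for image fields) into a Hugging Face dataset whose rows come back as torch tensors. A minimal sketch with made-up columns:

    import torch
    from datasets import Dataset

    data_dict = {
        "index": torch.arange(0, 6, 1),
        "next.reward": torch.zeros(6, dtype=torch.float32),
    }
    ds = Dataset.from_dict(data_dict).with_format("torch")
    print(ds[0]["index"])  # tensor(0)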
@@ -249,7 +262,7 @@ def eval_policy(
         for stacked_frames, done_index in zip(
             batch_stacked_frames, done_indices.flatten().tolist(), strict=False
         ):
-            if episode_counter >= num_episodes:
+            if episode_counter >= max_episodes_rendered:
                 continue
             video_dir.mkdir(parents=True, exist_ok=True)
             video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
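Note: comparing against `max_episodes_rendered` instead of `num_episodes` fixes the render cap, which was previously tied to the number of evaluation episodes rather than the intended rendering limit.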
@@ -292,8 +305,9 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
-        "episodes": data_dict,
     }
+    if return_episode_data:
+        info["episodes"] = episodes_as_hf_dataset
     if max_episodes_rendered > 0:
         info["videos"] = videos
     return info
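Since `episodes` is now only present on request, consumers of the returned dict should guard the lookup, e.g.:

    info = {"aggregated": {"avg_sum_reward": 1.0}}  # "episodes" absent by default
    episodes = info.get("episodes")  # None unless return_episode_data=True was passed
    if episodes is not None:
        print(f"{len(episodes)} frames of episode data")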
@@ -330,6 +344,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         transform=transform,
+        return_episode_data=False,
         seed=cfg.seed,
     )
     print(info["aggregated"])
@@ -337,7 +352,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
 
     # Save info
     with open(Path(out_dir) / "eval_info.json", "w") as f:
         # remove pytorch tensors which are not serializable to save the evaluation results only
-        del info["episodes"]
         del info["videos"]
         json.dump(info, f, indent=2)
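With `return_episode_data=False` at this call site, `info` never carries the `episodes` entry, so only `videos` has to be stripped before serialization. A generic, runnable sketch of the pattern (key names illustrative):

    import json

    info = {"aggregated": {"avg_sum_reward": 1.0}, "videos": object()}
    serializable = {k: v for k, v in info.items() if k not in ("videos", "episodes")}
    with open("eval_info.json", "w") as f:
        json.dump(serializable, f, indent=2)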

train.py

@@ -303,6 +303,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 rollout_env,
                 policy,
                 transform=offline_dataset.transform,
+                return_episode_data=True,
                 seed=cfg.seed,
             )