From 2fa693e93b1c1ec0535f8a28e0723a1ceaef0edc Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 19 Apr 2024 11:06:29 +0100
Subject: [PATCH] Quality of life patches for eval.py

---
 lerobot/scripts/eval.py  | 90 +++++++++++++++++++++++-----------------
 lerobot/scripts/train.py |  1 +
 2 files changed, 53 insertions(+), 38 deletions(-)

diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 6c5b757d..3c0204b7 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -44,6 +44,7 @@ import torch
 from datasets import Dataset
 from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
+from tqdm import trange
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
@@ -64,8 +65,12 @@ def eval_policy(
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
     transform: callable = None,
+    return_episode_data: bool = False,
     seed=None,
 ):
+    """
+    Set `return_episode_data` to True to return a Hugging Face dataset object under the "episodes" key of the returned dict.
+    """
     fps = env.unwrapped.metadata["render_fps"]
 
     if policy is not None:
@@ -118,10 +123,13 @@ def eval_policy(
 
     done = torch.tensor([False for _ in env.envs])
     step = 0
+    max_steps = env.envs[0]._max_episode_steps
+    progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
     while not done.all():
         # format from env keys to lerobot keys
         observation = preprocess_observation(observation)
-        observations.append(deepcopy(observation))
+        if return_episode_data:
+            observations.append(deepcopy(observation))
 
         # apply transform to normalize the observations
         for key in observation:
@@ -166,17 +174,20 @@ def eval_policy(
         successes.append(success)
 
         step += 1
+        progbar.update()
 
     env.close()
 
     # add the last observation when the env is done
-    observation = preprocess_observation(observation)
-    observations.append(deepcopy(observation))
+    if return_episode_data:
+        observation = preprocess_observation(observation)
+        observations.append(deepcopy(observation))
 
-    new_obses = {}
-    for key in observations[0].keys():  # noqa: SIM118
-        new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
-    observations = new_obses
+    if return_episode_data:
+        new_obses = {}
+        for key in observations[0].keys():  # noqa: SIM118
+            new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+        observations = new_obses
     actions = torch.stack(actions, dim=1)
     rewards = torch.stack(rewards, dim=1)
     successes = torch.stack(successes, dim=1)
@@ -208,40 +219,42 @@ def eval_policy(
             # TODO(rcadene): We need to add a missing last frame which is the observation
             # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
-            ep_dict = {
-                "action": actions[ep_id, :num_frames],
-                "episode_id": torch.tensor([ep_id] * num_frames),
-                "frame_id": torch.arange(0, num_frames, 1),
-                "timestamp": torch.arange(0, num_frames, 1) / fps,
-                "next.done": dones[ep_id, :num_frames],
-                "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
-                "episode_data_index_from": torch.tensor([idx_from] * num_frames),
-                "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
-            }
-            for key in observations:
-                ep_dict[key] = observations[key][ep_id][:num_frames]
-            ep_dicts.append(ep_dict)
+            if return_episode_data:
+                ep_dict = {
+                    "action": actions[ep_id, :num_frames],
+                    "episode_id": torch.tensor([ep_id] * num_frames),
+                    "frame_id": torch.arange(0, num_frames, 1),
+                    "timestamp": torch.arange(0, num_frames, 1) / fps,
+                    "next.done": dones[ep_id, :num_frames],
+                    "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+                    "episode_data_index_from": torch.tensor([idx_from] * num_frames),
+                    "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
+                }
+                for key in observations:
+                    ep_dict[key] = observations[key][ep_id][:num_frames]
+                ep_dicts.append(ep_dict)
 
             idx_from += num_frames
 
     # similar logic is implemented in dataset preprocessing
-    data_dict = {}
-    keys = ep_dicts[0].keys()
-    for key in keys:
-        if "image" not in key:
-            data_dict[key] = torch.cat([x[key] for x in ep_dicts])
-        else:
-            if key not in data_dict:
-                data_dict[key] = []
-            for ep_dict in ep_dicts:
-                for x in ep_dict[key]:
-                    # c h w -> h w c
-                    img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
-                    data_dict[key].append(img)
+    if return_episode_data:
+        data_dict = {}
+        keys = ep_dicts[0].keys()
+        for key in keys:
+            if "image" not in key:
+                data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+            else:
+                if key not in data_dict:
+                    data_dict[key] = []
+                for ep_dict in ep_dicts:
+                    for x in ep_dict[key]:
+                        # c h w -> h w c
+                        img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
+                        data_dict[key].append(img)
 
-    data_dict["index"] = torch.arange(0, total_frames, 1)
+        data_dict["index"] = torch.arange(0, total_frames, 1)
 
-    data_dict = Dataset.from_dict(data_dict).with_format("torch")
+        episodes_as_hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
 
     if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
@@ -249,7 +262,7 @@ def eval_policy(
         for stacked_frames, done_index in zip(
             batch_stacked_frames, done_indices.flatten().tolist(), strict=False
         ):
-            if episode_counter >= num_episodes:
+            if episode_counter >= max_episodes_rendered:
                 continue
             video_dir.mkdir(parents=True, exist_ok=True)
             video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
@@ -292,8 +305,9 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
-        "episodes": data_dict,
     }
+    if return_episode_data:
+        info["episodes"] = episodes_as_hf_dataset
     if max_episodes_rendered > 0:
         info["videos"] = videos
     return info
@@ -330,13 +344,13 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         transform=transform,
+        return_episode_data=False,
         seed=cfg.seed,
     )
     print(info["aggregated"])
 
     # Save info
     with open(Path(out_dir) / "eval_info.json", "w") as f:
         # remove pytorch tensors which are not serializable to save the evaluation results only
-        del info["episodes"]
         del info["videos"]
         json.dump(info, f, indent=2)

diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 146fcc21..8a4758b6 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -303,6 +303,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 rollout_env,
                 policy,
                 transform=offline_dataset.transform,
+                return_episode_data=True,
                 seed=cfg.seed,
             )
 
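
Usage sketch (illustrative, not part of the patch): how a caller might use the new
`return_episode_data` flag. This assumes `env`, `policy`, and `transform` have
already been built the way `eval()` above builds them; the seed value is arbitrary.

    # Ask eval_policy for the episode data alongside the usual metrics. With the
    # default return_episode_data=False, no "episodes" key is added, which is why
    # eval() above can JSON-serialize `info` after dropping only "videos".
    info = eval_policy(
        env,
        policy,
        transform=transform,
        return_episode_data=True,
        seed=42,  # illustrative; eval() passes cfg.seed
    )

    episodes = info["episodes"]  # a datasets.Dataset in "torch" format
    # Columns include "action", "episode_id", "frame_id", "timestamp",
    # "next.done", "next.reward", "index", and the observation keys.
    print(episodes.column_names)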