Quality of life patches for eval.py

This commit is contained in:
Alexander Soare 2024-04-19 11:06:29 +01:00
parent d5c4b0c344
commit 2fa693e93b
2 changed files with 53 additions and 38 deletions

View File

@ -44,6 +44,7 @@ import torch
from datasets import Dataset from datasets import Dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL import Image as PILImage from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
@ -64,8 +65,12 @@ def eval_policy(
video_dir: Path = None, video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
transform: callable = None, transform: callable = None,
return_episode_data: bool = False,
seed=None, seed=None,
): ):
"""
Set `return_episode_data` to True to also return the rollouts as a Hugging Face dataset under the "episodes" key of the returned dict.
"""
fps = env.unwrapped.metadata["render_fps"] fps = env.unwrapped.metadata["render_fps"]
if policy is not None: if policy is not None:
@ -118,9 +123,12 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs]) done = torch.tensor([False for _ in env.envs])
step = 0 step = 0
max_steps = env.envs[0]._max_episode_steps
progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
while not done.all(): while not done.all():
# format from env keys to lerobot keys # format from env keys to lerobot keys
observation = preprocess_observation(observation) observation = preprocess_observation(observation)
if return_episode_data:
observations.append(deepcopy(observation)) observations.append(deepcopy(observation))
# apply transform to normalize the observations # apply transform to normalize the observations
@ -166,13 +174,16 @@ def eval_policy(
successes.append(success) successes.append(success)
step += 1 step += 1
progbar.update()
env.close() env.close()
# add the last observation when the env is done # add the last observation when the env is done
if return_episode_data:
observation = preprocess_observation(observation) observation = preprocess_observation(observation)
observations.append(deepcopy(observation)) observations.append(deepcopy(observation))
if return_episode_data:
new_obses = {} new_obses = {}
for key in observations[0].keys(): # noqa: SIM118 for key in observations[0].keys(): # noqa: SIM118
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1) new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
@ -208,6 +219,7 @@ def eval_policy(
# TODO(rcadene): We need to add a missing last frame which is the observation # TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state" # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
if return_episode_data:
ep_dict = { ep_dict = {
"action": actions[ep_id, :num_frames], "action": actions[ep_id, :num_frames],
"episode_id": torch.tensor([ep_id] * num_frames), "episode_id": torch.tensor([ep_id] * num_frames),
@ -225,6 +237,7 @@ def eval_policy(
idx_from += num_frames idx_from += num_frames
# similar logic is implemented in dataset preprocessing # similar logic is implemented in dataset preprocessing
if return_episode_data:
data_dict = {} data_dict = {}
keys = ep_dicts[0].keys() keys = ep_dicts[0].keys()
for key in keys: for key in keys:
@ -241,7 +254,7 @@ def eval_policy(
data_dict["index"] = torch.arange(0, total_frames, 1) data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = Dataset.from_dict(data_dict).with_format("torch") episodes_as_hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
if max_episodes_rendered > 0: if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@ -249,7 +262,7 @@ def eval_policy(
for stacked_frames, done_index in zip( for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False batch_stacked_frames, done_indices.flatten().tolist(), strict=False
): ):
if episode_counter >= num_episodes: if episode_counter >= max_episodes_rendered:
continue continue
video_dir.mkdir(parents=True, exist_ok=True) video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4" video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
@ -292,8 +305,9 @@ def eval_policy(
"eval_s": time.time() - start, "eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes, "eval_ep_s": (time.time() - start) / num_episodes,
}, },
"episodes": data_dict,
} }
if return_episode_data:
info["episodes"] = episodes_as_hf_dataset
if max_episodes_rendered > 0: if max_episodes_rendered > 0:
info["videos"] = videos info["videos"] = videos
return info return info
@ -330,6 +344,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
max_episodes_rendered=10, max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval", video_dir=Path(out_dir) / "eval",
transform=transform, transform=transform,
return_episode_data=False,
seed=cfg.seed, seed=cfg.seed,
) )
print(info["aggregated"]) print(info["aggregated"])
@ -337,7 +352,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
# Save info # Save info
with open(Path(out_dir) / "eval_info.json", "w") as f: with open(Path(out_dir) / "eval_info.json", "w") as f:
# remove pytorch tensors which are not serializable to save the evaluation results only # remove pytorch tensors which are not serializable to save the evaluation results only
del info["episodes"]
del info["videos"] del info["videos"]
json.dump(info, f, indent=2) json.dump(info, f, indent=2)

View File

@ -303,6 +303,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
rollout_env, rollout_env,
policy, policy,
transform=offline_dataset.transform, transform=offline_dataset.transform,
return_episode_data=True,
seed=cfg.seed, seed=cfg.seed,
) )