Quality of life patches for eval.py

This commit is contained in:
Alexander Soare 2024-04-19 11:06:29 +01:00
parent d5c4b0c344
commit 2fa693e93b
2 changed files with 53 additions and 38 deletions

View File

@ -44,6 +44,7 @@ import torch
from datasets import Dataset
from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
@ -64,8 +65,12 @@ def eval_policy(
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
transform: callable = None,
return_episode_data: bool = False,
seed=None,
):
"""
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
"""
fps = env.unwrapped.metadata["render_fps"]
if policy is not None:
@ -118,10 +123,13 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs])
step = 0
max_steps = env.envs[0]._max_episode_steps
progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
while not done.all():
# format from env keys to lerobot keys
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
if return_episode_data:
observations.append(deepcopy(observation))
# apply transform to normalize the observations
for key in observation:
@ -166,17 +174,20 @@ def eval_policy(
successes.append(success)
step += 1
progbar.update()
env.close()
# add the last observation when the env is done
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
if return_episode_data:
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
new_obses = {}
for key in observations[0].keys(): # noqa: SIM118
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
observations = new_obses
if return_episode_data:
new_obses = {}
for key in observations[0].keys(): # noqa: SIM118
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
observations = new_obses
actions = torch.stack(actions, dim=1)
rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1)
@ -208,40 +219,42 @@ def eval_policy(
# TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
if return_episode_data:
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
idx_from += num_frames
# similar logic is implemented in dataset preprocessing
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if "image" not in key:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
data_dict[key].append(img)
if return_episode_data:
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if "image" not in key:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = Dataset.from_dict(data_dict).with_format("torch")
episodes_as_hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@ -249,7 +262,7 @@ def eval_policy(
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
if episode_counter >= max_episodes_rendered:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
@ -292,8 +305,9 @@ def eval_policy(
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
"episodes": data_dict,
}
if return_episode_data:
info["episodes"] = episodes_as_hf_dataset
if max_episodes_rendered > 0:
info["videos"] = videos
return info
@ -330,6 +344,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
transform=transform,
return_episode_data=False,
seed=cfg.seed,
)
print(info["aggregated"])
@ -337,7 +352,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
# Save info
with open(Path(out_dir) / "eval_info.json", "w") as f:
# remove pytorch tensors which are not serializable to save the evaluation results only
del info["episodes"]
del info["videos"]
json.dump(info, f, indent=2)

View File

@ -303,6 +303,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
rollout_env,
policy,
transform=offline_dataset.transform,
return_episode_data=True,
seed=cfg.seed,
)