Quality of life patches for eval.py
This commit is contained in:
parent
d5c4b0c344
commit
2fa693e93b
|
@ -44,6 +44,7 @@ import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
|
@ -64,8 +65,12 @@ def eval_policy(
|
||||||
video_dir: Path = None,
|
video_dir: Path = None,
|
||||||
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
|
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
|
||||||
transform: callable = None,
|
transform: callable = None,
|
||||||
|
return_episode_data: bool = False,
|
||||||
seed=None,
|
seed=None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
|
||||||
|
"""
|
||||||
fps = env.unwrapped.metadata["render_fps"]
|
fps = env.unwrapped.metadata["render_fps"]
|
||||||
|
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
|
@ -118,10 +123,13 @@ def eval_policy(
|
||||||
|
|
||||||
done = torch.tensor([False for _ in env.envs])
|
done = torch.tensor([False for _ in env.envs])
|
||||||
step = 0
|
step = 0
|
||||||
|
max_steps = env.envs[0]._max_episode_steps
|
||||||
|
progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
|
||||||
while not done.all():
|
while not done.all():
|
||||||
# format from env keys to lerobot keys
|
# format from env keys to lerobot keys
|
||||||
observation = preprocess_observation(observation)
|
observation = preprocess_observation(observation)
|
||||||
observations.append(deepcopy(observation))
|
if return_episode_data:
|
||||||
|
observations.append(deepcopy(observation))
|
||||||
|
|
||||||
# apply transform to normalize the observations
|
# apply transform to normalize the observations
|
||||||
for key in observation:
|
for key in observation:
|
||||||
|
@ -166,17 +174,20 @@ def eval_policy(
|
||||||
successes.append(success)
|
successes.append(success)
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
progbar.update()
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
# add the last observation when the env is done
|
# add the last observation when the env is done
|
||||||
observation = preprocess_observation(observation)
|
if return_episode_data:
|
||||||
observations.append(deepcopy(observation))
|
observation = preprocess_observation(observation)
|
||||||
|
observations.append(deepcopy(observation))
|
||||||
|
|
||||||
new_obses = {}
|
if return_episode_data:
|
||||||
for key in observations[0].keys(): # noqa: SIM118
|
new_obses = {}
|
||||||
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
|
for key in observations[0].keys(): # noqa: SIM118
|
||||||
observations = new_obses
|
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
|
||||||
|
observations = new_obses
|
||||||
actions = torch.stack(actions, dim=1)
|
actions = torch.stack(actions, dim=1)
|
||||||
rewards = torch.stack(rewards, dim=1)
|
rewards = torch.stack(rewards, dim=1)
|
||||||
successes = torch.stack(successes, dim=1)
|
successes = torch.stack(successes, dim=1)
|
||||||
|
@ -208,40 +219,42 @@ def eval_policy(
|
||||||
|
|
||||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||||
ep_dict = {
|
if return_episode_data:
|
||||||
"action": actions[ep_id, :num_frames],
|
ep_dict = {
|
||||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
"action": actions[ep_id, :num_frames],
|
||||||
"frame_id": torch.arange(0, num_frames, 1),
|
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
"frame_id": torch.arange(0, num_frames, 1),
|
||||||
"next.done": dones[ep_id, :num_frames],
|
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
"next.done": dones[ep_id, :num_frames],
|
||||||
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
|
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||||
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
|
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
|
||||||
}
|
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
|
||||||
for key in observations:
|
}
|
||||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
for key in observations:
|
||||||
ep_dicts.append(ep_dict)
|
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||||
|
ep_dicts.append(ep_dict)
|
||||||
|
|
||||||
idx_from += num_frames
|
idx_from += num_frames
|
||||||
|
|
||||||
# similar logic is implemented in dataset preprocessing
|
# similar logic is implemented in dataset preprocessing
|
||||||
data_dict = {}
|
if return_episode_data:
|
||||||
keys = ep_dicts[0].keys()
|
data_dict = {}
|
||||||
for key in keys:
|
keys = ep_dicts[0].keys()
|
||||||
if "image" not in key:
|
for key in keys:
|
||||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
if "image" not in key:
|
||||||
else:
|
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||||
if key not in data_dict:
|
else:
|
||||||
data_dict[key] = []
|
if key not in data_dict:
|
||||||
for ep_dict in ep_dicts:
|
data_dict[key] = []
|
||||||
for x in ep_dict[key]:
|
for ep_dict in ep_dicts:
|
||||||
# c h w -> h w c
|
for x in ep_dict[key]:
|
||||||
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
|
# c h w -> h w c
|
||||||
data_dict[key].append(img)
|
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
|
||||||
|
data_dict[key].append(img)
|
||||||
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
||||||
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
episodes_as_hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
|
||||||
|
|
||||||
if max_episodes_rendered > 0:
|
if max_episodes_rendered > 0:
|
||||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||||
|
@ -249,7 +262,7 @@ def eval_policy(
|
||||||
for stacked_frames, done_index in zip(
|
for stacked_frames, done_index in zip(
|
||||||
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
|
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
|
||||||
):
|
):
|
||||||
if episode_counter >= num_episodes:
|
if episode_counter >= max_episodes_rendered:
|
||||||
continue
|
continue
|
||||||
video_dir.mkdir(parents=True, exist_ok=True)
|
video_dir.mkdir(parents=True, exist_ok=True)
|
||||||
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
||||||
|
@ -292,8 +305,9 @@ def eval_policy(
|
||||||
"eval_s": time.time() - start,
|
"eval_s": time.time() - start,
|
||||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||||
},
|
},
|
||||||
"episodes": data_dict,
|
|
||||||
}
|
}
|
||||||
|
if return_episode_data:
|
||||||
|
info["episodes"] = episodes_as_hf_dataset
|
||||||
if max_episodes_rendered > 0:
|
if max_episodes_rendered > 0:
|
||||||
info["videos"] = videos
|
info["videos"] = videos
|
||||||
return info
|
return info
|
||||||
|
@ -330,6 +344,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||||
max_episodes_rendered=10,
|
max_episodes_rendered=10,
|
||||||
video_dir=Path(out_dir) / "eval",
|
video_dir=Path(out_dir) / "eval",
|
||||||
transform=transform,
|
transform=transform,
|
||||||
|
return_episode_data=False,
|
||||||
seed=cfg.seed,
|
seed=cfg.seed,
|
||||||
)
|
)
|
||||||
print(info["aggregated"])
|
print(info["aggregated"])
|
||||||
|
@ -337,7 +352,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||||
# Save info
|
# Save info
|
||||||
with open(Path(out_dir) / "eval_info.json", "w") as f:
|
with open(Path(out_dir) / "eval_info.json", "w") as f:
|
||||||
# remove pytorch tensors which are not serializable to save the evaluation results only
|
# remove pytorch tensors which are not serializable to save the evaluation results only
|
||||||
del info["episodes"]
|
|
||||||
del info["videos"]
|
del info["videos"]
|
||||||
json.dump(info, f, indent=2)
|
json.dump(info, f, indent=2)
|
||||||
|
|
||||||
|
|
|
@ -303,6 +303,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
rollout_env,
|
rollout_env,
|
||||||
policy,
|
policy,
|
||||||
transform=offline_dataset.transform,
|
transform=offline_dataset.transform,
|
||||||
|
return_episode_data=True,
|
||||||
seed=cfg.seed,
|
seed=cfg.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue