Quality of life patches for eval.py
parent d5c4b0c344
commit 2fa693e93b
@@ -44,6 +44,7 @@ import torch
 from datasets import Dataset
 from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
+from tqdm import trange

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
@@ -64,8 +65,12 @@ def eval_policy(
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
     transform: callable = None,
+    return_episode_data: bool = False,
     seed=None,
 ):
+    """
+    set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
+    """
     fps = env.unwrapped.metadata["render_fps"]

     if policy is not None:
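The flag defaults to False, so existing callers keep their current behaviour; a caller that opts in reads the rollouts back from the "episodes" entry of the returned info dict. A minimal, hypothetical call-site sketch (env, policy and transform are assumed to be set up elsewhere; only the flag and the "episodes" key come from this commit):

# Hypothetical call site, mirroring how eval() and train() invoke eval_policy below.
info = eval_policy(
    env,
    policy,
    transform=transform,
    return_episode_data=True,  # ask for the rollouts back
    seed=42,
)
episodes = info["episodes"]  # a datasets.Dataset already formatted as torch tensors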
@@ -118,9 +123,12 @@ def eval_policy(

     done = torch.tensor([False for _ in env.envs])
     step = 0
+    max_steps = env.envs[0]._max_episode_steps
+    progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
     while not done.all():
         # format from env keys to lerobot keys
         observation = preprocess_observation(observation)
-        observations.append(deepcopy(observation))
+        if return_episode_data:
+            observations.append(deepcopy(observation))

         # apply transform to normalize the observations
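The progress bar is sized by the first sub-environment's _max_episode_steps and is advanced manually once per vectorized step, so it shows an upper bound on the rollout length rather than the exact count. A standalone sketch of the same tqdm pattern with illustrative numbers:

from tqdm import trange

max_steps = 300  # illustrative; eval.py reads env.envs[0]._max_episode_steps
progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
done = False
step = 0
while not done:
    done = step >= 42  # stand-in for every sub-environment reporting done
    step += 1
    progbar.update()  # advance by one; the bar simply stops early if rollouts finish sooner
progbar.close()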
@@ -166,13 +174,16 @@ def eval_policy(
         successes.append(success)

         step += 1
+        progbar.update()

     env.close()

     # add the last observation when the env is done
-    observation = preprocess_observation(observation)
-    observations.append(deepcopy(observation))
+    if return_episode_data:
+        observation = preprocess_observation(observation)
+        observations.append(deepcopy(observation))

-    new_obses = {}
-    for key in observations[0].keys():  # noqa: SIM118
-        new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+    if return_episode_data:
+        new_obses = {}
+        for key in observations[0].keys():  # noqa: SIM118
+            new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
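Each entry of observations is a dict of batched tensors captured at one step, so stacking along dim=1 produces tensors shaped (batch, time, ...). A small sketch with illustrative key names and shapes:

import torch

# Three steps of a rollout with batch size 2 and a 4-dim state (illustrative).
observations = [{"observation.state": torch.randn(2, 4)} for _ in range(3)]

new_obses = {}
for key in observations[0]:
    # Insert a new time dimension at position 1: (batch, time, feature).
    new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)

print(new_obses["observation.state"].shape)  # torch.Size([2, 3, 4])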
@@ -208,6 +219,7 @@ def eval_policy(

         # TODO(rcadene): We need to add a missing last frame which is the observation
         # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
-        ep_dict = {
-            "action": actions[ep_id, :num_frames],
-            "episode_id": torch.tensor([ep_id] * num_frames),
+        if return_episode_data:
+            ep_dict = {
+                "action": actions[ep_id, :num_frames],
+                "episode_id": torch.tensor([ep_id] * num_frames),
@@ -225,6 +237,7 @@ def eval_policy(
         idx_from += num_frames

     # similar logic is implemented in dataset preprocessing
-    data_dict = {}
-    keys = ep_dicts[0].keys()
-    for key in keys:
+    if return_episode_data:
+        data_dict = {}
+        keys = ep_dicts[0].keys()
+        for key in keys:
@@ -241,7 +254,7 @@ def eval_policy(

         data_dict["index"] = torch.arange(0, total_frames, 1)

-        data_dict = Dataset.from_dict(data_dict).with_format("torch")
+        episodes_as_hf_dataset = Dataset.from_dict(data_dict).with_format("torch")

     if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
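Renaming the result of Dataset.from_dict(...) to episodes_as_hf_dataset keeps the Hugging Face dataset distinct from the plain data_dict of tensors it was built from. A minimal sketch of that conversion with illustrative columns (the real data_dict holds per-frame entries such as "action" and "episode_id"):

from datasets import Dataset  # torch must be installed for with_format("torch")

data_dict = {
    "episode_id": [0, 0, 0, 1, 1],  # illustrative values
    "index": list(range(5)),
}
episodes_as_hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
print(episodes_as_hf_dataset[0])        # rows come back as dicts of torch tensors
print(episodes_as_hf_dataset["index"])  # columns come back as torch tensors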
@@ -249,7 +262,7 @@ def eval_policy(
         for stacked_frames, done_index in zip(
             batch_stacked_frames, done_indices.flatten().tolist(), strict=False
         ):
-            if episode_counter >= num_episodes:
+            if episode_counter >= max_episodes_rendered:
                 continue
             video_dir.mkdir(parents=True, exist_ok=True)
             video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
@@ -292,8 +305,9 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
-        "episodes": data_dict,
     }
+    if return_episode_data:
+        info["episodes"] = episodes_as_hf_dataset
     if max_episodes_rendered > 0:
         info["videos"] = videos
     return info
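With this change, "episodes" is only present in info when return_episode_data is True, and "videos" only when rendering was requested, so downstream code should treat both keys as optional. A hypothetical consumer sketch (not from the diff):

episodes = info.get("episodes")  # None unless return_episode_data=True was passed
if episodes is not None:
    print(f"collected {len(episodes)} frames of episode data")
videos = info.get("videos", [])
print(f"{len(videos)} rendered rollout videos")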
@@ -330,6 +344,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         transform=transform,
+        return_episode_data=False,
         seed=cfg.seed,
     )
     print(info["aggregated"])
@@ -337,7 +352,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     # Save info
     with open(Path(out_dir) / "eval_info.json", "w") as f:
         # remove pytorch tensors which are not serializable to save the evaluation results only
-        del info["episodes"]
         del info["videos"]
         json.dump(info, f, indent=2)

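Because "episodes" is no longer added unconditionally, eval() no longer has to delete it before writing eval_info.json; only "videos", which holds tensors that json cannot serialize, still needs stripping. A runnable sketch of that save step with an illustrative info dict and output directory:

import json
from pathlib import Path

import torch

out_dir = Path("/tmp/eval_out")  # illustrative output directory
out_dir.mkdir(parents=True, exist_ok=True)

# "videos" holds tensors that json.dump cannot handle, so it is removed first,
# mirroring what eval() now does; there is no unconditional "episodes" key to worry about.
info = {"aggregated": {"eval_s": 12.3}, "videos": [torch.zeros(2, 3)]}
with open(out_dir / "eval_info.json", "w") as f:
    del info["videos"]
    json.dump(info, f, indent=2)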
@@ -303,6 +303,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 rollout_env,
                 policy,
                 transform=offline_dataset.transform,
+                return_episode_data=True,
                 seed=cfg.seed,
             )
