# lerobot/lerobot/scripts/eval.py

from pathlib import Path
import hydra
import imageio
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed
def eval_policy(
    env: EnvBase,
    policy: TensorDictModule = None,
    num_episodes: int = 10,
    max_steps: int = 30,
    save_video: bool = False,
    video_dir: Path = None,
):
    """Roll out ``policy`` in ``env`` and return aggregate evaluation metrics.

    Args:
        env: A torchrl environment supporting ``reset``/``rollout``/``render``.
        policy: TensorDictModule mapping observations to actions. May be ``None``
            to roll out a random policy (torchrl samples from the action spec).
        num_episodes: Number of evaluation episodes to run.
        max_steps: Maximum number of environment steps per episode.
        save_video: If True, render every step and save one mp4 per episode.
        video_dir: Output directory for videos; required when ``save_video`` is True.

    Returns:
        dict with ``avg_reward`` (mean episode return) and ``pc_success``
        (success percentage across episodes).

    Raises:
        ValueError: If ``save_video`` is True but ``video_dir`` is None.
    """
    if save_video and video_dir is None:
        # Fail early instead of crashing at video_dir.mkdir after the rollout.
        raise ValueError("video_dir must be provided when save_video is True")
    rewards = []
    successes = []
    for i in range(num_episodes):
        ep_frames = []

        def rendering_callback(env, td=None):
            # Mutates the enclosing list in place; no `nonlocal` needed.
            ep_frames.append(env.render())

        tensordict = env.reset()
        if save_video:
            # Render the first frame before the rollout starts so the video
            # includes the initial state.
            rendering_callback(env)
        rollout = env.rollout(
            max_steps=max_steps,
            policy=policy,
            callback=rendering_callback if save_video else None,
            auto_reset=False,
            tensordict=tensordict,
        )
        # "next" keys hold the post-step transition data in torchrl rollouts.
        ep_reward = rollout["next", "reward"].sum()
        ep_success = rollout["next", "success"].any()
        rewards.append(ep_reward.item())
        successes.append(ep_success.item())
        if save_video:
            video_dir.mkdir(parents=True, exist_ok=True)
            # TODO(rcadene): make fps configurable
            video_path = video_dir / f"eval_episode_{i}.mp4"
            imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
    metrics = {
        "avg_reward": np.nanmean(rewards),
        "pc_success": np.nanmean(successes) * 100,
    }
    return metrics
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def eval(cfg: dict):
    """Hydra entry point: build the env and TDMPC policy from ``cfg`` and evaluate.

    NOTE: the name shadows the ``eval`` builtin, but it is kept because it is
    the script's public entry point.

    Args:
        cfg: Hydra/omegaconf config (attribute access: ``cfg.seed``, ``cfg.log_dir``).
    """
    if not torch.cuda.is_available():
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise RuntimeError("CUDA is required to run evaluation.")
    set_seed(cfg.seed)
    print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)

    env = make_env(cfg)
    policy = TDMPC(cfg)
    # TODO(review): checkpoint path is hard-coded to a personal directory;
    # move it into the hydra config so other machines can run this script.
    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
    policy.load(ckpt_path)

    # Wrap the raw policy so it reads/writes the env's tensordict keys.
    policy = TensorDictModule(
        policy,
        in_keys=["observation", "step_count"],
        out_keys=["action"],
    )

    # policy can be None to rollout a random policy
    metrics = eval_policy(
        env,
        policy=policy,
        num_episodes=20,
        save_video=False,
        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
    )
    print(metrics)
# Script entry point: run the hydra-decorated eval() when executed directly.
if __name__ == "__main__":
    eval()