"""Evaluate a policy on an environment by running rollouts and computing metrics.
|
|
|
|
|
|
|
|
The script may be run in one of two ways:
|
|
|
|
|
|
|
|
1. By providing the path to a config file with the --config argument.
|
|
|
|
2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the
|
|
|
|
--revision argument.
|
|
|
|
|
|
|
|
In either case, it is possible to override config arguments by adding a list of config.key=value arguments.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
You have a specific config file to go with trained model weights, and want to run 10 episodes.
|
|
|
|
|
|
|
|
```
|
2024-03-22 21:25:23 +08:00
|
|
|
python lerobot/scripts/eval.py \
|
|
|
|
--config PATH/TO/FOLDER/config.yaml \
|
|
|
|
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
|
|
|
eval_episodes=10
|
2024-03-22 18:26:55 +08:00
|
|
|
```
|
|
|
|
|
|
|
|
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
|
|
|
|
you don't need to specify which weights to use):
|
|
|
|
|
|
|
|
```
|
|
|
|
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
|
|
|
|
```
|
|
|
|
"""
|
|
|
|
|
|
|
|
import argparse
import json
import logging
import threading
import time
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path

import einops
import gymnasium as gym
import imageio
import numpy as np
import torch
from datasets import Dataset
from huggingface_hub import snapshot_download
from PIL import Image as PILImage

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed


def write_video(video_path, stacked_frames, fps):
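    """Save a sequence of rendered frames as a video file at the given frame rate."""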
    imageio.mimsave(video_path, stacked_frames, fps=fps)


def eval_policy(
    env: gym.vector.VectorEnv,
    policy: torch.nn.Module,
    max_episodes_rendered: int = 0,
    video_dir: Path = None,
    # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
    transform: callable = None,
    seed=None,
):
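    """Run batched rollouts with `policy` on `env` and collect per-episode metrics.

    Returns a dict with "per_episode" metrics, "aggregated" statistics, an "episodes" Hugging Face
    dataset containing the rollout frames, and, when `max_episodes_rendered > 0`, the rendered "videos".
    """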
    fps = env.unwrapped.metadata["render_fps"]

    if policy is not None:
        policy.eval()
    device = "cpu" if policy is None else next(policy.parameters()).device

    start = time.time()
    sum_rewards = []
    max_rewards = []
    all_successes = []
    seeds = []
    threads = []  # for video saving threads
    episode_counter = 0  # for saving the correct number of videos

    num_episodes = len(env.envs)

    # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
    # needed as I'm currently taking a ceil.
    ep_frames = []

    def render_frame(env):
        # noqa: B023
        eps_rendered = min(max_episodes_rendered, len(env.envs))
        visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
        ep_frames.append(visu)  # noqa: B023

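    # TODO: per-episode seed tracking isn't implemented yet; record placeholders so "seed" stays in the per-episode info.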
    for _ in range(num_episodes):
        seeds.append("TODO")

    if hasattr(policy, "reset"):
        policy.reset()
    else:
        logging.warning(
            f"Policy {policy} doesn't have a `reset` method. It is required if the policy relies on an internal state during rollout."
        )

    # reset the environment
    observation, info = env.reset(seed=seed)
    if max_episodes_rendered > 0:
        render_frame(env)

    observations = []
    actions = []
    # episode
    # frame_id
    # timestamp
    rewards = []
    successes = []
    dones = []

    done = torch.tensor([False for _ in env.envs])
    step = 0
    while not done.all():
        # format from env keys to lerobot keys
        observation = preprocess_observation(observation)
        observations.append(deepcopy(observation))

        # apply transform to normalize the observations
        for key in observation:
            observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])

        # send observation to device/gpu
        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

        # get the next action for the environment
        with torch.inference_mode():
            action = policy.select_action(observation, step=step)

        # apply inverse transform to unnormalize the action
        action = postprocess_action(action, transform)

        # apply the next action
        observation, reward, terminated, truncated, info = env.step(action)
        if max_episodes_rendered > 0:
            render_frame(env)

        # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
        action = torch.from_numpy(action)
        reward = torch.from_numpy(reward)
        terminated = torch.from_numpy(terminated)
        truncated = torch.from_numpy(truncated)
        # the environment is considered done (no more steps) when a success state is reached (terminated is True),
        # the time limit is reached (truncated is True), or it was previously done.
        done = terminated | truncated | done

        if "final_info" in info:
            # VectorEnv stores is_success in `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
            success = [
                env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
            ]
        else:
            success = [False for _ in env.envs]
        success = torch.tensor(success)

        actions.append(action)
        rewards.append(reward)
        dones.append(done)
        successes.append(success)

        step += 1

    env.close()

    # add the last observation when the env is done
    observation = preprocess_observation(observation)
    observations.append(deepcopy(observation))

    new_obses = {}
    for key in observations[0].keys():  # noqa: SIM118
        new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
    observations = new_obses
    actions = torch.stack(actions, dim=1)
    rewards = torch.stack(rewards, dim=1)
    successes = torch.stack(successes, dim=1)
    dones = torch.stack(dones, dim=1)

    # Figure out where in each rollout sequence the first done condition was encountered (results after
    # this won't be included).
    # Note: this assumes that the shape of the done key is (batch_size, max_steps).
    # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
    done_indices = torch.argmax(dones.to(int), axis=1)  # (batch_size,)
    expand_done_indices = done_indices[:, None].expand(-1, step)
    expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
    mask = (expand_step_indices <= expand_done_indices).int()  # (batch_size, rollout_steps)
    batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
    batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
    batch_success = einops.reduce((successes * mask), "b n -> b", "any")
    sum_rewards.extend(batch_sum_reward.tolist())
    max_rewards.extend(batch_max_reward.tolist())
    all_successes.extend(batch_success.tolist())

    # similar logic is implemented in dataset preprocessing
    ep_dicts = []
    num_episodes = dones.shape[0]
    total_frames = 0
    idx_from = 0
    for ep_id in range(num_episodes):
        num_frames = done_indices[ep_id].item() + 1
        total_frames += num_frames

        # TODO(rcadene): We need to add a missing last frame which is the observation
        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
        ep_dict = {
            "action": actions[ep_id, :num_frames],
            "episode_id": torch.tensor([ep_id] * num_frames),
            "frame_id": torch.arange(0, num_frames, 1),
            "timestamp": torch.arange(0, num_frames, 1) / fps,
            "next.done": dones[ep_id, :num_frames],
            "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
            "episode_data_index_from": torch.tensor([idx_from] * num_frames),
            "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
        }
        for key in observations:
            ep_dict[key] = observations[key][ep_id][:num_frames]
        ep_dicts.append(ep_dict)

        idx_from += num_frames

    # similar logic is implemented in dataset preprocessing
    data_dict = {}
    keys = ep_dicts[0].keys()
    for key in keys:
        if "image" not in key:
            data_dict[key] = torch.cat([x[key] for x in ep_dicts])
        else:
            if key not in data_dict:
                data_dict[key] = []
            for ep_dict in ep_dicts:
                for x in ep_dict[key]:
                    # c h w -> h w c
                    img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
                    data_dict[key].append(img)

    data_dict["index"] = torch.arange(0, total_frames, 1)

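    # Pack the flattened per-frame data into a Hugging Face dataset with torch-formatted columns.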
    hf_dataset = Dataset.from_dict(data_dict).with_format("torch")

    if max_episodes_rendered > 0:
        batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)

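        # Write each rendered episode to an mp4 in its own background thread; threads are joined below.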
        for stacked_frames, done_index in zip(
            batch_stacked_frames, done_indices.flatten().tolist(), strict=False
        ):
            if episode_counter >= num_episodes:
                continue
            video_dir.mkdir(parents=True, exist_ok=True)
            video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
            thread = threading.Thread(
                target=write_video,
                args=(str(video_path), stacked_frames[:done_index], fps),
            )
            thread.start()
            threads.append(thread)
            episode_counter += 1

        videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")

    for thread in threads:
        thread.join()

    info = {
        "per_episode": [
            {
                "episode_ix": i,
                "sum_reward": sum_reward,
                "max_reward": max_reward,
                "success": success,
                "seed": seed,
            }
            for i, (sum_reward, max_reward, success, seed) in enumerate(
                zip(
                    sum_rewards[:num_episodes],
                    max_rewards[:num_episodes],
                    all_successes[:num_episodes],
                    seeds[:num_episodes],
                    strict=True,
                )
            )
        ],
        "aggregated": {
            "avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
            "avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
            "pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
            "eval_s": time.time() - start,
            "eval_ep_s": (time.time() - start) / num_episodes,
        },
        "episodes": hf_dataset,
    }
    if max_episodes_rendered > 0:
        info["videos"] = videos
    return info


def eval(cfg: dict, out_dir=None, stats_path=None):
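    """Build the environment, policy, and transform from a Hydra config, run the evaluation, and save results.

    Aggregated metrics are written to `eval_info.json` and rendered rollout videos under `out_dir`.
    """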
    if out_dir is None:
        raise NotImplementedError()

    init_logging()

    # Check device is available
    get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_global_seed(cfg.seed)

    log_output_dir(out_dir)

    logging.info("Making transforms.")
    # TODO(alexander-soare): Completely decouple datasets from evaluation.
    transform = make_dataset(cfg, stats_path=stats_path).transform

    logging.info("Making environment.")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

    logging.info("Making policy.")
    policy = make_policy(cfg)

    info = eval_policy(
        env,
        policy,
        max_episodes_rendered=10,
        video_dir=Path(out_dir) / "eval",
        transform=transform,
        seed=cfg.seed,
    )
    print(info["aggregated"])

    # Save info
    with open(Path(out_dir) / "eval_info.json", "w") as f:
        # remove pytorch tensors, which are not JSON-serializable, so that only the evaluation metrics are saved
        del info["episodes"]
        del info["videos"]
        json.dump(info, f, indent=2)

    logging.info("End of eval")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--config", help="Path to a specific yaml config you want to use.")
    group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
    parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
    parser.add_argument(
        "overrides",
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested keys, e.g. config.key=value)",
    )
    args = parser.parse_args()

    if args.config is not None:
        # Note: For the config_path, Hydra wants a path relative to this script file.
        cfg = init_hydra_config(args.config, args.overrides)
        # TODO(alexander-soare): Save and load stats in trained model directory.
        stats_path = None
    elif args.hub_id is not None:
        folder = Path(snapshot_download(args.hub_id, revision=args.revision))
        cfg = init_hydra_config(
            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
        )
        stats_path = folder / "stats.pth"

    eval(
        cfg,
        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
        stats_path=stats_path,
    )