Refactor eval.py (#127)
parent b7b69fcc3d
commit bccee745c3

Makefile | 7 +++++++
@@ -35,6 +35,7 @@ test-act-ete-train:
	training.offline_steps=2 \
	training.online_steps=0 \
	eval.n_episodes=1 \
+	eval.batch_size=1 \
	device=cpu \
	training.save_model=true \
	training.save_freq=2 \
@@ -47,6 +48,7 @@ test-act-ete-eval:
	python lerobot/scripts/eval.py \
	-p tests/outputs/act/checkpoints/000002 \
	eval.n_episodes=1 \
+	eval.batch_size=1 \
	env.episode_length=8 \
	device=cpu \

@@ -58,6 +60,7 @@ test-diffusion-ete-train:
	training.offline_steps=2 \
	training.online_steps=0 \
	eval.n_episodes=1 \
+	eval.batch_size=1 \
	device=cpu \
	training.save_model=true \
	training.save_freq=2 \
@@ -68,6 +71,7 @@ test-diffusion-ete-eval:
	python lerobot/scripts/eval.py \
	-p tests/outputs/diffusion/checkpoints/000002 \
	eval.n_episodes=1 \
+	eval.batch_size=1 \
	env.episode_length=8 \
	device=cpu \

@@ -81,6 +85,7 @@ test-tdmpc-ete-train:
	training.offline_steps=2 \
	training.online_steps=2 \
	eval.n_episodes=1 \
+	eval.batch_size=1 \
	env.episode_length=2 \
	device=cpu \
	training.save_model=true \
@@ -92,6 +97,7 @@ test-tdmpc-ete-eval:
	python lerobot/scripts/eval.py \
	-p tests/outputs/tdmpc/checkpoints/000002 \
	eval.n_episodes=1 \
+	eval.batch_size=1 \
	env.episode_length=8 \
	device=cpu \

@@ -100,5 +106,6 @@ test-default-ete-eval:
	python lerobot/scripts/eval.py \
	--config lerobot/configs/default.yaml \
	eval.n_episodes=1 \
+	eval.batch_size=1 \
	env.episode_length=8 \
	device=cpu \

@@ -1,13 +1,17 @@
 import importlib

 import gymnasium as gym
+from omegaconf import DictConfig


-def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
-    """
-    Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
-    returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
+def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
+    """Makes a gym vector environment according to the evaluation config.
+
+    n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
     """
+    if n_envs is not None and n_envs < 1:
+        raise ValueError("`n_envs must be at least 1")
+
     kwargs = {
         "obs_type": "pixels_agent_pos",
         "render_mode": "rgb_array",
@@ -28,16 +32,13 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:

     gym_handle = f"{package_name}/{cfg.env.task}"

-    if num_parallel_envs == 0:
-        # non-batched version of the env that returns an observation of shape (c)
-        env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
-    else:
-        # batched version of the env that returns an observation of shape (b, c)
-        env = gym.vector.SyncVectorEnv(
-            [
-                lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
-                for _ in range(num_parallel_envs)
-            ]
-        )
+    # batched version of the env that returns an observation of shape (b, c)
+    env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
+    env = env_cls(
+        [
+            lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
+            for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
+        ]
+    )

     return env

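Usage sketch for the new `make_env` signature (illustrative only, not part of this commit): the vector size now comes from `eval.batch_size` unless `n_envs` overrides it. This assumes `init_hydra_config(config_path, overrides=...)` behaves as it does in the repository's tests and that the gym package for the chosen env is installed; the override values are placeholders.

# Illustrative sketch, not part of the patch.
from lerobot.common.envs.factory import make_env
from lerobot.common.utils.utils import init_hydra_config

cfg = init_hydra_config(
    "lerobot/configs/default.yaml",
    overrides=["env=pusht", "eval.batch_size=4", "eval.use_async_envs=false", "device=cpu"],
)

eval_env = make_env(cfg)               # gym.vector env with cfg.eval.batch_size (here 4) copies
rollout_env = make_env(cfg, n_envs=1)  # n_envs overrides eval.batch_size, as train.py does for online rollouts
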
@@ -1,15 +1,23 @@
 import einops
+import numpy as np
 import torch
+from torch import Tensor


-def preprocess_observation(observation):
+def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
+    """Convert environment observation to LeRobot format observation.
+    Args:
+        observation: Dictionary of observation batches from a Gym vector environment.
+    Returns:
+        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
+    """
     # map to expected inputs for the policy
-    obs = {}
+    return_observations = {}

-    if isinstance(observation["pixels"], dict):
-        imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
+    if isinstance(observations["pixels"], dict):
+        imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
     else:
-        imgs = {"observation.image": observation["pixels"]}
+        imgs = {"observation.image": observations["pixels"]}

     for imgkey, img in imgs.items():
         img = torch.from_numpy(img)
@@ -26,17 +34,10 @@ def preprocess_observation(observation):
         img = img.type(torch.float32)
         img /= 255

-        obs[imgkey] = img
+        return_observations[imgkey] = img

-    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
-    obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
+    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
+    # requirement for "agent_pos"
+    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()

-    return obs
-
-
-def postprocess_action(action):
-    action = action.to("cpu").numpy()
-    assert (
-        action.ndim == 2
-    ), "we assume dimensions are respectively the number of parallel envs, action dimensions"
-    return action
+    return return_observations

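With `postprocess_action` removed, call sites now do the CPU/numpy conversion themselves. A minimal sketch of the round trip (illustrative, not part of the patch; `env` and `policy` are assumed to already exist, e.g. from `make_env` and `make_policy`):

# Illustrative sketch, not part of the patch.
import torch

from lerobot.common.envs.utils import preprocess_observation

observation, _ = env.reset(seed=42)
batch = preprocess_observation(observation)  # numpy dict -> tensors with LeRobot keys
batch = {key: batch[key].to("cpu", non_blocking=True) for key in batch}
with torch.inference_mode():
    action = policy.select_action(batch).to("cpu").numpy()  # (num_envs, action_dim)
observation, reward, terminated, truncated, info = env.step(action)
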
@@ -115,7 +115,7 @@ class Logger:
            for k, v in d.items():
                self._wandb.log({f"{mode}/{k}": v}, step=step)

-    def log_video(self, video, step, mode="train"):
+    def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
-        wandb_video = self._wandb.Video(video, fps=self._cfg.fps, format="mp4")
+        wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
         self._wandb.log({f"{mode}/video": wandb_video}, step=step)

@@ -10,6 +10,8 @@ hydra:
   name: default

 device: cuda # cpu
+# `seed` is used for training (eg: model initialization, dataset shuffling)
+# AND for the evaluation environments.
 seed: ???
 dataset_repo_id: lerobot/pusht

@@ -18,6 +20,8 @@ training:
   online_steps: ???
   online_steps_between_rollouts: ???
   online_sampling_ratio: 0.5
+  # `online_env_seed` is used for environments for online training data rollouts.
+  online_env_seed: ???
   eval_freq: ???
   save_freq: ???
   log_freq: 250
@@ -25,8 +29,10 @@ training:

 eval:
   n_episodes: 1
-  # TODO(alexander-soare): Right now this does not work. Reinstate this.
+  # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
   batch_size: 1
+  # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
+  use_async_envs: false

 wandb:
   enable: true

@@ -28,7 +28,8 @@ training:
     action: "[i / ${fps} for i in range(${policy.chunk_size})]"

 eval:
-  n_episodes:: 50
+  n_episodes: 50
+  batch_size: 50

 # See `configuration_act.py` for more details.
 policy:

@@ -28,6 +28,7 @@ training:

 eval:
   n_episodes: 50
+  batch_size: 50

 override_dataset_stats:
   # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?

@@ -8,6 +8,7 @@ training:
   eval_freq: 5000
   online_steps_between_rollouts: 1
   online_sampling_ratio: 0.5
+  online_env_seed: 10000

   batch_size: 256
   grad_clip_norm: 10.0

@@ -34,286 +34,328 @@ import time
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
+from typing import Callable

 import einops
 import gymnasium as gym
 import numpy as np
 import torch
-from datasets import Dataset, Features, Image, Sequence, Value
+from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from huggingface_hub.utils._validators import HFValidationError
 from PIL import Image as PILImage
+from torch import Tensor
 from tqdm import trange

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.envs.factory import make_env
-from lerobot.common.envs.utils import postprocess_action, preprocess_observation
+from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.policy_protocol import Policy
+from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.io_utils import write_video
 from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed


+def rollout(
+    env: gym.vector.VectorEnv,
+    policy: Policy,
+    seeds: list[int] | None = None,
+    return_observations: bool = False,
+    render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
+    enable_progbar: bool = False,
+) -> dict:
+    """Run a batched policy rollout once through a batch of environments.
+
+    Note that all environments in the batch are run until the last environment is done. This means some
+    data will probably need to be discarded (for environments that aren't the first one to be done).
+
+    The return dictionary contains:
+        (optional) "observation": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
+            keys. NOTE that this has an extra sequence element relative to the other keys in the
+            dictionary. This is because an extra observation is included for after the environment is
+            terminated or truncated.
+        "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
+            including the last observations).
+        "reward": A (batch, sequence) tensor of rewards received for applying the actions.
+        "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
+            environment termination/truncation).
+        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
+            the first True is followed by True's all the way till the end. This can be used for masking
+            extraneous elements from the sequences above.
+
+    Args:
+        env: The batch of environments.
+        policy: The policy.
+        seeds: The environments are seeded once at the start of the rollout. If provided, this argument
+            specifies the seeds for each of the environments.
+        return_observations: Whether to include all observations in the returned rollout data. Observations
+            are returned optionally because they typically take more memory to cache. Defaults to False.
+        render_callback: Optional rendering callback to be used after the environments are reset, and after
+            every step.
+        enable_progbar: Enable a progress bar over rollout steps.
+    Returns:
+        The dictionary described above.
+    """
+    device = get_device_from_parameters(policy)
+
+    # Reset the policy and environments.
+    policy.reset()
+
+    observation, info = env.reset(seed=seeds)
+    if render_callback is not None:
+        render_callback(env)
+
+    all_observations = []
+    all_actions = []
+    all_rewards = []
+    all_successes = []
+    all_dones = []
+
+    step = 0
+    # Keep track of which environments are done.
+    done = np.array([False] * env.num_envs)
+    max_steps = env.call("_max_episode_steps")[0]
+    progbar = trange(
+        max_steps,
+        desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
+        disable=not enable_progbar,
+        leave=False,
+    )
+    while not np.all(done):
+        # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
+        observation = preprocess_observation(observation)
+        if return_observations:
+            all_observations.append(deepcopy(observation))
+
+        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
+
+        with torch.inference_mode():
+            action = policy.select_action(observation)
+
+        # Convert to CPU / numpy.
+        action = action.to("cpu").numpy()
+        assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
+
+        # Apply the next action.
+        observation, reward, terminated, truncated, info = env.step(action)
+        if render_callback is not None:
+            render_callback(env)
+
+        # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
+        # available of none of the envs finished.
+        if "final_info" in info:
+            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
+        else:
+            successes = [False] * env.num_envs
+
+        # Keep track of which environments are done so far.
+        done = terminated | truncated | done
+
+        all_actions.append(torch.from_numpy(action))
+        all_rewards.append(torch.from_numpy(reward))
+        all_dones.append(torch.from_numpy(done))
+        all_successes.append(torch.tensor(successes))
+
+        step += 1
+        running_success_rate = (
+            einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
+        )
+        progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
+        progbar.update()
+
+    # Track the final observation.
+    if return_observations:
+        observation = preprocess_observation(observation)
+        all_observations.append(deepcopy(observation))
+
+    # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
+    ret = {
+        "action": torch.stack(all_actions, dim=1),
+        "reward": torch.stack(all_rewards, dim=1),
+        "success": torch.stack(all_successes, dim=1),
+        "done": torch.stack(all_dones, dim=1),
+    }
+    if return_observations:
+        stacked_observations = {}
+        for key in all_observations[0]:
+            stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
+        ret["observation"] = stacked_observations
+
+    return ret
+
+
 def eval_policy(
     env: gym.vector.VectorEnv,
     policy: torch.nn.Module,
+    n_episodes: int,
     max_episodes_rendered: int = 0,
-    video_dir: Path = None,
+    video_dir: Path | None = None,
     return_episode_data: bool = False,
-    seed=None,
-):
+    start_seed: int | None = None,
+    enable_progbar: bool = False,
+    enable_inner_progbar: bool = False,
+) -> dict:
     """
-    set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
+    Args:
+        env: The batch of environments.
+        policy: The policy.
+        n_episodes: The number of episodes to evaluate.
+        max_episodes_rendered: Maximum number of episodes to render into videos.
+        video_dir: Where to save rendered videos.
+        return_episode_data: Whether to return episode data for online training. Incorporates the data into
+            the "episodes" key of the returned dictionary.
+        start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
+            seed is incremented by 1. If not provided, the environments are not manually seeded.
+        enable_progbar: Enable progress bar over batches.
+        enable_inner_progbar: Enable progress bar over steps in each batch.
+    Returns:
+        Dictionary with metrics and data regarding the rollouts.
     """
+    start = time.time()
     policy.eval()

-    fps = env.unwrapped.metadata["render_fps"]
+    # Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
+    # divisible by env.num_envs we end up discarding some data in the last batch.
+    n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)

-    device = "cpu" if policy is None else next(policy.parameters()).device
+    # Keep track of some metrics.

-    start = time.time()
     sum_rewards = []
     max_rewards = []
     all_successes = []
-    seeds = []
+    all_seeds = []
     threads = []  # for video saving threads
-    episode_counter = 0  # for saving the correct number of videos
+    n_episodes_rendered = 0  # for saving the correct number of videos

-    num_episodes = len(env.envs)
-
-    # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
-    # needed as I'm currently taking a ceil.
-    ep_frames = []
-
-    def render_frame(env):
+    # Callback for visualization.
+    def render_frame(env: gym.vector.VectorEnv):
         # noqa: B023
-        eps_rendered = min(max_episodes_rendered, len(env.envs))
-        visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
-        ep_frames.append(visu)  # noqa: B023
-
-    for _ in range(num_episodes):
-        seeds.append("TODO")
-
-    if hasattr(policy, "reset"):
-        policy.reset()
-    else:
-        logging.warning(
-            f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
+        if n_episodes_rendered >= max_episodes_rendered:
+            return
+        n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
+        if isinstance(env, gym.vector.SyncVectorEnv):
+            ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)]))  # noqa: B023
+        elif isinstance(env, gym.vector.AsyncVectorEnv):
+            # Here we must render all frames and discard any we don't need.
+            ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
+
+    if max_episodes_rendered > 0:
+        video_paths: list[str] = []
+
+    if return_episode_data:
+        episode_data: dict | None = None
+
+    progbar = trange(n_batches, desc="Stepping through eval batches", disable=not enable_progbar)
+    for batch_ix in progbar:
+        # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
+        # step.
+        if max_episodes_rendered > 0:
+            ep_frames: list[np.ndarray] = []
+
+        seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
+        rollout_data = rollout(
+            env,
+            policy,
+            seeds=seeds,
+            return_observations=return_episode_data,
+            render_callback=render_frame if max_episodes_rendered > 0 else None,
+            enable_progbar=enable_inner_progbar,
         )

-    # reset the environment
-    observation, info = env.reset(seed=seed)
-    if max_episodes_rendered > 0:
-        render_frame(env)
+        # Figure out where in each rollout sequence the first done condition was encountered (results after
+        # this won't be included).
+        n_steps = rollout_data["done"].shape[1]
+        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
+        done_indices = torch.argmax(rollout_data["done"].to(int), axis=1)  # (batch_size, rollout_steps)
+        # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
+        # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
+        mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
+        # Extend metrics.
+        batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
+        sum_rewards.extend(batch_sum_rewards.tolist())
+        batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
+        max_rewards.extend(batch_max_rewards.tolist())
+        batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
+        all_successes.extend(batch_successes.tolist())
+        all_seeds.extend(seeds)

-    observations = []
-    actions = []
-    # episode
-    # frame_id
-    # timestamp
-    rewards = []
-    successes = []
-    dones = []
-
-    done = torch.tensor([False for _ in env.envs])
-    step = 0
-    max_steps = env.envs[0]._max_episode_steps
-    progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
-    while not done.all():
-        # format from env keys to lerobot keys
-        observation = preprocess_observation(observation)
         if return_episode_data:
-            observations.append(deepcopy(observation))
+            this_episode_data = _compile_episode_data(
+                rollout_data,
+                done_indices,
+                start_episode_index=batch_ix * env.num_envs,
+                start_data_index=(
+                    0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item())
+                ),
+                fps=env.unwrapped.metadata["render_fps"],
+            )
+            if episode_data is None:
+                episode_data = this_episode_data
+            else:
+                # Some sanity checks to make sure we are not correctly compiling the data.
+                assert (
+                    episode_data["hf_dataset"]["episode_index"][-1] + 1
+                    == this_episode_data["hf_dataset"]["episode_index"][0]
+                )
+                assert (
+                    episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0]
+                )
+                assert torch.equal(
+                    episode_data["episode_data_index"]["to"][-1],
+                    this_episode_data["episode_data_index"]["from"][0],
+                )
+                # Concatenate the episode data.
+                episode_data = {
+                    "hf_dataset": concatenate_datasets(
+                        [episode_data["hf_dataset"], this_episode_data["hf_dataset"]]
+                    ),
+                    "episode_data_index": {
+                        k: torch.cat(
+                            [
+                                episode_data["episode_data_index"][k],
+                                this_episode_data["episode_data_index"][k],
+                            ]
+                        )
+                        for k in ["from", "to"]
+                    },
+                }

-        # send observation to device/gpu
-        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
-
-        # get the next action for the environment
-        with torch.inference_mode():
-            action = policy.select_action(observation)
-
-        # convert to cpu numpy
-        action = postprocess_action(action)
-
-        # apply the next action
-        observation, reward, terminated, truncated, info = env.step(action)
+        # Maybe render video for visualization.
         if max_episodes_rendered > 0:
-            render_frame(env)
-
-        # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
-        action = torch.from_numpy(action)
-        reward = torch.from_numpy(reward)
-        terminated = torch.from_numpy(terminated)
-        truncated = torch.from_numpy(truncated)
-        # environment is considered done (no more steps), when success state is reached (terminated is True),
-        # or time limit is reached (truncated is True), or it was previsouly done.
-        done = terminated | truncated | done
-
-        if "final_info" in info:
-            # VectorEnv stores is_success into `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
-            success = [
-                env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
-            ]
-        else:
-            success = [False for _ in env.envs]
-        success = torch.tensor(success)
-
-        actions.append(action)
-        rewards.append(reward)
-        dones.append(done)
-        successes.append(success)
-
-        step += 1
+            batch_stacked_frames = np.stack(ep_frames, axis=1)  # (b, t, *)
+            for stacked_frames, done_index in zip(
+                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+            ):
+                if n_episodes_rendered >= max_episodes_rendered:
+                    break
+                video_dir.mkdir(parents=True, exist_ok=True)
+                video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
+                video_paths.append(str(video_path))
+                thread = threading.Thread(
+                    target=write_video,
+                    args=(
+                        str(video_path),
+                        stacked_frames[: done_index + 2],  # + 2 to capture the observation frame after done
+                        env.unwrapped.metadata["render_fps"],
+                    ),
+                )
+                thread.start()
+                threads.append(thread)
+                n_episodes_rendered += 1
+
+        progbar.set_postfix(
+            {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
+        )
         progbar.update()

-    env.close()
-
-    # add the last observation when the env is done
-    if return_episode_data:
-        observation = preprocess_observation(observation)
-        observations.append(deepcopy(observation))
-
-    if return_episode_data:
-        new_obses = {}
-        for key in observations[0].keys():  # noqa: SIM118
-            new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
-        observations = new_obses
-    actions = torch.stack(actions, dim=1)
-    rewards = torch.stack(rewards, dim=1)
-    successes = torch.stack(successes, dim=1)
-    dones = torch.stack(dones, dim=1)
-
-    # Figure out where in each rollout sequence the first done condition was encountered (results after
-    # this won't be included).
-    # Note: this assumes that the shape of the done key is (batch_size, max_steps).
-    # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-    done_indices = torch.argmax(dones.to(int), axis=1)  # (batch_size, rollout_steps)
-    expand_done_indices = done_indices[:, None].expand(-1, step)
-    expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
-    mask = (expand_step_indices <= expand_done_indices).int()  # (batch_size, rollout_steps)
-    batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
-    batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
-    batch_success = einops.reduce((successes * mask), "b n -> b", "any")
-    sum_rewards.extend(batch_sum_reward.tolist())
-    max_rewards.extend(batch_max_reward.tolist())
-    all_successes.extend(batch_success.tolist())
-
-    # similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-    num_episodes = dones.shape[0]
-    total_frames = 0
-    id_from = 0
-    for ep_id in range(num_episodes):
-        num_frames = done_indices[ep_id].item() + 1
-        total_frames += num_frames
-
-        # TODO(rcadene): We need to add a missing last frame which is the observation
-        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
-        if return_episode_data:
-            ep_dict = {
-                "action": actions[ep_id, :num_frames],
-                "episode_index": torch.tensor([ep_id] * num_frames),
-                "frame_index": torch.arange(0, num_frames, 1),
-                "timestamp": torch.arange(0, num_frames, 1) / fps,
-                "next.done": dones[ep_id, :num_frames],
-                "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
-            }
-            for key in observations:
-                ep_dict[key] = observations[key][ep_id][:num_frames]
-            ep_dicts.append(ep_dict)
-
-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from += num_frames
-
-    # similar logic is implemented in dataset preprocessing
-    if return_episode_data:
-        data_dict = {}
-        keys = ep_dicts[0].keys()
-        for key in keys:
-            if "image" not in key:
-                data_dict[key] = torch.cat([x[key] for x in ep_dicts])
-            else:
-                if key not in data_dict:
-                    data_dict[key] = []
-                for ep_dict in ep_dicts:
-                    for img in ep_dict[key]:
-                        # sanity check that images are channel first
-                        c, h, w = img.shape
-                        assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
-
-                        # sanity check that images are float32 in range [0,1]
-                        assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
-                        assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
-                        assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
-
-                        # from float32 in range [0,1] to uint8 in range [0,255]
-                        img *= 255
-                        img = img.type(torch.uint8)
-
-                        # convert to channel last and numpy as expected by PIL
-                        img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
-
-                        data_dict[key].append(img)
-
-        data_dict["index"] = torch.arange(0, total_frames, 1)
-        episode_data_index["from"] = torch.tensor(episode_data_index["from"])
-        episode_data_index["to"] = torch.tensor(episode_data_index["to"])
-
-        # TODO(rcadene): clean this
-        features = {}
-        for key in observations:
-            if "image" in key:
-                features[key] = Image()
-            else:
-                features[key] = Sequence(
-                    length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
-                )
-        features.update(
-            {
-                "action": Sequence(
-                    length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
-                ),
-                "episode_index": Value(dtype="int64", id=None),
-                "frame_index": Value(dtype="int64", id=None),
-                "timestamp": Value(dtype="float32", id=None),
-                "next.reward": Value(dtype="float32", id=None),
-                "next.done": Value(dtype="bool", id=None),
-                #'next.success': Value(dtype='bool', id=None),
-                "index": Value(dtype="int64", id=None),
-            }
-        )
-        features = Features(features)
-        hf_dataset = Dataset.from_dict(data_dict, features=features)
-        hf_dataset.set_transform(hf_transform_to_torch)
-
-    if max_episodes_rendered > 0:
-        batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
-
-        for stacked_frames, done_index in zip(
-            batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-        ):
-            if episode_counter >= max_episodes_rendered:
-                continue
-            video_dir.mkdir(parents=True, exist_ok=True)
-            video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-            thread = threading.Thread(
-                target=write_video,
-                args=(str(video_path), stacked_frames[:done_index], fps),
-            )
-            thread.start()
-            threads.append(thread)
-            episode_counter += 1
-
-        videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
-
+    # Wait till all video rendering threads are done.
     for thread in threads:
         thread.join()

+    # Compile eval info.
     info = {
         "per_episode": [
             {
@@ -325,32 +367,127 @@ def eval_policy(
             }
             for i, (sum_reward, max_reward, success, seed) in enumerate(
                 zip(
-                    sum_rewards[:num_episodes],
-                    max_rewards[:num_episodes],
-                    all_successes[:num_episodes],
-                    seeds[:num_episodes],
+                    sum_rewards[:n_episodes],
+                    max_rewards[:n_episodes],
+                    all_successes[:n_episodes],
+                    all_seeds[:n_episodes],
                     strict=True,
                 )
             )
         ],
         "aggregated": {
-            "avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
-            "avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
-            "pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
+            "avg_sum_reward": float(np.nanmean(sum_rewards[:n_episodes])),
+            "avg_max_reward": float(np.nanmean(max_rewards[:n_episodes])),
+            "pc_success": float(np.nanmean(all_successes[:n_episodes]) * 100),
             "eval_s": time.time() - start,
-            "eval_ep_s": (time.time() - start) / num_episodes,
+            "eval_ep_s": (time.time() - start) / n_episodes,
         },
     }

     if return_episode_data:
-        info["episodes"] = {
-            "hf_dataset": hf_dataset,
-            "episode_data_index": episode_data_index,
-        }
+        info["episodes"] = episode_data
     if max_episodes_rendered > 0:
-        info["videos"] = videos
+        info["video_paths"] = video_paths

     return info


+def _compile_episode_data(
+    rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
+) -> dict:
+    """Convenience function for `eval_policy(return_episode_data=True)`
+
+    Compiles all the rollout data into a Hugging Face dataset.
+
+    Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
+    """
+    ep_dicts = []
+    episode_data_index = {"from": [], "to": []}
+    total_frames = 0
+    data_index_from = start_data_index
+    for ep_ix in range(rollout_data["action"].shape[0]):
+        num_frames = done_indices[ep_ix].item() + 1  # + 1 to include the first done frame
+        total_frames += num_frames
+
+        # TODO(rcadene): We need to add a missing last frame which is the observation
+        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
+        ep_dict = {
+            "action": rollout_data["action"][ep_ix, :num_frames],
+            "episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames),
+            "frame_index": torch.arange(0, num_frames, 1),
+            "timestamp": torch.arange(0, num_frames, 1) / fps,
+            "next.done": rollout_data["done"][ep_ix, :num_frames],
+            "next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32),
+        }
+        for key in rollout_data["observation"]:
+            ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames]
+        ep_dicts.append(ep_dict)
+
+        episode_data_index["from"].append(data_index_from)
+        episode_data_index["to"].append(data_index_from + num_frames)
+
+        data_index_from += num_frames
+
+    data_dict = {}
+    for key in ep_dicts[0]:
+        if "image" not in key:
+            data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+        else:
+            if key not in data_dict:
+                data_dict[key] = []
+            for ep_dict in ep_dicts:
+                for img in ep_dict[key]:
+                    # sanity check that images are channel first
+                    c, h, w = img.shape
+                    assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
+
+                    # sanity check that images are float32 in range [0,1]
+                    assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
+                    assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
+                    assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
+
+                    # from float32 in range [0,1] to uint8 in range [0,255]
+                    img *= 255
+                    img = img.type(torch.uint8)
+
+                    # convert to channel last and numpy as expected by PIL
+                    img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
+
+                    data_dict[key].append(img)
+
+    data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
+    episode_data_index["from"] = torch.tensor(episode_data_index["from"])
+    episode_data_index["to"] = torch.tensor(episode_data_index["to"])
+
+    # TODO(rcadene): clean this
+    features = {}
+    for key in rollout_data["observation"]:
+        if "image" in key:
+            features[key] = Image()
+        else:
+            features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
+    features.update(
+        {
+            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
+            "episode_index": Value(dtype="int64", id=None),
+            "frame_index": Value(dtype="int64", id=None),
+            "timestamp": Value(dtype="float32", id=None),
+            "next.reward": Value(dtype="float32", id=None),
+            "next.done": Value(dtype="bool", id=None),
+            #'next.success': Value(dtype='bool', id=None),
+            "index": Value(dtype="int64", id=None),
+        }
+    )
+    features = Features(features)
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset.set_transform(hf_transform_to_torch)
+    return {
+        "hf_dataset": hf_dataset,
+        "episode_data_index": episode_data_index,
+    }
+
+
 def eval(
     pretrained_policy_path: str | None = None,
     hydra_cfg_path: str | None = None,
@@ -378,7 +515,7 @@ def eval(
     log_output_dir(out_dir)

     logging.info("Making environment.")
-    env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval.n_episodes)
+    env = make_env(hydra_cfg)

     logging.info("Making policy.")
     if hydra_cfg_path is None:
@@ -391,19 +528,21 @@ def eval(
     info = eval_policy(
         env,
         policy,
+        hydra_cfg.eval.n_episodes,
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
-        return_episode_data=False,
-        seed=hydra_cfg.seed,
+        start_seed=hydra_cfg.seed,
+        enable_progbar=True,
+        enable_inner_progbar=True,
     )
     print(info["aggregated"])

     # Save info
     with open(Path(out_dir) / "eval_info.json", "w") as f:
-        # remove pytorch tensors which are not serializable to save the evaluation results only
-        del info["videos"]
         json.dump(info, f, indent=2)

+    env.close()
+
     logging.info("End of eval")


@@ -269,7 +269,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     offline_dataset = make_dataset(cfg)

     logging.info("make_env")
-    env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
+    eval_env = make_env(cfg)

     logging.info("make_policy")
     policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
@@ -337,15 +337,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
         if step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info = eval_policy(
-                env,
+                eval_env,
                 policy,
+                cfg.eval.n_episodes,
                 video_dir=Path(out_dir) / "eval",
                 max_episodes_rendered=4,
-                seed=cfg.seed,
+                start_seed=cfg.seed,
             )
             log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
-                logger.log_video(eval_info["videos"][0], step, mode="eval")
+                logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")

         if cfg.training.save_model and step % cfg.training.save_freq == 0:
@@ -395,7 +396,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         step += 1

     # create an env dedicated to online episodes collection from policy rollout
-    rollout_env = make_env(cfg, num_parallel_envs=1)
+    online_training_env = make_env(cfg, n_envs=1)

     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
@@ -427,10 +428,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         policy.eval()
         with torch.no_grad():
             eval_info = eval_policy(
-                rollout_env,
+                online_training_env,
                 policy,
+                n_episodes=1,
                 return_episode_data=True,
-                seed=cfg.seed,
+                start_seed=cfg.training.online_env_seed,
+                enable_progbar=True,
             )

         add_episodes_inplace(
@@ -461,6 +464,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
         step += 1
         online_step += 1

+    eval_env.close()
+    online_training_env.close()
     logging.info("End of training")

@@ -37,7 +37,7 @@ def test_factory(env_name):
         overrides=[f"env={env_name}", f"device={DEVICE}"],
     )

-    env = make_env(cfg, num_parallel_envs=1)
+    env = make_env(cfg, n_envs=1)
     obs, _ = env.reset()
     obs = preprocess_observation(obs)

@@ -8,7 +8,7 @@ from lerobot import available_policies
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
-from lerobot.common.envs.utils import postprocess_action, preprocess_observation
+from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
@@ -80,7 +80,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     assert isinstance(policy, PyTorchModelHubMixin)

     # Check that we run select_actions and get the appropriate output.
-    env = make_env(cfg, num_parallel_envs=2)
+    env = make_env(cfg, n_envs=2)

     dataloader = torch.utils.data.DataLoader(
         dataset,
@@ -112,10 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):

     # get the next action for the environment
     with torch.inference_mode():
-        action = policy.select_action(observation)
-
-    # convert action to cpu numpy array
-    action = postprocess_action(action)
+        action = policy.select_action(observation).cpu().numpy()

     # Test step through policy
     env.step(action)

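A usage sketch of the refactored entry point (illustrative, not part of the patch): it mirrors the call made in `eval()` above, with `env` and `policy` standing for objects returned by `make_env` and `make_policy`, and placeholder numbers for episodes, rendering and seed.

# Illustrative sketch, not part of the patch.
from pathlib import Path

info = eval_policy(
    env,
    policy,
    n_episodes=10,
    max_episodes_rendered=2,
    video_dir=Path("outputs/eval_videos"),
    start_seed=100000,
    enable_progbar=True,
    enable_inner_progbar=True,
)
print(info["aggregated"])  # avg_sum_reward, avg_max_reward, pc_success, eval_s, eval_ep_s
env.close()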