Refactor eval.py (#127)
This commit is contained in:
parent
b7b69fcc3d
commit
bccee745c3
7
Makefile
7
Makefile
|
@ -35,6 +35,7 @@ test-act-ete-train:
|
|||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
|
@ -47,6 +48,7 @@ test-act-ete-eval:
|
|||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/act/checkpoints/000002 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
|
||||
|
@ -58,6 +60,7 @@ test-diffusion-ete-train:
|
|||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
|
@ -68,6 +71,7 @@ test-diffusion-ete-eval:
|
|||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/diffusion/checkpoints/000002 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
|
||||
|
@ -81,6 +85,7 @@ test-tdmpc-ete-train:
|
|||
training.offline_steps=2 \
|
||||
training.online_steps=2 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=cpu \
|
||||
training.save_model=true \
|
||||
|
@ -92,6 +97,7 @@ test-tdmpc-ete-eval:
|
|||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/tdmpc/checkpoints/000002 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
|
||||
|
@ -100,5 +106,6 @@ test-default-ete-eval:
|
|||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
"""
|
||||
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
|
||||
returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
"""
|
||||
if n_envs is not None and n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
kwargs = {
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_mode": "rgb_array",
|
||||
|
@ -28,16 +32,13 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
|||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
|
||||
if num_parallel_envs == 0:
|
||||
# non-batched version of the env that returns an observation of shape (c)
|
||||
env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
else:
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env = gym.vector.SyncVectorEnv(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
for _ in range(num_parallel_envs)
|
||||
]
|
||||
)
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
)
|
||||
|
||||
return env
|
||||
|
|
|
@ -1,15 +1,23 @@
|
|||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def preprocess_observation(observation):
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
obs = {}
|
||||
return_observations = {}
|
||||
|
||||
if isinstance(observation["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observation["pixels"]}
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
|
@ -26,17 +34,10 @@ def preprocess_observation(observation):
|
|||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
obs[imgkey] = img
|
||||
return_observations[imgkey] = img
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def postprocess_action(action):
|
||||
action = action.to("cpu").numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
return action
|
||||
return return_observations
|
||||
|
|
|
@ -115,7 +115,7 @@ class Logger:
|
|||
for k, v in d.items():
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video, step, mode="train"):
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
wandb_video = self._wandb.Video(video, fps=self._cfg.fps, format="mp4")
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
|
|
@ -10,6 +10,8 @@ hydra:
|
|||
name: default
|
||||
|
||||
device: cuda # cpu
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: ???
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
|
@ -18,6 +20,8 @@ training:
|
|||
online_steps: ???
|
||||
online_steps_between_rollouts: ???
|
||||
online_sampling_ratio: 0.5
|
||||
# `online_env_seed` is used for environments for online training data rollouts.
|
||||
online_env_seed: ???
|
||||
eval_freq: ???
|
||||
save_freq: ???
|
||||
log_freq: 250
|
||||
|
@ -25,8 +29,10 @@ training:
|
|||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
# TODO(alexander-soare): Right now this does not work. Reinstate this.
|
||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||
batch_size: 1
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: false
|
||||
|
||||
wandb:
|
||||
enable: true
|
||||
|
|
|
@ -28,7 +28,8 @@ training:
|
|||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes:: 50
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
|
|
|
@ -28,6 +28,7 @@ training:
|
|||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
|
|
|
@ -8,6 +8,7 @@ training:
|
|||
eval_freq: 5000
|
||||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
|
|
|
@ -34,286 +34,328 @@ import time
|
|||
from copy import deepcopy
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from PIL import Image as PILImage
|
||||
from torch import Tensor
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: Policy,
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
enable_progbar: bool = False,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
Note that all environments in the batch are run until the last environment is done. This means some
|
||||
data will probably need to be discarded (for environments that aren't the first one to be done).
|
||||
|
||||
The return dictionary contains:
|
||||
(optional) "observation": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE the that this has an extra sequence element relative to the other keys in the
|
||||
dictionary. This is because an extra observation is included for after the environment is
|
||||
terminated or truncated.
|
||||
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
|
||||
including the last observations).
|
||||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
seeds: The environments are seeded once at the start of the rollout. If provided, this argument
|
||||
specifies the seeds for each of the environments.
|
||||
return_observations: Whether to include all observations in the returned rollout data. Observations
|
||||
are returned optionally because they typically take more memory to cache. Defaults to False.
|
||||
render_callback: Optional rendering callback to be used after the environments are reset, and after
|
||||
every step.
|
||||
enable_progbar: Enable a progress bar over rollout steps.
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
device = get_device_from_parameters(policy)
|
||||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
all_successes = []
|
||||
all_dones = []
|
||||
|
||||
step = 0
|
||||
# Keep track of which environments are done.
|
||||
done = np.array([False] * env.num_envs)
|
||||
max_steps = env.call("_max_episode_steps")[0]
|
||||
progbar = trange(
|
||||
max_steps,
|
||||
desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
|
||||
disable=not enable_progbar,
|
||||
leave=False,
|
||||
)
|
||||
while not np.all(done):
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action = action.to("cpu").numpy()
|
||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
# Apply the next action.
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
||||
# available of none of the envs finished.
|
||||
if "final_info" in info:
|
||||
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
done = terminated | truncated | done
|
||||
|
||||
all_actions.append(torch.from_numpy(action))
|
||||
all_rewards.append(torch.from_numpy(reward))
|
||||
all_dones.append(torch.from_numpy(done))
|
||||
all_successes.append(torch.tensor(successes))
|
||||
|
||||
step += 1
|
||||
running_success_rate = (
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
|
||||
)
|
||||
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
|
||||
progbar.update()
|
||||
|
||||
# Track the final observation.
|
||||
if return_observations:
|
||||
observation = preprocess_observation(observation)
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
"reward": torch.stack(all_rewards, dim=1),
|
||||
"success": torch.stack(all_successes, dim=1),
|
||||
"done": torch.stack(all_dones, dim=1),
|
||||
}
|
||||
if return_observations:
|
||||
stacked_observations = {}
|
||||
for key in all_observations[0]:
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret["observation"] = stacked_observations
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: torch.nn.Module,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
video_dir: Path = None,
|
||||
video_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
seed=None,
|
||||
):
|
||||
start_seed: int | None = None,
|
||||
enable_progbar: bool = False,
|
||||
enable_inner_progbar: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
max_episodes_rendered: Maximum number of episodes to render into videos.
|
||||
video_dir: Where to save rendered videos.
|
||||
return_episode_data: Whether to return episode data for online training. Incorporates the data into
|
||||
the "episodes" key of the returned dictionary.
|
||||
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
|
||||
seed is incremented by 1. If not provided, the environments are not manually seeded.
|
||||
enable_progbar: Enable progress bar over batches.
|
||||
enable_inner_progbar: Enable progress bar over steps in each batch.
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"""
|
||||
start = time.time()
|
||||
policy.eval()
|
||||
|
||||
fps = env.unwrapped.metadata["render_fps"]
|
||||
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
|
||||
# divisible by env.num_envs we end up discarding some data in the last batch.
|
||||
n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)
|
||||
|
||||
device = "cpu" if policy is None else next(policy.parameters()).device
|
||||
|
||||
start = time.time()
|
||||
# Keep track of some metrics.
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
all_successes = []
|
||||
seeds = []
|
||||
all_seeds = []
|
||||
threads = [] # for video saving threads
|
||||
episode_counter = 0 # for saving the correct number of videos
|
||||
n_episodes_rendered = 0 # for saving the correct number of videos
|
||||
|
||||
num_episodes = len(env.envs)
|
||||
|
||||
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
|
||||
# needed as I'm currently taking a ceil.
|
||||
ep_frames = []
|
||||
|
||||
def render_frame(env):
|
||||
# Callback for visualization.
|
||||
def render_frame(env: gym.vector.VectorEnv):
|
||||
# noqa: B023
|
||||
eps_rendered = min(max_episodes_rendered, len(env.envs))
|
||||
visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
|
||||
ep_frames.append(visu) # noqa: B023
|
||||
if n_episodes_rendered >= max_episodes_rendered:
|
||||
return
|
||||
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
|
||||
if isinstance(env, gym.vector.SyncVectorEnv):
|
||||
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
|
||||
elif isinstance(env, gym.vector.AsyncVectorEnv):
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
|
||||
for _ in range(num_episodes):
|
||||
seeds.append("TODO")
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if hasattr(policy, "reset"):
|
||||
policy.reset()
|
||||
else:
|
||||
logging.warning(
|
||||
f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=not enable_progbar)
|
||||
for batch_ix in progbar:
|
||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||
# step.
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
|
||||
rollout_data = rollout(
|
||||
env,
|
||||
policy,
|
||||
seeds=seeds,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
enable_progbar=enable_inner_progbar,
|
||||
)
|
||||
|
||||
# reset the environment
|
||||
observation, info = env.reset(seed=seed)
|
||||
if max_episodes_rendered > 0:
|
||||
render_frame(env)
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
# this won't be included).
|
||||
n_steps = rollout_data["done"].shape[1]
|
||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||
done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps)
|
||||
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
|
||||
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
|
||||
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
|
||||
# Extend metrics.
|
||||
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
|
||||
sum_rewards.extend(batch_sum_rewards.tolist())
|
||||
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
|
||||
max_rewards.extend(batch_max_rewards.tolist())
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
all_successes.extend(batch_successes.tolist())
|
||||
all_seeds.extend(seeds)
|
||||
|
||||
observations = []
|
||||
actions = []
|
||||
# episode
|
||||
# frame_id
|
||||
# timestamp
|
||||
rewards = []
|
||||
successes = []
|
||||
dones = []
|
||||
|
||||
done = torch.tensor([False for _ in env.envs])
|
||||
step = 0
|
||||
max_steps = env.envs[0]._max_episode_steps
|
||||
progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
|
||||
while not done.all():
|
||||
# format from env keys to lerobot keys
|
||||
observation = preprocess_observation(observation)
|
||||
if return_episode_data:
|
||||
observations.append(deepcopy(observation))
|
||||
this_episode_data = _compile_episode_data(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(
|
||||
0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item())
|
||||
),
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
episode_data = this_episode_data
|
||||
else:
|
||||
# Some sanity checks to make sure we are not correctly compiling the data.
|
||||
assert (
|
||||
episode_data["hf_dataset"]["episode_index"][-1] + 1
|
||||
== this_episode_data["hf_dataset"]["episode_index"][0]
|
||||
)
|
||||
assert (
|
||||
episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0]
|
||||
)
|
||||
assert torch.equal(
|
||||
episode_data["episode_data_index"]["to"][-1],
|
||||
this_episode_data["episode_data_index"]["from"][0],
|
||||
)
|
||||
# Concatenate the episode data.
|
||||
episode_data = {
|
||||
"hf_dataset": concatenate_datasets(
|
||||
[episode_data["hf_dataset"], this_episode_data["hf_dataset"]]
|
||||
),
|
||||
"episode_data_index": {
|
||||
k: torch.cat(
|
||||
[
|
||||
episode_data["episode_data_index"][k],
|
||||
this_episode_data["episode_data_index"][k],
|
||||
]
|
||||
)
|
||||
for k in ["from", "to"]
|
||||
},
|
||||
}
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# convert to cpu numpy
|
||||
action = postprocess_action(action)
|
||||
|
||||
# apply the next action
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
# Maybe render video for visualization.
|
||||
if max_episodes_rendered > 0:
|
||||
render_frame(env)
|
||||
batch_stacked_frames = np.stack(ep_frames, axis=1) # (b, t, *)
|
||||
for stacked_frames, done_index in zip(
|
||||
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
|
||||
):
|
||||
if n_episodes_rendered >= max_episodes_rendered:
|
||||
break
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(
|
||||
str(video_path),
|
||||
stacked_frames[: done_index + 2], # + 2 to capture the observation frame after done
|
||||
env.unwrapped.metadata["render_fps"],
|
||||
),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
n_episodes_rendered += 1
|
||||
|
||||
# TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
|
||||
action = torch.from_numpy(action)
|
||||
reward = torch.from_numpy(reward)
|
||||
terminated = torch.from_numpy(terminated)
|
||||
truncated = torch.from_numpy(truncated)
|
||||
# environment is considered done (no more steps), when success state is reached (terminated is True),
|
||||
# or time limit is reached (truncated is True), or it was previsouly done.
|
||||
done = terminated | truncated | done
|
||||
|
||||
if "final_info" in info:
|
||||
# VectorEnv stores is_success into `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
|
||||
success = [
|
||||
env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
|
||||
]
|
||||
else:
|
||||
success = [False for _ in env.envs]
|
||||
success = torch.tensor(success)
|
||||
|
||||
actions.append(action)
|
||||
rewards.append(reward)
|
||||
dones.append(done)
|
||||
successes.append(success)
|
||||
|
||||
step += 1
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
|
||||
)
|
||||
progbar.update()
|
||||
|
||||
env.close()
|
||||
|
||||
# add the last observation when the env is done
|
||||
if return_episode_data:
|
||||
observation = preprocess_observation(observation)
|
||||
observations.append(deepcopy(observation))
|
||||
|
||||
if return_episode_data:
|
||||
new_obses = {}
|
||||
for key in observations[0].keys(): # noqa: SIM118
|
||||
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
|
||||
observations = new_obses
|
||||
actions = torch.stack(actions, dim=1)
|
||||
rewards = torch.stack(rewards, dim=1)
|
||||
successes = torch.stack(successes, dim=1)
|
||||
dones = torch.stack(dones, dim=1)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
# this won't be included).
|
||||
# Note: this assumes that the shape of the done key is (batch_size, max_steps).
|
||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||
done_indices = torch.argmax(dones.to(int), axis=1) # (batch_size, rollout_steps)
|
||||
expand_done_indices = done_indices[:, None].expand(-1, step)
|
||||
expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
|
||||
mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps)
|
||||
batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
|
||||
batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
|
||||
batch_success = einops.reduce((successes * mask), "b n -> b", "any")
|
||||
sum_rewards.extend(batch_sum_reward.tolist())
|
||||
max_rewards.extend(batch_max_reward.tolist())
|
||||
all_successes.extend(batch_success.tolist())
|
||||
|
||||
# similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
num_episodes = dones.shape[0]
|
||||
total_frames = 0
|
||||
id_from = 0
|
||||
for ep_id in range(num_episodes):
|
||||
num_frames = done_indices[ep_id].item() + 1
|
||||
total_frames += num_frames
|
||||
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
if return_episode_data:
|
||||
ep_dict = {
|
||||
"action": actions[ep_id, :num_frames],
|
||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
if return_episode_data:
|
||||
data_dict = {}
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if "image" not in key:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for img in ep_dict[key]:
|
||||
# sanity check that images are channel first
|
||||
c, h, w = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
|
||||
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
|
||||
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
|
||||
|
||||
# from float32 in range [0,1] to uint8 in range [0,255]
|
||||
img *= 255
|
||||
img = img.type(torch.uint8)
|
||||
|
||||
# convert to channel last and numpy as expected by PIL
|
||||
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
|
||||
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
|
||||
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
|
||||
|
||||
# TODO(rcadene): clean this
|
||||
features = {}
|
||||
for key in observations:
|
||||
if "image" in key:
|
||||
features[key] = Image()
|
||||
else:
|
||||
features[key] = Sequence(
|
||||
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features.update(
|
||||
{
|
||||
"action": Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
)
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
|
||||
for stacked_frames, done_index in zip(
|
||||
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
|
||||
):
|
||||
if episode_counter >= max_episodes_rendered:
|
||||
continue
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(str(video_path), stacked_frames[:done_index], fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
episode_counter += 1
|
||||
|
||||
videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
|
||||
|
||||
# Wait till all video rendering threads are done.
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Compile eval info.
|
||||
info = {
|
||||
"per_episode": [
|
||||
{
|
||||
|
@ -325,32 +367,127 @@ def eval_policy(
|
|||
}
|
||||
for i, (sum_reward, max_reward, success, seed) in enumerate(
|
||||
zip(
|
||||
sum_rewards[:num_episodes],
|
||||
max_rewards[:num_episodes],
|
||||
all_successes[:num_episodes],
|
||||
seeds[:num_episodes],
|
||||
sum_rewards[:n_episodes],
|
||||
max_rewards[:n_episodes],
|
||||
all_successes[:n_episodes],
|
||||
all_seeds[:n_episodes],
|
||||
strict=True,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
|
||||
"avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
|
||||
"pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
|
||||
"avg_sum_reward": float(np.nanmean(sum_rewards[:n_episodes])),
|
||||
"avg_max_reward": float(np.nanmean(max_rewards[:n_episodes])),
|
||||
"pc_success": float(np.nanmean(all_successes[:n_episodes]) * 100),
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
"eval_ep_s": (time.time() - start) / n_episodes,
|
||||
},
|
||||
}
|
||||
|
||||
if return_episode_data:
|
||||
info["episodes"] = {
|
||||
"hf_dataset": hf_dataset,
|
||||
"episode_data_index": episode_data_index,
|
||||
}
|
||||
info["episodes"] = episode_data
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
info["videos"] = videos
|
||||
info["video_paths"] = video_paths
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def _compile_episode_data(
|
||||
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
|
||||
) -> dict:
|
||||
"""Convenience function for `eval_policy(return_episode_data=True)`
|
||||
|
||||
Compiles all the rollout data into a Hugging Face dataset.
|
||||
|
||||
Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
|
||||
"""
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
total_frames = 0
|
||||
data_index_from = start_data_index
|
||||
for ep_ix in range(rollout_data["action"].shape[0]):
|
||||
num_frames = done_indices[ep_ix].item() + 1 # + 1 to include the first done frame
|
||||
total_frames += num_frames
|
||||
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, :num_frames],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, :num_frames],
|
||||
"next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32),
|
||||
}
|
||||
for key in rollout_data["observation"]:
|
||||
ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(data_index_from)
|
||||
episode_data_index["to"].append(data_index_from + num_frames)
|
||||
|
||||
data_index_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
for key in ep_dicts[0]:
|
||||
if "image" not in key:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for img in ep_dict[key]:
|
||||
# sanity check that images are channel first
|
||||
c, h, w = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
|
||||
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
|
||||
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
|
||||
|
||||
# from float32 in range [0,1] to uint8 in range [0,255]
|
||||
img *= 255
|
||||
img = img.type(torch.uint8)
|
||||
|
||||
# convert to channel last and numpy as expected by PIL
|
||||
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
|
||||
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
|
||||
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
|
||||
|
||||
# TODO(rcadene): clean this
|
||||
features = {}
|
||||
for key in rollout_data["observation"]:
|
||||
if "image" in key:
|
||||
features[key] = Image()
|
||||
else:
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
features.update(
|
||||
{
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
)
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return {
|
||||
"hf_dataset": hf_dataset,
|
||||
"episode_data_index": episode_data_index,
|
||||
}
|
||||
|
||||
|
||||
def eval(
|
||||
pretrained_policy_path: str | None = None,
|
||||
hydra_cfg_path: str | None = None,
|
||||
|
@ -378,7 +515,7 @@ def eval(
|
|||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval.n_episodes)
|
||||
env = make_env(hydra_cfg)
|
||||
|
||||
logging.info("Making policy.")
|
||||
if hydra_cfg_path is None:
|
||||
|
@ -391,19 +528,21 @@ def eval(
|
|||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
hydra_cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
return_episode_data=False,
|
||||
seed=hydra_cfg.seed,
|
||||
start_seed=hydra_cfg.seed,
|
||||
enable_progbar=True,
|
||||
enable_inner_progbar=True,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
|
||||
# Save info
|
||||
with open(Path(out_dir) / "eval_info.json", "w") as f:
|
||||
# remove pytorch tensors which are not serializable to save the evaluation results only
|
||||
del info["videos"]
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
env.close()
|
||||
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
|
|
|
@ -269,7 +269,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
offline_dataset = make_dataset(cfg)
|
||||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
|
||||
eval_env = make_env(cfg)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
|
||||
|
@ -337,15 +337,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
if step % cfg.training.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
eval_info = eval_policy(
|
||||
env,
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
max_episodes_rendered=4,
|
||||
seed=cfg.seed,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(eval_info["videos"][0], step, mode="eval")
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.training.save_model and step % cfg.training.save_freq == 0:
|
||||
|
@ -395,7 +396,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
step += 1
|
||||
|
||||
# create an env dedicated to online episodes collection from policy rollout
|
||||
rollout_env = make_env(cfg, num_parallel_envs=1)
|
||||
online_training_env = make_env(cfg, n_envs=1)
|
||||
|
||||
# create an empty online dataset similar to offline dataset
|
||||
online_dataset = deepcopy(offline_dataset)
|
||||
|
@ -427,10 +428,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
policy.eval()
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
rollout_env,
|
||||
online_training_env,
|
||||
policy,
|
||||
n_episodes=1,
|
||||
return_episode_data=True,
|
||||
seed=cfg.seed,
|
||||
start_seed=cfg.training.online_env_seed,
|
||||
enable_progbar=True,
|
||||
)
|
||||
|
||||
add_episodes_inplace(
|
||||
|
@ -461,6 +464,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
step += 1
|
||||
online_step += 1
|
||||
|
||||
eval_env.close()
|
||||
online_training_env.close()
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ def test_factory(env_name):
|
|||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
env = make_env(cfg, n_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from lerobot import available_policies
|
|||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
|
@ -80,7 +80,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
assert isinstance(policy, PyTorchModelHubMixin)
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
env = make_env(cfg, n_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
|
@ -112,10 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
action = policy.select_action(observation).cpu().numpy()
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
|
Loading…
Reference in New Issue