Merge pull request from Cadene/user/alexander-soare/multistep_policy_and_serial_env

Incorporate SerialEnv and introduct multistep policy logic
This commit is contained in:
Remi 2024-03-20 15:55:27 +01:00 committed by GitHub
commit ec536ef0fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 460 additions and 283 deletions

View File

@ -57,9 +57,9 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("observation", "image"): "b c h w -> c 1 1",
("action",): "b c -> c",
}
@property

View File

@ -115,11 +115,11 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
@property
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("action",): "b c -> c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
d[("observation", "image", cam)] = "b c h w -> c 1 1"
return d
@property

View File

@ -1,4 +1,3 @@
import abc
from collections import deque
from typing import Optional
@ -27,7 +26,6 @@ class AbstractEnv(EnvBase):
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []
if pixels_only:
assert from_pixels
@ -45,36 +43,20 @@ class AbstractEnv(EnvBase):
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def register_rendering_hook(self, func):
self._rendering_hooks.append(func)
def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)
def reset_rendering_hooks(self):
self._rendering_hooks = []
@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()
raise NotImplementedError("Abstract method")
@abc.abstractmethod
def _reset(self, tensordict: Optional[TensorDict] = None):
raise NotImplementedError()
raise NotImplementedError("Abstract method")
@abc.abstractmethod
def _step(self, tensordict: TensorDict):
raise NotImplementedError()
raise NotImplementedError("Abstract method")
@abc.abstractmethod
def _make_env(self):
raise NotImplementedError()
raise NotImplementedError("Abstract method")
@abc.abstractmethod
def _make_spec(self):
raise NotImplementedError()
raise NotImplementedError("Abstract method")
@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError()
raise NotImplementedError("Abstract method")

View File

@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
class AlohaEnv(AbstractEnv):
_reset_warning_issued = False
def __init__(
self,
task,
@ -120,90 +122,76 @@ class AlohaEnv(AbstractEnv):
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
if tensordict is not None and not AlohaEnv._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
AlohaEnv._reset_warning_issued = True
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
obs = self._format_raw_obs(raw_obs.observation)
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
obs = self._format_raw_obs(raw_obs.observation)
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
_, reward, _, raw_obs = self._env.step(action)
num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
# TODO(rcadene): add an enum
success = done = reward == 4
obs = self._format_raw_obs(raw_obs)
# TOOD(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),

View File

@ -1,14 +1,18 @@
from torchrl.envs import SerialEnv
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
def make_env(cfg, transform=None):
"""
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
environments. The env therefore returns batches.`
"""
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
}
@ -31,22 +35,33 @@ def make_env(cfg, transform=None):
else:
raise ValueError(cfg.env.name)
env = clsfunc(**kwargs)
def _make_env(seed):
nonlocal kwargs
kwargs["seed"] = seed
env = clsfunc(**kwargs)
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env
return env
return SerialEnv(
cfg.rollout_batch_size,
create_env_fn=_make_env,
create_env_kwargs=[
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
# def make_env(env_name, frame_skip, device, is_test=False):

View File

@ -1,8 +1,8 @@
import importlib
import logging
from collections import deque
from typing import Optional
import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@ -20,6 +20,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
class PushtEnv(AbstractEnv):
_reset_warning_issued = False
def __init__(
self,
task="pusht",
@ -80,80 +82,67 @@ class PushtEnv(AbstractEnv):
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
if tensordict is not None and not PushtEnv._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
PushtEnv._reset_warning_issued = True
obs = self._format_raw_obs(raw_obs)
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
obs = self._format_raw_obs(raw_obs)
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
raw_obs, reward, done, info = self._env.step(action)
num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"reward": torch.tensor([reward], dtype=torch.float32),
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([done], dtype=torch.bool),
},

View File

@ -118,7 +118,6 @@ class SimxarmEnv(AbstractEnv):
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@ -152,8 +151,6 @@ class SimxarmEnv(AbstractEnv):
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),

View File

@ -0,0 +1,70 @@
from collections import deque
import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module):
"""Base policy which all policies should be derived from.
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
documentation for more information.
"""
def __init__(self, n_action_steps: int | None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
adds that dimension.
"""
super().__init__()
self.n_action_steps = n_action_steps
self.clear_action_queue()
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
def select_actions(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
"""
raise NotImplementedError("Abstract method")
def clear_action_queue(self):
"""This should be called whenever the environment is reset."""
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def forward(self, *args, **kwargs) -> Tensor:
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overriden.
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
"""
if self.n_action_steps is None:
return self.select_actions(*args, **kwargs)
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()

View File

@ -2,10 +2,10 @@ import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
@ -40,9 +40,9 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld
class ActionChunkingTransformerPolicy(nn.Module):
class ActionChunkingTransformerPolicy(AbstractPolicy):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
super().__init__(n_action_steps)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
@ -147,16 +147,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
return loss
@torch.no_grad()
def forward(self, observation, step_count):
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
# TODO(rcadene): remove unused step_count
del step_count
self.eval()
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
# observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
@ -180,11 +179,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# remove bsize=1
action = action.squeeze(0)
# take first predicted action or n first actions
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
action = action[: self.n_action_steps]
return action
def _forward(self, qpos, image, actions=None, is_pad=None):

View File

@ -3,14 +3,14 @@ import time
import hydra
import torch
import torch.nn as nn
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
class DiffusionPolicy(nn.Module):
class DiffusionPolicy(AbstractPolicy):
def __init__(
self,
cfg,
@ -34,7 +34,7 @@ class DiffusionPolicy(nn.Module):
# parameters passed to step
**kwargs,
):
super().__init__()
super().__init__(n_action_steps)
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
@ -93,21 +93,16 @@ class DiffusionPolicy(nn.Module):
)
@torch.no_grad()
def forward(self, observation, step_count):
def select_actions(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
action = out["action"].squeeze(0)
action = out["action"]
return action
def update(self, replay_buffer, step):

View File

@ -1,4 +1,7 @@
def make_policy(cfg):
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPC

View File

@ -9,6 +9,7 @@ import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy
FIRST_FRAME = 0
@ -85,11 +86,11 @@ class TOLD(nn.Module):
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPC(nn.Module):
class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg, device):
super().__init__()
super().__init__(None)
self.action_dim = cfg.action_dim
self.cfg = cfg
@ -124,20 +125,19 @@ class TDMPC(nn.Module):
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def forward(self, observation, step_count):
t0 = step_count.item() == 0
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
t0 = step_count.item() == 0
obs = {
# TODO(rcadene): remove contiguous hack...
"rgb": observation["image"].contiguous(),
"state": observation["state"].contiguous(),
}
action = self.act(obs, t0=t0, step=self.step.item())
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
return action
@torch.no_grad()

View File

@ -10,6 +10,9 @@ hydra:
name: default
seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
# NOTE: only diffusion policy supports rollout_batch_size > 1
rollout_batch_size: 1
device: cuda # cpu
prefetch: 4
eval_freq: ???

View File

@ -3,6 +3,7 @@ import threading
import time
from pathlib import Path
import einops
import hydra
import imageio
import numpy as np
@ -10,10 +11,12 @@ import torch
import tqdm
from tensordict.nn import TensorDictModule
from torchrl.envs import EnvBase
from torchrl.envs.batched_envs import BatchedEnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import init_logging, set_seed
@ -23,8 +26,8 @@ def write_video(video_path, stacked_frames, fps):
def eval_policy(
env: EnvBase,
policy: TensorDictModule = None,
env: BatchedEnvBase,
policy: AbstractPolicy,
num_episodes: int = 10,
max_steps: int = 30,
save_video: bool = False,
@ -37,55 +40,75 @@ def eval_policy(
sum_rewards = []
max_rewards = []
successes = []
threads = []
for i in tqdm.tqdm(range(num_episodes)):
ep_frames = []
if save_video or (return_first_video and i == 0):
threads = [] # for video saving threads
episode_counter = 0 # for saving the correct number of videos
def render_frame(env):
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
# needed as I'm currently taking a ceil.
for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
ep_frames = []
def maybe_render_frame(env: EnvBase, _):
if save_video or (return_first_video and i == 0): # noqa: B023
ep_frames.append(env.render()) # noqa: B023
env.register_rendering_hook(render_frame)
with torch.inference_mode():
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
policy.clear_action_queue()
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
auto_cast_to_device=True,
callback=maybe_render_frame,
break_when_any_done=env.batch_size[0] == 1,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
ep_sum_reward = rollout["next", "reward"].sum()
ep_max_reward = rollout["next", "reward"].max()
ep_success = rollout["next", "success"].any()
sum_rewards.append(ep_sum_reward.item())
max_rewards.append(ep_max_reward.item())
successes.append(ep_success.item())
# Figure out where in each rollout sequence the first done condition was encountered (results after this won't
# be included).
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
rollout_steps = rollout["next", "done"].shape[1]
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
sum_rewards.extend(batch_sum_reward.tolist())
max_rewards.extend(batch_max_reward.tolist())
successes.extend(batch_success.tolist())
if save_video or (return_first_video and i == 0):
stacked_frames = np.stack(ep_frames)
batch_stacked_frames = np.stack(ep_frames) # (t, b, *)
batch_stacked_frames = batch_stacked_frames.transpose(
1, 0, *range(2, batch_stacked_frames.ndim)
) # (b, t, *)
if save_video:
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{i}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames, fps),
)
thread.start()
threads.append(thread)
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
if return_first_video and i == 0:
first_video = stacked_frames.transpose(0, 3, 1, 2)
env.reset_rendering_hooks()
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
for thread in threads:
thread.join()
info = {
"avg_sum_reward": np.nanmean(sum_rewards),
"avg_max_reward": np.nanmean(max_rewards),
"pc_success": np.nanmean(successes) * 100,
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
}
@ -139,7 +162,7 @@ def eval(cfg: dict, out_dir=None):
save_video=True,
video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps,
max_steps=cfg.env.episode_length // cfg.n_action_steps,
max_steps=cfg.env.episode_length,
num_episodes=cfg.eval_episodes,
)
print(metrics)

View File

@ -112,6 +112,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
if cfg.online_steps > 0:
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
init_logging()
@ -192,7 +194,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env,
td_policy,
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length // cfg.n_action_steps,
max_steps=cfg.env.episode_length,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,
@ -218,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO: add configurable number of rollout? (default=1)
with torch.no_grad():
rollout = env.rollout(
max_steps=cfg.env.episode_length // cfg.n_action_steps,
max_steps=cfg.env.episode_length,
policy=td_policy,
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
assert len(rollout) <= cfg.env.episode_length
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)

81
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"
@ -658,13 +658,13 @@ typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "fsspec"
version = "2024.2.0"
version = "2024.3.1"
description = "File-system specification"
optional = false
python-versions = ">=3.8"
files = [
{file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"},
{file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"},
{file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
{file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
]
[package.extras]
@ -1468,32 +1468,32 @@ setuptools = "*"
[[package]]
name = "numba"
version = "0.59.0"
version = "0.59.1"
description = "compiling Python code using LLVM"
optional = false
python-versions = ">=3.9"
files = [
{file = "numba-0.59.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8d061d800473fb8fef76a455221f4ad649a53f5e0f96e3f6c8b8553ee6fa98fa"},
{file = "numba-0.59.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c086a434e7d3891ce5dfd3d1e7ee8102ac1e733962098578b507864120559ceb"},
{file = "numba-0.59.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9e20736bf62e61f8353fb71b0d3a1efba636c7a303d511600fc57648b55823ed"},
{file = "numba-0.59.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e86e6786aec31d2002122199486e10bbc0dc40f78d76364cded375912b13614c"},
{file = "numba-0.59.0-cp310-cp310-win_amd64.whl", hash = "sha256:0307ee91b24500bb7e64d8a109848baf3a3905df48ce142b8ac60aaa406a0400"},
{file = "numba-0.59.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d540f69a8245fb714419c2209e9af6104e568eb97623adc8943642e61f5d6d8e"},
{file = "numba-0.59.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1192d6b2906bf3ff72b1d97458724d98860ab86a91abdd4cfd9328432b661e31"},
{file = "numba-0.59.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:90efb436d3413809fcd15298c6d395cb7d98184350472588356ccf19db9e37c8"},
{file = "numba-0.59.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd3dac45e25d927dcb65d44fb3a973994f5add2b15add13337844afe669dd1ba"},
{file = "numba-0.59.0-cp311-cp311-win_amd64.whl", hash = "sha256:753dc601a159861808cc3207bad5c17724d3b69552fd22768fddbf302a817a4c"},
{file = "numba-0.59.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ce62bc0e6dd5264e7ff7f34f41786889fa81a6b860662f824aa7532537a7bee0"},
{file = "numba-0.59.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8cbef55b73741b5eea2dbaf1b0590b14977ca95a13a07d200b794f8f6833a01c"},
{file = "numba-0.59.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:70d26ba589f764be45ea8c272caa467dbe882b9676f6749fe6f42678091f5f21"},
{file = "numba-0.59.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e125f7d69968118c28ec0eed9fbedd75440e64214b8d2eac033c22c04db48492"},
{file = "numba-0.59.0-cp312-cp312-win_amd64.whl", hash = "sha256:4981659220b61a03c1e557654027d271f56f3087448967a55c79a0e5f926de62"},
{file = "numba-0.59.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe4d7562d1eed754a7511ed7ba962067f198f86909741c5c6e18c4f1819b1f47"},
{file = "numba-0.59.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6feb1504bb432280f900deaf4b1dadcee68812209500ed3f81c375cbceab24dc"},
{file = "numba-0.59.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:944faad25ee23ea9dda582bfb0189fb9f4fc232359a80ab2a028b94c14ce2b1d"},
{file = "numba-0.59.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5516a469514bfae52a9d7989db4940653a5cbfac106f44cb9c50133b7ad6224b"},
{file = "numba-0.59.0-cp39-cp39-win_amd64.whl", hash = "sha256:32bd0a41525ec0b1b853da244808f4e5333867df3c43c30c33f89cf20b9c2b63"},
{file = "numba-0.59.0.tar.gz", hash = "sha256:12b9b064a3e4ad00e2371fc5212ef0396c80f41caec9b5ec391c8b04b6eaf2a8"},
{file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"},
{file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"},
{file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"},
{file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"},
{file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"},
{file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"},
{file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"},
{file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"},
{file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"},
{file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"},
{file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"},
{file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"},
{file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"},
{file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"},
{file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"},
{file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"},
{file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"},
{file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"},
{file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"},
{file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"},
{file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"},
]
[package.dependencies]
@ -2684,13 +2684,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
[[package]]
name = "sentry-sdk"
version = "1.41.0"
version = "1.42.0"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
{file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"},
{file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"},
{file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"},
{file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"},
]
[package.dependencies]
@ -2714,6 +2714,7 @@ grpcio = ["grpcio (>=1.21.1)"]
httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
loguru = ["loguru (>=0.5)"]
openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
pure-eval = ["asttokens", "executing", "pure-eval"]
@ -2829,18 +2830,18 @@ test = ["pytest"]
[[package]]
name = "setuptools"
version = "69.1.1"
version = "69.2.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "setuptools-69.1.1-py3-none-any.whl", hash = "sha256:02fa291a0471b3a18b2b2481ed902af520c69e8ae0919c13da936542754b4c56"},
{file = "setuptools-69.1.1.tar.gz", hash = "sha256:5c0806c7d9af348e6dd3777b4f4dbb42c7ad85b190104837488eab9a7c945cf8"},
{file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"},
{file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
@ -2949,7 +2950,7 @@ mpmath = ">=0.19"
[[package]]
name = "tensordict"
version = "0.4.0+551331d"
version = "0.4.0+ca4256e"
description = ""
optional = false
python-versions = "*"
@ -2970,7 +2971,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a"
resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
[[package]]
name = "termcolor"
@ -3311,18 +3312,18 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
[[package]]
name = "zipp"
version = "3.17.0"
version = "3.18.1"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
{file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
{file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
{file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"},
{file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
[metadata]
lock-version = "2.0"

Binary file not shown.

View File

@ -1,25 +1,138 @@
from omegaconf import open_dict
import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase
from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.abstract import AbstractPolicy
from .utils import DEVICE, init_config
@pytest.mark.parametrize(
"env_name,policy_name",
"env_name,policy_name,extra_overrides",
[
("simxarm", "tdmpc"),
("pusht", "tdmpc"),
("simxarm", "diffusion"),
("pusht", "diffusion"),
("simxarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("simxarm", "diffusion", []),
("pusht", "diffusion", []),
("aloha", "act", ["env.task=sim_insertion_scripted"]),
],
)
def test_factory(env_name, policy_name):
def test_concrete_policy(env_name, policy_name, extra_overrides):
"""
Tests:
- Making the policy object.
- Updating the policy.
- Using the policy to select actions at inference time.
"""
cfg = init_config(
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]
+ extra_overrides
)
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that we run select_actions and get the appropriate output.
if env_name == "simxarm":
# TODO(rcadene): Not implemented
return
if policy_name == "tdmpc":
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
with open_dict(cfg):
cfg["n_obs_steps"] = 1
offline_buffer = make_offline_buffer(cfg)
env = make_env(cfg, transform=offline_buffer.transform)
if env_name != "aloha":
# TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
# seq_length as a list is not supported for now.
policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
action = policy(
env.observation_spec.rand()["observation"].to(DEVICE),
torch.tensor(0, device=DEVICE),
)
assert action.shape == env.action_spec.shape
def test_abstract_policy_forward():
"""
Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
- The policy is invoked the expected number of times during a rollout.
- The environment's termination condition is respected even when part way through an action trajectory.
- The observations are returned correctly.
"""
n_action_steps = 8 # our test policy will output 8 action step horizons
terminate_at = 10 # some number that is more than n_action_steps but not a multiple
rollout_max_steps = terminate_at + 1 # some number greater than terminate_at
# A minimal environment for testing.
class StubEnv(EnvBase):
def __init__(self):
super().__init__()
self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
def _step(self, tensordict: TensorDict) -> TensorDict:
self.invocation_count += 1
return TensorDict(
{
"observation": torch.tensor([self.invocation_count]),
"reward": torch.tensor([self.invocation_count]),
"terminated": torch.tensor(
tensordict["action"].item() == terminate_at
),
}
)
def _reset(self, tensordict: TensorDict) -> TensorDict:
self.invocation_count = 0
return TensorDict(
{
"observation": torch.tensor([self.invocation_count]),
"reward": torch.tensor([self.invocation_count]),
}
)
def _set_seed(self, seed: int | None):
return
class StubPolicy(AbstractPolicy):
def __init__(self):
super().__init__(n_action_steps)
self.n_policy_invocations = 0
def update(self):
pass
def select_actions(self):
self.n_policy_invocations += 1
return torch.stack(
[torch.tensor([i]) for i in range(self.n_action_steps)]
).unsqueeze(0)
env = StubEnv()
policy = StubPolicy()
policy = TensorDictModule(
policy,
in_keys=[],
out_keys=["action"],
)
# Keep track to make sure the policy is called the expected number of times
rollout = env.rollout(rollout_max_steps, policy)
assert len(rollout) == terminate_at + 1 # +1 for the reset observation
assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))