Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act

Alexander Soare 2024-04-09 08:36:28 +01:00
commit e6c6c2367f
13 changed files with 109 additions and 247 deletions

.github/poetry/cpu/poetry.lock (generated, vendored): 126 lines changed

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -889,6 +889,69 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
[[package]]
name = "gym-aloha"
version = "0.1.0"
description = "A gym environment for ALOHA"
optional = true
python-versions = "^3.10"
files = []
develop = false
[package.dependencies]
dm-control = "1.0.14"
gymnasium = "^0.29.1"
mujoco = "^2.3.7"
[package.source]
type = "git"
url = "git@github.com:huggingface/gym-aloha.git"
reference = "HEAD"
resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
[[package]]
name = "gym-pusht"
version = "0.1.0"
description = "A gymnasium environment for PushT."
optional = true
python-versions = "^3.10"
files = []
develop = false
[package.dependencies]
gymnasium = "^0.29.1"
opencv-python = "^4.9.0.80"
pygame = "^2.5.2"
pymunk = "^6.6.0"
scikit-image = "^0.22.0"
shapely = "^2.0.3"
[package.source]
type = "git"
url = "git@github.com:huggingface/gym-pusht.git"
reference = "HEAD"
resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf"
[[package]]
name = "gym-xarm"
version = "0.1.0"
description = "A gym environment for xArm"
optional = true
python-versions = "^3.10"
files = []
develop = false
[package.dependencies]
gymnasium = "^0.29.1"
gymnasium-robotics = "^1.2.4"
mujoco = "^2.3.7"
[package.source]
type = "git"
url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD"
resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
[[package]]
name = "gymnasium"
version = "0.29.1"
@@ -2988,31 +3051,6 @@ numpy = "*"
packaging = "*"
protobuf = ">=3.20"
[[package]]
name = "tensordict"
version = "0.4.0+b4c91e8"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
torch = ">=2.1.0"
[package.extras]
checkpointing = ["torchsnapshot-nightly"]
h5 = ["h5py (>=3.8)"]
tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
[package.source]
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
[[package]]
name = "termcolor"
version = "2.4.0"
@@ -3091,40 +3129,6 @@ type = "legacy"
url = "https://download.pytorch.org/whl/cpu"
reference = "torch-cpu"
[[package]]
name = "torchrl"
version = "0.4.0+13bef42"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
packaging = "*"
tensordict = ">=0.4.0"
torch = ">=2.1.0"
[package.extras]
all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
checkpointing = ["torchsnapshot"]
dm-control = ["dm_control"]
gym-continuous = ["gymnasium", "mujoco"]
marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
rendering = ["moviepy"]
tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
[package.source]
type = "git"
url = "https://github.com/pytorch/rl"
reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
[[package]]
name = "torchvision"
version = "0.17.1+cpu"
@@ -3330,4 +3334,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"
content-hash = "32cd6caa01276a90b37cb177204e5b1511e92838f3f0268391034042d56f3bd6"


@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = {version = "^2.2.1", source = "torch-cpu"}
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
@@ -53,7 +51,12 @@ huggingface-hub = "^0.21.4"
gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
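As an aside, a minimal sketch (not part of the diff) of how optional extras like the gym-* dependencies above are typically probed at runtime; the import module names are assumed from the package names and may differ:

# Hedged sketch: probe for the optional gym-* extras before using them.
import importlib.util

for module in ("gym_pusht", "gym_xarm", "gym_aloha"):
    available = importlib.util.find_spec(module) is not None
    print(f"{module}: {'installed' if available else 'not installed (optional extra)'}")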


@@ -1,82 +0,0 @@
from collections import deque

import torch
from torch import Tensor, nn


class AbstractPolicy(nn.Module):
    """Base policy which all policies should be derived from.

    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
    documentation for more information.

    Note:
        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
        1. set the required class attributes:
            - for classes inheriting from `AbstractDataset`: `available_datasets`
            - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
            - for classes inheriting from `AbstractPolicy`: `name`
        2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
        3. update variables in `tests/test_available.py` by importing your new class
    """

    name: str | None = None  # same name should be used to instantiate the policy in factory.py

    def __init__(self, n_action_steps: int | None):
        """
        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
            action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method
            then adds that dimension.
        """
        super().__init__()
        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
        self.n_action_steps = n_action_steps
        self.clear_action_queue()

    def update(self, replay_buffer, step):
        """One step of the policy's learning algorithm."""
        raise NotImplementedError("Abstract method")

    def save(self, fp):
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        d = torch.load(fp)
        self.load_state_dict(d)

    def select_actions(self, observation) -> Tensor:
        """Select an action (or trajectory of actions) based on an observation during rollout.

        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor
        of actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
        """
        raise NotImplementedError("Abstract method")

    def clear_action_queue(self):
        """This should be called whenever the environment is reset."""
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)

    def forward(self, *args, **kwargs) -> Tensor:
        """Inference step that makes multi-step policies compatible with their single-step environments.

        WARNING: In general, this should not be overridden.

        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
        observation, and (3) repopulates the action queue when empty. This method handles the aforementioned logic so
        that the subclass doesn't have to.

        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
            the action trajectory horizon and * is the action dimensions.
        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
        """
        if self.n_action_steps is None:
            return self.select_actions(*args, **kwargs)
        if len(self._action_queue) == 0:
            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
            # (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
        return self._action_queue.popleft()
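
For reference, a minimal runnable sketch (not a lerobot class) of the queue-based dispatch this removed forward method implemented; the zero-valued select_actions below stands in for a real model:

from collections import deque

import torch


class QueuedPolicy:
    """Illustrative stand-in: cache a predicted action chunk, pop one action per step."""

    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, observation: torch.Tensor) -> torch.Tensor:
        # Stand-in for a real model: returns (batch_size, n_action_steps, action_dim).
        return torch.zeros(observation.shape[0], self.n_action_steps, 2)

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        if len(self._action_queue) == 0:
            # The queue holds (batch_size, action_dim) slices, hence the transpose.
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


policy = QueuedPolicy(n_action_steps=5)
obs = torch.zeros(1, 10)
print(policy.forward(obs).shape)  # torch.Size([1, 2]); 4 actions remain queued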


@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        "ActionChunkingTransformerPolicy does not handle multiple observation steps."
    )

    def __init__(self, cfg, device, n_action_steps=1):
    def __init__(self, cfg, device):
        """
        TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
        """
@@ -73,7 +73,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        if getattr(cfg, "n_obs_steps", 1) != 1:
            raise ValueError(self._multiple_obs_steps_not_handled_msg)
        self.cfg = cfg
        self.n_action_steps = n_action_steps
        self.n_action_steps = cfg.n_action_steps
        self.device = get_safe_torch_device(device)
        self.camera_names = cfg.camera_names
        self.use_vae = cfg.use_vae
@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)

    def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
    def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
        """
        This method wraps `select_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
        # the image index dimension.

    def update(self, batch, *_) -> dict:
    def update(self, batch, *_, **__) -> dict:
        start_time = time.time()

        self._preprocess_batch(batch)
@@ -311,7 +311,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
    def _forward(
        self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
    ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
        """
        Args:
            robot_state: (B, J) batch of robot joint configurations.
@@ -344,16 +344,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
            # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
            pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)

            # Forward pass through VAE encoder.
            # Forward pass through VAE encoder to get the latent PDF parameters.
            cls_token_out = self.vae_encoder(
                vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
            )[0]  # select the class token, with shape (B, D)
            latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)

            # Sample the latent with the reparameterization trick.
            mu = latent_pdf_params[:, : self.latent_dim]
            # This is 2log(sigma). Done this way to match the original implementation.
            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
            # Sample the latent with the reparameterization trick.
            latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
        else:
            # When not using the VAE encoder, we set the latent to be all zeros.
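
As an illustration of the reparameterization step above, a tiny self-contained snippet; the shapes and the latent_pdf_params tensor are invented for the example:

import torch

torch.manual_seed(0)
batch_size, latent_dim = 4, 32

# Pretend projection output: first half is mu, second half is 2*log(sigma).
latent_pdf_params = torch.randn(batch_size, 2 * latent_dim)
mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]

# Reparameterization trick: sample = mu + sigma * eps with eps ~ N(0, 1),
# which keeps the sample differentiable w.r.t. mu and log_sigma_x2.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
print(latent_sample.shape)  # torch.Size([4, 32])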


@@ -16,18 +16,15 @@ def make_policy(cfg):
            cfg_obs_encoder=cfg.obs_encoder,
            cfg_optimizer=cfg.optimizer,
            cfg_ema=cfg.ema,
            n_obs_steps=cfg.n_obs_steps,
            n_action_steps=cfg.n_action_steps,
            # n_obs_steps=cfg.n_obs_steps,
            # n_action_steps=cfg.n_action_steps,
            **cfg.policy,
        )
    elif cfg.policy.name == "act":
        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy

        policy = ActionChunkingTransformerPolicy(
            cfg.policy,
            cfg.device,
            n_action_steps=cfg.n_action_steps,
        )
        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
        policy.to(cfg.device)
    else:
        raise ValueError(cfg.policy.name)
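
A rough sketch of what the updated ACT call site implies: n_action_steps is now read from the policy sub-config instead of being passed as a keyword argument. The MockActPolicy and SimpleNamespace config below are illustrative stand-ins, not the real classes:

from types import SimpleNamespace


class MockActPolicy:
    """Stand-in mirroring the new two-argument constructor."""

    def __init__(self, cfg, device):
        # n_action_steps now comes from the policy config itself.
        self.n_action_steps = cfg.n_action_steps
        self.device = device


cfg = SimpleNamespace(policy=SimpleNamespace(name="act", n_action_steps=100), device="cpu")
policy = MockActPolicy(cfg.policy, cfg.device)
print(policy.n_action_steps)  # 100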


@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
        # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.model.eval()
        self.model_target.eval()
        self.batch_size = cfg.batch_size

        self.register_buffer("step", torch.zeros(1))
@@ -151,6 +150,8 @@ class TDMPCPolicy(nn.Module):
        t0 = step == 0
        self.eval()

        if len(self._queues["action"]) == 0:
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
@@ -172,7 +173,7 @@ class TDMPCPolicy(nn.Module):
                actions.append(action)
            action = torch.stack(actions)

            # self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
            # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
            if i in range(self.n_action_steps):
                self._queues["action"].append(action)
@@ -325,7 +326,7 @@ class TDMPCPolicy(nn.Module):
    def _td_target(self, next_z, reward, mask):
        """Compute the TD-target from a reward and the observation at the following time step."""
        next_v = self.model.V(next_z)
        td_target = reward + self.cfg.discount * mask * next_v
        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
        return td_target

    def forward(self, batch, step):
@@ -420,6 +421,8 @@ class TDMPCPolicy(nn.Module):
        # idxs = torch.cat([idxs, demo_idxs])
        # weights = torch.cat([weights, demo_weights])

        batch_size = batch["index"].shape[0]

        # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
        # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
        # batch size b = 256, time/horizon t = 5
@@ -433,7 +436,7 @@ class TDMPCPolicy(nn.Module):
        # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)

        obses = {
            "rgb": batch["observation.image"],
@@ -476,7 +479,7 @@ class TDMPCPolicy(nn.Module):
            td_targets = self._td_target(next_z, reward, mask)

        # Latent rollout
        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
        zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
        reward_preds = torch.empty_like(reward, device=self.device)
        assert reward.shape[0] == horizon
        z = self.model.encode(obs)
@@ -485,22 +488,21 @@ class TDMPCPolicy(nn.Module):
        for t in range(horizon):
            z, reward_pred = self.model.next(z, action[t])
            zs[t + 1] = z
            reward_preds[t] = reward_pred
            reward_preds[t] = reward_pred.squeeze(1)

        with torch.no_grad():
            v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")

        # Predictions
        qs = self.model.Q(zs[:-1], action, return_type="all")
        qs = qs.squeeze(3)
        value_info["Q"] = qs.mean().item()
        v = self.model.V(zs[:-1])
        value_info["V"] = v.mean().item()

        # Losses
        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
            dim=0
        )
        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
        reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
        q_value_loss, priority_loss = 0, 0
        for q in range(self.cfg.num_q):
@@ -508,7 +510,9 @@ class TDMPCPolicy(nn.Module):
            priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
        expectile = h.linear_schedule(self.cfg.expectile, step)
        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
            dim=0
        )

        total_loss = (
            self.cfg.consistency_coef * consistency_loss
@@ -517,7 +521,7 @@ class TDMPCPolicy(nn.Module):
            + self.cfg.value_coef * v_value_loss
        )

        weighted_loss = (total_loss.squeeze(1) * weights).mean()
        weighted_loss = (total_loss * weights).mean()
        weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))

        has_nan = torch.isnan(weighted_loss).item()
        if has_nan:
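
To make the shape changes above concrete, a small self-contained check of the new convention: loss terms stay as (horizon, batch_size) once the trailing singleton dims are squeezed, and the per-sample weights are (batch_size,). All tensors and sizes below are invented for illustration:

import torch

horizon, batch_size = 5, 8

# Stand-ins for model outputs; the reward head keeps a trailing singleton dim.
reward_pred = torch.randn(horizon, batch_size, 1)
reward = torch.randn(horizon, batch_size)
loss_mask = torch.ones(horizon, batch_size)
weights = torch.ones(batch_size)  # per-sample weights, shape (batch_size,)

rho = torch.pow(0.5, torch.arange(horizon).float()).view(-1, 1)  # (horizon, 1) broadcasts over batch

# Squeeze the trailing dim so every term is (horizon, batch_size), then sum over time.
reward_loss = (rho * (reward_pred.squeeze(2) - reward) ** 2 * loss_mask).sum(dim=0)  # (batch_size,)
weighted_loss = (reward_loss * weights).mean()
print(reward_loss.shape, weighted_loss.shape)  # torch.Size([8]) torch.Size([])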


@@ -51,6 +51,7 @@ policy:
  utd: 1
  n_obs_steps: ${n_obs_steps}
  n_action_steps: ${n_action_steps}
  temporal_agg: false


@@ -38,6 +38,7 @@ policy:
  horizon: ${horizon}
  n_obs_steps: ${n_obs_steps}
  n_action_steps: ${n_action_steps}
  num_inference_steps: 100
  obs_as_global_cond: ${obs_as_global_cond}
  # crop_shape: null


@@ -36,7 +36,6 @@ policy:
  log_std_max: 2

  # learning
  batch_size: 256
  max_buffer_size: 10000
  horizon: 5
  reward_coef: 0.5


@@ -86,14 +86,10 @@ def eval_policy(
    def maybe_render_frame(env):
        if save_video:  # noqa: B023
            if return_first_video:
                # TODO(now): Put mode back in.
                visu = env.envs[0].render()
                # visu = env.envs[0].render(mode="visualization")
                visu = visu[None, ...]  # add batch dim
            else:
                # TODO(now): Put mode back in.
                visu = np.stack([env.render() for env in env.envs])
                # visu = np.stack([env.render(mode="visualization") for env in env.envs])
            ep_frames.append(visu)  # noqa: B023

    for _ in range(num_episodes):
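
For clarity on the array shapes maybe_render_frame produces, a tiny sketch with dummy frames standing in for real env.render() outputs:

import numpy as np

# Dummy frames standing in for env.envs[i].render() outputs of shape (H, W, C).
frames_per_env = [np.zeros((96, 96, 3), dtype=np.uint8) for _ in range(4)]

first_video_frame = frames_per_env[0][None, ...]  # add batch dim: (1, H, W, C)
all_frames = np.stack(frames_per_env)             # stack all envs: (num_envs, H, W, C)
print(first_video_frame.shape, all_frames.shape)  # (1, 96, 96, 3) (4, 96, 96, 3)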

poetry.lock (generated): 65 lines changed

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -898,7 +898,7 @@ mujoco = "^2.3.7"
type = "git"
url = "git@github.com:huggingface/gym-aloha.git"
reference = "HEAD"
resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
[[package]]
name = "gym-pusht"
@@ -3339,31 +3339,6 @@ numpy = "*"
packaging = "*"
protobuf = ">=3.20"
[[package]]
name = "tensordict"
version = "0.4.0+f622b2f"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
torch = ">=2.1.0"
[package.extras]
checkpointing = ["torchsnapshot-nightly"]
h5 = ["h5py (>=3.8)"]
tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
[package.source]
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
resolved_reference = "f622b2f973320f769b6c09793ca827f27e47d603"
[[package]]
name = "termcolor"
version = "2.4.0"
@@ -3464,40 +3439,6 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.9.1)"]
[[package]]
name = "torchrl"
version = "0.4.0+13bef42"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
packaging = "*"
tensordict = ">=0.4.0"
torch = ">=2.1.0"
[package.extras]
all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
checkpointing = ["torchsnapshot"]
dm-control = ["dm_control"]
gym-continuous = ["gymnasium", "mujoco"]
marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
rendering = ["moviepy"]
tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
[package.source]
type = "git"
url = "https://github.com/pytorch/rl"
reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
[[package]]
name = "torchvision"
version = "0.17.2"
@@ -3741,4 +3682,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "cb450ac7186e004536d75409edd42cd96062f7b1fd47822a5460d12eab8762f9"
content-hash = "bf4627c62a45764931729ce373f1038fe289b6caebb01e66d878f6f278c54518"


@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = "^2.2.1"
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"


@@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
        ("xarm", "tdmpc", ["policy.mpc=true"]),
        ("pusht", "tdmpc", ["policy.mpc=false"]),
        ("pusht", "diffusion", []),
        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
        # ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
        # ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): xarm not working with diffusion
        # ("xarm", "diffusion", []),
    ],
@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=cfg.policy.batch_size,
        batch_size=2,
        shuffle=True,
        pin_memory=DEVICE != "cpu",
        drop_last=True,
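
As a footnote on the dataloader change, a self-contained sketch showing the fixed batch_size=2 with a dummy in-memory dataset standing in for the real one:

import torch

# Dummy dataset standing in for the real lerobot dataset used by the test.
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3))

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,  # 0 keeps this sketch self-contained; the test uses 4
    batch_size=2,   # small fixed batch size, no longer tied to cfg.policy.batch_size
    shuffle=True,
    drop_last=True,
)
(batch,) = next(iter(dataloader))
print(batch.shape)  # torch.Size([2, 3])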