Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act

Alexander Soare 2024-04-09 08:36:28 +01:00
commit e6c6c2367f
13 changed files with 109 additions and 247 deletions

.github/poetry/cpu/poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -889,6 +889,69 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
[[package]]
name = "gym-aloha"
version = "0.1.0"
description = "A gym environment for ALOHA"
optional = true
python-versions = "^3.10"
files = []
develop = false
[package.dependencies]
dm-control = "1.0.14"
gymnasium = "^0.29.1"
mujoco = "^2.3.7"
[package.source]
type = "git"
url = "git@github.com:huggingface/gym-aloha.git"
reference = "HEAD"
resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
[[package]]
name = "gym-pusht"
version = "0.1.0"
description = "A gymnasium environment for PushT."
optional = true
python-versions = "^3.10"
files = []
develop = false
[package.dependencies]
gymnasium = "^0.29.1"
opencv-python = "^4.9.0.80"
pygame = "^2.5.2"
pymunk = "^6.6.0"
scikit-image = "^0.22.0"
shapely = "^2.0.3"
[package.source]
type = "git"
url = "git@github.com:huggingface/gym-pusht.git"
reference = "HEAD"
resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf"
[[package]]
name = "gym-xarm"
version = "0.1.0"
description = "A gym environment for xArm"
optional = true
python-versions = "^3.10"
files = []
develop = false
[package.dependencies]
gymnasium = "^0.29.1"
gymnasium-robotics = "^1.2.4"
mujoco = "^2.3.7"
[package.source]
type = "git"
url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD"
resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
[[package]]
name = "gymnasium"
version = "0.29.1"
@@ -2988,31 +3051,6 @@ numpy = "*"
packaging = "*"
protobuf = ">=3.20"
[[package]]
name = "tensordict"
version = "0.4.0+b4c91e8"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
torch = ">=2.1.0"
[package.extras]
checkpointing = ["torchsnapshot-nightly"]
h5 = ["h5py (>=3.8)"]
tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
[package.source]
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
[[package]]
name = "termcolor"
version = "2.4.0"
@@ -3091,40 +3129,6 @@ type = "legacy"
url = "https://download.pytorch.org/whl/cpu"
reference = "torch-cpu"
[[package]]
name = "torchrl"
version = "0.4.0+13bef42"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
packaging = "*"
tensordict = ">=0.4.0"
torch = ">=2.1.0"
[package.extras]
all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
checkpointing = ["torchsnapshot"]
dm-control = ["dm_control"]
gym-continuous = ["gymnasium", "mujoco"]
marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
rendering = ["moviepy"]
tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
[package.source]
type = "git"
url = "https://github.com/pytorch/rl"
reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
[[package]]
name = "torchvision"
version = "0.17.1+cpu"
@@ -3330,4 +3334,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"
+content-hash = "32cd6caa01276a90b37cb177204e5b1511e92838f3f0268391034042d56f3bd6"

View File

@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = {version = "^2.2.1", source = "torch-cpu"}
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
@@ -53,7 +51,12 @@ huggingface-hub = "^0.21.4"
gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
+gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
+gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
+gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
+# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
+# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
+# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"

View File

@@ -1,82 +0,0 @@
from collections import deque
import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module):
"""Base policy which all policies should be derived from.
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
documentation for more information.
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
name: str | None = None # same name should be used to instantiate the policy in factory.py
def __init__(self, n_action_steps: int | None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
adds that dimension.
"""
super().__init__()
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
self.n_action_steps = n_action_steps
self.clear_action_queue()
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
def select_actions(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
"""
raise NotImplementedError("Abstract method")
def clear_action_queue(self):
"""This should be called whenever the environment is reset."""
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def forward(self, *args, **kwargs) -> Tensor:
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overriden.
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
"""
if self.n_action_steps is None:
return self.select_actions(*args, **kwargs)
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()
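
For context on the queueing mechanism this removed class implemented (and which concrete policies such as ActionChunkingTransformerPolicy now handle themselves, see `select_action` below), here is a minimal standalone sketch of the pattern. It is not code from the repository; DummyPolicy, its planner, and the shapes are invented for illustration.

from collections import deque

import torch
from torch import Tensor, nn


class DummyPolicy(nn.Module):
    """Toy multi-step policy: plan a whole trajectory at once, serve it one step at a time."""

    def __init__(self, n_action_steps: int, action_dim: int = 2):
        super().__init__()
        self.n_action_steps = n_action_steps
        self.action_dim = action_dim
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, obs: Tensor) -> Tensor:
        # Stand-in planner returning a (batch, horizon, action_dim) trajectory.
        return torch.zeros(obs.shape[0], self.n_action_steps, self.action_dim)

    def forward(self, obs: Tensor) -> Tensor:
        # Only re-plan when the cached trajectory is exhausted.
        if len(self._action_queue) == 0:
            # (batch, horizon, *) -> (horizon, batch, *) so the deque pops over time steps.
            self._action_queue.extend(self.select_actions(obs).transpose(0, 1))
        return self._action_queue.popleft()


policy = DummyPolicy(n_action_steps=3)
obs = torch.randn(4, 8)   # batch of 4 observations
a0 = policy(obs)          # plans 3 steps, returns step 0, shape (4, 2)
a1 = policy(obs)          # served from the queue, no re-planning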

View File

@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
-def __init__(self, cfg, device, n_action_steps=1):
+def __init__(self, cfg, device):
"""
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
"""
@@ -73,7 +73,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
if getattr(cfg, "n_obs_steps", 1) != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg
-self.n_action_steps = n_action_steps
+self.n_action_steps = cfg.n_action_steps
self.device = get_safe_torch_device(device)
self.camera_names = cfg.camera_names
self.use_vae = cfg.use_vae
@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
-def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
+def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
"""
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
-def update(self, batch, *_) -> dict:
+def update(self, batch, *_, **__) -> dict:
start_time = time.time()
self._preprocess_batch(batch)
@@ -311,7 +311,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
def _forward(
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
-) -> tuple[Tensor, tuple[Tensor, Tensor]]:
+) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
@@ -344,16 +344,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
-# Forward pass through VAE encoder.
+# Forward pass through VAE encoder to get the latent PDF parameters.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
-# Sample the latent with the reparameterization trick.
mu = latent_pdf_params[:, : self.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
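
The comment that moves in the hunk above sits next to the VAE's reparameterization step: the encoder projection yields mu and 2·log(sigma), and the latent is sampled as mu + sigma * eps. A hedged, self-contained illustration of that step follows; the batch size, latent_dim, and the random stand-in for the encoder projection are invented for the example, and only the sampling convention mirrors the diff.

import torch

batch_size, latent_dim = 8, 32
# Stand-in for the encoder output projection: first half is mu, second half is 2*log(sigma).
latent_pdf_params = torch.randn(batch_size, 2 * latent_dim)

mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # 2*log(sigma), matching the comment above

# Reparameterization trick: sigma = exp(2*log(sigma) / 2), eps ~ N(0, I).
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)

# Sampling this way keeps gradients flowing into mu and log_sigma_x2,
# which is what makes the VAE's KL term trainable.
assert latent_sample.shape == (batch_size, latent_dim)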

View File

@@ -16,18 +16,15 @@ def make_policy(cfg):
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
-n_obs_steps=cfg.n_obs_steps,
-n_action_steps=cfg.n_action_steps,
+# n_obs_steps=cfg.n_obs_steps,
+# n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
-policy = ActionChunkingTransformerPolicy(
-cfg.policy,
-cfg.device,
-n_action_steps=cfg.n_action_steps,
-)
+policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
+policy.to(cfg.device)
else:
raise ValueError(cfg.policy.name)
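
The change above moves `n_action_steps` into the ACT policy's own config section and makes the device transfer an explicit `policy.to(cfg.device)` call after construction. A small hypothetical sketch of that dispatch-by-name pattern; the names and the dict-based config are invented for illustration (the repository itself uses Hydra/OmegaConf configs).

import torch
from torch import nn


class SketchPolicy(nn.Module):
    """Placeholder policy that reads everything it needs from its own config section."""

    def __init__(self, cfg: dict):
        super().__init__()
        self.n_action_steps = cfg["n_action_steps"]
        self.net = nn.Linear(4, 2)


def make_policy_sketch(cfg: dict) -> nn.Module:
    if cfg["policy"]["name"] == "act_sketch":
        policy = SketchPolicy(cfg["policy"])
        policy.to(cfg["device"])  # device handling stays outside the constructor
    else:
        raise ValueError(cfg["policy"]["name"])
    return policy


policy = make_policy_sketch(
    {"policy": {"name": "act_sketch", "n_action_steps": 1}, "device": "cpu"}
)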

View File

@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
-self.batch_size = cfg.batch_size
self.register_buffer("step", torch.zeros(1))
@@ -151,6 +150,8 @@
t0 = step == 0
+self.eval()
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
@@ -172,7 +173,7 @@
actions.append(action)
action = torch.stack(actions)
-# self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
+# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
if i in range(self.n_action_steps):
self._queues["action"].append(action)
@@ -325,7 +326,7 @@
def _td_target(self, next_z, reward, mask):
"""Compute the TD-target from a reward and the observation at the following time step."""
next_v = self.model.V(next_z)
-td_target = reward + self.cfg.discount * mask * next_v
+td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
return td_target
def forward(self, batch, step):
@@ -420,6 +421,8 @@
# idxs = torch.cat([idxs, demo_idxs])
# weights = torch.cat([weights, demo_weights])
+batch_size = batch["index"].shape[0]
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
# batch size b = 256, time/horizon t = 5
@@ -433,7 +436,7 @@
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
obses = {
"rgb": batch["observation.image"],
@@ -476,7 +479,7 @@
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
-zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
@@ -485,22 +488,21 @@
for t in range(horizon):
z, reward_pred = self.model.next(z, action[t])
zs[t + 1] = z
-reward_preds[t] = reward_pred
+reward_preds[t] = reward_pred.squeeze(1)
with torch.no_grad():
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
# Predictions
qs = self.model.Q(zs[:-1], action, return_type="all")
+qs = qs.squeeze(3)
value_info["Q"] = qs.mean().item()
v = self.model.V(zs[:-1])
value_info["V"] = v.mean().item()
# Losses
-rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
+rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
-consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-dim=0
-)
+consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
@@ -508,7 +510,9 @@
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step)
-v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+dim=0
+)
total_loss = (
self.cfg.consistency_coef * consistency_loss
@@ -517,7 +521,7 @@
+ self.cfg.value_coef * v_value_loss
)
-weighted_loss = (total_loss.squeeze(1) * weights).mean()
+weighted_loss = (total_loss * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
has_nan = torch.isnan(weighted_loss).item()
if has_nan:
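
Most of the edits above are shape bookkeeping after dropping tensordict: the value and reward heads return tensors with a trailing singleton dimension that now needs an explicit squeeze, and the per-sample weights are sized by the batch rather than shaped like the reward. The TD target itself is unchanged: td_target = r + discount * mask * V(z'). Below is a hedged sketch of the shapes involved; the horizon, batch size, and the squared-error stand-in for h.mse are invented for illustration.

import torch

horizon, batch_size, discount = 5, 256, 0.99

reward = torch.randn(horizon, batch_size)        # (T, B)
mask = torch.ones(horizon, batch_size)           # 1 while the episode is still running
next_v = torch.randn(horizon, batch_size, 1)     # value head keeps a trailing singleton dim

# Squeeze the trailing dim so broadcasting matches reward's (T, B) shape.
td_target = reward + discount * mask * next_v.squeeze(2)
assert td_target.shape == (horizon, batch_size)

# Per-timestep discount rho^t, broadcast over the batch only, hence view(-1, 1).
rho = torch.pow(0.5, torch.arange(horizon, dtype=torch.float32)).view(-1, 1)
loss_mask = torch.ones(horizon, batch_size)
reward_loss = (rho * (td_target - reward).pow(2) * loss_mask).sum(dim=0)  # (B,)

# Weights are per-sample now, so no squeeze is needed before averaging.
weights = torch.ones(batch_size)
weighted_loss = (reward_loss * weights).mean()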

View File

@@ -51,6 +51,7 @@ policy:
utd: 1
n_obs_steps: ${n_obs_steps}
+n_action_steps: ${n_action_steps}
temporal_agg: false

View File

@@ -38,6 +38,7 @@ policy:
horizon: ${horizon}
n_obs_steps: ${n_obs_steps}
+n_action_steps: ${n_action_steps}
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null

View File

@@ -36,7 +36,6 @@ policy:
log_std_max: 2
# learning
-batch_size: 256
max_buffer_size: 10000
horizon: 5
reward_coef: 0.5

View File

@@ -86,14 +86,10 @@ def eval_policy(
def maybe_render_frame(env):
if save_video:  # noqa: B023
if return_first_video:
-# TODO(now): Put mode back in.
visu = env.envs[0].render()
-# visu = env.envs[0].render(mode="visualization")
visu = visu[None, ...]  # add batch dim
else:
-# TODO(now): Put mode back in.
visu = np.stack([env.render() for env in env.envs])
-# visu = np.stack([env.render(mode="visualization") for env in env.envs])
ep_frames.append(visu)  # noqa: B023
for _ in range(num_episodes):

poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -898,7 +898,7 @@ mujoco = "^2.3.7"
type = "git"
url = "git@github.com:huggingface/gym-aloha.git"
reference = "HEAD"
-resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
[[package]]
name = "gym-pusht"
@@ -3339,31 +3339,6 @@ numpy = "*"
packaging = "*"
protobuf = ">=3.20"
[[package]]
name = "tensordict"
version = "0.4.0+f622b2f"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
torch = ">=2.1.0"
[package.extras]
checkpointing = ["torchsnapshot-nightly"]
h5 = ["h5py (>=3.8)"]
tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
[package.source]
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
resolved_reference = "f622b2f973320f769b6c09793ca827f27e47d603"
[[package]]
name = "termcolor"
version = "2.4.0"
@@ -3464,40 +3439,6 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.9.1)"]
[[package]]
name = "torchrl"
version = "0.4.0+13bef42"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.dependencies]
cloudpickle = "*"
numpy = "*"
packaging = "*"
tensordict = ">=0.4.0"
torch = ">=2.1.0"
[package.extras]
all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
checkpointing = ["torchsnapshot"]
dm-control = ["dm_control"]
gym-continuous = ["gymnasium", "mujoco"]
marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
rendering = ["moviepy"]
tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
[package.source]
type = "git"
url = "https://github.com/pytorch/rl"
reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
[[package]]
name = "torchvision"
version = "0.17.2"
@@ -3741,4 +3682,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "cb450ac7186e004536d75409edd42cd96062f7b1fd47822a5460d12eab8762f9"
+content-hash = "bf4627c62a45764931729ce373f1038fe289b6caebb01e66d878f6f278c54518"

View File

@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = "^2.2.1"
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"

View File

@@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
("xarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
-# ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
-#("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
-#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
-#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
+("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
+("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
+("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
# TODO(aliberts): xarm not working with diffusion
# ("xarm", "diffusion", []),
],
@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
-batch_size=cfg.policy.batch_size,
+batch_size=2,
shuffle=True,
pin_memory=DEVICE != "cpu",
drop_last=True,