Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act
commit e6c6c2367f
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -889,6 +889,69 @@ files = [
 [package.extras]
 protobuf = ["grpcio-tools (>=1.62.1)"]
 
+[[package]]
+name = "gym-aloha"
+version = "0.1.0"
+description = "A gym environment for ALOHA"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dm-control = "1.0.14"
+gymnasium = "^0.29.1"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-aloha.git"
+reference = "HEAD"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
+
+[[package]]
+name = "gym-pusht"
+version = "0.1.0"
+description = "A gymnasium environment for PushT."
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+opencv-python = "^4.9.0.80"
+pygame = "^2.5.2"
+pymunk = "^6.6.0"
+scikit-image = "^0.22.0"
+shapely = "^2.0.3"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-pusht.git"
+reference = "HEAD"
+resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf"
+
+[[package]]
+name = "gym-xarm"
+version = "0.1.0"
+description = "A gym environment for xArm"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+gymnasium-robotics = "^1.2.4"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-xarm.git"
+reference = "HEAD"
+resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
+
 [[package]]
 name = "gymnasium"
 version = "0.29.1"
@@ -2988,31 +3051,6 @@ numpy = "*"
 packaging = "*"
 protobuf = ">=3.20"
 
-[[package]]
-name = "tensordict"
-version = "0.4.0+b4c91e8"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-torch = ">=2.1.0"
-
-[package.extras]
-checkpointing = ["torchsnapshot-nightly"]
-h5 = ["h5py (>=3.8)"]
-tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/tensordict"
-reference = "HEAD"
-resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
-
 [[package]]
 name = "termcolor"
 version = "2.4.0"
@@ -3091,40 +3129,6 @@ type = "legacy"
 url = "https://download.pytorch.org/whl/cpu"
 reference = "torch-cpu"
 
-[[package]]
-name = "torchrl"
-version = "0.4.0+13bef42"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-packaging = "*"
-tensordict = ">=0.4.0"
-torch = ">=2.1.0"
-
-[package.extras]
-all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
-atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
-checkpointing = ["torchsnapshot"]
-dm-control = ["dm_control"]
-gym-continuous = ["gymnasium", "mujoco"]
-marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
-offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
-rendering = ["moviepy"]
-tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
-utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/rl"
-reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-
 [[package]]
 name = "torchvision"
 version = "0.17.1+cpu"
@@ -3330,4 +3334,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"
+content-hash = "32cd6caa01276a90b37cb177204e5b1511e92838f3f0268391034042d56f3bd6"

@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
 numba = "^0.59.0"
 mpmath = "^1.3.0"
 torch = {version = "^2.2.1", source = "torch-cpu"}
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
 mujoco = "^2.3.7"
 opencv-python = "^4.9.0.80"
 diffusers = "^0.26.3"
@@ -53,7 +51,12 @@ huggingface-hub = "^0.21.4"
 gymnasium-robotics = "^1.2.4"
 gymnasium = "^0.29.1"
 cmake = "^3.29.0.1"
+gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
+gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
+gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
+# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
+# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
+# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.6.2"

@@ -1,82 +0,0 @@
-from collections import deque
-
-import torch
-from torch import Tensor, nn
-
-
-class AbstractPolicy(nn.Module):
-    """Base policy which all policies should be derived from.
-
-    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
-    documentation for more information.
-
-    Note:
-        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-            1. set the required class attributes:
-                - for classes inheriting from `AbstractDataset`: `available_datasets`
-                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-                - for classes inheriting from `AbstractPolicy`: `name`
-            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-            3. update variables in `tests/test_available.py` by importing your new class
-    """
-
-    name: str | None = None  # same name should be used to instantiate the policy in factory.py
-
-    def __init__(self, n_action_steps: int | None):
-        """
-        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
-        action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
-        adds that dimension.
-        """
-        super().__init__()
-        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
-        self.n_action_steps = n_action_steps
-        self.clear_action_queue()
-
-    def update(self, replay_buffer, step):
-        """One step of the policy's learning algorithm."""
-        raise NotImplementedError("Abstract method")
-
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-
-    def select_actions(self, observation) -> Tensor:
-        """Select an action (or trajectory of actions) based on an observation during rollout.
-
-        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
-        actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
-        """
-        raise NotImplementedError("Abstract method")
-
-    def clear_action_queue(self):
-        """This should be called whenever the environment is reset."""
-        if self.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.n_action_steps)
-
-    def forward(self, *args, **kwargs) -> Tensor:
-        """Inference step that makes multi-step policies compatible with their single-step environments.
-
-        WARNING: In general, this should not be overridden.
-
-        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
-        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
-        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
-        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
-        the subclass doesn't have to.
-
-        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
-        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
-           the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
-        """
-        if self.n_action_steps is None:
-            return self.select_actions(*args, **kwargs)
-        if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
-            # (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
-        return self._action_queue.popleft()
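
The file deleted above implemented one load-bearing pattern: plan a chunk of actions once, then serve them one per environment step from a queue. A minimal runnable sketch of that pattern (the `QueuedPolicy` class and its zero-action planner below are illustrative stand-ins, not lerobot code):

from collections import deque

import torch


class QueuedPolicy:
    """Sketch of the action-queue pattern from `AbstractPolicy.forward`."""

    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, observation: torch.Tensor) -> torch.Tensor:
        # Stand-in planner: returns a (batch_size, n_action_steps, action_dim) chunk.
        return torch.zeros(observation.shape[0], self.n_action_steps, 2)

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        if len(self._action_queue) == 0:
            # The queue holds (batch_size, action_dim) slices, hence the transpose
            # from (batch_size, n_action_steps, action_dim).
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


policy = QueuedPolicy(n_action_steps=4)
obs = torch.zeros(1, 10)
for _ in range(8):
    action = policy.forward(obs)  # replans on steps 0 and 4, pops the queue otherwise
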
@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         "ActionChunkingTransformerPolicy does not handle multiple observation steps."
     )
 
-    def __init__(self, cfg, device, n_action_steps=1):
+    def __init__(self, cfg, device):
         """
         TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
@@ -73,7 +73,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if getattr(cfg, "n_obs_steps", 1) != 1:
             raise ValueError(self._multiple_obs_steps_not_handled_msg)
         self.cfg = cfg
-        self.n_action_steps = n_action_steps
+        self.n_action_steps = cfg.n_action_steps
         self.device = get_safe_torch_device(device)
         self.camera_names = cfg.camera_names
         self.use_vae = cfg.use_vae
@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.n_action_steps)
 
-    def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
         """
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
         # the image index dimension.
 
-    def update(self, batch, *_) -> dict:
+    def update(self, batch, *_, **__) -> dict:
         start_time = time.time()
         self._preprocess_batch(batch)
 
@@ -311,7 +311,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
     def _forward(
         self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
-    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
         """
         Args:
             robot_state: (B, J) batch of robot joint configurations.
@@ -344,16 +344,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
             pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)
 
-            # Forward pass through VAE encoder.
+            # Forward pass through VAE encoder to get the latent PDF parameters.
             cls_token_out = self.vae_encoder(
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
 
-            # Sample the latent with the reparameterization trick.
             mu = latent_pdf_params[:, : self.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
             log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+            # Sample the latent with the reparameterization trick.
             latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
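
The comment shuffle in this hunk touches the reparameterization trick: the latent is sampled as mu + sigma * eps, with sigma stored as 2*log(sigma), so the sample stays differentiable with respect to the encoder outputs. A standalone sketch (batch size and `latent_dim` are assumed for illustration):

import torch

latent_dim = 32
# Assume the VAE encoder produced latent PDF parameters of shape (B, 2 * latent_dim).
latent_pdf_params = torch.randn(8, 2 * latent_dim)

mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # 2*log(sigma), as in the diff

# Reparameterization trick: gradients flow through mu and log_sigma_x2, not the noise.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
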
@@ -16,18 +16,15 @@ def make_policy(cfg):
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
-            n_obs_steps=cfg.n_obs_steps,
-            n_action_steps=cfg.n_action_steps,
+            # n_obs_steps=cfg.n_obs_steps,
+            # n_action_steps=cfg.n_action_steps,
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
         from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
 
-        policy = ActionChunkingTransformerPolicy(
-            cfg.policy,
-            cfg.device,
-            n_action_steps=cfg.n_action_steps,
-        )
+        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
        policy.to(cfg.device)
     else:
         raise ValueError(cfg.policy.name)

@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
         # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
-        self.batch_size = cfg.batch_size
 
         self.register_buffer("step", torch.zeros(1))
 
@@ -151,6 +150,8 @@ class TDMPCPolicy(nn.Module):
 
         t0 = step == 0
 
+        self.eval()
+
         if len(self._queues["action"]) == 0:
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
 
@@ -172,7 +173,7 @@ class TDMPCPolicy(nn.Module):
             actions.append(action)
         action = torch.stack(actions)
 
-        # self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
+        # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
         if i in range(self.n_action_steps):
             self._queues["action"].append(action)
 
@@ -325,7 +326,7 @@ class TDMPCPolicy(nn.Module):
     def _td_target(self, next_z, reward, mask):
         """Compute the TD-target from a reward and the observation at the following time step."""
         next_v = self.model.V(next_z)
-        td_target = reward + self.cfg.discount * mask * next_v
+        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
         return td_target
 
     def forward(self, batch, step):
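
The added `.squeeze(2)` is consistent with `model.V` returning a trailing singleton value dimension, i.e. `next_v` of shape (T, B, 1) against `reward` and `mask` of shape (T, B). A shape-only sketch of the TD-target under that assumption (all sizes invented):

import torch

horizon, batch_size, discount = 5, 4, 0.99
reward = torch.randn(horizon, batch_size)     # (T, B)
mask = torch.ones(horizon, batch_size)        # (T, B)
next_v = torch.randn(horizon, batch_size, 1)  # (T, B, 1), e.g. from a value head

td_target = reward + discount * mask * next_v.squeeze(2)  # (T, B)
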
@@ -420,6 +421,8 @@ class TDMPCPolicy(nn.Module):
         # idxs = torch.cat([idxs, demo_idxs])
         # weights = torch.cat([weights, demo_weights])
 
+        batch_size = batch["index"].shape[0]
+
         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
         # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
         # batch size b = 256, time/horizon t = 5
@@ -433,7 +436,7 @@ class TDMPCPolicy(nn.Module):
         # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
         done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
         mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
 
         obses = {
             "rgb": batch["observation.image"],
@@ -476,7 +479,7 @@ class TDMPCPolicy(nn.Module):
         td_targets = self._td_target(next_z, reward, mask)
 
         # Latent rollout
-        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+        zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
         reward_preds = torch.empty_like(reward, device=self.device)
         assert reward.shape[0] == horizon
         z = self.model.encode(obs)
@@ -485,22 +488,21 @@ class TDMPCPolicy(nn.Module):
         for t in range(horizon):
             z, reward_pred = self.model.next(z, action[t])
             zs[t + 1] = z
-            reward_preds[t] = reward_pred
+            reward_preds[t] = reward_pred.squeeze(1)
 
         with torch.no_grad():
             v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
 
         # Predictions
         qs = self.model.Q(zs[:-1], action, return_type="all")
+        qs = qs.squeeze(3)
         value_info["Q"] = qs.mean().item()
         v = self.model.V(zs[:-1])
         value_info["V"] = v.mean().item()
 
         # Losses
-        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-            dim=0
-        )
+        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
+        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
         reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
         q_value_loss, priority_loss = 0, 0
         for q in range(self.cfg.num_q):
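
The `rho` reshape from `view(-1, 1, 1)` to `view(-1, 1)` tracks the per-step losses dropping their trailing singleton dimension; the weighting itself is a geometric decay rho^t over the rollout horizon. A shape-only sketch (sizes assumed):

import torch

horizon, batch_size, rho_coef = 5, 4, 0.5
per_step_loss = torch.randn(horizon, batch_size).abs()  # (T, B)

# Geometric decay rho^t down the rollout, broadcast across the batch.
rho = torch.pow(rho_coef, torch.arange(horizon)).view(-1, 1)  # (T, 1)
weighted = (rho * per_step_loss).sum(dim=0)  # (B,) per-sample loss
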
@@ -508,7 +510,9 @@ class TDMPCPolicy(nn.Module):
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
 
         expectile = h.linear_schedule(self.cfg.expectile, step)
-        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+            dim=0
+        )
 
         total_loss = (
             self.cfg.consistency_coef * consistency_loss
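
`h.l2_expectile` itself is not shown in this diff. One plausible reading, consistent with the added `.squeeze(2)`, is an asymmetric squared error; the helper below is an assumption for illustration, not lerobot's implementation:

import torch

def l2_expectile(diff: torch.Tensor, expectile: float = 0.7) -> torch.Tensor:
    # Asymmetric squared error: positive and negative residuals are weighted
    # by `expectile` and `1 - expectile` respectively (an assumed definition).
    weight = torch.where(diff > 0, expectile, 1 - expectile)
    return weight * diff.pow(2)

v_target = torch.randn(5, 4, 1)  # (T, B, 1), e.g. min over target Q heads
v = torch.randn(5, 4, 1)         # (T, B, 1) from the V head
loss = l2_expectile(v_target - v).squeeze(2)  # (T, B), matching loss_mask
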
@@ -517,7 +521,7 @@ class TDMPCPolicy(nn.Module):
             + self.cfg.value_coef * v_value_loss
         )
 
-        weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss = (total_loss * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
         has_nan = torch.isnan(weighted_loss).item()
         if has_nan:

@@ -51,6 +51,7 @@ policy:
   utd: 1
 
   n_obs_steps: ${n_obs_steps}
+  n_action_steps: ${n_action_steps}
 
   temporal_agg: false
 
@@ -38,6 +38,7 @@ policy:
 
   horizon: ${horizon}
   n_obs_steps: ${n_obs_steps}
+  n_action_steps: ${n_action_steps}
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null

@@ -36,7 +36,6 @@ policy:
   log_std_max: 2
 
   # learning
-  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5

@@ -86,14 +86,10 @@ def eval_policy(
     def maybe_render_frame(env):
         if save_video:  # noqa: B023
             if return_first_video:
-                # TODO(now): Put mode back in.
                 visu = env.envs[0].render()
-                # visu = env.envs[0].render(mode="visualization")
                 visu = visu[None, ...]  # add batch dim
             else:
-                # TODO(now): Put mode back in.
                 visu = np.stack([env.render() for env in env.envs])
-                # visu = np.stack([env.render(mode="visualization") for env in env.envs])
             ep_frames.append(visu)  # noqa: B023
 
     for _ in range(num_episodes):

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -898,7 +898,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-aloha.git"
 reference = "HEAD"
-resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
 
 [[package]]
 name = "gym-pusht"
@@ -3339,31 +3339,6 @@ numpy = "*"
 packaging = "*"
 protobuf = ">=3.20"
 
-[[package]]
-name = "tensordict"
-version = "0.4.0+f622b2f"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-torch = ">=2.1.0"
-
-[package.extras]
-checkpointing = ["torchsnapshot-nightly"]
-h5 = ["h5py (>=3.8)"]
-tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/tensordict"
-reference = "HEAD"
-resolved_reference = "f622b2f973320f769b6c09793ca827f27e47d603"
-
 [[package]]
 name = "termcolor"
 version = "2.4.0"
@@ -3464,40 +3439,6 @@ typing-extensions = ">=4.8.0"
 opt-einsum = ["opt-einsum (>=3.3)"]
 optree = ["optree (>=0.9.1)"]
 
-[[package]]
-name = "torchrl"
-version = "0.4.0+13bef42"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-packaging = "*"
-tensordict = ">=0.4.0"
-torch = ">=2.1.0"
-
-[package.extras]
-all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
-atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
-checkpointing = ["torchsnapshot"]
-dm-control = ["dm_control"]
-gym-continuous = ["gymnasium", "mujoco"]
-marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
-offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
-rendering = ["moviepy"]
-tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
-utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/rl"
-reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-
 [[package]]
 name = "torchvision"
 version = "0.17.2"
@@ -3741,4 +3682,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "cb450ac7186e004536d75409edd42cd96062f7b1fd47822a5460d12eab8762f9"
+content-hash = "bf4627c62a45764931729ce373f1038fe289b6caebb01e66d878f6f278c54518"

@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
 numba = "^0.59.0"
 mpmath = "^1.3.0"
 torch = "^2.2.1"
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
 mujoco = "^2.3.7"
 opencv-python = "^4.9.0.80"
 diffusers = "^0.26.3"

@@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
         ("xarm", "tdmpc", ["policy.mpc=true"]),
         ("pusht", "tdmpc", ["policy.mpc=false"]),
         ("pusht", "diffusion", []),
-        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
-        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
-        # ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
-        # ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
         # TODO(aliberts): xarm not working with diffusion
         # ("xarm", "diffusion", []),
     ],
@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
-        batch_size=cfg.policy.batch_size,
+        batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
         drop_last=True,