Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act
commit e6c6c2367f

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.

 [[package]]
 name = "absl-py"
@@ -889,6 +889,69 @@ files = [
 [package.extras]
 protobuf = ["grpcio-tools (>=1.62.1)"]

+[[package]]
+name = "gym-aloha"
+version = "0.1.0"
+description = "A gym environment for ALOHA"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dm-control = "1.0.14"
+gymnasium = "^0.29.1"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-aloha.git"
+reference = "HEAD"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
+
+[[package]]
+name = "gym-pusht"
+version = "0.1.0"
+description = "A gymnasium environment for PushT."
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+opencv-python = "^4.9.0.80"
+pygame = "^2.5.2"
+pymunk = "^6.6.0"
+scikit-image = "^0.22.0"
+shapely = "^2.0.3"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-pusht.git"
+reference = "HEAD"
+resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf"
+
+[[package]]
+name = "gym-xarm"
+version = "0.1.0"
+description = "A gym environment for xArm"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+gymnasium-robotics = "^1.2.4"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-xarm.git"
+reference = "HEAD"
+resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
+
 [[package]]
 name = "gymnasium"
 version = "0.29.1"
@@ -2988,31 +3051,6 @@ numpy = "*"
 packaging = "*"
 protobuf = ">=3.20"

-[[package]]
-name = "tensordict"
-version = "0.4.0+b4c91e8"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-torch = ">=2.1.0"
-
-[package.extras]
-checkpointing = ["torchsnapshot-nightly"]
-h5 = ["h5py (>=3.8)"]
-tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/tensordict"
-reference = "HEAD"
-resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
-
 [[package]]
 name = "termcolor"
 version = "2.4.0"
@@ -3091,40 +3129,6 @@ type = "legacy"
 url = "https://download.pytorch.org/whl/cpu"
 reference = "torch-cpu"

-[[package]]
-name = "torchrl"
-version = "0.4.0+13bef42"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-packaging = "*"
-tensordict = ">=0.4.0"
-torch = ">=2.1.0"
-
-[package.extras]
-all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
-atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
-checkpointing = ["torchsnapshot"]
-dm-control = ["dm_control"]
-gym-continuous = ["gymnasium", "mujoco"]
-marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
-offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
-rendering = ["moviepy"]
-tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
-utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/rl"
-reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-
 [[package]]
 name = "torchvision"
 version = "0.17.1+cpu"
@@ -3330,4 +3334,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"
+content-hash = "32cd6caa01276a90b37cb177204e5b1511e92838f3f0268391034042d56f3bd6"

@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
 numba = "^0.59.0"
 mpmath = "^1.3.0"
 torch = {version = "^2.2.1", source = "torch-cpu"}
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
 mujoco = "^2.3.7"
 opencv-python = "^4.9.0.80"
 diffusers = "^0.26.3"
@@ -53,7 +51,12 @@ huggingface-hub = "^0.21.4"
 gymnasium-robotics = "^1.2.4"
 gymnasium = "^0.29.1"
 cmake = "^3.29.0.1"
+gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
+gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
+gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
+# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
+# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
+# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}

 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.6.2"

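With the simulation environments now pulled in as optional git dependencies, code that needs them typically checks for their presence before importing. A minimal sketch of such a guard, assuming the import name gym_aloha and an extra named aloha (both hypothetical here, not taken from this diff):

import importlib.util


def require_extra(module_name: str, extra: str) -> None:
    # Fail early with a readable hint when an optional dependency is missing.
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"'{module_name}' is not installed. Install the matching extra, "
            f"for example: poetry install --extras {extra}"
        )


require_extra("gym_aloha", "aloha")  # hypothetical module/extra names
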
@@ -1,82 +0,0 @@
-from collections import deque
-
-import torch
-from torch import Tensor, nn
-
-
-class AbstractPolicy(nn.Module):
-    """Base policy which all policies should be derived from.
-
-    The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
-    documentation for more information.
-
-    Note:
-        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-            1. set the required class attributes:
-                - for classes inheriting from `AbstractDataset`: `available_datasets`
-                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-                - for classes inheriting from `AbstractPolicy`: `name`
-            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-            3. update variables in `tests/test_available.py` by importing your new class
-    """
-
-    name: str | None = None # same name should be used to instantiate the policy in factory.py
-
-    def __init__(self, n_action_steps: int | None):
-        """
-        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
-            action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
-            adds that dimension.
-        """
-        super().__init__()
-        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
-        self.n_action_steps = n_action_steps
-        self.clear_action_queue()
-
-    def update(self, replay_buffer, step):
-        """One step of the policy's learning algorithm."""
-        raise NotImplementedError("Abstract method")
-
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-
-    def select_actions(self, observation) -> Tensor:
-        """Select an action (or trajectory of actions) based on an observation during rollout.
-
-        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
-        actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
-        """
-        raise NotImplementedError("Abstract method")
-
-    def clear_action_queue(self):
-        """This should be called whenever the environment is reset."""
-        if self.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.n_action_steps)
-
-    def forward(self, *args, **kwargs) -> Tensor:
-        """Inference step that makes multi-step policies compatible with their single-step environments.
-
-        WARNING: In general, this should not be overriden.
-
-        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
-        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
-        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
-        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
-        the subclass doesn't have to.
-
-        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
-        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
-            the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
-        """
-        if self.n_action_steps is None:
-            return self.select_actions(*args, **kwargs)
-        if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
-            # (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
-        return self._action_queue.popleft()

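The file removed above implemented the queue-and-refill pattern for multi-step policies; the same idea survives in ActionChunkingTransformerPolicy.select_action (see the hunks that follow). A self-contained sketch of that pattern, with a toy planner standing in for a real policy (illustration only, not repository code):

from collections import deque

import torch


class ChunkedPolicy:
    # Toy stand-in: plans a chunk of actions, then replays it one step at a time.

    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, observation: torch.Tensor) -> torch.Tensor:
        # Pretend planner: returns (batch_size, n_action_steps, action_dim).
        batch_size = observation.shape[0]
        return torch.zeros(batch_size, self.n_action_steps, 2)

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        # Refill only when the queue is empty; the transpose turns
        # (batch_size, n_action_steps, *) into (n_action_steps, batch_size, *)
        # so each popleft() yields one (batch_size, *) action.
        if len(self._action_queue) == 0:
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


policy = ChunkedPolicy(n_action_steps=3)
obs = torch.zeros(1, 4)
steps = [policy.forward(obs) for _ in range(5)]  # replans once, after the third step
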
@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         "ActionChunkingTransformerPolicy does not handle multiple observation steps."
     )

-    def __init__(self, cfg, device, n_action_steps=1):
+    def __init__(self, cfg, device):
         """
         TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
@@ -73,7 +73,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if getattr(cfg, "n_obs_steps", 1) != 1:
             raise ValueError(self._multiple_obs_steps_not_handled_msg)
         self.cfg = cfg
-        self.n_action_steps = n_action_steps
+        self.n_action_steps = cfg.n_action_steps
         self.device = get_safe_torch_device(device)
         self.camera_names = cfg.camera_names
         self.use_vae = cfg.use_vae
@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.n_action_steps)

-    def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
         """
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
         # the image index dimension.

-    def update(self, batch, *_) -> dict:
+    def update(self, batch, *_, **__) -> dict:
         start_time = time.time()
         self._preprocess_batch(batch)

@@ -311,7 +311,7 @@ class ActionChunkingTransformerPolicy(nn.Module):

     def _forward(
         self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
-    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
         """
         Args:
             robot_state: (B, J) batch of robot joint configurations.
@@ -344,16 +344,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
             pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)

-            # Forward pass through VAE encoder.
+            # Forward pass through VAE encoder to get the latent PDF parameters.
             cls_token_out = self.vae_encoder(
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
             )[0] # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)

-            # Sample the latent with the reparameterization trick.
             mu = latent_pdf_params[:, : self.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
             log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]

+            # Sample the latent with the reparameterization trick.
             latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.

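The comment moved in the last hunk marks the reparameterization step. Taken on its own, it draws a sample from N(mu, sigma^2) while keeping the draw differentiable in mu and sigma. A standalone restatement with dummy tensors, assuming (as the diff's comment says) that the projection outputs mu followed by 2*log(sigma) along the last dimension; this is an illustration, not repository code:

import torch

latent_dim = 32
latent_pdf_params = torch.randn(4, 2 * latent_dim)  # (B, 2 * latent_dim), dummy values

mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # this is 2 * log(sigma)

# exp(2 * log(sigma) / 2) == sigma, so this is mu + sigma * eps with eps ~ N(0, 1),
# and gradients flow through both mu and log_sigma_x2.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
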
@@ -16,18 +16,15 @@ def make_policy(cfg):
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
-            n_obs_steps=cfg.n_obs_steps,
-            n_action_steps=cfg.n_action_steps,
+            # n_obs_steps=cfg.n_obs_steps,
+            # n_action_steps=cfg.n_action_steps,
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
         from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy

-        policy = ActionChunkingTransformerPolicy(
-            cfg.policy,
-            cfg.device,
-            n_action_steps=cfg.n_action_steps,
-        )
+        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
+        policy.to(cfg.device)
     else:
         raise ValueError(cfg.policy.name)

@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
         # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
-        self.batch_size = cfg.batch_size

         self.register_buffer("step", torch.zeros(1))

@@ -151,6 +150,8 @@ class TDMPCPolicy(nn.Module):

         t0 = step == 0

+        self.eval()
+
         if len(self._queues["action"]) == 0:
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}

@@ -172,7 +173,7 @@ class TDMPCPolicy(nn.Module):
                 actions.append(action)
             action = torch.stack(actions)

-            # self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
+            # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
             if i in range(self.n_action_steps):
                 self._queues["action"].append(action)

@@ -325,7 +326,7 @@ class TDMPCPolicy(nn.Module):
     def _td_target(self, next_z, reward, mask):
         """Compute the TD-target from a reward and the observation at the following time step."""
         next_v = self.model.V(next_z)
-        td_target = reward + self.cfg.discount * mask * next_v
+        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
         return td_target

     def forward(self, batch, step):
@@ -420,6 +421,8 @@ class TDMPCPolicy(nn.Module):
         # idxs = torch.cat([idxs, demo_idxs])
         # weights = torch.cat([weights, demo_weights])

+        batch_size = batch["index"].shape[0]
+
         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
         # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
         # batch size b = 256, time/horizon t = 5
@@ -433,7 +436,7 @@ class TDMPCPolicy(nn.Module):
         # idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
         done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
         mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)

         obses = {
             "rgb": batch["observation.image"],
@@ -476,7 +479,7 @@ class TDMPCPolicy(nn.Module):
         td_targets = self._td_target(next_z, reward, mask)

         # Latent rollout
-        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+        zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
         reward_preds = torch.empty_like(reward, device=self.device)
         assert reward.shape[0] == horizon
         z = self.model.encode(obs)
@@ -485,22 +488,21 @@ class TDMPCPolicy(nn.Module):
         for t in range(horizon):
             z, reward_pred = self.model.next(z, action[t])
             zs[t + 1] = z
-            reward_preds[t] = reward_pred
+            reward_preds[t] = reward_pred.squeeze(1)

         with torch.no_grad():
             v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")

         # Predictions
         qs = self.model.Q(zs[:-1], action, return_type="all")
+        qs = qs.squeeze(3)
         value_info["Q"] = qs.mean().item()
         v = self.model.V(zs[:-1])
         value_info["V"] = v.mean().item()

         # Losses
-        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-            dim=0
-        )
+        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
+        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
         reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
         q_value_loss, priority_loss = 0, 0
         for q in range(self.cfg.num_q):
@@ -508,7 +510,9 @@ class TDMPCPolicy(nn.Module):
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)

         expectile = h.linear_schedule(self.cfg.expectile, step)
-        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+            dim=0
+        )

         total_loss = (
             self.cfg.consistency_coef * consistency_loss
@@ -517,7 +521,7 @@ class TDMPCPolicy(nn.Module):
             + self.cfg.value_coef * v_value_loss
         )

-        weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss = (total_loss * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
         has_nan = torch.isnan(weighted_loss).item()
         if has_nan:

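The squeeze() calls added in the TDMPC hunks above are shape bookkeeping: trailing singleton dimensions from the value, reward, and Q heads are dropped so that per-step quantities stay (horizon, batch) and the per-sample loss stays (batch,). A sketch under assumed shapes (the (T, B, 1)-style head outputs are inferences from the diff, not taken from the repository):

import torch

T, B, num_q = 5, 256, 2  # horizon, batch size, number of Q heads (illustrative values)

reward = torch.zeros(T, B)
next_v = torch.zeros(T, B, 1)                  # value head keeps a trailing singleton dim
td_target = reward + 0.99 * next_v.squeeze(2)  # (T, B), aligned with reward

qs = torch.zeros(num_q, T, B, 1).squeeze(3)    # (num_q, T, B), so qs[q] matches td_target

rho = torch.pow(0.5, torch.arange(T)).view(-1, 1)  # (T, 1), broadcasts against (T, B)
loss_mask = torch.ones(T, B)
reward_preds = torch.zeros(T, B)

reward_loss = (rho * (reward_preds - reward) ** 2 * loss_mask).sum(dim=0)  # (B,), one loss per sample
weights = torch.ones(B)
weighted_loss = (reward_loss * weights).mean()  # scalar; no extra squeeze needed
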
@@ -51,6 +51,7 @@ policy:
   utd: 1

   n_obs_steps: ${n_obs_steps}
+  n_action_steps: ${n_action_steps}

   temporal_agg: false

@@ -38,6 +38,7 @@ policy:

   horizon: ${horizon}
   n_obs_steps: ${n_obs_steps}
+  n_action_steps: ${n_action_steps}
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null

@@ -36,7 +36,6 @@ policy:
   log_std_max: 2

   # learning
-  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5

@@ -86,14 +86,10 @@ def eval_policy(
     def maybe_render_frame(env):
         if save_video: # noqa: B023
             if return_first_video:
-                # TODO(now): Put mode back in.
                 visu = env.envs[0].render()
-                # visu = env.envs[0].render(mode="visualization")
                 visu = visu[None, ...] # add batch dim
             else:
-                # TODO(now): Put mode back in.
                 visu = np.stack([env.render() for env in env.envs])
-                # visu = np.stack([env.render(mode="visualization") for env in env.envs])
             ep_frames.append(visu) # noqa: B023

     for _ in range(num_episodes):

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.

 [[package]]
 name = "absl-py"
@@ -898,7 +898,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-aloha.git"
 reference = "HEAD"
-resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"

 [[package]]
 name = "gym-pusht"
@@ -3339,31 +3339,6 @@ numpy = "*"
 packaging = "*"
 protobuf = ">=3.20"

-[[package]]
-name = "tensordict"
-version = "0.4.0+f622b2f"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-torch = ">=2.1.0"
-
-[package.extras]
-checkpointing = ["torchsnapshot-nightly"]
-h5 = ["h5py (>=3.8)"]
-tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/tensordict"
-reference = "HEAD"
-resolved_reference = "f622b2f973320f769b6c09793ca827f27e47d603"
-
 [[package]]
 name = "termcolor"
 version = "2.4.0"
@@ -3464,40 +3439,6 @@ typing-extensions = ">=4.8.0"
 opt-einsum = ["opt-einsum (>=3.3)"]
 optree = ["optree (>=0.9.1)"]

-[[package]]
-name = "torchrl"
-version = "0.4.0+13bef42"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-packaging = "*"
-tensordict = ">=0.4.0"
-torch = ">=2.1.0"
-
-[package.extras]
-all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
-atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
-checkpointing = ["torchsnapshot"]
-dm-control = ["dm_control"]
-gym-continuous = ["gymnasium", "mujoco"]
-marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
-offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
-rendering = ["moviepy"]
-tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
-utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/rl"
-reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-
 [[package]]
 name = "torchvision"
 version = "0.17.2"
@@ -3741,4 +3682,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "cb450ac7186e004536d75409edd42cd96062f7b1fd47822a5460d12eab8762f9"
+content-hash = "bf4627c62a45764931729ce373f1038fe289b6caebb01e66d878f6f278c54518"

@@ -39,8 +39,6 @@ scikit-image = "^0.22.0"
 numba = "^0.59.0"
 mpmath = "^1.3.0"
 torch = "^2.2.1"
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
 mujoco = "^2.3.7"
 opencv-python = "^4.9.0.80"
 diffusers = "^0.26.3"

@@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
         ("xarm", "tdmpc", ["policy.mpc=true"]),
         ("pusht", "tdmpc", ["policy.mpc=false"]),
         ("pusht", "diffusion", []),
-        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
-        #("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
-        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
-        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
         # TODO(aliberts): xarm not working with diffusion
         # ("xarm", "diffusion", []),
     ],
@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
-        batch_size=cfg.policy.batch_size,
+        batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
         drop_last=True,