From 73dfa3c8e331f9a4558ea51d1830360cfb2d4f1b Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 9 Apr 2024 02:50:32 +0000 Subject: [PATCH 1/4] tests for tdmpc and diffusion policy are passing --- lerobot/common/policies/factory.py | 4 ++-- lerobot/common/policies/tdmpc/policy.py | 24 +++++++++++++----------- lerobot/configs/policy/diffusion.yaml | 1 + lerobot/configs/policy/tdmpc.yaml | 1 - poetry.lock | 4 ++-- tests/test_policies.py | 2 +- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 90e7ecc1..371ab221 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -16,8 +16,8 @@ def make_policy(cfg): cfg_obs_encoder=cfg.obs_encoder, cfg_optimizer=cfg.optimizer, cfg_ema=cfg.ema, - n_obs_steps=cfg.n_obs_steps, - n_action_steps=cfg.n_action_steps, + # n_obs_steps=cfg.n_obs_steps, + # n_action_steps=cfg.n_action_steps, **cfg.policy, ) elif cfg.policy.name == "act": diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 04aa5b11..2d547f26 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module): # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) self.model.eval() self.model_target.eval() - self.batch_size = cfg.batch_size self.register_buffer("step", torch.zeros(1)) @@ -325,7 +324,7 @@ class TDMPCPolicy(nn.Module): def _td_target(self, next_z, reward, mask): """Compute the TD-target from a reward and the observation at the following time step.""" next_v = self.model.V(next_z) - td_target = reward + self.cfg.discount * mask * next_v + td_target = reward + self.cfg.discount * mask * next_v.squeeze(2) return td_target def forward(self, batch, step): @@ -420,6 +419,8 @@ class TDMPCPolicy(nn.Module): # idxs = torch.cat([idxs, demo_idxs]) # weights = torch.cat([weights, demo_weights]) + batch_size = batch["index"].shape[0] + # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels) # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention # batch size b = 256, time/horizon t = 5 @@ -433,7 +434,7 @@ class TDMPCPolicy(nn.Module): # idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device) + weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device) obses = { "rgb": batch["observation.image"], @@ -476,7 +477,7 @@ class TDMPCPolicy(nn.Module): td_targets = self._td_target(next_z, reward, mask) # Latent rollout - zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device) + zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device) reward_preds = torch.empty_like(reward, device=self.device) assert reward.shape[0] == horizon z = self.model.encode(obs) @@ -485,22 +486,21 @@ class TDMPCPolicy(nn.Module): for t in range(horizon): z, reward_pred = self.model.next(z, action[t]) zs[t + 1] = z - reward_preds[t] = reward_pred + reward_preds[t] = reward_pred.squeeze(1) with torch.no_grad(): v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min") # Predictions qs = self.model.Q(zs[:-1], action, return_type="all") + qs = qs.squeeze(3) value_info["Q"] = 
qs.mean().item() v = self.model.V(zs[:-1]) value_info["V"] = v.mean().item() # Losses - rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1) - consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum( - dim=0 - ) + rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1) + consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0) reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0) q_value_loss, priority_loss = 0, 0 for q in range(self.cfg.num_q): @@ -508,7 +508,9 @@ class TDMPCPolicy(nn.Module): priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) expectile = h.linear_schedule(self.cfg.expectile, step) - v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0) + v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum( + dim=0 + ) total_loss = ( self.cfg.consistency_coef * consistency_loss @@ -517,7 +519,7 @@ class TDMPCPolicy(nn.Module): + self.cfg.value_coef * v_value_loss ) - weighted_loss = (total_loss.squeeze(1) * weights).mean() + weighted_loss = (total_loss * weights).mean() weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon)) has_nan = torch.isnan(weighted_loss).item() if has_nan: diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 6da62e10..811ee824 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -38,6 +38,7 @@ policy: horizon: ${horizon} n_obs_steps: ${n_obs_steps} + n_action_steps: ${n_action_steps} num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 4fd2b6bb..2ebaad9b 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -36,7 +36,6 @@ policy: log_std_max: 2 # learning - batch_size: 256 max_buffer_size: 10000 horizon: 5 reward_coef: 0.5 diff --git a/poetry.lock b/poetry.lock index 98449df4..e5105b75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -897,7 +897,7 @@ mujoco = "^2.3.7" type = "git" url = "git@github.com:huggingface/gym-aloha.git" reference = "HEAD" -resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11" +resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f" [[package]] name = "gym-pusht" diff --git a/tests/test_policies.py b/tests/test_policies.py index 82033b78..5d0c0d89 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides): dataloader = torch.utils.data.DataLoader( dataset, num_workers=4, - batch_size=cfg.policy.batch_size, + batch_size=2, shuffle=True, pin_memory=DEVICE != "cpu", drop_last=True, From 6902e01db07e2f27d862166d093c23e24654c900 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 9 Apr 2024 03:28:56 +0000 Subject: [PATCH 2/4] tests are passing for aloha/act policies, removes abstract policy --- lerobot/common/policies/abstract.py | 82 ------------- lerobot/common/policies/act/policy.py | 153 ++++++++++++------------ lerobot/common/policies/factory.py | 6 +- lerobot/common/policies/tdmpc/policy.py | 4 +- lerobot/configs/policy/act.yaml | 4 + tests/test_policies.py | 8 +- 6 files changed, 90 insertions(+), 167 deletions(-) delete mode 100644 lerobot/common/policies/abstract.py diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py deleted file mode 100644 index 6dc72bef..00000000 --- a/lerobot/common/policies/abstract.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections import deque - -import torch -from torch import Tensor, nn - - -class AbstractPolicy(nn.Module): - """Base policy which all policies should be derived from. - - The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its - documentation for more information. - - Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class - """ - - name: str | None = None # same name should be used to instantiate the policy in factory.py - - def __init__(self, n_action_steps: int | None): - """ - n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single - action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then - adds that dimension. - """ - super().__init__() - assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute." - self.n_action_steps = n_action_steps - self.clear_action_queue() - - def update(self, replay_buffer, step): - """One step of the policy's learning algorithm.""" - raise NotImplementedError("Abstract method") - - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - self.load_state_dict(d) - - def select_actions(self, observation) -> Tensor: - """Select an action (or trajectory of actions) based on an observation during rollout. - - If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of - actions. 
Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions. - """ - raise NotImplementedError("Abstract method") - - def clear_action_queue(self): - """This should be called whenever the environment is reset.""" - if self.n_action_steps is not None: - self._action_queue = deque([], maxlen=self.n_action_steps) - - def forward(self, *args, **kwargs) -> Tensor: - """Inference step that makes multi-step policies compatible with their single-step environments. - - WARNING: In general, this should not be overriden. - - Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit - into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an - observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment - observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that - the subclass doesn't have to. - - This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made: - 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is - the action trajectory horizon and * is the action dimensions. - 2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined. - """ - if self.n_action_steps is None: - return self.select_actions(*args, **kwargs) - if len(self._action_queue) == 0: - # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape - # (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1)) - return self._action_queue.popleft() diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index ae4f7320..4138e910 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -1,13 +1,14 @@ import logging import time +from collections import deque import torch import torch.nn.functional as F # noqa: N812 import torchvision.transforms as transforms +from torch import nn -from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.act.detr_vae import build -from lerobot.common.utils import get_safe_torch_device +from lerobot.common.policies.utils import populate_queues def build_act_model_and_optimizer(cfg): @@ -41,75 +42,61 @@ def kl_divergence(mu, logvar): return total_kld, dimension_wise_kld, mean_kld -class ActionChunkingTransformerPolicy(AbstractPolicy): +class ActionChunkingTransformerPolicy(nn.Module): name = "act" - def __init__(self, cfg, device, n_action_steps=1): - super().__init__(n_action_steps) + def __init__(self, cfg, n_obs_steps, n_action_steps): + super().__init__() self.cfg = cfg + self.n_obs_steps = n_obs_steps + if self.n_obs_steps > 1: + raise NotImplementedError() self.n_action_steps = n_action_steps - self.device = get_safe_torch_device(device) self.model, self.optimizer = build_act_model_and_optimizer(cfg) self.kl_weight = self.cfg.kl_weight logging.info(f"KL Weight {self.kl_weight}") - self.to(self.device) - def update(self, replay_buffer, step): + def reset(self): + """ + Clear observation and action queues. 
Should be called on `env.reset()` + """ + self._queues = { + "observation.images.top": deque(maxlen=self.n_obs_steps), + "observation.state": deque(maxlen=self.n_obs_steps), + "action": deque(maxlen=self.n_action_steps), + } + + def forward(self, batch, step): del step start_time = time.time() self.train() - num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices + image = batch["observation.images.top"] + # batch, num_cam, channel, height, width + image = image.unsqueeze(1) + assert image.ndim == 5 - assert batch_size % self.cfg.horizon == 0 - assert batch_size % num_slices == 0 + state = batch["observation.state"] + # batch, qpos_dim + assert state.ndim == 2 - def process_batch(batch, horizon, num_slices): - # trajectory t = 64, horizon h = 16 - # (t h) ... -> t h ... - batch = batch.reshape(num_slices, horizon) + action = batch["action"] + # batch, seq, action_dim + assert action.ndim == 3 - image = batch["observation", "image", "top"] - image = image[:, 0] # first observation t=0 - # batch, num_cam, channel, height, width - image = image.unsqueeze(1) - assert image.ndim == 5 - image = image.float() - - state = batch["observation", "state"] - state = state[:, 0] # first observation t=0 - # batch, qpos_dim - assert state.ndim == 2 - - action = batch["action"] - # batch, seq, action_dim - assert action.ndim == 3 - assert action.shape[1] == horizon - - if self.cfg.n_obs_steps > 1: - raise NotImplementedError() - # # keep first n observations of the slice corresponding to t=[-1,0] - # image = image[:, : self.cfg.n_obs_steps] - # state = state[:, : self.cfg.n_obs_steps] - - out = { - "obs": { - "image": image.to(self.device, non_blocking=True), - "agent_pos": state.to(self.device, non_blocking=True), - }, - "action": action.to(self.device, non_blocking=True), - } - return out - - batch = replay_buffer.sample(batch_size) - batch = process_batch(batch, self.cfg.horizon, num_slices) + preprocessed_batch = { + "obs": { + "image": image, + "agent_pos": state, + }, + "action": action, + } data_s = time.time() - start_time - loss = self.compute_loss(batch) + loss = self.compute_loss(preprocessed_batch) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( @@ -150,40 +137,52 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): return loss @torch.no_grad() - def select_actions(self, observation, step_count): - if observation["image"].shape[0] != 1: - raise NotImplementedError("Batch size > 1 not handled") + def select_action(self, batch, step): + assert "observation.images.top" in batch + assert "observation.state" in batch + assert len(batch) == 2 + + self._queues = populate_queues(self._queues, batch) # TODO(rcadene): remove unused step_count - del step_count + del step self.eval() - # TODO(rcadene): remove hack - # add 1 camera dimension - observation["image", "top"] = observation["image", "top"].unsqueeze(1) + if len(self._queues["action"]) == 0: + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} - obs_dict = { - "image": observation["image", "top"], - "agent_pos": observation["state"], - } - action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"]) + if self.n_obs_steps == 1: + # hack to remove the time dimension + for key in batch: + assert batch[key].shape[1] == 1 + batch[key] = batch[key][:, 0] - if self.cfg.temporal_agg: - # TODO(rcadene): implement temporal aggregation - raise NotImplementedError() - # all_time_actions[[t], t:t+num_queries] = action - # actions_for_curr_step = all_time_actions[:, t] - # 
actions_populated = torch.all(actions_for_curr_step != 0, axis=1) - # actions_for_curr_step = actions_for_curr_step[actions_populated] - # k = 0.01 - # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) - # exp_weights = exp_weights / exp_weights.sum() - # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) - # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + actions = self._forward( + # TODO(rcadene): remove unsqueeze hack to add the "number of cameras" dimension + image=batch["observation.images.top"].unsqueeze(1), + qpos=batch["observation.state"], + ) - # take first predicted action or n first actions - action = action[: self.n_action_steps] + if self.cfg.temporal_agg: + # TODO(rcadene): implement temporal aggregation + raise NotImplementedError() + # all_time_actions[[t], t:t+num_queries] = action + # actions_for_curr_step = all_time_actions[:, t] + # actions_populated = torch.all(actions_for_curr_step != 0, axis=1) + # actions_for_curr_step = actions_for_curr_step[actions_populated] + # k = 0.01 + # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + # exp_weights = exp_weights / exp_weights.sum() + # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) + # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + + # act returns a sequence of `n` actions, but we consider only + # the first `n_action_steps` actions subset + for i in range(self.n_action_steps): + self._queues["action"].append(actions[:, i]) + + action = self._queues["action"].popleft() return action def _forward(self, qpos, image, actions=None, is_pad=None): diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 371ab221..8636aa6e 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -25,10 +25,10 @@ def make_policy(cfg): policy = ActionChunkingTransformerPolicy( cfg.policy, - cfg.device, - n_obs_steps=cfg.n_obs_steps, - n_action_steps=cfg.n_action_steps, + n_obs_steps=cfg.policy.n_obs_steps, + n_action_steps=cfg.policy.n_action_steps, ) + policy.to(cfg.device) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 2d547f26..942ee9b1 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -150,6 +150,8 @@ class TDMPCPolicy(nn.Module): t0 = step == 0 + self.eval() + if len(self._queues["action"]) == 0: batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} @@ -171,7 +173,7 @@ class TDMPCPolicy(nn.Module): actions.append(action) action = torch.stack(actions) - # self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time + # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time if i in range(self.n_action_steps): self._queues["action"].append(action) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 9dca436f..cf5d7508 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -50,8 +50,12 @@ policy: utd: 1 n_obs_steps: ${n_obs_steps} + n_action_steps: ${n_action_steps} temporal_agg: false state_dim: ??? action_dim: ??? 
+ + delta_timestamps: + action: "[i / ${fps} for i in range(${horizon})]" diff --git a/tests/test_policies.py b/tests/test_policies.py index 5d0c0d89..8ccc7c62 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH ("xarm", "tdmpc", ["policy.mpc=true"]), ("pusht", "tdmpc", ["policy.mpc=false"]), ("pusht", "diffusion", []), - # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]), - #("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]), - #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]), - #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]), + ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]), + ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]), + ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]), + ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]), # TODO(aliberts): xarm not working with diffusion # ("xarm", "diffusion", []), ], From 253e495df237ccd7b6db1dfacac2cdeeba29bd82 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 9 Apr 2024 03:46:05 +0000 Subject: [PATCH 3/4] remove render(mode=visualization) --- lerobot/scripts/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index e7ba53fc..512bb451 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -86,10 +86,10 @@ def eval_policy( def maybe_render_frame(env): if save_video: # noqa: B023 if return_first_video: - visu = env.envs[0].render(mode="visualization") + visu = env.envs[0].render() visu = visu[None, ...] # add batch dim else: - visu = np.stack([env.render(mode="visualization") for env in env.envs]) + visu = np.stack([env.render() for env in env.envs]) ep_frames.append(visu) # noqa: B023 for _ in range(num_episodes): From 19e7661b8da4f561b6557008ed969dda8a1f13f9 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 9 Apr 2024 03:50:49 +0000 Subject: [PATCH 4/4] Remove torchrl/tensordict from dependecies + update poetry cpu --- .github/poetry/cpu/poetry.lock | 126 +++++++++++++++--------------- .github/poetry/cpu/pyproject.toml | 9 ++- poetry.lock | 61 +-------------- pyproject.toml | 2 - 4 files changed, 72 insertions(+), 126 deletions(-) diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index ba820f34..15b27c76 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -889,6 +889,69 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.62.1)"] +[[package]] +name = "gym-aloha" +version = "0.1.0" +description = "A gym environment for ALOHA" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +dm-control = "1.0.14" +gymnasium = "^0.29.1" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-aloha.git" +reference = "HEAD" +resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f" + +[[package]] +name = "gym-pusht" +version = "0.1.0" +description = "A gymnasium environment for PushT." +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +opencv-python = "^4.9.0.80" +pygame = "^2.5.2" +pymunk = "^6.6.0" +scikit-image = "^0.22.0" +shapely = "^2.0.3" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-pusht.git" +reference = "HEAD" +resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf" + +[[package]] +name = "gym-xarm" +version = "0.1.0" +description = "A gym environment for xArm" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +gymnasium-robotics = "^1.2.4" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-xarm.git" +reference = "HEAD" +resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d" + [[package]] name = "gymnasium" version = "0.29.1" @@ -2988,31 +3051,6 @@ numpy = "*" packaging = "*" protobuf = ">=3.20" -[[package]] -name = "tensordict" -version = "0.4.0+b4c91e8" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.dependencies] -cloudpickle = "*" -numpy = "*" -torch = ">=2.1.0" - -[package.extras] -checkpointing = ["torchsnapshot-nightly"] -h5 = ["h5py (>=3.8)"] -tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"] - -[package.source] -type = "git" -url = "https://github.com/pytorch/tensordict" -reference = "HEAD" -resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078" - [[package]] name = "termcolor" version = "2.4.0" @@ -3091,40 +3129,6 @@ type = "legacy" url = "https://download.pytorch.org/whl/cpu" reference = "torch-cpu" -[[package]] -name = "torchrl" -version = "0.4.0+13bef42" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.dependencies] -cloudpickle = "*" -numpy = "*" -packaging = "*" -tensordict = ">=0.4.0" -torch = ">=2.1.0" - -[package.extras] -all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"] -atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"] -checkpointing = ["torchsnapshot"] -dm-control = ["dm_control"] -gym-continuous = ["gymnasium", "mujoco"] -marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"] -offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"] -rendering = ["moviepy"] -tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"] -utils = ["git", "hydra-core (>=1.1)", 
"hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"] - -[package.source] -type = "git" -url = "https://github.com/pytorch/rl" -reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" -resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" - [[package]] name = "torchvision" version = "0.17.1+cpu" @@ -3330,4 +3334,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd" +content-hash = "32cd6caa01276a90b37cb177204e5b1511e92838f3f0268391034042d56f3bd6" diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index e84b93c9..d310da47 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -39,8 +39,6 @@ scikit-image = "^0.22.0" numba = "^0.59.0" mpmath = "^1.3.0" torch = {version = "^2.2.1", source = "torch-cpu"} -tensordict = {git = "https://github.com/pytorch/tensordict"} -torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"} mujoco = "^2.3.7" opencv-python = "^4.9.0.80" diffusers = "^0.26.3" @@ -53,7 +51,12 @@ huggingface-hub = "^0.21.4" gymnasium-robotics = "^1.2.4" gymnasium = "^0.29.1" cmake = "^3.29.0.1" - +gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} +gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true} +gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true} +# gym-pusht = { path = "../gym-pusht", develop = true, optional = true} +# gym-xarm = { path = "../gym-xarm", develop = true, optional = true} +# gym-aloha = { path = "../gym-aloha", develop = true, optional = true} [tool.poetry.group.dev.dependencies] pre-commit = "^3.6.2" diff --git a/poetry.lock b/poetry.lock index e5105b75..95c9f31e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3255,31 +3255,6 @@ numpy = "*" packaging = "*" protobuf = ">=3.20" -[[package]] -name = "tensordict" -version = "0.4.0+b4c91e8" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.dependencies] -cloudpickle = "*" -numpy = "*" -torch = ">=2.1.0" - -[package.extras] -checkpointing = ["torchsnapshot-nightly"] -h5 = ["h5py (>=3.8)"] -tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"] - -[package.source] -type = "git" -url = "https://github.com/pytorch/tensordict" -reference = "HEAD" -resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078" - [[package]] name = "termcolor" version = "2.4.0" @@ -3380,40 +3355,6 @@ typing-extensions = ">=4.8.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.9.1)"] -[[package]] -name = "torchrl" -version = "0.4.0+13bef42" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.dependencies] -cloudpickle = "*" -numpy = "*" -packaging = "*" -tensordict = ">=0.4.0" -torch = ">=2.1.0" - -[package.extras] -all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"] -atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"] -checkpointing = 
["torchsnapshot"] -dm-control = ["dm_control"] -gym-continuous = ["gymnasium", "mujoco"] -marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"] -offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"] -rendering = ["moviepy"] -tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"] -utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"] - -[package.source] -type = "git" -url = "https://github.com/pytorch/rl" -reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" -resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" - [[package]] name = "torchvision" version = "0.17.1" @@ -3657,4 +3598,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "cb450ac7186e004536d75409edd42cd96062f7b1fd47822a5460d12eab8762f9" +content-hash = "bf4627c62a45764931729ce373f1038fe289b6caebb01e66d878f6f278c54518" diff --git a/pyproject.toml b/pyproject.toml index e78a502d..a549e66f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,8 +39,6 @@ scikit-image = "^0.22.0" numba = "^0.59.0" mpmath = "^1.3.0" torch = "^2.2.1" -tensordict = {git = "https://github.com/pytorch/tensordict"} -torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"} mujoco = "^2.3.7" opencv-python = "^4.9.0.80" diffusers = "^0.26.3"