diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 971f4b63..749bb533 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -35,7 +35,7 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: kwargs["task"] = cfg.env.task env_fn = lambda: gym.make( # noqa: E731 - "gym_aloha/AlohaInsertion-v0", + "gym_aloha/AlohaTransferCube-v0", **kwargs, ) else: diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index a9a5ac06..75d5ca0e 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -3,9 +3,10 @@ As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. """ -from collections import deque + import math import time +from collections import deque from itertools import chain from typing import Callable @@ -22,67 +23,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.common.utils import get_safe_torch_device -# class AbstractPolicy(nn.Module): -# """Base policy which all policies should be derived from. - -# The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its -# documentation for more information. - -# Note: -# When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: -# 1. set the required class attributes: -# - for classes inheriting from `AbstractDataset`: `available_datasets` -# - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` -# - for classes inheriting from `AbstractPolicy`: `name` -# 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) -# 3. update variables in `tests/test_available.py` by importing your new class -# """ - -# name: str | None = None # same name should be used to instantiate the policy in factory.py - -# def __init__(self, n_action_steps: int | None): -# """ -# n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single -# action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then -# adds that dimension. -# """ -# super().__init__() -# assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute." -# self.n_action_steps = n_action_steps -# self.clear_action_queue() - -# def clear_action_queue(self): -# """This should be called whenever the environment is reset.""" -# if self.n_action_steps is not None: -# self._action_queue = deque([], maxlen=self.n_action_steps) - -# def forward(self, fn) -> Tensor: -# """Inference step that makes multi-step policies compatible with their single-step environments. - -# WARNING: In general, this should not be overriden. - -# Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit -# into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an -# observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment -# observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that -# the subclass doesn't have to. - -# This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made: -# 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is -# the action trajectory horizon and * is the action dimensions. -# 2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined. -# """ -# if self.n_action_steps is None: -# return self.select_actions(*args, **kwargs) -# if len(self._action_queue) == 0: -# # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape -# # (n_action_steps, batch_size, *), hence the transpose. -# self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1)) -# return self._action_queue.popleft() - - - - class ActionChunkingTransformerPolicy(nn.Module): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost @@ -228,18 +168,30 @@ class ActionChunkingTransformerPolicy(nn.Module): if p.dim() > 1: nn.init.xavier_uniform_(p) - @torch.no_grad() - def select_action(self, batch, *_): - # TODO(now): Implement queueing mechanism. - self.eval() - self._preprocess_batch(batch) + def reset(self): + """This should be called whenever the environment is reset.""" + if self.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.n_action_steps) - # TODO(now): What's up with this 0.182? - action = self.forward( - robot_state=batch["observation.state"] * 0.182, - image=batch["observation.images.top"], - return_loss=False, - ) + def select_action(self, batch: dict[str, Tensor], *_): + """ + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + if len(self._action_queue) == 0: + # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape + # (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) + return self._action_queue.popleft() + + @torch.no_grad() + def select_actions(self, batch: dict[str, Tensor]): + """Use the action chunking transformer to generate a sequence of actions.""" + self.eval() + self._preprocess_batch(batch, add_obs_steps_dim=True) + + action = self.forward(batch, return_loss=False) if self.cfg.temporal_agg: # TODO(rcadene): implement temporal aggregation @@ -257,25 +209,37 @@ class ActionChunkingTransformerPolicy(nn.Module): return action[: self.n_action_steps] def __call__(self, *args, **kwargs): - # TODO(now): Temporary bridge. + # TODO(now): Temporary bridge until we know what to do about the `update` method. return self.update(*args, **kwargs) - def _preprocess_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def _preprocess_batch( + self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False + ) -> dict[str, Tensor]: """ - Expects batch to have (at least): + This function expects `batch` to have (at least): { - "observation.state": (B, 1, J) tensor of robot states (joint configuration) - - "observation.images.top": (B, 1, C, H, W) tensor of images. + "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). + "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. "action": (B, H, J) tensor of actions (positional target for robot joint configuration) "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds. } """ + if add_obs_steps_dim: + # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now, + # this just amounts to an unsqueeze. + for k in batch: + if k.startswith("observation."): + batch[k] = batch[k].unsqueeze(1) + if batch["observation.state"].shape[1] != 1: raise ValueError(self._multiple_obs_steps_not_handled_msg) batch["observation.state"] = batch["observation.state"].squeeze(1) - # TODO(alexander-soare): generalize this to multiple images. Note: no squeeze is required for - # "observation.images.top" because then we'd have to unsqueeze to get get the image index dimension. + # TODO(alexander-soare): generalize this to multiple images. + assert ( + sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1 + ), "ACT only handles one image for now." + # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get + # the image index dimension. def update(self, batch, *_): start_time = time.time() @@ -378,9 +342,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # Forward pass through VAE encoder and sample the latent with the reparameterization trick. cls_token_out = self.vae_encoder( vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) - )[ - 0 - ] # (B, D) + )[0] # (B, D) latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) mu = latent_pdf_params[:, : self.latent_dim] # This is 2log(sigma). Done this way to match the original implementation. diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 90e7ecc1..cc956014 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -26,7 +26,6 @@ def make_policy(cfg): policy = ActionChunkingTransformerPolicy( cfg.policy, cfg.device, - n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, ) else: diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index c1d1801f..80f50003 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -58,6 +58,6 @@ policy: action_dim: ??? delta_timestamps: - observation.image: [0.0] + observation.images.top: [0.0] observation.state: [0.0] action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98] diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index b05f9704..b43f4ed1 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -89,7 +89,9 @@ def eval_policy( visu = env.envs[0].render(mode="visualization") visu = visu[None, ...] # add batch dim else: - visu = np.stack([env.render(mode="visualization") for env in env.envs]) + # TODO(now): Put mode back in. + visu = np.stack([env.render() for env in env.envs]) + # visu = np.stack([env.render(mode="visualization") for env in env.envs]) ep_frames.append(visu) # noqa: B023 for _ in range(num_episodes): @@ -248,7 +250,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making transforms.") # TODO(alexander-soare): Completely decouple datasets from evaluation. - dataset = make_dataset(cfg, stats_path=stats_path) + transform = make_dataset(cfg, stats_path=stats_path).transform logging.info("Making environment.") env = make_env(cfg, num_parallel_envs=cfg.rollout_batch_size) @@ -263,7 +265,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, # TODO(rcadene): what should we do with the transform? - transform=dataset.transform, + transform=transform, seed=cfg.seed, ) print(info["aggregated"]) diff --git a/poetry.lock b/poetry.lock index f96f66bc..60354b8a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -941,7 +941,7 @@ mujoco = "^2.3.7" type = "git" url = "git@github.com:huggingface/gym-xarm.git" reference = "HEAD" -resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c" +resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d" [[package]] name = "gymnasium" @@ -1709,20 +1709,20 @@ pyopengl = "*" [[package]] name = "networkx" -version = "3.2.1" +version = "3.3" description = "Python package for creating and manipulating graphs and networks" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, - {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, + {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"}, + {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"}, ] [package.extras] -default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] -developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] -doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] -extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] @@ -3699,20 +3699,20 @@ watchdog = ["watchdog (>=2.3)"] [[package]] name = "zarr" -version = "2.17.1" +version = "2.17.2" description = "An implementation of chunked, compressed, N-dimensional arrays for Python" optional = false python-versions = ">=3.9" files = [ - {file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"}, - {file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"}, + {file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"}, + {file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"}, ] [package.dependencies] asciitree = "*" fasteners = {version = "*", markers = "sys_platform != \"emscripten\""} numcodecs = ">=0.10.0" -numpy = ">=1.21.1" +numpy = ">=1.23" [package.extras] docs = ["numcodecs[msgpack]", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-automodapi", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]