Eval reproduction works with gym_aloha

This commit is contained in:
Alexander Soare 2024-04-08 10:23:26 +01:00
parent e982c732f1
commit 1bab4a1dd5
6 changed files with 66 additions and 103 deletions

View File

@ -35,7 +35,7 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
kwargs["task"] = cfg.env.task kwargs["task"] = cfg.env.task
env_fn = lambda: gym.make( # noqa: E731 env_fn = lambda: gym.make( # noqa: E731
"gym_aloha/AlohaInsertion-v0", "gym_aloha/AlohaTransferCube-v0",
**kwargs, **kwargs,
) )
else: else:

View File

@ -3,9 +3,10 @@
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
""" """
from collections import deque
import math import math
import time import time
from collections import deque
from itertools import chain from itertools import chain
from typing import Callable from typing import Callable
@ -22,67 +23,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.utils import get_safe_torch_device from lerobot.common.utils import get_safe_torch_device
# class AbstractPolicy(nn.Module):
# """Base policy which all policies should be derived from.
# The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
# documentation for more information.
# Note:
# When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
# 1. set the required class attributes:
# - for classes inheriting from `AbstractDataset`: `available_datasets`
# - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
# - for classes inheriting from `AbstractPolicy`: `name`
# 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
# 3. update variables in `tests/test_available.py` by importing your new class
# """
# name: str | None = None # same name should be used to instantiate the policy in factory.py
# def __init__(self, n_action_steps: int | None):
# """
# n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
# action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
# adds that dimension.
# """
# super().__init__()
# assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
# self.n_action_steps = n_action_steps
# self.clear_action_queue()
# def clear_action_queue(self):
# """This should be called whenever the environment is reset."""
# if self.n_action_steps is not None:
# self._action_queue = deque([], maxlen=self.n_action_steps)
# def forward(self, fn) -> Tensor:
# """Inference step that makes multi-step policies compatible with their single-step environments.
# WARNING: In general, this should not be overridden.
# Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
# into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
# observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
# observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
# the subclass doesn't have to.
# This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
# 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
# the action trajectory horizon and * is the action dimensions.
# 2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
# """
# if self.n_action_steps is None:
# return self.select_actions(*args, **kwargs)
# if len(self._action_queue) == 0:
# # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# # (n_action_steps, batch_size, *), hence the transpose.
# self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
# return self._action_queue.popleft()
class ActionChunkingTransformerPolicy(nn.Module): class ActionChunkingTransformerPolicy(nn.Module):
""" """
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
@ -228,18 +168,30 @@ class ActionChunkingTransformerPolicy(nn.Module):
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
@torch.no_grad() def reset(self):
def select_action(self, batch, *_): """This should be called whenever the environment is reset."""
# TODO(now): Implement queueing mechanism. if self.n_action_steps is not None:
self.eval() self._action_queue = deque([], maxlen=self.n_action_steps)
self._preprocess_batch(batch)
# TODO(now): What's up with this 0.182? def select_action(self, batch: dict[str, Tensor], *_):
action = self.forward( """
robot_state=batch["observation.state"] * 0.182, This method wraps `select_actions` in order to return one action at a time for execution in the
image=batch["observation.images.top"], environment. It works by managing the actions in a queue and only calling `select_actions` when the
return_loss=False, queue is empty.
) """
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad()
def select_actions(self, batch: dict[str, Tensor]):
"""Use the action chunking transformer to generate a sequence of actions."""
self.eval()
self._preprocess_batch(batch, add_obs_steps_dim=True)
action = self.forward(batch, return_loss=False)
if self.cfg.temporal_agg: if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation # TODO(rcadene): implement temporal aggregation
@ -257,25 +209,37 @@ class ActionChunkingTransformerPolicy(nn.Module):
return action[: self.n_action_steps] return action[: self.n_action_steps]
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# TODO(now): Temporary bridge. # TODO(now): Temporary bridge until we know what to do about the `update` method.
return self.update(*args, **kwargs) return self.update(*args, **kwargs)
def _preprocess_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def _preprocess_batch(
self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
) -> dict[str, Tensor]:
""" """
Expects batch to have (at least): This function expects `batch` to have (at least):
{ {
"observation.state": (B, 1, J) tensor of robot states (joint configuration) "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
"observation.images.top": (B, 1, C, H, W) tensor of images.
"action": (B, H, J) tensor of actions (positional target for robot joint configuration) "action": (B, H, J) tensor of actions (positional target for robot joint configuration)
"action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds. "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
} }
""" """
if add_obs_steps_dim:
# Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
# this just amounts to an unsqueeze.
for k in batch:
if k.startswith("observation."):
batch[k] = batch[k].unsqueeze(1)
if batch["observation.state"].shape[1] != 1: if batch["observation.state"].shape[1] != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg) raise ValueError(self._multiple_obs_steps_not_handled_msg)
batch["observation.state"] = batch["observation.state"].squeeze(1) batch["observation.state"] = batch["observation.state"].squeeze(1)
# TODO(alexander-soare): generalize this to multiple images. Note: no squeeze is required for # TODO(alexander-soare): generalize this to multiple images.
# "observation.images.top" because then we'd have to unsqueeze to get get the image index dimension. assert (
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
), "ACT only handles one image for now."
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
def update(self, batch, *_): def update(self, batch, *_):
start_time = time.time() start_time = time.time()
@ -378,9 +342,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Forward pass through VAE encoder and sample the latent with the reparameterization trick. # Forward pass through VAE encoder and sample the latent with the reparameterization trick.
cls_token_out = self.vae_encoder( cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[ )[0] # (B, D)
0
] # (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim] mu = latent_pdf_params[:, : self.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation. # This is 2log(sigma). Done this way to match the original implementation.

View File

@ -26,7 +26,6 @@ def make_policy(cfg):
policy = ActionChunkingTransformerPolicy( policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.policy,
cfg.device, cfg.device,
n_obs_steps=cfg.n_obs_steps,
n_action_steps=cfg.n_action_steps, n_action_steps=cfg.n_action_steps,
) )
else: else:

View File

@ -58,6 +58,6 @@ policy:
action_dim: ??? action_dim: ???
delta_timestamps: delta_timestamps:
observation.image: [0.0] observation.images.top: [0.0]
observation.state: [0.0] observation.state: [0.0]
action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98] action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98]

View File

@ -89,7 +89,9 @@ def eval_policy(
visu = env.envs[0].render(mode="visualization") visu = env.envs[0].render(mode="visualization")
visu = visu[None, ...] # add batch dim visu = visu[None, ...] # add batch dim
else: else:
visu = np.stack([env.render(mode="visualization") for env in env.envs]) # TODO(now): Put mode back in.
visu = np.stack([env.render() for env in env.envs])
# visu = np.stack([env.render(mode="visualization") for env in env.envs])
ep_frames.append(visu) # noqa: B023 ep_frames.append(visu) # noqa: B023
for _ in range(num_episodes): for _ in range(num_episodes):
@ -248,7 +250,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making transforms.") logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation. # TODO(alexander-soare): Completely decouple datasets from evaluation.
dataset = make_dataset(cfg, stats_path=stats_path) transform = make_dataset(cfg, stats_path=stats_path).transform
logging.info("Making environment.") logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.rollout_batch_size) env = make_env(cfg, num_parallel_envs=cfg.rollout_batch_size)
@ -263,7 +265,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
video_dir=Path(out_dir) / "eval", video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps, fps=cfg.env.fps,
# TODO(rcadene): what should we do with the transform? # TODO(rcadene): what should we do with the transform?
transform=dataset.transform, transform=transform,
seed=cfg.seed, seed=cfg.seed,
) )
print(info["aggregated"]) print(info["aggregated"])

26
poetry.lock generated
View File

@ -941,7 +941,7 @@ mujoco = "^2.3.7"
type = "git" type = "git"
url = "git@github.com:huggingface/gym-xarm.git" url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD" reference = "HEAD"
resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c" resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
[[package]] [[package]]
name = "gymnasium" name = "gymnasium"
@ -1709,20 +1709,20 @@ pyopengl = "*"
[[package]] [[package]]
name = "networkx" name = "networkx"
version = "3.2.1" version = "3.3"
description = "Python package for creating and manipulating graphs and networks" description = "Python package for creating and manipulating graphs and networks"
optional = false optional = false
python-versions = ">=3.9" python-versions = ">=3.10"
files = [ files = [
{file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"},
{file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"},
] ]
[package.extras] [package.extras]
default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"]
test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]] [[package]]
@ -3699,20 +3699,20 @@ watchdog = ["watchdog (>=2.3)"]
[[package]] [[package]]
name = "zarr" name = "zarr"
version = "2.17.1" version = "2.17.2"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python" description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false optional = false
python-versions = ">=3.9" python-versions = ">=3.9"
files = [ files = [
{file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"}, {file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"},
{file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"}, {file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"},
] ]
[package.dependencies] [package.dependencies]
asciitree = "*" asciitree = "*"
fasteners = {version = "*", markers = "sys_platform != \"emscripten\""} fasteners = {version = "*", markers = "sys_platform != \"emscripten\""}
numcodecs = ">=0.10.0" numcodecs = ">=0.10.0"
numpy = ">=1.21.1" numpy = ">=1.23"
[package.extras] [package.extras]
docs = ["numcodecs[msgpack]", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-automodapi", "sphinx-copybutton", "sphinx-design", "sphinx-issues"] docs = ["numcodecs[msgpack]", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-automodapi", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]