Eval reproduction works with gym_aloha
parent e982c732f1
commit 1bab4a1dd5
@@ -35,7 +35,7 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
         kwargs["task"] = cfg.env.task
 
         env_fn = lambda: gym.make(  # noqa: E731
-            "gym_aloha/AlohaInsertion-v0",
+            "gym_aloha/AlohaTransferCube-v0",
            **kwargs,
         )
     else:
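Note: for readers unfamiliar with this factory, below is a minimal sketch of the pattern the hunk touches. Only the `gym.make` call and the `kwargs["task"]` assignment mirror the hunk; the `num_parallel_envs` branching and the task value are assumptions.

```python
# Sketch of the env factory pattern (assumes gym_aloha is installed; importing it
# registers "gym_aloha/AlohaTransferCube-v0" with gymnasium).
import gymnasium as gym
import gym_aloha  # noqa: F401

def make_env_sketch(num_parallel_envs: int = 0):
    kwargs = {"task": "transfer_cube"}  # hypothetical value; the hunk uses cfg.env.task
    env_fn = lambda: gym.make("gym_aloha/AlohaTransferCube-v0", **kwargs)  # noqa: E731
    if num_parallel_envs == 0:
        return env_fn()
    # One sub-env per rollout in the batch, all stepped synchronously.
    return gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)])
```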
@@ -3,9 +3,10 @@
 As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
+The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
 """
-from collections import deque
 
 import math
 import time
+from collections import deque
 from itertools import chain
 from typing import Callable
 
@@ -22,67 +23,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.utils import get_safe_torch_device
 
 
-# class AbstractPolicy(nn.Module):
-#     """Base policy which all policies should be derived from.
-
-#     The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
-#     documentation for more information.
-
-#     Note:
-#         When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-#             1. set the required class attributes:
-#                 - for classes inheriting from `AbstractDataset`: `available_datasets`
-#                 - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-#                 - for classes inheriting from `AbstractPolicy`: `name`
-#             2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-#             3. update variables in `tests/test_available.py` by importing your new class
-#     """
-
-#     name: str | None = None  # same name should be used to instantiate the policy in factory.py
-
-#     def __init__(self, n_action_steps: int | None):
-#         """
-#         n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
-#         action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
-#         adds that dimension.
-#         """
-#         super().__init__()
-#         assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
-#         self.n_action_steps = n_action_steps
-#         self.clear_action_queue()
-
-#     def clear_action_queue(self):
-#         """This should be called whenever the environment is reset."""
-#         if self.n_action_steps is not None:
-#             self._action_queue = deque([], maxlen=self.n_action_steps)
-
-#     def forward(self, fn) -> Tensor:
-#         """Inference step that makes multi-step policies compatible with their single-step environments.
-
-#         WARNING: In general, this should not be overriden.
-
-#         Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
-#         into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
-#         observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
-#         observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
-#         the subclass doesn't have to.
-
-#         This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
-#         1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
-#             the action trajectory horizon and * is the action dimensions.
-#         2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
-#         """
-#         if self.n_action_steps is None:
-#             return self.select_actions(*args, **kwargs)
-#         if len(self._action_queue) == 0:
-#             # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
-#             # (n_action_steps, batch_size, *), hence the transpose.
-#             self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
-#         return self._action_queue.popleft()
-
-
 
 class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
@@ -228,18 +168,30 @@ class ActionChunkingTransformerPolicy(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    @torch.no_grad()
-    def select_action(self, batch, *_):
-        # TODO(now): Implement queueing mechanism.
-        self.eval()
-        self._preprocess_batch(batch)
+    def reset(self):
+        """This should be called whenever the environment is reset."""
+        if self.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.n_action_steps)
 
-        # TODO(now): What's up with this 0.182?
-        action = self.forward(
-            robot_state=batch["observation.state"] * 0.182,
-            image=batch["observation.images.top"],
-            return_loss=False,
-        )
+    def select_action(self, batch: dict[str, Tensor], *_):
+        """
+        This method wraps `select_actions` in order to return one action at a time for execution in the
+        environment. It works by managing the actions in a queue and only calling `select_actions` when the
+        queue is empty.
+        """
+        if len(self._action_queue) == 0:
+            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
+            # (n_action_steps, batch_size, *), hence the transpose.
+            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
+        return self._action_queue.popleft()
+
+    @torch.no_grad()
+    def select_actions(self, batch: dict[str, Tensor]):
+        """Use the action chunking transformer to generate a sequence of actions."""
+        self.eval()
+        self._preprocess_batch(batch, add_obs_steps_dim=True)
+
+        action = self.forward(batch, return_loss=False)
 
         if self.cfg.temporal_agg:
             # TODO(rcadene): implement temporal aggregation
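Note: the queue mechanism introduced here is easy to verify in isolation. Below is a minimal, self-contained sketch; `QueueDemo` and its stubbed `select_actions` are hypothetical stand-ins, while the queue/transpose logic and the shapes follow the docstring in the hunk.

```python
# Standalone sketch of the action queue in select_action: the policy plans a chunk
# of n_action_steps actions at once, then serves them one per environment step.
from collections import deque

import torch

class QueueDemo:
    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)  # same as reset()

    def select_actions(self, batch):
        # Stub: a (batch_size, n_action_steps, action_dim) chunk of actions.
        return torch.randn(2, self.n_action_steps, 14)

    def select_action(self, batch):
        if len(self._action_queue) == 0:
            # The queue holds (batch_size, action_dim) items, so the
            # (batch_size, n_action_steps, *) chunk is transposed to
            # (n_action_steps, batch_size, *) before extending.
            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
        return self._action_queue.popleft()

demo = QueueDemo(n_action_steps=5)
for _ in range(7):  # select_actions is only re-invoked on iterations 1 and 6
    action = demo.select_action(batch=None)
    assert action.shape == (2, 14)
```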
@@ -257,25 +209,37 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return action[: self.n_action_steps]
 
     def __call__(self, *args, **kwargs):
-        # TODO(now): Temporary bridge.
+        # TODO(now): Temporary bridge until we know what to do about the `update` method.
         return self.update(*args, **kwargs)
 
-    def _preprocess_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def _preprocess_batch(
+        self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
+    ) -> dict[str, Tensor]:
         """
-        Expects batch to have (at least):
+        This function expects `batch` to have (at least):
         {
-            "observation.state": (B, 1, J) tensor of robot states (joint configuration)
-
-            "observation.images.top": (B, 1, C, H, W) tensor of images.
+            "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
+            "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
             "action": (B, H, J) tensor of actions (positional target for robot joint configuration)
             "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
         }
         """
+        if add_obs_steps_dim:
+            # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
+            # this just amounts to an unsqueeze.
+            for k in batch:
+                if k.startswith("observation."):
+                    batch[k] = batch[k].unsqueeze(1)
+
         if batch["observation.state"].shape[1] != 1:
             raise ValueError(self._multiple_obs_steps_not_handled_msg)
         batch["observation.state"] = batch["observation.state"].squeeze(1)
-        # TODO(alexander-soare): generalize this to multiple images. Note: no squeeze is required for
-        # "observation.images.top" because then we'd have to unsqueeze to get get the image index dimension.
+        # TODO(alexander-soare): generalize this to multiple images.
+        assert (
+            sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
+        ), "ACT only handles one image for now."
+        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
+        # the image index dimension.
 
     def update(self, batch, *_):
         start_time = time.time()
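Note: below is a small sketch of what `add_obs_steps_dim` does to a batch coming straight from the environment. The shapes (B=2, J=14, 480x640 images) are illustrative; only the unsqueeze/squeeze logic mirrors the hunk.

```python
# Sketch of the observation-steps handling in _preprocess_batch.
import torch

batch = {
    "observation.state": torch.zeros(2, 14),                 # (B, J)
    "observation.images.top": torch.zeros(2, 3, 480, 640),   # (B, C, H, W)
}
# add_obs_steps_dim=True: give every observation key a steps dim of size 1...
for k in batch:
    if k.startswith("observation."):
        batch[k] = batch[k].unsqueeze(1)
assert batch["observation.state"].shape == (2, 1, 14)
# ...which the shared code path then squeezes back out of the state, since
# n_obs_steps == 1 is the only supported case. No squeeze for the image: its
# extra dim is kept as the (single) image-index dimension.
assert batch["observation.state"].shape[1] == 1
batch["observation.state"] = batch["observation.state"].squeeze(1)
```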
@@ -378,9 +342,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # Forward pass through VAE encoder and sample the latent with the reparameterization trick.
             cls_token_out = self.vae_encoder(
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
-            )[
-                0
-            ]  # (B, D)
+            )[0]  # (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
             mu = latent_pdf_params[:, : self.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
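Note: the surrounding code samples the latent with the reparameterization trick, and the comment flags that the projection outputs 2log(sigma) rather than sigma itself. A sketch of the sampling step under that convention follows; tensor names and sizes are illustrative, not the policy's actual attributes.

```python
# Sketch of the reparameterization trick with the "2log(sigma)" packing noted in
# the hunk: the projection outputs [mu | 2*log(sigma)] along the last dim.
import torch

latent_dim = 32
latent_pdf_params = torch.randn(8, 2 * latent_dim)  # stand-in for the projection output
mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]    # this is 2*log(sigma)
sigma = (log_sigma_x2 / 2).exp()                    # recover sigma
latent = mu + sigma * torch.randn_like(mu)          # differentiable w.r.t. mu and sigma
```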
@@ -26,7 +26,6 @@ def make_policy(cfg):
         policy = ActionChunkingTransformerPolicy(
             cfg.policy,
             cfg.device,
-            n_obs_steps=cfg.n_obs_steps,
             n_action_steps=cfg.n_action_steps,
         )
     else:
@@ -58,6 +58,6 @@ policy:
   action_dim: ???
 
   delta_timestamps:
-    observation.image: [0.0]
+    observation.images.top: [0.0]
     observation.state: [0.0]
     action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98]
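Note: the `action` list above is just 100 chunk steps spaced 0.02 s apart, i.e. a 100-step action horizon at 50 fps. A sketch of generating it instead of hand-writing the list (the `fps` and `horizon` values here are read off the list itself, not taken from the config):

```python
# Generate the action delta_timestamps: horizon steps at 1/fps second intervals.
fps = 50
horizon = 100
action_delta_timestamps = [round(i / fps, 2) for i in range(horizon)]
assert action_delta_timestamps[:3] == [0.0, 0.02, 0.04]
assert action_delta_timestamps[-1] == 1.98
```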
@@ -89,7 +89,9 @@ def eval_policy(
             visu = env.envs[0].render(mode="visualization")
             visu = visu[None, ...]  # add batch dim
         else:
-            visu = np.stack([env.render(mode="visualization") for env in env.envs])
+            # TODO(now): Put mode back in.
+            visu = np.stack([env.render() for env in env.envs])
+            # visu = np.stack([env.render(mode="visualization") for env in env.envs])
         ep_frames.append(visu)  # noqa: B023
 
     for _ in range(num_episodes):
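Note: the `mode` kwarg likely had to be dropped because gymnasium's current API fixes the render mode at construction time (`gym.make(..., render_mode=...)`), so `render()` takes no arguments. A self-contained sketch of the stacking pattern this branch uses; `_StubEnv` is a hypothetical stand-in for a sub-env of the vectorized env.

```python
# Each sub-env renders an (H, W, C) frame; stacking gives (num_envs, H, W, C).
import numpy as np

class _StubEnv:
    def render(self):  # gymnasium >= 0.26: no mode argument at render time
        return np.zeros((480, 640, 3), dtype=np.uint8)

envs = [_StubEnv() for _ in range(4)]
frames = np.stack([e.render() for e in envs])
assert frames.shape == (4, 480, 640, 3)
```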
@@ -248,7 +250,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
 
     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    dataset = make_dataset(cfg, stats_path=stats_path)
+    transform = make_dataset(cfg, stats_path=stats_path).transform
 
     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.rollout_batch_size)
@@ -263,7 +265,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
         # TODO(rcadene): what should we do with the transform?
-        transform=dataset.transform,
+        transform=transform,
         seed=cfg.seed,
     )
     print(info["aggregated"])
@@ -941,7 +941,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-xarm.git"
 reference = "HEAD"
-resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c"
+resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
 
 [[package]]
 name = "gymnasium"
@@ -1709,20 +1709,20 @@ pyopengl = "*"
 
 [[package]]
 name = "networkx"
-version = "3.2.1"
+version = "3.3"
 description = "Python package for creating and manipulating graphs and networks"
 optional = false
-python-versions = ">=3.9"
+python-versions = ">=3.10"
 files = [
-    {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"},
-    {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"},
+    {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"},
+    {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"},
 ]
 
 [package.extras]
-default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
-developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
-doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
-extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
+default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
+developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
+doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
+extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"]
 test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
 
 [[package]]
@@ -3699,20 +3699,20 @@ watchdog = ["watchdog (>=2.3)"]
 
 [[package]]
 name = "zarr"
-version = "2.17.1"
+version = "2.17.2"
 description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"},
-    {file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"},
+    {file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"},
+    {file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"},
 ]
 
 [package.dependencies]
 asciitree = "*"
 fasteners = {version = "*", markers = "sys_platform != \"emscripten\""}
 numcodecs = ">=0.10.0"
-numpy = ">=1.21.1"
+numpy = ">=1.23"
 
 [package.extras]
 docs = ["numcodecs[msgpack]", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-automodapi", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]