From 30023535f977212f12026b848e831b8367005328 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 10:56:43 +0100 Subject: [PATCH] revision 1 --- lerobot/common/datasets/factory.py | 9 +++++---- .../common/policies/act/configuration_act.py | 2 +- lerobot/common/policies/act/modeling_act.py | 18 ++++++++++-------- lerobot/common/policies/factory.py | 10 +++++----- lerobot/configs/policy/act.yaml | 1 + tests/test_available.py | 4 ++-- 6 files changed, 24 insertions(+), 20 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 10106fe9..4ae161f6 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -86,10 +86,11 @@ def make_dataset( ] ) - delta_timestamps = cfg.policy.delta_timestamps - for key in delta_timestamps: - if isinstance(delta_timestamps[key], str): - delta_timestamps[key] = eval(delta_timestamps[key]) + delta_timestamps = cfg.policy.get("delta_timestamps") + if delta_timestamps is not None: + for key in delta_timestamps: + if isinstance(delta_timestamps[key], str): + delta_timestamps[key] = eval(delta_timestamps[key]) dataset = clsfunc( dataset_id=cfg.dataset_id, diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 84d960db..74ed270e 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field @dataclass -class ActConfig: +class ActionChunkingTransformerConfig: """Configuration class for the Action Chunking Transformers policy. Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 769c9470..1361e071 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,10 +20,10 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.policies.act.configuration_act import ActConfig +from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig -class ActPolicy(nn.Module): +class ActionChunkingTransformerPolicy(nn.Module): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) @@ -61,9 +61,11 @@ class ActPolicy(nn.Module): """ name = "act" - _multiple_obs_steps_not_handled_msg = "ActPolicy does not handle multiple observation steps." + _multiple_obs_steps_not_handled_msg = ( + "ActionChunkingTransformerPolicy does not handle multiple observation steps." + ) - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): """ TODO(alexander-soare): Add documentation for all parameters once we have model configs established. """ @@ -398,7 +400,7 @@ class ActPolicy(nn.Module): class _TransformerEncoder(nn.Module): """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() @@ -411,7 +413,7 @@ class _TransformerEncoder(nn.Module): class _TransformerEncoderLayer(nn.Module): - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) @@ -449,7 +451,7 @@ class _TransformerEncoderLayer(nn.Module): class _TransformerDecoder(nn.Module): - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) @@ -472,7 +474,7 @@ class _TransformerDecoder(nn.Module): class _TransformerDecoderLayer(nn.Module): - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index ed8ba7cf..80ae27da 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -28,21 +28,21 @@ def make_policy(cfg): **cfg.policy, ) elif cfg.policy.name == "act": - from lerobot.common.policies.act.configuration_act import ActConfig - from lerobot.common.policies.act.modeling_act import ActPolicy + from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig + from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy - expected_kwargs = set(inspect.signature(ActConfig).parameters) + expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters) assert set(cfg.policy).issuperset( expected_kwargs ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}" - policy_cfg = ActConfig( + policy_cfg = ActionChunkingTransformerConfig( **{ k: v for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items() if k in expected_kwargs } ) - policy = ActPolicy(policy_cfg) + policy = ActionChunkingTransformerPolicy(policy_cfg) policy.to(get_safe_torch_device(cfg.device)) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 22f2d53a..bd883613 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -11,6 +11,7 @@ log_freq: 250 n_obs_steps: 1 # when temporal_agg=False, n_action_steps=horizon +# See `configuration_act.py` for more details. policy: name: act diff --git a/tests/test_available.py b/tests/test_available.py index 36791a3e..b25a921f 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -18,14 +18,14 @@ from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.datasets.pusht import PushtDataset -from lerobot.common.policies.act.modeling_act import ActPolicy +from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.diffusion.policy import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy def test_available(): policy_classes = [ - ActPolicy, + ActionChunkingTransformerPolicy, DiffusionPolicy, TDMPCPolicy, ]