revision 1

Alexander Soare 2024-04-15 10:56:43 +01:00
parent 40d417ef60
commit 30023535f9
6 changed files with 24 additions and 20 deletions

View File

@@ -86,10 +86,11 @@ def make_dataset(
             ]
         )
-    delta_timestamps = cfg.policy.delta_timestamps
-    for key in delta_timestamps:
-        if isinstance(delta_timestamps[key], str):
-            delta_timestamps[key] = eval(delta_timestamps[key])
+    delta_timestamps = cfg.policy.get("delta_timestamps")
+    if delta_timestamps is not None:
+        for key in delta_timestamps:
+            if isinstance(delta_timestamps[key], str):
+                delta_timestamps[key] = eval(delta_timestamps[key])
 
     dataset = clsfunc(
         dataset_id=cfg.dataset_id,
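The switch from attribute access to `.get()` above is what makes `delta_timestamps` an optional config key. A minimal sketch of the difference, assuming an OmegaConf config in struct mode (as Hydra typically produces); the config contents here are hypothetical:

from omegaconf import OmegaConf

# Hypothetical policy config with no delta_timestamps key.
policy_cfg = OmegaConf.create({"name": "act"})
OmegaConf.set_struct(policy_cfg, True)

# Old behaviour: attribute access on a missing key raises under struct mode.
# policy_cfg.delta_timestamps  # -> ConfigAttributeError

# New behaviour: .get() returns None, so the eval loop is simply skipped.
delta_timestamps = policy_cfg.get("delta_timestamps")
if delta_timestamps is not None:
    for key in delta_timestamps:
        if isinstance(delta_timestamps[key], str):
            # e.g. "[i / 10 for i in range(5)]" is evaluated into a list
            delta_timestamps[key] = eval(delta_timestamps[key])

Note that eval on config-supplied strings is convenient but trusts the config author; the guard only changes whether the key must exist, not how the strings are evaluated.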

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
 
 @dataclass
-class ActConfig:
+class ActionChunkingTransformerConfig:
     """Configuration class for the Action Chunking Transformers policy.
 
     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

View File

@@ -20,10 +20,10 @@ from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d
 
-from lerobot.common.policies.act.configuration_act import ActConfig
+from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
 
 
-class ActPolicy(nn.Module):
+class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -61,9 +61,11 @@ class ActPolicy(nn.Module):
     """
 
     name = "act"
-    _multiple_obs_steps_not_handled_msg = "ActPolicy does not handle multiple observation steps."
+    _multiple_obs_steps_not_handled_msg = (
+        "ActionChunkingTransformerPolicy does not handle multiple observation steps."
+    )
 
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """
         TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
@@ -398,7 +400,7 @@ class ActPolicy(nn.Module):
 class _TransformerEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""
 
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
         self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
@@ -411,7 +413,7 @@ class _TransformerEncoder(nn.Module):
 
 class _TransformerEncoderLayer(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
@@ -449,7 +451,7 @@ class _TransformerEncoderLayer(nn.Module):
 
 class _TransformerDecoder(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """Convenience module for running multiple decoder layers followed by normalization."""
         super().__init__()
         self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
@@ -472,7 +474,7 @@ class _TransformerDecoder(nn.Module):
 
 class _TransformerDecoderLayer(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
         self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
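Across these hunks the construction pattern is unchanged: a single config dataclass parameterizes every sub-module. A standalone sketch of that pattern, using hypothetical stand-in names (EncoderConfig, Encoder) and only the config fields visible in this diff (d_model, n_heads, dropout, n_encoder_layers, pre_norm); it is an illustration, not the actual implementation:

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class EncoderConfig:  # stand-in; the real fields live in ActionChunkingTransformerConfig
    d_model: int = 512
    n_heads: int = 8
    dropout: float = 0.1
    n_encoder_layers: int = 4
    pre_norm: bool = False


class Encoder(nn.Module):
    def __init__(self, cfg: EncoderConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=cfg.d_model, nhead=cfg.n_heads, dropout=cfg.dropout, norm_first=cfg.pre_norm
            )
            for _ in range(cfg.n_encoder_layers)
        )
        # Final LayerNorm only when pre-norm is used, mirroring the diff above.
        self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


cfg = EncoderConfig()
out = Encoder(cfg)(torch.randn(10, 2, cfg.d_model))  # (seq, batch, d_model)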

View File

@@ -28,21 +28,21 @@ def make_policy(cfg):
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
-        from lerobot.common.policies.act.configuration_act import ActConfig
-        from lerobot.common.policies.act.modeling_act import ActPolicy
+        from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
+        from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 
-        expected_kwargs = set(inspect.signature(ActConfig).parameters)
+        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
         assert set(cfg.policy).issuperset(
             expected_kwargs
         ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
-        policy_cfg = ActConfig(
+        policy_cfg = ActionChunkingTransformerConfig(
             **{
                 k: v
                 for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
                 if k in expected_kwargs
             }
         )
-        policy = ActPolicy(policy_cfg)
+        policy = ActionChunkingTransformerPolicy(policy_cfg)
         policy.to(get_safe_torch_device(cfg.device))
     else:
         raise ValueError(cfg.policy.name)
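The filtering idiom in this hunk (inspect.signature selects only the kwargs the config dataclass accepts, so extra Hydra keys like `name` are dropped) is easy to check in isolation. A minimal sketch with a hypothetical stand-in dataclass; the real field set lives in configuration_act.py:

import inspect
from dataclasses import dataclass


@dataclass
class DemoConfig:  # hypothetical stand-in for ActionChunkingTransformerConfig
    d_model: int = 512
    dropout: float = 0.1


raw = {"name": "act", "d_model": 256, "dropout": 0.2}  # e.g. OmegaConf.to_container output
expected_kwargs = set(inspect.signature(DemoConfig).parameters)  # {"d_model", "dropout"}
cfg = DemoConfig(**{k: v for k, v in raw.items() if k in expected_kwargs})
assert cfg.d_model == 256 and "name" not in expected_kwargs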

View File

@@ -11,6 +11,7 @@ log_freq: 250
 n_obs_steps: 1
 # when temporal_agg=False, n_action_steps=horizon
+# See `configuration_act.py` for more details.
 
 policy:
   name: act

View File

@@ -18,14 +18,14 @@ from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
-from lerobot.common.policies.act.modeling_act import ActPolicy
+from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
 
 def test_available():
     policy_classes = [
-        ActPolicy,
+        ActionChunkingTransformerPolicy,
         DiffusionPolicy,
         TDMPCPolicy,
     ]
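Since the modeling_act diff above shows `name = "act"` as a class attribute on the renamed policy, a natural smoke check would be the following; the test name is hypothetical and the body of test_available() is not shown in this commit:

from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy


def test_act_policy_name():
    # The class attribute visible in the modeling_act diff above.
    assert ActionChunkingTransformerPolicy.name == "act"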