revision 1

This commit is contained in:
Alexander Soare 2024-04-15 10:56:43 +01:00
parent 40d417ef60
commit 30023535f9
6 changed files with 24 additions and 20 deletions

View File

@ -86,10 +86,11 @@ def make_dataset(
]
)
delta_timestamps = cfg.policy.delta_timestamps
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
dataset = clsfunc(
dataset_id=cfg.dataset_id,

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass, field
@dataclass
class ActConfig:
class ActionChunkingTransformerConfig:
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

View File

@ -20,10 +20,10 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActConfig
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
class ActPolicy(nn.Module):
class ActionChunkingTransformerPolicy(nn.Module):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@ -61,9 +61,11 @@ class ActPolicy(nn.Module):
"""
name = "act"
_multiple_obs_steps_not_handled_msg = "ActPolicy does not handle multiple observation steps."
_multiple_obs_steps_not_handled_msg = (
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
def __init__(self, cfg: ActConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig):
"""
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
"""
@ -398,7 +400,7 @@ class ActPolicy(nn.Module):
class _TransformerEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, cfg: ActConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig):
super().__init__()
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
@ -411,7 +413,7 @@ class _TransformerEncoder(nn.Module):
class _TransformerEncoderLayer(nn.Module):
def __init__(self, cfg: ActConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
@ -449,7 +451,7 @@ class _TransformerEncoderLayer(nn.Module):
class _TransformerDecoder(nn.Module):
def __init__(self, cfg: ActConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
@ -472,7 +474,7 @@ class _TransformerDecoder(nn.Module):
class _TransformerDecoderLayer(nn.Module):
def __init__(self, cfg: ActConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)

View File

@ -28,21 +28,21 @@ def make_policy(cfg):
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActConfig
from lerobot.common.policies.act.modeling_act import ActPolicy
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
expected_kwargs = set(inspect.signature(ActConfig).parameters)
expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
assert set(cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
policy_cfg = ActConfig(
policy_cfg = ActionChunkingTransformerConfig(
**{
k: v
for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
if k in expected_kwargs
}
)
policy = ActPolicy(policy_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg)
policy.to(get_safe_torch_device(cfg.device))
else:
raise ValueError(cfg.policy.name)

View File

@ -11,6 +11,7 @@ log_freq: 250
n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon
# See `configuration_act.py` for more details.
policy:
name: act

View File

@ -18,14 +18,14 @@ from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.policies.act.modeling_act import ActPolicy
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
def test_available():
policy_classes = [
ActPolicy,
ActionChunkingTransformerPolicy,
DiffusionPolicy,
TDMPCPolicy,
]