commit 30023535f9
parent 40d417ef60

Rename ActConfig to ActionChunkingTransformerConfig and ActPolicy to
ActionChunkingTransformerPolicy, and make `delta_timestamps` an optional
key of the policy config.
@@ -86,7 +86,8 @@ def make_dataset(
             ]
         )

-    delta_timestamps = cfg.policy.delta_timestamps
-    for key in delta_timestamps:
-        if isinstance(delta_timestamps[key], str):
-            delta_timestamps[key] = eval(delta_timestamps[key])
+    delta_timestamps = cfg.policy.get("delta_timestamps")
+    if delta_timestamps is not None:
+        for key in delta_timestamps:
+            if isinstance(delta_timestamps[key], str):
+                delta_timestamps[key] = eval(delta_timestamps[key])
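Note: with `cfg.policy.get("delta_timestamps")`, policies whose config has no
`delta_timestamps` key now get `None` and skip the parsing loop, instead of
failing on attribute access. String-valued entries are still turned into
Python lists via `eval()`. A minimal sketch of the parsing behavior (the
dictionary contents here are illustrative, not taken from this diff):

    # Hypothetical config value: a list comprehension serialized as a string.
    delta_timestamps = {"action": "[i / 50 for i in range(100)]"}
    for key in delta_timestamps:
        if isinstance(delta_timestamps[key], str):
            # eval() executes the string as a Python expression, so this
            # assumes the config file is trusted input.
            delta_timestamps[key] = eval(delta_timestamps[key])
    assert delta_timestamps["action"][:3] == [0.0, 0.02, 0.04]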
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field


 @dataclass
-class ActConfig:
+class ActionChunkingTransformerConfig:
     """Configuration class for the Action Chunking Transformers policy.

     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -20,10 +20,10 @@ from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d

-from lerobot.common.policies.act.configuration_act import ActConfig
+from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig


-class ActPolicy(nn.Module):
+class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
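Note: downstream code now imports and constructs the policy under its new
names. A minimal usage sketch, assuming the config dataclass can be built
from its defaults (the real defaults live in `configuration_act.py`):

    from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
    from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

    # Per the docstring, defaults target bimanual Aloha tasks.
    cfg = ActionChunkingTransformerConfig()
    policy = ActionChunkingTransformerPolicy(cfg)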
@@ -61,9 +61,11 @@ class ActPolicy(nn.Module):
     """

     name = "act"
-    _multiple_obs_steps_not_handled_msg = "ActPolicy does not handle multiple observation steps."
+    _multiple_obs_steps_not_handled_msg = (
+        "ActionChunkingTransformerPolicy does not handle multiple observation steps."
+    )

-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """
         TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
@@ -398,7 +400,7 @@ class ActPolicy(nn.Module):
 class _TransformerEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""

-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
         self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
@@ -411,7 +413,7 @@ class _TransformerEncoder(nn.Module):


 class _TransformerEncoderLayer(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)

@@ -449,7 +451,7 @@ class _TransformerEncoderLayer(nn.Module):


 class _TransformerDecoder(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """Convenience module for running multiple decoder layers followed by normalization."""
         super().__init__()
         self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
@@ -472,7 +474,7 @@ class _TransformerDecoder(nn.Module):


 class _TransformerDecoderLayer(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
         self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
@@ -28,21 +28,21 @@ def make_policy(cfg):
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
-        from lerobot.common.policies.act.configuration_act import ActConfig
-        from lerobot.common.policies.act.modeling_act import ActPolicy
+        from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
+        from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

-        expected_kwargs = set(inspect.signature(ActConfig).parameters)
+        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
         assert set(cfg.policy).issuperset(
             expected_kwargs
         ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
-        policy_cfg = ActConfig(
+        policy_cfg = ActionChunkingTransformerConfig(
             **{
                 k: v
                 for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
                 if k in expected_kwargs
             }
         )
-        policy = ActPolicy(policy_cfg)
+        policy = ActionChunkingTransformerPolicy(policy_cfg)
         policy.to(get_safe_torch_device(cfg.device))
     else:
         raise ValueError(cfg.policy.name)
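Note: the factory still narrows Hydra's flat `cfg.policy` mapping down to
exactly the dataclass's constructor arguments; only the class names change.
A self-contained sketch of that filtering pattern (`ExampleConfig` and the
dict below are hypothetical stand-ins, not the real config):

    import inspect
    from dataclasses import dataclass

    @dataclass
    class ExampleConfig:
        d_model: int = 512
        n_heads: int = 8

    # Hydra-style mapping with extra keys the dataclass does not accept.
    raw_cfg = {"name": "act", "d_model": 256, "n_heads": 8, "lr": 1e-5}
    expected_kwargs = set(inspect.signature(ExampleConfig).parameters)
    # Keep only the keys that match constructor parameters.
    policy_cfg = ExampleConfig(**{k: v for k, v in raw_cfg.items() if k in expected_kwargs})
    assert policy_cfg.d_model == 256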
@@ -11,6 +11,7 @@ log_freq: 250
 n_obs_steps: 1
 # when temporal_agg=False, n_action_steps=horizon

+# See `configuration_act.py` for more details.
 policy:
   name: act

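Note: the comment added here points readers at `configuration_act.py` for the
meaning of the keys under `policy:`. Because `make_dataset` now uses `.get()`,
the `delta_timestamps` key may be omitted from this file entirely. A small
sketch of the lookup semantics with OmegaConf (the inline config is
illustrative):

    from omegaconf import OmegaConf

    policy = OmegaConf.create({"name": "act"})
    # dict-style .get() returns None for a missing key, so callers can
    # branch on "is not None" instead of wrapping access in try/except.
    assert policy.get("delta_timestamps") is None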
@@ -18,14 +18,14 @@ from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset

-from lerobot.common.policies.act.modeling_act import ActPolicy
+from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy


 def test_available():
     policy_classes = [
-        ActPolicy,
+        ActionChunkingTransformerPolicy,
         DiffusionPolicy,
         TDMPCPolicy,
     ]