Merge remote-tracking branch 'upstream/main' into refactor_dp
commit 14f3ffb412
@@ -146,7 +146,8 @@ jobs:
device=cpu \
save_model=true \
save_freq=2 \
horizon=20 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/act/
@@ -0,0 +1,114 @@
from dataclasses import dataclass, field


@dataclass
class ActionChunkingTransformerConfig:
    """Configuration class for the Action Chunking Transformers policy.

    Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `state_dim`, `action_dim` and `camera_names`.

    Args:
        state_dim: Dimensionality of the observation state space (excluding images).
        action_dim: Dimensionality of the action space.
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        camera_names: The (unique) set of names for the cameras.
        chunk_size: The size of the action prediction "chunks" in units of environment steps.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            This should be no greater than the chunk size. For example, if the chunk size is 100, you may
            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
            environment, and throws the other 50 out.
        image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
            [0, 1]) for normalization.
        image_normalization_std: Value by which to divide the input image pixels (after the mean has been
            subtracted).
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        use_pretrained_backbone: Whether the backbone should be initialized with ImageNet pretrained weights
            from torchvision.
        replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
            convolution.
        pre_norm: Whether to use "pre-norm" in the transformer blocks.
        d_model: The transformer blocks' main hidden dimension.
        n_heads: The number of heads to use in the transformer blocks' multi-head attention.
        dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
            layers.
        feedforward_activation: The activation to use in the transformer block's feed-forward layers.
        n_encoder_layers: The number of transformer layers to use for the transformer encoder.
        n_decoder_layers: The number of transformer layers to use for the transformer decoder.
        use_vae: Whether to use a variational objective during training. This introduces another transformer
            which is used as the VAE's encoder (not to be confused with the transformer encoder - see
            documentation in the policy class).
        latent_dim: The VAE's latent dimension.
        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
        use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
            environment step.
        dropout: Dropout to use in the transformer layers (see code for details).
        kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
            is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
    """

    # Environment.
    state_dim: int = 14
    action_dim: int = 14

    # Inputs / output structure.
    n_obs_steps: int = 1
    camera_names: list[str] = field(default_factory=lambda: ["top"])
    chunk_size: int = 100
    n_action_steps: int = 100

    # Vision preprocessing.
    image_normalization_mean: tuple[float, float, float] = field(
        default_factory=lambda: [0.485, 0.456, 0.406]
    )
    image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])

    # Architecture.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    use_pretrained_backbone: bool = True
    replace_final_stride_with_dilation: int = False
    # Transformer layers.
    pre_norm: bool = False
    d_model: int = 512
    n_heads: int = 8
    dim_feedforward: int = 3200
    feedforward_activation: str = "relu"
    n_encoder_layers: int = 4
    n_decoder_layers: int = 1
    # VAE.
    use_vae: bool = True
    latent_dim: int = 32
    n_vae_encoder_layers: int = 4

    # Inference.
    use_temporal_aggregation: bool = False

    # Training and loss computation.
    dropout: float = 0.1
    kl_weight: float = 10.0

    # ---
    # TODO(alexander-soare): Remove these from the policy config.
    batch_size: int = 8
    lr: float = 1e-5
    lr_backbone: float = 1e-5
    weight_decay: float = 1e-4
    grad_clip_norm: float = 10
    utd: int = 1

    def __post_init__(self):
        """Input validation."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
        if self.use_temporal_aggregation:
            raise NotImplementedError("Temporal aggregation is not yet implemented.")
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                "The chunk size is the upper bound for the number of action steps per model invocation."
            )
        if self.camera_names != ["top"]:
            raise ValueError("For now, `camera_names` can only be ['top']")
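Not part of the diff: a minimal usage sketch of the new config class above, assuming the import path introduced by this commit.

from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

# Defaults target bimanual Aloha: state_dim=14, action_dim=14, camera_names=["top"].
cfg = ActionChunkingTransformerConfig(chunk_size=100, n_action_steps=50)

# __post_init__ enforces the chunking contract: n_action_steps may not exceed chunk_size.
try:
    ActionChunkingTransformerConfig(chunk_size=20, n_action_steps=50)
except ValueError as err:
    print(err)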
@@ -20,7 +20,7 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.common.utils import get_safe_torch_device
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig


class ActionChunkingTransformerPolicy(nn.Module):
@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        "ActionChunkingTransformerPolicy does not handle multiple observation steps."
    )

    def __init__(self, cfg, device):
    def __init__(self, cfg: ActionChunkingTransformerConfig):
        """
        TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
        """
@@ -73,79 +73,64 @@ class ActionChunkingTransformerPolicy(nn.Module):
        if getattr(cfg, "n_obs_steps", 1) != 1:
            raise ValueError(self._multiple_obs_steps_not_handled_msg)
        self.cfg = cfg
        self.n_action_steps = cfg.n_action_steps
        self.device = get_safe_torch_device(device)
        self.camera_names = cfg.camera_names
        self.use_vae = cfg.use_vae
        self.horizon = cfg.horizon
        self.d_model = cfg.d_model

        transformer_common_kwargs = dict(  # noqa: C408
            d_model=self.d_model,
            num_heads=cfg.num_heads,
            dim_feedforward=cfg.dim_feedforward,
            dropout=cfg.dropout,
            activation=cfg.activation,
            normalize_before=cfg.pre_norm,
        )

        # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
        # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
        if self.use_vae:
            self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs)
            self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model)
        if self.cfg.use_vae:
            self.vae_encoder = _TransformerEncoder(cfg)
            self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
            # Projection layer for joint-space configuration to hidden dimension.
            self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
            self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
            # Projection layer for action (joint-space target) to hidden dimension.
            self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model)
            self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
            self.latent_dim = cfg.latent_dim
            # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
            self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2)
            self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
            # Fixed sinusoidal positional embedding for the whole input to the VAE encoder. Unsqueeze for batch
            # dimension.
            self.register_buffer(
                "vae_encoder_pos_enc",
                _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0),
                _create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
            )

        # Backbone for image feature extraction.
        self.image_normalizer = transforms.Normalize(
            mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
            mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
        )
        backbone_model = getattr(torchvision.models, cfg.backbone)(
            replace_stride_with_dilation=[False, False, cfg.dilation],
            pretrained=cfg.pretrained_backbone,
        backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
            replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
            pretrained=cfg.use_pretrained_backbone,
            norm_layer=FrozenBatchNorm2d,
        )
        # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
        # map).
        # Note: The forward method of this returns a dict: {"feature_map": output}.
        self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})

        # Transformer (acts as VAE decoder when training with the variational objective).
        self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs)
        self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs)
        self.encoder = _TransformerEncoder(cfg)
        self.decoder = _TransformerDecoder(cfg)

        # Transformer encoder input projections. The tokens will be structured like
        # [latent, robot_state, image_feature_map_pixels].
        self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model)
        self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
        self.encoder_img_feat_input_proj = nn.Conv2d(
            backbone_model.fc.in_features, self.d_model, kernel_size=1
            backbone_model.fc.in_features, cfg.d_model, kernel_size=1
        )
        # Transformer encoder positional embeddings.
        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model)
        self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2)
        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
        self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)

        # Transformer decoder.
        # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
        self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model)
        self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)

        # Final action regression head on the output of the transformer's decoder.
        self.action_head = nn.Linear(self.d_model, cfg.action_dim)
        self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)

        self._reset_parameters()

        self._create_optimizer()
        self.to(self.device)

    def _create_optimizer(self):
        optimizer_params_dicts = [
@@ -173,8 +158,8 @@ class ActionChunkingTransformerPolicy(nn.Module):

    def reset(self):
        """This should be called whenever the environment is reset."""
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)
        if self.cfg.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
@@ -184,8 +169,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
        queue is empty.
        """
        if len(self._action_queue) == 0:
            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
            # (n_action_steps, batch_size, *), hence the transpose.
            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
            # shape (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
        return self._action_queue.popleft()
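Not part of the diff: a self-contained sketch of the action-queue pattern that `reset`/`select_action` implement above, with a stand-in `select_actions` so the shapes are easy to follow (names and sizes are illustrative).

from collections import deque

import torch

n_action_steps = 3
action_dim = 2
_action_queue = deque([], maxlen=n_action_steps)

def select_actions(batch):
    # Stand-in for the policy call: returns (batch_size, n_action_steps, action_dim).
    return torch.zeros(1, n_action_steps, action_dim)

def select_action(batch):
    if len(_action_queue) == 0:
        # Transpose so the queue holds n_action_steps entries of shape (batch_size, action_dim).
        _action_queue.extend(select_actions(batch).transpose(0, 1))
    return _action_queue.popleft()

for t in range(6):
    action = select_action({})  # the policy is only re-queried every n_action_steps environment steps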
@@ -197,20 +182,7 @@ class ActionChunkingTransformerPolicy(nn.Module):

        action = self.forward(batch, return_loss=False)

        if self.cfg.temporal_agg:
            # TODO(rcadene): implement temporal aggregation
            raise NotImplementedError()
            # all_time_actions[[t], t:t+num_queries] = action
            # actions_for_curr_step = all_time_actions[:, t]
            # actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
            # actions_for_curr_step = actions_for_curr_step[actions_populated]
            # k = 0.01
            # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
            # exp_weights = exp_weights / exp_weights.sum()
            # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
            # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)

        return action[: self.n_action_steps]
        return action[: self.cfg.n_action_steps]

    def __call__(self, *args, **kwargs) -> dict:
        # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
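The commented-out block above mirrors the temporal ensembling from the original ACT implementation, which this commit leaves unimplemented (see `use_temporal_aggregation` in the new config). A rough standalone sketch of just the exponential weighting, with made-up numbers:

import numpy as np

# Overlapping chunk predictions for the current environment step, oldest prediction first.
actions_for_curr_step = np.array([[0.10], [0.12], [0.11], [0.15]])

k = 0.01
exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))  # earlier predictions get slightly larger weight
exp_weights = exp_weights / exp_weights.sum()
blended_action = (actions_for_curr_step * exp_weights[:, None]).sum(axis=0)
print(blended_action)  # a single smoothed action for this step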
@@ -251,9 +223,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
        self.train()

        num_slices = self.cfg.batch_size
        batch_size = self.cfg.horizon * num_slices
        batch_size = self.cfg.chunk_size * num_slices

        assert batch_size % self.cfg.horizon == 0
        assert batch_size % self.cfg.chunk_size == 0
        assert batch_size % num_slices == 0

        loss = self.forward(batch, return_loss=True)["loss"]
@@ -324,7 +296,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
            Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
            latent dimension.
        """
        if self.use_vae and self.training:
        if self.cfg.use_vae and self.training:
            assert (
                actions is not None
            ), "actions must be provided when using the variational objective in training mode."
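For orientation on how `kl_weight` from the new config enters training: a hedged sketch of the loss combination described in the config docstring, assuming an L1 reconstruction term and a diagonal-Gaussian KL against a standard normal (as in the original ACT objective); all tensors below are dummies.

import torch
import torch.nn.functional as F

kl_weight = 10.0
actions = torch.randn(8, 100, 14)      # ground-truth action chunk
actions_hat = torch.randn(8, 100, 14)  # decoder predictions
mu = torch.randn(8, 32)                # latent mean from the VAE encoder's cls token
log_sigma_x2 = torch.randn(8, 32)      # log(sigma^2)

reconstruction_loss = F.l1_loss(actions_hat, actions)
# KL divergence between N(mu, sigma^2) and N(0, 1), summed over the latent dim, averaged over the batch.
kld_loss = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
loss = reconstruction_loss + kl_weight * kld_loss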
@@ -332,7 +304,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        batch_size = robot_state.shape[0]

        # Prepare the latent for input to the transformer encoder.
        if self.use_vae and actions is not None:
        if self.cfg.use_vae and actions is not None:
            # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
            cls_embed = einops.repeat(
                self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@@ -367,7 +339,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        # Camera observation features and positional embeddings.
        all_cam_features = []
        all_cam_pos_embeds = []
        for cam_id, _ in enumerate(self.camera_names):
        for cam_id, _ in enumerate(self.cfg.camera_names):
            cam_features = self.backbone(image[:, cam_id])["feature_map"]
            cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
            cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
@@ -399,7 +371,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
        # Forward pass through the transformer modules.
        encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
        decoder_in = torch.zeros(
            (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device
            (self.cfg.chunk_size, batch_size, self.cfg.d_model),
            dtype=pos_embed.dtype,
            device=pos_embed.device,
        )
        decoder_out = self.decoder(
            decoder_in,
@@ -426,16 +400,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
class _TransformerEncoder(nn.Module):
    """Convenience module for running multiple encoder layers, maybe followed by normalization."""

    def __init__(self, num_layers: int, **encoder_layer_kwargs: dict):
    def __init__(self, cfg: ActionChunkingTransformerConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)]
        )
        self.norm = (
            nn.LayerNorm(encoder_layer_kwargs["d_model"])
            if encoder_layer_kwargs["normalize_before"]
            else nn.Identity()
        )
        self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
        self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()

    def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
        for layer in self.layers:
@@ -445,39 +413,31 @@ class _TransformerEncoder(nn.Module):


class _TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int,
        dropout: float,
        activation: str,
        normalize_before: bool,
    ):
    def __init__(self, cfg: ActionChunkingTransformerConfig):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
        self.dropout = nn.Dropout(cfg.dropout)
        self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(cfg.d_model)
        self.norm2 = nn.LayerNorm(cfg.d_model)
        self.dropout1 = nn.Dropout(cfg.dropout)
        self.dropout2 = nn.Dropout(cfg.dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self.activation = _get_activation_fn(cfg.feedforward_activation)
        self.pre_norm = cfg.pre_norm

    def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
        skip = x
        if self.normalize_before:
        if self.pre_norm:
            x = self.norm1(x)
        q = k = x if pos_embed is None else x + pos_embed
        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
        x = skip + self.dropout1(x)
        if self.normalize_before:
        if self.pre_norm:
            skip = x
            x = self.norm2(x)
        else:
@@ -485,20 +445,17 @@ class _TransformerEncoderLayer(nn.Module):
            skip = x
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = skip + self.dropout2(x)
        if not self.normalize_before:
        if not self.pre_norm:
            x = self.norm2(x)
        return x


class _TransformerDecoder(nn.Module):
    def __init__(self, num_layers: int, **decoder_layer_kwargs):
    def __init__(self, cfg: ActionChunkingTransformerConfig):
        """Convenience module for running multiple decoder layers followed by normalization."""
        super().__init__()
        self.layers = nn.ModuleList(
            [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)]
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"])
        self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
        self.norm = nn.LayerNorm(cfg.d_model)

    def forward(
        self,
@@ -517,33 +474,25 @@ class _TransformerDecoder(nn.Module):


class _TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int,
        dropout: float,
        activation: str,
        normalize_before: bool,
    ):
    def __init__(self, cfg: ActionChunkingTransformerConfig):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
        self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
        self.dropout = nn.Dropout(cfg.dropout)
        self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(cfg.d_model)
        self.norm2 = nn.LayerNorm(cfg.d_model)
        self.norm3 = nn.LayerNorm(cfg.d_model)
        self.dropout1 = nn.Dropout(cfg.dropout)
        self.dropout2 = nn.Dropout(cfg.dropout)
        self.dropout3 = nn.Dropout(cfg.dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self.activation = _get_activation_fn(cfg.feedforward_activation)
        self.pre_norm = cfg.pre_norm

    def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
        return tensor if pos_embed is None else tensor + pos_embed
@@ -566,12 +515,12 @@ class _TransformerDecoderLayer(nn.Module):
            (DS, B, C) tensor of decoder output features.
        """
        skip = x
        if self.normalize_before:
        if self.pre_norm:
            x = self.norm1(x)
        q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
        x = skip + self.dropout1(x)
        if self.normalize_before:
        if self.pre_norm:
            skip = x
            x = self.norm2(x)
        else:
@@ -583,7 +532,7 @@ class _TransformerDecoderLayer(nn.Module):
            value=encoder_out,
        )[0]  # select just the output, not the attention weights
        x = skip + self.dropout2(x)
        if self.normalize_before:
        if self.pre_norm:
            skip = x
            x = self.norm3(x)
        else:
@@ -591,7 +540,7 @@ class _TransformerDecoderLayer(nn.Module):
            skip = x
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = skip + self.dropout3(x)
        if not self.normalize_before:
        if not self.pre_norm:
            x = self.norm3(x)
        return x
@@ -1,3 +1,10 @@
import inspect

from omegaconf import OmegaConf

from lerobot.common.utils import get_safe_torch_device


def make_policy(cfg):
    if cfg.policy.name == "tdmpc":
        from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -19,10 +26,22 @@ def make_policy(cfg):
            **cfg.policy,
        )
    elif cfg.policy.name == "act":
        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
        from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
        from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
        policy.to(cfg.device)
        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
        assert set(cfg.policy).issuperset(
            expected_kwargs
        ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
        policy_cfg = ActionChunkingTransformerConfig(
            **{
                k: v
                for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
                if k in expected_kwargs
            }
        )
        policy = ActionChunkingTransformerPolicy(policy_cfg)
        policy.to(get_safe_torch_device(cfg.device))
    else:
        raise ValueError(cfg.policy.name)
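The filtering idiom in the factory above (keep only the Hydra keys that the dataclass signature expects) can be seen in isolation with a toy stand-in; `ToyConfig` and the dict below are hypothetical:

import inspect
from dataclasses import dataclass

@dataclass
class ToyConfig:  # hypothetical stand-in for ActionChunkingTransformerConfig
    chunk_size: int = 100
    n_action_steps: int = 100

hydra_policy_cfg = {"name": "act", "chunk_size": 20, "n_action_steps": 20, "batch_size": 2}

expected_kwargs = set(inspect.signature(ToyConfig).parameters)
toy_cfg = ToyConfig(**{k: v for k, v in hydra_policy_cfg.items() if k in expected_kwargs})
print(toy_cfg)  # ToyConfig(chunk_size=20, n_action_steps=20); "name" and "batch_size" are dropped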
@@ -8,61 +8,65 @@ eval_freq: 10000
save_freq: 100000
log_freq: 250

horizon: 100
n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon
n_action_steps: ${horizon}

# See `configuration_act.py` for more details.
policy:
  name: act

  pretrained_model_path:

  # Environment.
  # Inherit these from the environment.
  state_dim: ???
  action_dim: ???

  # Inputs / output structure.
  n_obs_steps: ${n_obs_steps}
  camera_names: [top]  # [top, front_close, left_pillar, right_pillar]
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  # Vision preprocessing.
  image_normalization_mean: [0.485, 0.456, 0.406]
  image_normalization_std: [0.229, 0.224, 0.225]

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  use_pretrained_backbone: true
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  d_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  use_temporal_aggregation: false

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0

  # ---
  # TODO(alexander-soare): Remove these from the policy config.
  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  pretrained_backbone: true
  weight_decay: 1e-4
  grad_clip_norm: 10
  backbone: resnet18
  horizon: ${horizon} # chunk_size
  kl_weight: 10
  d_model: 512
  dim_feedforward: 3200
  vae_enc_layers: 4
  enc_layers: 4
  dec_layers: 1
  num_heads: 8
  #camera_names: [top, front_close, left_pillar, right_pillar]
  camera_names: [top]
  dilation: false
  dropout: 0.1
  pre_norm: false
  activation: relu
  latent_dim: 32

  use_vae: true

  batch_size: 8

  per_alpha: 0.6
  per_beta: 0.4

  balanced_sampling: false
  utd: 1

  n_obs_steps: ${n_obs_steps}
  n_action_steps: ${n_action_steps}

  temporal_agg: false

  state_dim: 14
  action_dim: 14

  image_normalization:
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]

  delta_timestamps:
    observation.images.top: [0.0]
    observation.state: [0.0]
    action: "[i / ${fps} for i in range(${horizon})]"
    observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
    observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
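Not part of the diff: what the new `delta_timestamps` expressions resolve to once OmegaConf interpolates them and the strings are evaluated, assuming fps=50 (the Aloha default), n_obs_steps=1 and chunk_size=100.

fps = 50           # assumed environment value
n_obs_steps = 1
chunk_size = 100

observation_timestamps = [i / fps for i in range(1 - n_obs_steps, 1)]  # [0.0]
action_timestamps = [i / fps for i in range(chunk_size)]               # [0.0, 0.02, 0.04, ..., 1.98]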
@@ -18,7 +18,7 @@ from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset

from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy