From 55e484124ac3cf5bdfece0411a9b8ec8744f1ba4 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 12 Apr 2024 16:55:32 +0100 Subject: [PATCH 1/7] draft pr --- lerobot/common/datasets/factory.py | 9 +- .../common/policies/act/configuration_act.py | 68 ++ lerobot/common/policies/act/modeling_act.py | 626 ++++++++++++++++++ lerobot/common/policies/factory.py | 17 +- lerobot/configs/policy/act.yaml | 89 +-- 5 files changed, 758 insertions(+), 51 deletions(-) create mode 100644 lerobot/common/policies/act/configuration_act.py create mode 100644 lerobot/common/policies/act/modeling_act.py diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4ae161f6..10106fe9 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -86,11 +86,10 @@ def make_dataset( ] ) - delta_timestamps = cfg.policy.get("delta_timestamps") - if delta_timestamps is not None: - for key in delta_timestamps: - if isinstance(delta_timestamps[key], str): - delta_timestamps[key] = eval(delta_timestamps[key]) + delta_timestamps = cfg.policy.delta_timestamps + for key in delta_timestamps: + if isinstance(delta_timestamps[key], str): + delta_timestamps[key] = eval(delta_timestamps[key]) dataset = clsfunc( dataset_id=cfg.dataset_id, diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py new file mode 100644 index 00000000..a3dc1590 --- /dev/null +++ b/lerobot/common/policies/act/configuration_act.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass + + +@dataclass +class ActConfig: + """ + TODO(now): Document all variables + TODO(now): Pick sensible defaults for a use case? + """ + + # Environment. + state_dim: int + action_dim: int + + # Inputs / output structure. + n_obs_steps: int + camera_names: list[str] + chunk_size: int + n_action_steps: int + + # Vision preprocessing. + image_normalization_mean: tuple[float, float, float] + image_normalization_std: tuple[float, float, float] + + # Architecture. + # Vision backbone. + vision_backbone: str + use_pretrained_backbone: bool + replace_final_stride_with_dilation: int + # Transformer layers. + pre_norm: bool + d_model: int + n_heads: int + dim_feedforward: int + feedforward_activation: str + n_encoder_layers: int + n_decoder_layers: int + # VAE. + use_vae: bool + latent_dim: int + n_vae_encoder_layers: int + + # Inference. + use_temporal_aggregation: bool + + # Training and loss computation. + dropout: float + kl_weight: float + + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: int + lr: float + lr_backbone: float + weight_decay: float + grad_clip_norm: float + utd: int + + def __post_init__(self): + """Input validation.""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError("`vision_backbone` must be one of the ResNet variants.") + if self.use_temporal_aggregation: + raise NotImplementedError("Temporal aggregation is not yet implemented.") + if self.n_action_steps > self.chunk_size: + raise ValueError( + "The chunk size is the upper bound for the number of action steps per model invocation." + ) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py new file mode 100644 index 00000000..769c9470 --- /dev/null +++ b/lerobot/common/policies/act/modeling_act.py @@ -0,0 +1,626 @@ +"""Action Chunking Transformer Policy + +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). +The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. +""" + +import math +import time +from collections import deque +from itertools import chain +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +import torchvision.transforms as transforms +from torch import Tensor, nn +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.misc import FrozenBatchNorm2d + +from lerobot.common.policies.act.configuration_act import ActConfig + + +class ActPolicy(nn.Module): + """ + Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost + Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) + + Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. + - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the + model that encodes the target data (a sequence of actions), and the condition (the robot + joint-space). + - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with + cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we + have an option to train this model without the variational objective (in which case we drop the + `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). + + Transformer + Used alone for inference + (acts as VAE decoder + during training) + ┌───────────────────────┐ + │ Outputs │ + │ ▲ │ + │ ┌─────►┌───────┐ │ + ┌──────┐ │ │ │Transf.│ │ + │ │ │ ├─────►│decoder│ │ + ┌────┴────┐ │ │ │ │ │ │ + │ │ │ │ ┌───┴───┬─►│ │ │ + │ VAE │ │ │ │ │ └───────┘ │ + │ encoder │ │ │ │Transf.│ │ + │ │ │ │ │encoder│ │ + └───▲─────┘ │ │ │ │ │ + │ │ │ └───▲───┘ │ + │ │ │ │ │ + inputs └─────┼─────┘ │ + │ │ + └───────────────────────┘ + """ + + name = "act" + _multiple_obs_steps_not_handled_msg = "ActPolicy does not handle multiple observation steps." + + def __init__(self, cfg: ActConfig): + """ + TODO(alexander-soare): Add documentation for all parameters once we have model configs established. + """ + super().__init__() + if getattr(cfg, "n_obs_steps", 1) != 1: + raise ValueError(self._multiple_obs_steps_not_handled_msg) + self.cfg = cfg + + # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. + # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + if self.cfg.use_vae: + self.vae_encoder = _TransformerEncoder(cfg) + self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) + # Projection layer for joint-space configuration to hidden dimension. + self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + # Projection layer for action (joint-space target) to hidden dimension. + self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + self.latent_dim = cfg.latent_dim + # Projection layer from the VAE encoder's output to the latent distribution's parameter space. + self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) + # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch + # dimension. + self.register_buffer( + "vae_encoder_pos_enc", + _create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0), + ) + + # Backbone for image feature extraction. + self.image_normalizer = transforms.Normalize( + mean=cfg.image_normalization_mean, std=cfg.image_normalization_std + ) + backbone_model = getattr(torchvision.models, cfg.vision_backbone)( + replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], + pretrained=cfg.use_pretrained_backbone, + norm_layer=FrozenBatchNorm2d, + ) + # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature + # map). + # Note: The forward method of this returns a dict: {"feature_map": output}. + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + # Transformer (acts as VAE decoder when training with the variational objective). + self.encoder = _TransformerEncoder(cfg) + self.decoder = _TransformerDecoder(cfg) + + # Transformer encoder input projections. The tokens will be structured like + # [latent, robot_state, image_feature_map_pixels]. + self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) + self.encoder_img_feat_input_proj = nn.Conv2d( + backbone_model.fc.in_features, cfg.d_model, kernel_size=1 + ) + # Transformer encoder positional embeddings. + self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model) + self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2) + + # Transformer decoder. + # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). + self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) + + # Final action regression head on the output of the transformer's decoder. + self.action_head = nn.Linear(cfg.d_model, cfg.action_dim) + + self._reset_parameters() + self._create_optimizer() + + def _create_optimizer(self): + optimizer_params_dicts = [ + { + "params": [ + p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad + ] + }, + { + "params": [ + p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad + ], + "lr": self.cfg.lr_backbone, + }, + ] + self.optimizer = torch.optim.AdamW( + optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay + ) + + def _reset_parameters(self): + """Xavier-uniform initialization of the transformer parameters as in the original code.""" + for p in chain(self.encoder.parameters(), self.decoder.parameters()): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def reset(self): + """This should be called whenever the environment is reset.""" + if self.cfg.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.cfg.n_action_steps) + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + """ + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + if len(self._action_queue) == 0: + # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has + # shape (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) + return self._action_queue.popleft() + + @torch.no_grad + def select_actions(self, batch: dict[str, Tensor]) -> Tensor: + """Use the action chunking transformer to generate a sequence of actions.""" + self.eval() + self._preprocess_batch(batch, add_obs_steps_dim=True) + + action = self.forward(batch, return_loss=False) + + return action[: self.cfg.n_action_steps] + + def __call__(self, *args, **kwargs) -> dict: + # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method. + return self.update(*args, **kwargs) + + def _preprocess_batch( + self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False + ) -> dict[str, Tensor]: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). + "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. + } + """ + if add_obs_steps_dim: + # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now, + # this just amounts to an unsqueeze. + for k in batch: + if k.startswith("observation."): + batch[k] = batch[k].unsqueeze(1) + + if batch["observation.state"].shape[1] != 1: + raise ValueError(self._multiple_obs_steps_not_handled_msg) + batch["observation.state"] = batch["observation.state"].squeeze(1) + # TODO(alexander-soare): generalize this to multiple images. + assert ( + sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1 + ), "ACT only handles one image for now." + # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get + # the image index dimension. + + def update(self, batch, **_) -> dict: + """Run the model in train mode, compute the loss, and do an optimization step.""" + start_time = time.time() + self._preprocess_batch(batch) + + self.train() + + num_slices = self.cfg.batch_size + batch_size = self.cfg.chunk_size * num_slices + + assert batch_size % self.cfg.chunk_size == 0 + assert batch_size % num_slices == 0 + + loss = self.forward(batch, return_loss=True)["loss"] + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + + info = { + "loss": loss.item(), + "grad_norm": float(grad_norm), + "lr": self.cfg.lr, + "update_s": time.time() - start_time, + } + + return info + + def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor: + """A forward pass through the DNN part of this policy with optional loss computation.""" + images = self.image_normalizer(batch["observation.images.top"]) + + if return_loss: # training time + actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward( + batch["observation.state"], images, batch["action"] + ) + + l1_loss = ( + F.l1_loss(batch["action"], actions_hat, reduction="none") + * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() + + loss_dict = {} + loss_dict["l1"] = l1_loss + if self.cfg.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss_dict["kl"] = mean_kld + loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight + else: + loss_dict["loss"] = loss_dict["l1"] + return loss_dict + else: + action, _ = self._forward(batch["observation.state"], images) + return action + + def _forward( + self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None + ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]: + """ + Args: + robot_state: (B, J) batch of robot joint configurations. + image: (B, N, C, H, W) batch of N camera frames. + actions: (B, S, A) batch of actions from the target dataset which must be provided if the + VAE is enabled and the model is in training mode. + Returns: + (B, S, A) batch of action sequences + Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the + latent dimension. + """ + if self.cfg.use_vae and self.training: + assert ( + actions is not None + ), "actions must be provided when using the variational objective in training mode." + + batch_size = robot_state.shape[0] + + # Prepare the latent for input to the transformer encoder. + if self.cfg.use_vae and actions is not None: + # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. + cls_embed = einops.repeat( + self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size + ) # (B, 1, D) + robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) + vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) + + # Prepare fixed positional embedding. + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + + # Forward pass through VAE encoder to get the latent PDF parameters. + cls_token_out = self.vae_encoder( + vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) + )[0] # select the class token, with shape (B, D) + latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) + mu = latent_pdf_params[:, : self.latent_dim] + # This is 2log(sigma). Done this way to match the original implementation. + log_sigma_x2 = latent_pdf_params[:, self.latent_dim :] + + # Sample the latent with the reparameterization trick. + latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) + else: + # When not using the VAE encoder, we set the latent to be all zeros. + mu = log_sigma_x2 = None + latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( + robot_state.device + ) + + # Prepare all other transformer encoder inputs. + # Camera observation features and positional embeddings. + all_cam_features = [] + all_cam_pos_embeds = [] + for cam_id, _ in enumerate(self.cfg.camera_names): + cam_features = self.backbone(image[:, cam_id])["feature_map"] + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) + cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) + all_cam_features.append(cam_features) + all_cam_pos_embeds.append(cam_pos_embed) + # Concatenate camera observation feature maps and positional embeddings along the width dimension. + encoder_in = torch.cat(all_cam_features, axis=3) + cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) + + # Get positional embeddings for robot state and latent. + robot_state_embed = self.encoder_robot_state_input_proj(robot_state) + latent_embed = self.encoder_latent_input_proj(latent_sample) + + # Stack encoder input and positional embeddings moving to (S, B, C). + encoder_in = torch.cat( + [ + torch.stack([latent_embed, robot_state_embed], axis=0), + encoder_in.flatten(2).permute(2, 0, 1), + ] + ) + pos_embed = torch.cat( + [ + self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1), + cam_pos_embed.flatten(2).permute(2, 0, 1), + ], + axis=0, + ) + + # Forward pass through the transformer modules. + encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) + decoder_in = torch.zeros( + (self.cfg.chunk_size, batch_size, self.cfg.d_model), + dtype=pos_embed.dtype, + device=pos_embed.device, + ) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=pos_embed, + decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + ) + + # Move back to (B, S, C). + decoder_out = decoder_out.transpose(0, 1) + + actions = self.action_head(decoder_out) + + return actions, (mu, log_sigma_x2) + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + self.load_state_dict(d) + + +class _TransformerEncoder(nn.Module): + """Convenience module for running multiple encoder layers, maybe followed by normalization.""" + + def __init__(self, cfg: ActConfig): + super().__init__() + self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) + self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() + + def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: + for layer in self.layers: + x = layer(x, pos_embed=pos_embed) + x = self.norm(x) + return x + + +class _TransformerEncoderLayer(nn.Module): + def __init__(self, cfg: ActConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + + # Feed forward layers. + self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) + self.dropout = nn.Dropout(cfg.dropout) + self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) + + self.norm1 = nn.LayerNorm(cfg.d_model) + self.norm2 = nn.LayerNorm(cfg.d_model) + self.dropout1 = nn.Dropout(cfg.dropout) + self.dropout2 = nn.Dropout(cfg.dropout) + + self.activation = _get_activation_fn(cfg.feedforward_activation) + self.pre_norm = cfg.pre_norm + + def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = x if pos_embed is None else x + pos_embed + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout2(x) + if not self.pre_norm: + x = self.norm2(x) + return x + + +class _TransformerDecoder(nn.Module): + def __init__(self, cfg: ActConfig): + """Convenience module for running multiple decoder layers followed by normalization.""" + super().__init__() + self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) + self.norm = nn.LayerNorm(cfg.d_model) + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + for layer in self.layers: + x = layer( + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + ) + if self.norm is not None: + x = self.norm(x) + return x + + +class _TransformerDecoderLayer(nn.Module): + def __init__(self, cfg: ActConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + + # Feed forward layers. + self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) + self.dropout = nn.Dropout(cfg.dropout) + self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) + + self.norm1 = nn.LayerNorm(cfg.d_model) + self.norm2 = nn.LayerNorm(cfg.d_model) + self.norm3 = nn.LayerNorm(cfg.d_model) + self.dropout1 = nn.Dropout(cfg.dropout) + self.dropout2 = nn.Dropout(cfg.dropout) + self.dropout3 = nn.Dropout(cfg.dropout) + + self.activation = _get_activation_fn(cfg.feedforward_activation) + self.pre_norm = cfg.pre_norm + + def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + """ + Args: + x: (Decoder Sequence, Batch, Channel) tensor of input tokens. + encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are + cross-attending with. + decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). + encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). + Returns: + (DS, B, C) tensor of decoder output features. + """ + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.multihead_attn( + query=self.maybe_add_pos_embed(x, decoder_pos_embed), + key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), + value=encoder_out, + )[0] # select just the output, not the attention weights + x = skip + self.dropout2(x) + if self.pre_norm: + skip = x + x = self.norm3(x) + else: + x = self.norm2(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout3(x) + if not self.pre_norm: + x = self.norm3(x) + return x + + +def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor: + """1D sinusoidal positional embeddings as in Attention is All You Need. + + Args: + num_positions: Number of token positions required. + Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension). + + """ + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + return torch.from_numpy(sinusoid_table).float() + + +class _SinusoidalPositionEmbedding2D(nn.Module): + """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. + + The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H + for the vertical direction, and 1/W for the horizontal direction. + """ + + def __init__(self, dimension: int): + """ + Args: + dimension: The desired dimension of the embeddings. + """ + super().__init__() + self.dimension = dimension + self._two_pi = 2 * math.pi + self._eps = 1e-6 + # Inverse "common ratio" for the geometric progression in sinusoid frequencies. + self._temperature = 10000 + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. + Returns: + A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. + """ + not_mask = torch.ones_like(x[0, :1]) # (1, H, W) + # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations + # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. + y_range = not_mask.cumsum(1, dtype=torch.float32) + x_range = not_mask.cumsum(2, dtype=torch.float32) + + # "Normalize" the position index such that it ranges in [0, 2π]. + # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range + # are non-zero by construction. This is an artifact of the original code. + y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi + x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi + + inverse_frequency = self._temperature ** ( + 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + ) + + x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + + # Note: this stack then flatten operation results in interleaved sine and cosine terms. + # pos_embed_x and pos_embed_y are (1, H, W, C // 2). + pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) + + return pos_embed + + +def _get_activation_fn(activation: str) -> Callable: + """Return an activation function given a string.""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 9077d4d0..d9ba3f07 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,3 +1,8 @@ +import inspect + +from lerobot.common.utils import get_safe_torch_device + + def make_policy(cfg): if cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPCPolicy @@ -21,10 +26,16 @@ def make_policy(cfg): **cfg.policy, ) elif cfg.policy.name == "act": - from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy + from lerobot.common.policies.act.configuration_act import ActConfig + from lerobot.common.policies.act.modeling_act import ActPolicy - policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device) - policy.to(cfg.device) + expected_kwargs = set(inspect.signature(ActConfig).parameters) + assert set(cfg.policy).issuperset( + expected_kwargs + ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}" + policy_cfg = ActConfig(**{k: v for k, v in cfg.policy.items() if k in expected_kwargs}) + policy = ActPolicy(policy_cfg) + policy.to(get_safe_torch_device(cfg.device)) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index e2074b46..8ae3087c 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -8,61 +8,64 @@ eval_freq: 10000 save_freq: 100000 log_freq: 250 -horizon: 100 n_obs_steps: 1 # when temporal_agg=False, n_action_steps=horizon -n_action_steps: ${horizon} policy: name: act pretrained_model_path: + # Environment. + # Inherit these from the environment. + state_dim: ??? + action_dim: ??? + + # Inputs / output structure. + n_obs_steps: ${n_obs_steps} + camera_names: [top] # [top, front_close, left_pillar, right_pillar] + chunk_size: 100 # chunk_size + n_action_steps: 100 + + # Vision preprocessing. + image_normalization_mean: [0.485, 0.456, 0.406] + image_normalization_std: [0.229, 0.224, 0.225] + + # Architecture. + # Vision backbone. + vision_backbone: resnet18 + use_pretrained_backbone: true + replace_final_stride_with_dilation: false + # Transformer layers. + pre_norm: false + d_model: 512 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + n_encoder_layers: 4 + n_decoder_layers: 1 + # VAE. + use_vae: true + latent_dim: 32 + n_vae_encoder_layers: 4 + + # Inference. + use_temporal_aggregation: false + + # Training and loss computation. + dropout: 0.1 + kl_weight: 10 + + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: 8 lr: 1e-5 lr_backbone: 1e-5 - pretrained_backbone: true weight_decay: 1e-4 grad_clip_norm: 10 - backbone: resnet18 - horizon: ${horizon} # chunk_size - kl_weight: 10 - d_model: 512 - dim_feedforward: 3200 - vae_enc_layers: 4 - enc_layers: 4 - dec_layers: 1 - num_heads: 8 - #camera_names: [top, front_close, left_pillar, right_pillar] - camera_names: [top] - dilation: false - dropout: 0.1 - pre_norm: false - activation: relu - latent_dim: 32 - - use_vae: true - - batch_size: 8 - - per_alpha: 0.6 - per_beta: 0.4 - - balanced_sampling: false utd: 1 - n_obs_steps: ${n_obs_steps} - n_action_steps: ${n_action_steps} - - temporal_agg: false - - state_dim: 14 - action_dim: 14 - - image_normalization: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - delta_timestamps: - observation.images.top: [0.0] - observation.state: [0.0] - action: "[i / ${fps} for i in range(${horizon})]" + observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + action: "[i / ${fps} for i in range(${policy.chunk_size})]" From 34f00753eb384575dcfdac7a95505c38fc58aab5 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 12 Apr 2024 17:13:25 +0100 Subject: [PATCH 2/7] remove policy.py --- lerobot/common/policies/act/policy.py | 678 -------------------------- 1 file changed, 678 deletions(-) delete mode 100644 lerobot/common/policies/act/policy.py diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py deleted file mode 100644 index 25b814ed..00000000 --- a/lerobot/common/policies/act/policy.py +++ /dev/null @@ -1,678 +0,0 @@ -"""Action Chunking Transformer Policy - -As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). -The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. -""" - -import math -import time -from collections import deque -from itertools import chain -from typing import Callable - -import einops -import numpy as np -import torch -import torch.nn.functional as F # noqa: N812 -import torchvision -import torchvision.transforms as transforms -from torch import Tensor, nn -from torchvision.models._utils import IntermediateLayerGetter -from torchvision.ops.misc import FrozenBatchNorm2d - -from lerobot.common.utils import get_safe_torch_device - - -class ActionChunkingTransformerPolicy(nn.Module): - """ - Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost - Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) - - Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. - - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the - model that encodes the target data (a sequence of actions), and the condition (the robot - joint-space). - - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with - cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we - have an option to train this model without the variational objective (in which case we drop the - `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). - - Transformer - Used alone for inference - (acts as VAE decoder - during training) - ┌───────────────────────┐ - │ Outputs │ - │ ▲ │ - │ ┌─────►┌───────┐ │ - ┌──────┐ │ │ │Transf.│ │ - │ │ │ ├─────►│decoder│ │ - ┌────┴────┐ │ │ │ │ │ │ - │ │ │ │ ┌───┴───┬─►│ │ │ - │ VAE │ │ │ │ │ └───────┘ │ - │ encoder │ │ │ │Transf.│ │ - │ │ │ │ │encoder│ │ - └───▲─────┘ │ │ │ │ │ - │ │ │ └───▲───┘ │ - │ │ │ │ │ - inputs └─────┼─────┘ │ - │ │ - └───────────────────────┘ - """ - - name = "act" - _multiple_obs_steps_not_handled_msg = ( - "ActionChunkingTransformerPolicy does not handle multiple observation steps." - ) - - def __init__(self, cfg, device): - """ - TODO(alexander-soare): Add documentation for all parameters once we have model configs established. - """ - super().__init__() - if getattr(cfg, "n_obs_steps", 1) != 1: - raise ValueError(self._multiple_obs_steps_not_handled_msg) - self.cfg = cfg - self.n_action_steps = cfg.n_action_steps - self.device = get_safe_torch_device(device) - self.camera_names = cfg.camera_names - self.use_vae = cfg.use_vae - self.horizon = cfg.horizon - self.d_model = cfg.d_model - - transformer_common_kwargs = dict( # noqa: C408 - d_model=self.d_model, - num_heads=cfg.num_heads, - dim_feedforward=cfg.dim_feedforward, - dropout=cfg.dropout, - activation=cfg.activation, - normalize_before=cfg.pre_norm, - ) - - # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. - # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). - if self.use_vae: - self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs) - self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model) - # Projection layer for joint-space configuration to hidden dimension. - self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) - # Projection layer for action (joint-space target) to hidden dimension. - self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model) - self.latent_dim = cfg.latent_dim - # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2) - # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch - # dimension. - self.register_buffer( - "vae_encoder_pos_enc", - _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0), - ) - - # Backbone for image feature extraction. - self.image_normalizer = transforms.Normalize( - mean=cfg.image_normalization.mean, std=cfg.image_normalization.std - ) - backbone_model = getattr(torchvision.models, cfg.backbone)( - replace_stride_with_dilation=[False, False, cfg.dilation], - pretrained=cfg.pretrained_backbone, - norm_layer=FrozenBatchNorm2d, - ) - # Note: The forward method of this returns a dict: {"feature_map": output}. - self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) - - # Transformer (acts as VAE decoder when training with the variational objective). - self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs) - self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs) - - # Transformer encoder input projections. The tokens will be structured like - # [latent, robot_state, image_feature_map_pixels]. - self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) - self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model) - self.encoder_img_feat_input_proj = nn.Conv2d( - backbone_model.fc.in_features, self.d_model, kernel_size=1 - ) - # Transformer encoder positional embeddings. - self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model) - self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2) - - # Transformer decoder. - # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). - self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model) - - # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(self.d_model, cfg.action_dim) - - self._reset_parameters() - - self._create_optimizer() - self.to(self.device) - - def _create_optimizer(self): - optimizer_params_dicts = [ - { - "params": [ - p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad - ] - }, - { - "params": [ - p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad - ], - "lr": self.cfg.lr_backbone, - }, - ] - self.optimizer = torch.optim.AdamW( - optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay - ) - - def _reset_parameters(self): - """Xavier-uniform initialization of the transformer parameters as in the original code.""" - for p in chain(self.encoder.parameters(), self.decoder.parameters()): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def reset(self): - """This should be called whenever the environment is reset.""" - if self.n_action_steps is not None: - self._action_queue = deque([], maxlen=self.n_action_steps) - - def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor: - """ - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - if len(self._action_queue) == 0: - # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape - # (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) - return self._action_queue.popleft() - - @torch.no_grad() - def select_actions(self, batch: dict[str, Tensor]) -> Tensor: - """Use the action chunking transformer to generate a sequence of actions.""" - self.eval() - self._preprocess_batch(batch, add_obs_steps_dim=True) - - action = self.forward(batch, return_loss=False) - - if self.cfg.temporal_agg: - # TODO(rcadene): implement temporal aggregation - raise NotImplementedError() - # all_time_actions[[t], t:t+num_queries] = action - # actions_for_curr_step = all_time_actions[:, t] - # actions_populated = torch.all(actions_for_curr_step != 0, axis=1) - # actions_for_curr_step = actions_for_curr_step[actions_populated] - # k = 0.01 - # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) - # exp_weights = exp_weights / exp_weights.sum() - # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) - # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) - - return action[: self.n_action_steps] - - def __call__(self, *args, **kwargs) -> dict: - # TODO(now): Temporary bridge until we know what to do about the `update` method. - return self.update(*args, **kwargs) - - def _preprocess_batch( - self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False - ) -> dict[str, Tensor]: - """ - This function expects `batch` to have (at least): - { - "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). - "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. - "action": (B, H, J) tensor of actions (positional target for robot joint configuration) - "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds. - } - """ - if add_obs_steps_dim: - # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now, - # this just amounts to an unsqueeze. - for k in batch: - if k.startswith("observation."): - batch[k] = batch[k].unsqueeze(1) - - if batch["observation.state"].shape[1] != 1: - raise ValueError(self._multiple_obs_steps_not_handled_msg) - batch["observation.state"] = batch["observation.state"].squeeze(1) - # TODO(alexander-soare): generalize this to multiple images. - assert ( - sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1 - ), "ACT only handles one image for now." - # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get - # the image index dimension. - - def update(self, batch, *_, **__) -> dict: - start_time = time.time() - self._preprocess_batch(batch) - - self.train() - - num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices - - assert batch_size % self.cfg.horizon == 0 - assert batch_size % num_slices == 0 - - loss = self.forward(batch, return_loss=True)["loss"] - loss.backward() - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, - ) - - self.optimizer.step() - self.optimizer.zero_grad() - - info = { - "loss": loss.item(), - "grad_norm": float(grad_norm), - "lr": self.cfg.lr, - "update_s": time.time() - start_time, - } - - return info - - def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor: - images = self.image_normalizer(batch["observation.images.top"]) - - if return_loss: # training time - actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward( - batch["observation.state"], images, batch["action"] - ) - - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") - * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() - - loss_dict = {} - loss_dict["l1"] = l1_loss - if self.cfg.use_vae: - # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for - # each dimension independently, we sum over the latent dimension to get the total - # KL-divergence per batch element, then take the mean over the batch. - # (See App. B of https://arxiv.org/abs/1312.6114 for more details). - mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() - ) - loss_dict["kl"] = mean_kld - loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight - else: - loss_dict["loss"] = loss_dict["l1"] - return loss_dict - else: - action, _ = self._forward(batch["observation.state"], images) - return action - - def _forward( - self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None - ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]: - """ - Args: - robot_state: (B, J) batch of robot joint configurations. - image: (B, N, C, H, W) batch of N camera frames. - actions: (B, S, A) batch of actions from the target dataset which must be provided if the - VAE is enabled and the model is in training mode. - Returns: - (B, S, A) batch of action sequences - Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the - latent dimension. - """ - if self.use_vae and self.training: - assert ( - actions is not None - ), "actions must be provided when using the variational objective in training mode." - - batch_size = robot_state.shape[0] - - # Prepare the latent for input to the transformer encoder. - if self.use_vae and actions is not None: - # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. - cls_embed = einops.repeat( - self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size - ) # (B, 1, D) - robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) - vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) - - # Prepare fixed positional embedding. - # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. - pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) - - # Forward pass through VAE encoder to get the latent PDF parameters. - cls_token_out = self.vae_encoder( - vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) - )[0] # select the class token, with shape (B, D) - latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) - mu = latent_pdf_params[:, : self.latent_dim] - # This is 2log(sigma). Done this way to match the original implementation. - log_sigma_x2 = latent_pdf_params[:, self.latent_dim :] - - # Sample the latent with the reparameterization trick. - latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) - else: - # When not using the VAE encoder, we set the latent to be all zeros. - mu = log_sigma_x2 = None - latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( - robot_state.device - ) - - # Prepare all other transformer encoder inputs. - # Camera observation features and positional embeddings. - all_cam_features = [] - all_cam_pos_embeds = [] - for cam_id, _ in enumerate(self.camera_names): - cam_features = self.backbone(image[:, cam_id])["feature_map"] - cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) - cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) - all_cam_features.append(cam_features) - all_cam_pos_embeds.append(cam_pos_embed) - # Concatenate camera observation feature maps and positional embeddings along the width dimension. - encoder_in = torch.cat(all_cam_features, axis=3) - cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) - - # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(robot_state) - latent_embed = self.encoder_latent_input_proj(latent_sample) - - # Stack encoder input and positional embeddings moving to (S, B, C). - encoder_in = torch.cat( - [ - torch.stack([latent_embed, robot_state_embed], axis=0), - encoder_in.flatten(2).permute(2, 0, 1), - ] - ) - pos_embed = torch.cat( - [ - self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1), - cam_pos_embed.flatten(2).permute(2, 0, 1), - ], - axis=0, - ) - - # Forward pass through the transformer modules. - encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) - decoder_in = torch.zeros( - (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device - ) - decoder_out = self.decoder( - decoder_in, - encoder_out, - encoder_pos_embed=pos_embed, - decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), - ) - - # Move back to (B, S, C). - decoder_out = decoder_out.transpose(0, 1) - - actions = self.action_head(decoder_out) - - return actions, (mu, log_sigma_x2) - - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - self.load_state_dict(d) - - -class _TransformerEncoder(nn.Module): - """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - - def __init__(self, num_layers: int, **encoder_layer_kwargs: dict): - super().__init__() - self.layers = nn.ModuleList( - [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)] - ) - self.norm = ( - nn.LayerNorm(encoder_layer_kwargs["d_model"]) - if encoder_layer_kwargs["normalize_before"] - else nn.Identity() - ) - - def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: - for layer in self.layers: - x = layer(x, pos_embed=pos_embed) - x = self.norm(x) - return x - - -class _TransformerEncoderLayer(nn.Module): - def __init__( - self, - d_model: int, - num_heads: int, - dim_feedforward: int, - dropout: float, - activation: str, - normalize_before: bool, - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) - - # Feed forward layers. - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: - skip = x - if self.normalize_before: - x = self.norm1(x) - q = k = x if pos_embed is None else x + pos_embed - x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights - x = skip + self.dropout1(x) - if self.normalize_before: - skip = x - x = self.norm2(x) - else: - x = self.norm1(x) - skip = x - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = skip + self.dropout2(x) - if not self.normalize_before: - x = self.norm2(x) - return x - - -class _TransformerDecoder(nn.Module): - def __init__(self, num_layers: int, **decoder_layer_kwargs): - """Convenience module for running multiple decoder layers followed by normalization.""" - super().__init__() - self.layers = nn.ModuleList( - [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)] - ) - self.num_layers = num_layers - self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"]) - - def forward( - self, - x: Tensor, - encoder_out: Tensor, - decoder_pos_embed: Tensor | None = None, - encoder_pos_embed: Tensor | None = None, - ) -> Tensor: - for layer in self.layers: - x = layer( - x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed - ) - if self.norm is not None: - x = self.norm(x) - return x - - -class _TransformerDecoderLayer(nn.Module): - def __init__( - self, - d_model: int, - num_heads: int, - dim_feedforward: int, - dropout: float, - activation: str, - normalize_before: bool, - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) - - # Feed forward layers. - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: - return tensor if pos_embed is None else tensor + pos_embed - - def forward( - self, - x: Tensor, - encoder_out: Tensor, - decoder_pos_embed: Tensor | None = None, - encoder_pos_embed: Tensor | None = None, - ) -> Tensor: - """ - Args: - x: (Decoder Sequence, Batch, Channel) tensor of input tokens. - encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are - cross-attending with. - decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). - encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). - Returns: - (DS, B, C) tensor of decoder output features. - """ - skip = x - if self.normalize_before: - x = self.norm1(x) - q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) - x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights - x = skip + self.dropout1(x) - if self.normalize_before: - skip = x - x = self.norm2(x) - else: - x = self.norm1(x) - skip = x - x = self.multihead_attn( - query=self.maybe_add_pos_embed(x, decoder_pos_embed), - key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), - value=encoder_out, - )[0] # select just the output, not the attention weights - x = skip + self.dropout2(x) - if self.normalize_before: - skip = x - x = self.norm3(x) - else: - x = self.norm2(x) - skip = x - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = skip + self.dropout3(x) - if not self.normalize_before: - x = self.norm3(x) - return x - - -def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor: - """1D sinusoidal positional embeddings as in Attention is All You Need. - - Args: - num_positions: Number of token positions required. - Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension). - - """ - - def get_position_angle_vec(position): - return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] - - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - return torch.from_numpy(sinusoid_table).float() - - -class _SinusoidalPositionEmbedding2D(nn.Module): - """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. - - The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H - for the vertical direction, and 1/W for the horizontal direction. - """ - - def __init__(self, dimension: int): - """ - Args: - dimension: The desired dimension of the embeddings. - """ - super().__init__() - self.dimension = dimension - self._two_pi = 2 * math.pi - self._eps = 1e-6 - # Inverse "common ratio" for the geometric progression in sinusoid frequencies. - self._temperature = 10000 - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. - Returns: - A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. - """ - not_mask = torch.ones_like(x[0, :1]) # (1, H, W) - # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations - # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. - y_range = not_mask.cumsum(1, dtype=torch.float32) - x_range = not_mask.cumsum(2, dtype=torch.float32) - - # "Normalize" the position index such that it ranges in [0, 2π]. - # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range - # are non-zero by construction. This is an artifact of the original code. - y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi - x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi - - inverse_frequency = self._temperature ** ( - 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension - ) - - x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) - y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) - - # Note: this stack then flatten operation results in interleaved sine and cosine terms. - # pos_embed_x and pos_embed_y are (1, H, W, C // 2). - pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) - pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) - pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) - - return pos_embed - - -def _get_activation_fn(activation: str) -> Callable: - """Return an activation function given a string.""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") From ef4bd9e25c31c14ec8ac83f8cb92274b64548137 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 09:39:23 +0100 Subject: [PATCH 3/7] Use dataclass config for ACT --- .../common/policies/act/configuration_act.py | 114 ++++++++++++------ lerobot/configs/policy/act.yaml | 2 +- tests/test_available.py | 4 +- 3 files changed, 83 insertions(+), 37 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index a3dc1590..84d960db 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -1,60 +1,104 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass class ActConfig: - """ - TODO(now): Document all variables - TODO(now): Pick sensible defaults for a use case? + """Configuration class for the Action Chunking Transformers policy. + + Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `state_dim`, `action_dim` and `camera_names`. + + Args: + state_dim: Dimensionality of the observation state space (excluding images). + action_dim: Dimensionality of the action space. + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + camera_names: The (unique) set of names for the cameras. + chunk_size: The size of the action prediction "chunks" in units of environment steps. + n_action_steps: The number of action steps to run in the environment for one invocation of the policy. + This should be no greater than the chunk size. For example, if the chunk size size 100, you may + set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the + environment, and throws the other 50 out. + image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in + [0, 1]) for normalization. + image_normalization_std: Value by which to divide the input image pixels (after the mean has been + subtracted). + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + use_pretrained_backbone: Whether the backbone should be initialized with ImageNet, pretrained weights + from torchvision. + replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated + convolution. + pre_norm: Whether to use "pre-norm" in the transformer blocks. + d_model: The transformer blocks' main hidden dimension. + n_heads: The number of heads to use in the transformer blocks' multi-head attention. + dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward + layers. + feedforward_activation: The activation to use in the transformer block's feed-forward layers. + n_encoder_layers: The number of transformer layers to use for the transformer encoder. + n_decoder_layers: The number of transformer layers to use for the transformer decoder. + use_vae: Whether to use a variational objective during training. This introduces another transformer + which is used as the VAE's encoder (not to be confused with the transformer encoder - see + documentation in the policy class). + latent_dim: The VAE's latent dimension. + n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. + use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given + environment step. + dropout: Dropout to use in the transformer layers (see code for details). + kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective + is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. """ # Environment. - state_dim: int - action_dim: int + state_dim: int = 14 + action_dim: int = 14 # Inputs / output structure. - n_obs_steps: int - camera_names: list[str] - chunk_size: int - n_action_steps: int + n_obs_steps: int = 1 + camera_names: list[str] = field(default_factory=lambda: ["top"]) + chunk_size: int = 100 + n_action_steps: int = 100 # Vision preprocessing. - image_normalization_mean: tuple[float, float, float] - image_normalization_std: tuple[float, float, float] + image_normalization_mean: tuple[float, float, float] = field( + default_factory=lambda: [0.485, 0.456, 0.406] + ) + image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225]) # Architecture. # Vision backbone. - vision_backbone: str - use_pretrained_backbone: bool - replace_final_stride_with_dilation: int + vision_backbone: str = "resnet18" + use_pretrained_backbone: bool = True + replace_final_stride_with_dilation: int = False # Transformer layers. - pre_norm: bool - d_model: int - n_heads: int - dim_feedforward: int - feedforward_activation: str - n_encoder_layers: int - n_decoder_layers: int + pre_norm: bool = False + d_model: int = 512 + n_heads: int = 8 + dim_feedforward: int = 3200 + feedforward_activation: str = "relu" + n_encoder_layers: int = 4 + n_decoder_layers: int = 1 # VAE. - use_vae: bool - latent_dim: int - n_vae_encoder_layers: int + use_vae: bool = True + latent_dim: int = 32 + n_vae_encoder_layers: int = 4 # Inference. - use_temporal_aggregation: bool + use_temporal_aggregation: bool = False # Training and loss computation. - dropout: float - kl_weight: float + dropout: float = 0.1 + kl_weight: float = 10.0 # --- # TODO(alexander-soare): Remove these from the policy config. - batch_size: int - lr: float - lr_backbone: float - weight_decay: float - grad_clip_norm: float - utd: int + batch_size: int = 8 + lr: float = 1e-5 + lr_backbone: float = 1e-5 + weight_decay: float = 1e-4 + grad_clip_norm: float = 10 + utd: int = 1 def __post_init__(self): """Input validation.""" @@ -66,3 +110,5 @@ class ActConfig: raise ValueError( "The chunk size is the upper bound for the number of action steps per model invocation." ) + if self.camera_names != ["top"]: + raise ValueError("For now, `camera_names` can only be ['top']") diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 8ae3087c..22f2d53a 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -54,7 +54,7 @@ policy: # Training and loss computation. dropout: 0.1 - kl_weight: 10 + kl_weight: 10.0 # --- # TODO(alexander-soare): Remove these from the policy config. diff --git a/tests/test_available.py b/tests/test_available.py index be74a42a..36791a3e 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -18,14 +18,14 @@ from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.datasets.pusht import PushtDataset -from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy +from lerobot.common.policies.act.modeling_act import ActPolicy from lerobot.common.policies.diffusion.policy import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy def test_available(): policy_classes = [ - ActionChunkingTransformerPolicy, + ActPolicy, DiffusionPolicy, TDMPCPolicy, ] From 2ccf89d78c32a9beaed716be6310a8239d78e6de Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 09:47:25 +0100 Subject: [PATCH 4/7] try fix tests --- .github/workflows/test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3411e11..a86193b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -146,7 +146,8 @@ jobs: device=cpu \ save_model=true \ save_freq=2 \ - horizon=20 \ + policy.n_action_steps=20 \ + policy.chunk_size=20 \ policy.batch_size=2 \ hydra.run.dir=tests/outputs/act/ From 9241b5e8302bc2c9fe415b1c1c2f988ead6de746 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 09:52:54 +0100 Subject: [PATCH 5/7] pass step as kwarg --- lerobot/scripts/eval.py | 2 +- lerobot/scripts/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index d676623e..2b8906d7 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -130,7 +130,7 @@ def eval_policy( # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step) + action = policy.select_action(observation, step=step) # apply inverse transform to unnormalize the action action = postprocess_action(action, transform) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 03506f2a..5ff6538d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step) + train_info = policy(batch, step=step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: From 40d417ef608ffb77301b00ff565f8ad68163cf40 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 09:59:18 +0100 Subject: [PATCH 6/7] Make sure to make remove all traces of omegaconf from policy config --- lerobot/common/policies/factory.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index d9ba3f07..ed8ba7cf 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,5 +1,7 @@ import inspect +from omegaconf import OmegaConf + from lerobot.common.utils import get_safe_torch_device @@ -33,7 +35,13 @@ def make_policy(cfg): assert set(cfg.policy).issuperset( expected_kwargs ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}" - policy_cfg = ActConfig(**{k: v for k, v in cfg.policy.items() if k in expected_kwargs}) + policy_cfg = ActConfig( + **{ + k: v + for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items() + if k in expected_kwargs + } + ) policy = ActPolicy(policy_cfg) policy.to(get_safe_torch_device(cfg.device)) else: From 30023535f977212f12026b848e831b8367005328 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 10:56:43 +0100 Subject: [PATCH 7/7] revision 1 --- lerobot/common/datasets/factory.py | 9 +++++---- .../common/policies/act/configuration_act.py | 2 +- lerobot/common/policies/act/modeling_act.py | 18 ++++++++++-------- lerobot/common/policies/factory.py | 10 +++++----- lerobot/configs/policy/act.yaml | 1 + tests/test_available.py | 4 ++-- 6 files changed, 24 insertions(+), 20 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 10106fe9..4ae161f6 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -86,10 +86,11 @@ def make_dataset( ] ) - delta_timestamps = cfg.policy.delta_timestamps - for key in delta_timestamps: - if isinstance(delta_timestamps[key], str): - delta_timestamps[key] = eval(delta_timestamps[key]) + delta_timestamps = cfg.policy.get("delta_timestamps") + if delta_timestamps is not None: + for key in delta_timestamps: + if isinstance(delta_timestamps[key], str): + delta_timestamps[key] = eval(delta_timestamps[key]) dataset = clsfunc( dataset_id=cfg.dataset_id, diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 84d960db..74ed270e 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field @dataclass -class ActConfig: +class ActionChunkingTransformerConfig: """Configuration class for the Action Chunking Transformers policy. Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 769c9470..1361e071 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,10 +20,10 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.policies.act.configuration_act import ActConfig +from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig -class ActPolicy(nn.Module): +class ActionChunkingTransformerPolicy(nn.Module): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) @@ -61,9 +61,11 @@ class ActPolicy(nn.Module): """ name = "act" - _multiple_obs_steps_not_handled_msg = "ActPolicy does not handle multiple observation steps." + _multiple_obs_steps_not_handled_msg = ( + "ActionChunkingTransformerPolicy does not handle multiple observation steps." + ) - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): """ TODO(alexander-soare): Add documentation for all parameters once we have model configs established. """ @@ -398,7 +400,7 @@ class ActPolicy(nn.Module): class _TransformerEncoder(nn.Module): """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() @@ -411,7 +413,7 @@ class _TransformerEncoder(nn.Module): class _TransformerEncoderLayer(nn.Module): - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) @@ -449,7 +451,7 @@ class _TransformerEncoderLayer(nn.Module): class _TransformerDecoder(nn.Module): - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) @@ -472,7 +474,7 @@ class _TransformerDecoder(nn.Module): class _TransformerDecoderLayer(nn.Module): - def __init__(self, cfg: ActConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index ed8ba7cf..80ae27da 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -28,21 +28,21 @@ def make_policy(cfg): **cfg.policy, ) elif cfg.policy.name == "act": - from lerobot.common.policies.act.configuration_act import ActConfig - from lerobot.common.policies.act.modeling_act import ActPolicy + from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig + from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy - expected_kwargs = set(inspect.signature(ActConfig).parameters) + expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters) assert set(cfg.policy).issuperset( expected_kwargs ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}" - policy_cfg = ActConfig( + policy_cfg = ActionChunkingTransformerConfig( **{ k: v for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items() if k in expected_kwargs } ) - policy = ActPolicy(policy_cfg) + policy = ActionChunkingTransformerPolicy(policy_cfg) policy.to(get_safe_torch_device(cfg.device)) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 22f2d53a..bd883613 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -11,6 +11,7 @@ log_freq: 250 n_obs_steps: 1 # when temporal_agg=False, n_action_steps=horizon +# See `configuration_act.py` for more details. policy: name: act diff --git a/tests/test_available.py b/tests/test_available.py index 36791a3e..b25a921f 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -18,14 +18,14 @@ from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.datasets.pusht import PushtDataset -from lerobot.common.policies.act.modeling_act import ActPolicy +from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.diffusion.policy import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy def test_available(): policy_classes = [ - ActPolicy, + ActionChunkingTransformerPolicy, DiffusionPolicy, TDMPCPolicy, ]