diff --git a/lerobot/common/policies/act/backbone.py b/lerobot/common/policies/act/backbone.py deleted file mode 100644 index 6399d339..00000000 --- a/lerobot/common/policies/act/backbone.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import List - -import torch -import torchvision -from torch import nn -from torchvision.models._utils import IntermediateLayerGetter - -from .position_encoding import build_position_encoding -from .utils import NestedTensor, is_main_process - - -class FrozenBatchNorm2d(torch.nn.Module): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed. - - Copy-paste from torchvision.misc.ops with added eps before rqsrt, - without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] - produce nans. - """ - - def __init__(self, n): - super().__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - num_batches_tracked_key = prefix + "num_batches_tracked" - if num_batches_tracked_key in state_dict: - del state_dict[num_batches_tracked_key] - - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - - def forward(self, x): - # move reshapes to the beginning - # to make it fuser-friendly - w = self.weight.reshape(1, -1, 1, 1) - b = self.bias.reshape(1, -1, 1, 1) - rv = self.running_var.reshape(1, -1, 1, 1) - rm = self.running_mean.reshape(1, -1, 1, 1) - eps = 1e-5 - scale = w * (rv + eps).rsqrt() - bias = b - rm * scale - return x * scale + bias - - -class BackboneBase(nn.Module): - def __init__( - self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool - ): - super().__init__() - # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? - # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: - # parameter.requires_grad_(False) - if return_interm_layers: - return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} - else: - return_layers = {"layer4": "0"} - self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) - self.num_channels = num_channels - - def forward(self, tensor): - xs = self.body(tensor) - return xs - # out: Dict[str, NestedTensor] = {} - # for name, x in xs.items(): - # m = tensor_list.mask - # assert m is not None - # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] - # out[name] = NestedTensor(x, mask) - # return out - - -class Backbone(BackboneBase): - """ResNet backbone with frozen BatchNorm.""" - - def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool): - backbone = getattr(torchvision.models, name)( - replace_stride_with_dilation=[False, False, dilation], - pretrained=is_main_process(), - norm_layer=FrozenBatchNorm2d, - ) # pretrained # TODO do we want frozen batch_norm?? 
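# Editor's sketch (not part of the patch): the hand-rolled FrozenBatchNorm2d deleted above, like the
# torchvision.ops.misc.FrozenBatchNorm2d the refactor switches to, just folds the stored batch
# statistics into a fixed affine transform, so with identical statistics it matches nn.BatchNorm2d in
# eval mode. Assumes a torchvision recent enough that both default to eps=1e-5.
import torch
from torch import nn
from torchvision.ops.misc import FrozenBatchNorm2d

bn = nn.BatchNorm2d(8).eval()
with torch.no_grad():  # give the statistics and affine parameters non-trivial values
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)

frozen = FrozenBatchNorm2d(8)
frozen.load_state_dict({k: v for k, v in bn.state_dict().items() if k != "num_batches_tracked"})

x = torch.randn(2, 8, 4, 4)
# Both compute (x - running_mean) / sqrt(running_var + eps) * weight + bias, with nothing updated.
assert torch.allclose(bn(x), frozen(x), atol=1e-4)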
- num_channels = 512 if name in ("resnet18", "resnet34") else 2048 - super().__init__(backbone, train_backbone, num_channels, return_interm_layers) - - -class Joiner(nn.Sequential): - def __init__(self, backbone, position_embedding): - super().__init__(backbone, position_embedding) - - def forward(self, tensor_list: NestedTensor): - xs = self[0](tensor_list) - out: List[NestedTensor] = [] - pos = [] - for _, x in xs.items(): - out.append(x) - # position encoding - pos.append(self[1](x).to(x.dtype)) - - return out, pos - - -def build_backbone(args): - position_embedding = build_position_encoding(args) - train_backbone = args.lr_backbone > 0 - return_interm_layers = args.masks - backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) - model = Joiner(backbone, position_embedding) - model.num_channels = backbone.num_channels - return model diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py deleted file mode 100644 index aaf4d098..00000000 --- a/lerobot/common/policies/act/detr_vae.py +++ /dev/null @@ -1,229 +0,0 @@ -import einops -import numpy as np -import torch -from torch import nn - -from .backbone import build_backbone -from .transformer import Transformer, TransformerEncoder - - -def get_sinusoid_encoding_table(n_position, d_hid): - def get_position_angle_vec(position): - return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] - - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - - return torch.FloatTensor(sinusoid_table).unsqueeze(0) - - -class ActionChunkingTransformer(nn.Module): - """ - Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware - (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) - - Note: In this code we use the symbols `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. - - The `vae_encoder` is, as per the literature around conditional variational auto-encoders (cVAE), the - part of the model that encodes the target data (here, a sequence of actions), and the condition - (here, we include the robot joint-space state as an input to the encoder). - - The `transformer` is the cVAE's decoder. But since we have an option to train this model without the - variational objective (in which case we drop the `vae_encoder` altogether), we don't call it the - `vae_decoder`. - # TODO(now): remove the following - - The `encoder` is actually a component of the cVAE's "decoder". But we refer to it as an "encoder" - because, in terms of the transformer with cross-attention that forms the cVAE's decoder, it is the - "encoder" part. We drop the `vae_` prefix because we have an option to train this model without the - variational objective (in which case we drop the `vae_encoder` altogether), and nothing about this - model has anything to do with a VAE). - - The `decoder` is a building block of the VAE decoder, and is just the "decoder" part of a - transformer with cross-attention. For the same reasoning behind the naming of `encoder`, we make - this term agnostic to the option to use a variational objective for training. - - """ - - def __init__( - self, backbones, transformer, vae_encoder, state_dim, action_dim, horizon, camera_names, use_vae - ): - """Initializes the model. 
- Parameters: - backbones: torch module of the backbone to be used. See backbone.py - transformer: torch module of the transformer architecture. See transformer.py - state_dim: robot state dimension of the environment - horizon: number of object queries, ie detection slot. This is the maximal number of objects - DETR can detect in a single image. For COCO, we recommend 100 queries. - - Args: - state_dim: Robot positional state dimension. - action_dim: Action dimension. - horizon: The number of actions to generate in one forward pass. - use_vae: Whether to use the variational objective. TODO(now): Give more details. - """ - super().__init__() - - self.camera_names = camera_names - self.transformer = transformer - self.vae_encoder = vae_encoder - self.use_vae = use_vae - hidden_dim = transformer.d_model - - # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. - # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). - if use_vae: - self.cls_embed = nn.Embedding(1, hidden_dim) - # Projection layer for joint-space configuration to hidden dimension. - self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim) - # Projection layer for action (joint-space target) to hidden dimension. - self.vae_encoder_action_input_proj = nn.Linear(state_dim, hidden_dim) - # Final size of latent z. TODO(now): Add to hyperparams. - self.latent_dim = 32 - # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(hidden_dim, self.latent_dim * 2) - # Fixed sinusoidal positional embedding the whole input to the VAE encoder. - self.register_buffer( - "vae_encoder_pos_enc", get_sinusoid_encoding_table(1 + 1 + horizon, hidden_dim) - ) - - # Transformer encoder input projections. The tokens will be structured like - # [latent, robot_state, image_feature_map_pixels]. - self.backbones = nn.ModuleList(backbones) - self.encoder_img_feat_input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) - self.encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim) - self.encoder_latent_input_proj = nn.Linear(self.latent_dim, hidden_dim) - # TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image - # feature dimension with a dry run. - self.additional_pos_embed = nn.Embedding( - 2, hidden_dim - ) # learned position embedding for proprio and latent - - # Transformer decoder. - # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). - self.decoder_pos_embed = nn.Embedding(horizon, hidden_dim) - # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(hidden_dim, action_dim) - - def forward(self, robot_state, image, actions=None): - """ - Args: - robot_state: (B, J) batch of robot joint configurations. - image: (B, N, C, H, W) batch of N camera frames. - actions: (B, S, A) batch of actions from the target dataset which must be provided if the - VAE is enabled and the model is in training mode. - """ - if self.use_vae and self.training: - assert ( - actions is not None - ), "actions must be provided when using the variational objective in training mode." - - batch_size, _ = robot_state.shape - - # Prepare the latent for input to the transformer. - if self.use_vae and actions is not None: - # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. 
- cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D) - robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) - vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) - # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. - # Prepare fixed positional embedding. - pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) - # Forward pass through VAE encoder and sample the latent with the reparameterization trick. - cls_token_out = self.vae_encoder( - vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2) - )[0] # (B, D) - latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) - mu = latent_pdf_params[:, : self.latent_dim] - logvar = latent_pdf_params[:, self.latent_dim :] - # Use reparameterization trick to sample from the latent's PDF. - latent_sample = mu + logvar.div(2).exp() * torch.randn_like(mu) - else: - # When not using the VAE encoder, we set the latent to be all zeros. - mu = logvar = None - latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=robot_state.dtype).to( - robot_state.device - ) - - # Prepare all other transformer inputs. - # Image observation features and position embeddings. - all_cam_features = [] - all_cam_pos = [] - for cam_id, _ in enumerate(self.camera_names): - # TODO(now): remove the positional embedding from the backbones. - cam_features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED - cam_features = cam_features[0] # take the last layer feature - pos = pos[0] - cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) - all_cam_features.append(cam_features) - all_cam_pos.append(pos) - # Concatenate image observation feature maps along the width dimension. - transformer_input = torch.cat(all_cam_features, axis=3) - # TODO(now): remove the positional embedding from the backbones. - pos = torch.cat(all_cam_pos, axis=3) - robot_state_embed = self.encoder_robot_state_input_proj(robot_state) - latent_embed = self.encoder_latent_input_proj(latent_sample) - - # TODO(now): Explain all of this madness. - transformer_input = torch.cat( - [ - torch.stack([latent_embed, robot_state_embed], axis=0), - transformer_input.flatten(2).permute(2, 0, 1), - ] - ) - pos_embed = torch.cat( - [self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0 - ) - - # Run the transformer and project the outputs to the action space. 
- transformer_output = self.transformer( - transformer_input, - encoder_pos=pos_embed, - decoder_pos=self.decoder_pos_embed.weight.unsqueeze(1), - ).transpose(0, 1) # back to (B, S, C) - actions = self.action_head(transformer_output) - return actions, [mu, logvar] - - -def build(args): - # From state - # backbone = None # from state for now, no need for conv nets - # From image - backbones = [] - backbone = build_backbone(args) - backbones.append(backbone) - - transformer = Transformer( - d_model=args.hidden_dim, - dropout=args.dropout, - nhead=args.nheads, - dim_feedforward=args.dim_feedforward, - num_encoder_layers=args.enc_layers, - num_decoder_layers=args.dec_layers, - normalize_before=args.pre_norm, - ) - - # TODO(now): args.enc_layers shouldn't be shared with the transformer decoder - vae_encoder = TransformerEncoder( - num_layers=args.enc_layers, - d_model=args.hidden_dim, - nhead=args.nheads, - dim_feedforward=args.dim_feedforward, - dropout=args.dropout, - activation="relu", - normalize_before=args.pre_norm, - ) - - model = ActionChunkingTransformer( - backbones, - transformer, - vae_encoder, - state_dim=args.state_dim, - action_dim=args.action_dim, - horizon=args.num_queries, - camera_names=args.camera_names, - use_vae=args.vae, - ) - - n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) - print("number of parameters: {:.2f}M".format(n_parameters / 1e6)) - - return model diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 7d24620a..906ea0cd 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -1,50 +1,32 @@ -import logging -import time +"""Action Chunking Transformer Policy +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). +""" + +import logging +import math +import time +from itertools import chain +from typing import Callable, Optional + +import einops +import numpy as np import torch import torch.nn.functional as F # noqa: N812 +import torchvision import torchvision.transforms as transforms +from torch import Tensor, nn +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.common.policies.abstract import AbstractPolicy -from lerobot.common.policies.act.detr_vae import build from lerobot.common.utils import get_safe_torch_device -def build_act_model_and_optimizer(cfg): - model = build(cfg) - - param_dicts = [ - {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, - { - "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], - "lr": cfg.lr_backbone, - }, - ] - optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay) - - return model, optimizer - - -def kl_divergence(mu, logvar): - batch_size = mu.size(0) - assert batch_size != 0 - if mu.data.ndimension() == 4: - mu = mu.view(mu.size(0), mu.size(1)) - if logvar.data.ndimension() == 4: - logvar = logvar.view(logvar.size(0), logvar.size(1)) - - klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) - total_kld = klds.sum(1).mean(0, True) - dimension_wise_kld = klds.mean(0) - mean_kld = klds.mean(1).mean(0, True) - - return total_kld, dimension_wise_kld, mean_kld - - class ActionChunkingTransformerPolicy(AbstractPolicy): """ - Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware - (https://arxiv.org/abs/2304.13705). 
+ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost + Hardware (https://arxiv.org/abs/2304.13705). """ name = "act" @@ -68,7 +50,35 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): self.cfg = cfg self.n_action_steps = n_action_steps self.device = get_safe_torch_device(device) - self.model, self.optimizer = build_act_model_and_optimizer(cfg) + + self.model = ActionChunkingTransformer( + cfg, + state_dim=cfg.state_dim, + action_dim=cfg.action_dim, + horizon=cfg.horizon, + camera_names=cfg.camera_names, + use_vae=cfg.vae, + ) + + optimizer_params_dicts = [ + { + "params": [ + p + for n, p in self.model.named_parameters() + if not n.startswith("backbone") and p.requires_grad + ] + }, + { + "params": [ + p + for n, p in self.model.named_parameters() + if n.startswith("backbone") and p.requires_grad + ], + "lr": cfg.lr_backbone, + }, + ] + self.optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay) + self.kl_weight = self.cfg.kl_weight logging.info(f"KL Weight {self.kl_weight}") self.to(self.device) @@ -140,12 +150,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): self.optimizer.step() self.optimizer.zero_grad() - # self.lr_scheduler.step() info = { "loss": loss.item(), "grad_norm": float(grad_norm), - # "lr": self.lr_scheduler.get_last_lr()[0], "lr": self.cfg.lr, "data_s": data_s, "update_s": time.time() - start_time, @@ -213,31 +221,495 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): action = action[: self.n_action_steps] return action - def _forward(self, qpos, image, actions=None, is_pad=None): - env_state = None + def _forward(self, qpos, image, actions=None): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image = normalize(image) is_training = actions is not None if is_training: # training time - actions = actions[:, : self.model.num_queries] - if is_pad is not None: - is_pad = is_pad[:, : self.model.num_queries] + actions = actions[:, : self.model.horizon] - a_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) + a_hat, (mu, log_sigma_x2) = self.model(qpos, image, actions) all_l1 = F.l1_loss(actions, a_hat, reduction="none") - l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean() + l1 = all_l1.mean() loss_dict = {} loss_dict["l1"] = l1 if self.cfg.vae: - total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) - loss_dict["kl"] = total_kld[0] + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). 
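# Editor's sketch (not part of the patch): the closed-form KL term computed just below, written out
# standalone with hypothetical sizes (B = batch, L = latent dim). log_sigma_x2 is 2*log(sigma), so
# its exp() is the variance, and Dkl(N(mu, sigma^2) || N(0, 1)) per dimension is
# -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2); sum over latents, then mean over the batch.
import torch

B, L = 4, 32
mu = torch.randn(B, L)
log_sigma_x2 = torch.randn(B, L)
kl = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
# The same (mu, log_sigma_x2) pair drives the reparameterization trick used to sample the latent:
latent_sample = mu + (log_sigma_x2 / 2).exp() * torch.randn_like(mu)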
+ mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean() + loss_dict["kl"] = mean_kld loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight else: loss_dict["loss"] = loss_dict["l1"] return loss_dict else: - action, _ = self.model(qpos, image, env_state) # no action, sample from prior + action, _ = self.model(qpos, image) # no action, sample from prior return action + + +def create_sinusoidal_position_embedding(n_position, d_hid): + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +# TODO(alexander-soare) move all this code into the policy when we have the policy API established. +class ActionChunkingTransformer(nn.Module): + """ + Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware + (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) + + Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. + - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the + model that encodes the target data (a sequence of actions), and the condition (the robot + joint-space). + - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with + cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we + have an option to train this model without the variational objective (in which case we drop the + `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). + + Transformer + Used alone for inference + (acts as VAE decoder + during training) + ┌───────────────────────┐ + │ Outputs │ + │ ▲ │ + │ ┌─────►┌───────┐ │ + ┌──────┐ │ │ │Transf.│ │ + │ │ │ ├─────►│decoder│ │ + ┌────┴────┐ │ │ │ │ │ │ + │ │ │ │ ┌───┴───┬─►│ │ │ + │ VAE │ │ │ │ │ └───────┘ │ + │ encoder │ │ │ │Transf.│ │ + │ │ │ │ │encoder│ │ + └───▲─────┘ │ │ │ │ │ + │ │ │ └───▲───┘ │ + │ │ │ │ │ + inputs └─────┼─────┘ │ + │ │ + └───────────────────────┘ + """ + + def __init__(self, args, state_dim, action_dim, horizon, camera_names, use_vae): + """Initializes the model. + Parameters: + state_dim: robot state dimension of the environment + horizon: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + + Args: + state_dim: Robot positional state dimension. + action_dim: Action dimension. + horizon: The number of actions to generate in one forward pass. + use_vae: Whether to use the variational objective. TODO(now): Give more details. + """ + super().__init__() + + self.camera_names = camera_names + self.use_vae = use_vae + self.horizon = horizon + self.hidden_dim = args.hidden_dim + + transformer_common_kwargs = dict( # noqa: C408 + d_model=self.hidden_dim, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + dropout=args.dropout, + activation=args.activation, + normalize_before=args.pre_norm, + ) + + # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. 
+ # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + if use_vae: + # TODO(now): args.enc_layers shouldn't be shared with the transformer decoder + self.vae_encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs) + self.cls_embed = nn.Embedding(1, self.hidden_dim) + # Projection layer for joint-space configuration to hidden dimension. + self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim) + # Projection layer for action (joint-space target) to hidden dimension. + self.vae_encoder_action_input_proj = nn.Linear(state_dim, self.hidden_dim) + # Final size of latent z. TODO(now): Add to hyperparams. + self.latent_dim = 32 + # Projection layer from the VAE encoder's output to the latent distribution's parameter space. + self.vae_encoder_latent_output_proj = nn.Linear(self.hidden_dim, self.latent_dim * 2) + # Fixed sinusoidal positional embedding the whole input to the VAE encoder. + self.register_buffer( + "vae_encoder_pos_enc", create_sinusoidal_position_embedding(1 + 1 + horizon, self.hidden_dim) + ) + + # Backbone for image feature extraction. + self.backbone_position_embedding = SinusoidalPositionEmbedding2D(self.hidden_dim // 2) + backbone_model = getattr(torchvision.models, args.backbone)( + replace_stride_with_dilation=[False, False, args.dilation], + pretrained=True, # TODO(now): Add pretrained option + norm_layer=FrozenBatchNorm2d, + ) + # Note: The forward method of this returns a dict: {"feature_map": output}. + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + # Transformer (acts as VAE decoder when training with the variational objective). + self.encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs) + self.decoder = TransformerDecoder(num_layers=args.dec_layers, **transformer_common_kwargs) + + # Transformer encoder input projections. The tokens will be structured like + # [latent, robot_state, image_feature_map_pixels]. + self.encoder_img_feat_input_proj = nn.Conv2d( + backbone_model.fc.in_features, self.hidden_dim, kernel_size=1 + ) + self.encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.hidden_dim) + # TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image + # feature dimension with a dry run. + self.additional_pos_embed = nn.Embedding( + 2, self.hidden_dim + ) # learned position embedding for proprio and latent + + # Transformer decoder. + # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). + self.decoder_pos_embed_embed = nn.Embedding(horizon, self.hidden_dim) + # Final action regression head on the output of the transformer's decoder. + self.action_head = nn.Linear(self.hidden_dim, action_dim) + + self._reset_parameters() + + def _reset_parameters(self): + """Xavier-uniform initialization of the transformer parameters as in the original code.""" + for p in chain(self.encoder.parameters(), self.decoder.parameters()): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, robot_state, image, actions=None): + """ + Args: + robot_state: (B, J) batch of robot joint configurations. + image: (B, N, C, H, W) batch of N camera frames. + actions: (B, S, A) batch of actions from the target dataset which must be provided if the + VAE is enabled and the model is in training mode. 
+ """ + if self.use_vae and self.training: + assert ( + actions is not None + ), "actions must be provided when using the variational objective in training mode." + + batch_size, _ = robot_state.shape + + # Prepare the latent for input to the transformer. + if self.use_vae and actions is not None: + # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. + cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D) + robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) + vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + # Prepare fixed positional embedding. + pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + # Forward pass through VAE encoder and sample the latent with the reparameterization trick. + cls_token_out = self.vae_encoder( + vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2) + )[0] # (B, D) + latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) + mu = latent_pdf_params[:, : self.latent_dim] + # This is 2log(sigma). Done this way to match the original implementation. + log_sigma_x2 = latent_pdf_params[:, self.latent_dim :] + # Use reparameterization trick to sample from the latent's PDF. + latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) + else: + # When not using the VAE encoder, we set the latent to be all zeros. + mu = log_sigma_x2 = None + latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( + robot_state.device + ) + + # Prepare all other transformer inputs. + # Image observation features and position embeddings. + all_cam_features = [] + all_cam_pos = [] + for cam_id, _ in enumerate(self.camera_names): + cam_features = self.backbone(image[:, cam_id])["feature_map"] + pos = self.backbone_position_embedding(cam_features).to(dtype=cam_features.dtype) + cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) + all_cam_features.append(cam_features) + all_cam_pos.append(pos) + # Concatenate image observation feature maps along the width dimension. + encoder_in = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + robot_state_embed = self.encoder_robot_state_input_proj(robot_state) + latent_embed = self.encoder_latent_input_proj(latent_sample) + + # TODO(now): Explain all of this madness. 
+ encoder_in = torch.cat( + [ + torch.stack([latent_embed, robot_state_embed], axis=0), + encoder_in.flatten(2).permute(2, 0, 1), + ] + ) + pos_embed = torch.cat( + [self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0 + ) + + encoder_out = self.encoder(encoder_in, pos=pos_embed) + decoder_in = torch.zeros( + (self.horizon, batch_size, self.hidden_dim), dtype=pos_embed.dtype, device=pos_embed.device + ) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=pos_embed, + decoder_pos_embed=self.decoder_pos_embed_embed.weight.unsqueeze(1), + ).transpose(0, 1) # back to (B, S, C) + + actions = self.action_head(decoder_out) + return actions, [mu, log_sigma_x2] + + +class TransformerEncoder(nn.Module): + def __init__( + self, + num_layers, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + for _ in range(num_layers) + ] + ) + self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity() + + def forward(self, x, pos: Optional[Tensor] = None): + for layer in self.layers: + x = layer(x, pos=pos) + x = self.norm(x) + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def forward(self, x, pos_embed: Optional[Tensor] = None): + skip = x + if self.normalize_before: + x = self.norm1(x) + q = k = x if pos_embed is None else x + pos_embed + x = self.self_attn(q, k, value=x)[0] + x = skip + self.dropout1(x) + if self.normalize_before: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout2(x) + if not self.normalize_before: + x = self.norm2(x) + return x + + +class TransformerDecoder(nn.Module): + def __init__( + self, + num_layers, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + for _ in range(num_layers) + ] + ) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + + def forward( + self, x, encoder_out, decoder_pos_embed: Tensor | None = None, encoder_pos_embed: Tensor | None = None + ): + for layer in self.layers: + x = layer( + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + ) + if self.norm is not None: + x = self.norm(x) + return x + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, 
dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + """ + Args: + x: (Decoder Sequence, Batch, Channel) tensor of input tokens. + encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are + cross-attending with. + decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). + encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). + Returns: + (DS, B, C) tensor of decoder output features. + """ + skip = x + if self.normalize_before: + x = self.norm1(x) + q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) + x = self.self_attn(q, k, value=x)[0] + x = skip + self.dropout1(x) + if self.normalize_before: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.multihead_attn( + query=self.maybe_add_pos_embed(x, decoder_pos_embed), + key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), + value=encoder_out, + )[0] + x = skip + self.dropout2(x) + if self.normalize_before: + skip = x + x = self.norm3(x) + else: + x = self.norm2(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout3(x) + if not self.normalize_before: + x = self.norm3(x) + return x + + +class SinusoidalPositionEmbedding2D(nn.Module): + """Sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. + + The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H + for the vertical direction, and 1/W for the horizontal direction. + """ + + def __init__(self, dimension: int): + """ + Args: + dimension: The desired dimension of the embeddings. + """ + super().__init__() + self.dimension = dimension + self._two_pi = 2 * math.pi + self._eps = 1e-6 + # Inverse "common ratio" for the geometric progression in sinusoid frequencies. + self._temperature = 10000 + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. + Returns: + A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. + """ + not_mask = torch.ones_like(x[0, [0]]) # (1, H, W) + # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations + # they would be range(0, H) and range(0, W). Keeping it at as to match the original code. + y_range = not_mask.cumsum(1, dtype=torch.float32) + x_range = not_mask.cumsum(2, dtype=torch.float32) + + # "Normalize" the position index such that it ranges in [0, 2π]. + # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range + # are non-zero by construction. 
This is an artifact of the original code. + y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi + x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi + + inverse_frequency = self._temperature ** ( + 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + ) + + x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + + # Note: this stack then flatten operation results in interleaved sine and cosine terms. + # pos_embed_x and pos_embed are (1, H, W, C // 2). + pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) + + return pos_embed + + +def _get_activation_fn(activation: str) -> Callable: + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") diff --git a/lerobot/common/policies/act/position_encoding.py b/lerobot/common/policies/act/position_encoding.py deleted file mode 100644 index 63bb4840..00000000 --- a/lerobot/common/policies/act/position_encoding.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Various positional encodings for the transformer. -""" - -import math - -import torch -from torch import nn - -from .utils import NestedTensor - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, tensor): - x = tensor - # mask = tensor_list.mask - # assert mask is not None - # not_mask = ~mask - - not_mask = torch.ones_like(x[0, [0]]) - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - - -class PositionEmbeddingLearned(nn.Module): - """ - Absolute pos embedding, learned. 
- """ - - def __init__(self, num_pos_feats=256): - super().__init__() - self.row_embed = nn.Embedding(50, num_pos_feats) - self.col_embed = nn.Embedding(50, num_pos_feats) - self.reset_parameters() - - def reset_parameters(self): - nn.init.uniform_(self.row_embed.weight) - nn.init.uniform_(self.col_embed.weight) - - def forward(self, tensor_list: NestedTensor): - x = tensor_list.tensors - h, w = x.shape[-2:] - i = torch.arange(w, device=x.device) - j = torch.arange(h, device=x.device) - x_emb = self.col_embed(i) - y_emb = self.row_embed(j) - pos = ( - torch.cat( - [ - x_emb.unsqueeze(0).repeat(h, 1, 1), - y_emb.unsqueeze(1).repeat(1, w, 1), - ], - dim=-1, - ) - .permute(2, 0, 1) - .unsqueeze(0) - .repeat(x.shape[0], 1, 1, 1) - ) - return pos - - -def build_position_encoding(args): - n_steps = args.hidden_dim // 2 - if args.position_embedding in ("v2", "sine"): - # TODO find a better way of exposing other arguments - position_embedding = PositionEmbeddingSine(n_steps, normalize=True) - elif args.position_embedding in ("v3", "learned"): - position_embedding = PositionEmbeddingLearned(n_steps) - else: - raise ValueError(f"not supported {args.position_embedding}") - - return position_embedding diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py deleted file mode 100644 index 7e71f3ea..00000000 --- a/lerobot/common/policies/act/transformer.py +++ /dev/null @@ -1,240 +0,0 @@ -""" -TODO(now) -""" - -from typing import Optional - -import torch -import torch.nn.functional as F # noqa: N812 -from torch import Tensor, nn - - -class Transformer(nn.Module): - def __init__( - self, - d_model=512, - nhead=8, - num_encoder_layers=6, - num_decoder_layers=6, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - normalize_before=False, - ): - super().__init__() - self.encoder = TransformerEncoder( - num_encoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before - ) - self.decoder = TransformerDecoder( - num_decoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before - ) - self.d_model = d_model - self.nhead = nhead - self._init_params() # TODO(now): move to somewhere common - - def _init_params(self): - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward(self, x, encoder_pos, decoder_pos): - """ - Args: - x: ((E)ncoder (S)equence, (B)atch, (C)hannels) - decoder_pos: (Decoder Sequence, C) tensor for the decoder's positional embedding. 
- encoder_pos: (ES, C) tenso - """ - # TODO flatten only when input has H and W - bs = x.shape[1] - - encoder_out = self.encoder(x, pos=encoder_pos) - decoder_in = torch.zeros( - (decoder_pos.shape[0], bs, decoder_pos.shape[2]), - dtype=decoder_pos.dtype, - device=decoder_pos.device, - ) - decoder_out = self.decoder(decoder_in, encoder_out, encoder_pos=encoder_pos, decoder_pos=decoder_pos) - return decoder_out - - -class TransformerEncoder(nn.Module): - def __init__( - self, - num_layers, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - normalize_before=False, - ): - super().__init__() - self.layers = nn.ModuleList( - [ - TransformerEncoderLayer( - d_model, nhead, dim_feedforward, dropout, activation, normalize_before - ) - for _ in range(num_layers) - ] - ) - self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity() - - def forward(self, x, pos: Optional[Tensor] = None): - for layer in self.layers: - x = layer(x, pos=pos) - x = self.norm(x) - return x - - -class TransformerEncoderLayer(nn.Module): - def __init__( - self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def forward(self, x, pos: Optional[Tensor] = None): - skip = x - if self.normalize_before: - x = self.norm1(x) - q = k = x if pos is None else x + pos - x = self.self_attn(q, k, value=x)[0] - x = skip + self.dropout1(x) - if self.normalize_before: - skip = x - x = self.norm2(x) - else: - x = self.norm1(x) - skip = x - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = skip + self.dropout2(x) - if not self.normalize_before: - x = self.norm2(x) - return x - - -class TransformerDecoder(nn.Module): - def __init__( - self, - num_layers, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - normalize_before=False, - ): - super().__init__() - self.layers = nn.ModuleList( - [ - TransformerDecoderLayer( - d_model, nhead, dim_feedforward, dropout, activation, normalize_before - ) - for _ in range(num_layers) - ] - ) - self.num_layers = num_layers - self.norm = nn.LayerNorm(d_model) - - def forward(self, x, encoder_out, decoder_pos: Tensor | None = None, encoder_pos: Tensor | None = None): - for layer in self.layers: - x = layer(x, encoder_out, decoder_pos=decoder_pos, encoder_pos=encoder_pos) - if self.norm is not None: - x = self.norm(x) - return x - - -class TransformerDecoderLayer(nn.Module): - def __init__( - self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = 
nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def maybe_add_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor: - return tensor if pos is None else tensor + pos - - def forward( - self, - x: Tensor, - encoder_out: Tensor, - decoder_pos: Tensor | None = None, - encoder_pos: Tensor | None = None, - ) -> Tensor: - """ - Args: - x: (Decoder Sequence, Batch, Channel) tensor of input tokens. - encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are - cross-attending with. - decoder_pos: (ES, 1, C) positional embedding for keys (from the encoder). - encoder_pos: (DS, 1, C) Positional_embedding for the queries (from the decoder). - Returns: - (DS, B, C) tensor of decoder output features. - """ - skip = x - if self.normalize_before: - x = self.norm1(x) - q = k = self.maybe_add_pos_embed(x, decoder_pos) - x = self.self_attn(q, k, value=x)[0] - x = skip + self.dropout1(x) - if self.normalize_before: - skip = x - x = self.norm2(x) - else: - x = self.norm1(x) - skip = x - x = self.multihead_attn( - query=self.maybe_add_pos_embed(x, decoder_pos), - key=self.maybe_add_pos_embed(encoder_out, encoder_pos), - value=encoder_out, - )[0] - x = skip + self.dropout2(x) - if self.normalize_before: - skip = x - x = self.norm3(x) - else: - x = self.norm2(x) - skip = x - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = skip + self.dropout3(x) - if not self.normalize_before: - x = self.norm3(x) - return x - - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") diff --git a/lerobot/common/policies/act/utils.py b/lerobot/common/policies/act/utils.py deleted file mode 100644 index 0d935839..00000000 --- a/lerobot/common/policies/act/utils.py +++ /dev/null @@ -1,478 +0,0 @@ -""" -Misc functions, including distributed helpers. - -Mostly copy-paste from torchvision references. -""" - -import datetime -import os -import pickle -import subprocess -import time -from collections import defaultdict, deque -from typing import List, Optional - -import torch -import torch.distributed as dist - -# needed due to empty tensor bug in pytorch and torchvision 0.5 -import torchvision -from packaging import version -from torch import Tensor - -if version.parse(torchvision.__version__) < version.parse("0.7"): - from torchvision.ops import _new_empty_tensor - from torchvision.ops.misc import _output_size - - -class SmoothedValue: - """Track a series of values and provide access to smoothed values over a - window or the global series average. - """ - - def __init__(self, window_size=20, fmt=None): - if fmt is None: - fmt = "{median:.4f} ({global_avg:.4f})" - self.deque = deque(maxlen=window_size) - self.total = 0.0 - self.count = 0 - self.fmt = fmt - - def update(self, value, n=1): - self.deque.append(value) - self.count += n - self.total += value * n - - def synchronize_between_processes(self): - """ - Warning: does not synchronize the deque! 
- """ - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - dist.barrier() - dist.all_reduce(t) - t = t.tolist() - self.count = int(t[0]) - self.total = t[1] - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque), dtype=torch.float32) - return d.mean().item() - - @property - def global_avg(self): - return self.total / self.count - - @property - def max(self): - return max(self.deque) - - @property - def value(self): - return self.deque[-1] - - def __str__(self): - return self.fmt.format( - median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value - ) - - -def all_gather(data): - """ - Run all_gather on arbitrary picklable data (not necessarily tensors) - Args: - data: any picklable object - Returns: - list[data]: list of data gathered from each rank - """ - world_size = get_world_size() - if world_size == 1: - return [data] - - # serialized to a Tensor - buffer = pickle.dumps(data) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).to("cuda") - - # obtain Tensor size of each rank - local_size = torch.tensor([tensor.numel()], device="cuda") - size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] - dist.all_gather(size_list, local_size) - size_list = [int(size.item()) for size in size_list] - max_size = max(size_list) - - # receiving Tensor from all ranks - # we pad the tensor because torch all_gather does not support - # gathering tensors of different shapes - tensor_list = [] - for _ in size_list: - tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) - if local_size != max_size: - padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") - tensor = torch.cat((tensor, padding), dim=0) - dist.all_gather(tensor_list, tensor) - - data_list = [] - for size, tensor in zip(size_list, tensor_list, strict=False): - buffer = tensor.cpu().numpy().tobytes()[:size] - data_list.append(pickle.loads(buffer)) - - return data_list - - -def reduce_dict(input_dict, average=True): - """ - Args: - input_dict (dict): all the values will be reduced - average (bool): whether to do average or sum - Reduce the values in the dictionary from all processes so that all processes - have the averaged results. Returns a dict with the same fields as - input_dict, after reduction. 
- """ - world_size = get_world_size() - if world_size < 2: - return input_dict - with torch.no_grad(): - names = [] - values = [] - # sort the keys so that they are consistent across processes - for k in sorted(input_dict.keys()): - names.append(k) - values.append(input_dict[k]) - values = torch.stack(values, dim=0) - dist.all_reduce(values) - if average: - values /= world_size - reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416 - return reduced_dict - - -class MetricLogger: - def __init__(self, delimiter="\t"): - self.meters = defaultdict(SmoothedValue) - self.delimiter = delimiter - - def update(self, **kwargs): - for k, v in kwargs.items(): - if isinstance(v, torch.Tensor): - v = v.item() - assert isinstance(v, (float, int)) - self.meters[k].update(v) - - def __getattr__(self, attr): - if attr in self.meters: - return self.meters[attr] - if attr in self.__dict__: - return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) - - def __str__(self): - loss_str = [] - for name, meter in self.meters.items(): - loss_str.append("{}: {}".format(name, str(meter))) - return self.delimiter.join(loss_str) - - def synchronize_between_processes(self): - for meter in self.meters.values(): - meter.synchronize_between_processes() - - def add_meter(self, name, meter): - self.meters[name] = meter - - def log_every(self, iterable, print_freq, header=None): - if not header: - header = "" - start_time = time.time() - end = time.time() - iter_time = SmoothedValue(fmt="{avg:.4f}") - data_time = SmoothedValue(fmt="{avg:.4f}") - space_fmt = ":" + str(len(str(len(iterable)))) + "d" - if torch.cuda.is_available(): - log_msg = self.delimiter.join( - [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - "max mem: {memory:.0f}", - ] - ) - else: - log_msg = self.delimiter.join( - [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - ] - ) - mega_b = 1024.0 * 1024.0 - for i, obj in enumerate(iterable): - data_time.update(time.time() - end) - yield obj - iter_time.update(time.time() - end) - if i % print_freq == 0 or i == len(iterable) - 1: - eta_seconds = iter_time.global_avg * (len(iterable) - i) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - if torch.cuda.is_available(): - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - memory=torch.cuda.max_memory_allocated() / mega_b, - ) - ) - else: - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - ) - ) - end = time.time() - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) - - -def get_sha(): - cwd = os.path.dirname(os.path.abspath(__file__)) - - def _run(command): - return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() - - sha = "N/A" - diff = "clean" - branch = "N/A" - try: - sha = _run(["git", "rev-parse", "HEAD"]) - subprocess.check_output(["git", "diff"], cwd=cwd) - diff = _run(["git", "diff-index", "HEAD"]) - diff = "has uncommited changes" if diff else "clean" - branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) - except Exception: - pass - message = f"sha: {sha}, status: {diff}, branch: {branch}" 
- return message - - -def collate_fn(batch): - batch = list(zip(*batch, strict=False)) - batch[0] = nested_tensor_from_tensor_list(batch[0]) - return tuple(batch) - - -def _max_by_axis(the_list): - # type: (List[List[int]]) -> List[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - -class NestedTensor: - def __init__(self, tensors, mask: Optional[Tensor]): - self.tensors = tensors - self.mask = mask - - def to(self, device): - # type: (Device) -> NestedTensor # noqa - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - assert mask is not None - cast_mask = mask.to(device) - else: - cast_mask = None - return NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): - # TODO make this more general - if tensor_list[0].ndim == 3: - if torchvision._is_tracing(): - # nested_tensor_from_tensor_list() does not export well to ONNX - # call _onnx_nested_tensor_from_tensor_list() instead - return _onnx_nested_tensor_from_tensor_list(tensor_list) - - # TODO make it support different-sized images - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) - batch_shape = [len(tensor_list)] + max_size - b, c, h, w = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((b, h, w), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], : img.shape[2]] = False - else: - raise ValueError("not supported") - return NestedTensor(tensor, mask) - - -# _onnx_nested_tensor_from_tensor_list() is an implementation of -# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
-@torch.jit.unused -def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: - max_size = [] - for i in range(tensor_list[0].dim()): - max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to( - torch.int64 - ) - max_size.append(max_size_i) - max_size = tuple(max_size) - - # work around for - # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - # m[: img.shape[1], :img.shape[2]] = False - # which is not yet supported in onnx - padded_imgs = [] - padded_masks = [] - for img in tensor_list: - padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)] - padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) - padded_imgs.append(padded_img) - - m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) - padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) - padded_masks.append(padded_mask.to(torch.bool)) - - tensor = torch.stack(padded_imgs) - mask = torch.stack(padded_masks) - - return NestedTensor(tensor, mask=mask) - - -def setup_for_distributed(is_master): - """ - This function disables printing when not in master process - """ - import builtins as __builtin__ - - builtin_print = __builtin__.print - - def print(*args, **kwargs): - force = kwargs.pop("force", False) - if is_master or force: - builtin_print(*args, **kwargs) - - __builtin__.print = print - - -def is_dist_avail_and_initialized(): - if not dist.is_available(): - return False - if not dist.is_initialized(): - return False - return True - - -def get_world_size(): - if not is_dist_avail_and_initialized(): - return 1 - return dist.get_world_size() - - -def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(): - return get_rank() == 0 - - -def save_on_master(*args, **kwargs): - if is_main_process(): - torch.save(*args, **kwargs) - - -def init_distributed_mode(args): - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ["WORLD_SIZE"]) - args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() - else: - print("Not using distributed mode") - args.distributed = False - return - - args.distributed = True - - torch.cuda.set_device(args.gpu) - args.dist_backend = "nccl" - print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group( - backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank - ) - torch.distributed.barrier() - setup_for_distributed(args.rank == 0) - - -@torch.no_grad() -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - if target.numel() == 0: - return [torch.zeros([], device=output.device)] - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor - """ - Equivalent to nn.functional.interpolate, but 
with support for empty batch sizes. - This will eventually be supported natively by PyTorch, and this - class can go away. - """ - if version.parse(torchvision.__version__) < version.parse("0.7"): - if input.numel() > 0: - return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) - - output_shape = _output_size(2, input, size, scale_factor) - output_shape = list(input.shape[:-2]) + list(output_shape) - return _new_empty_tensor(input, output_shape) - else: - return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 1086b595..22b6cd6f 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -33,11 +33,10 @@ policy: nheads: 8 #camera_names: [top, front_close, left_pillar, right_pillar] camera_names: [top] - position_embedding: sine - masks: false dilation: false dropout: 0.1 pre_norm: false + activation: relu vae: true diff --git a/scripts/convert_act_weights.py b/scripts/convert_act_weights.py index d0c0c3e7..c8f83422 100644 --- a/scripts/convert_act_weights.py +++ b/scripts/convert_act_weights.py @@ -11,6 +11,19 @@ policy = make_policy(cfg) state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt") +# Remove keys based on what they start with. + +start_removals = [ + # There is a bug that means the pretrained model doesn't even use the final decoder layers. + *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)], + "model.is_pad_head.", +] + +for to_remove in start_removals: + for k in list(state_dict.keys()): + if k.startswith(to_remove): + del state_dict[k] + # Replace keys based on what they start with. @@ -26,6 +39,9 @@ start_replacements = [ ("model.input_proj.", "model.encoder_img_feat_input_proj."), ("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"), ("model.latent_out_proj.", "model.encoder_latent_input_proj."), + ("model.transformer.encoder.", "model.encoder."), + ("model.transformer.decoder.", "model.decoder."), + ("model.backbones.0.0.body.", "model.backbone."), ] for to_replace, replace_with in start_replacements: @@ -35,18 +51,6 @@ for to_replace, replace_with in start_replacements: state_dict[k_] = state_dict[k] del state_dict[k] -# Remove keys based on what they start with. - -start_removals = [ - # There is a bug that means the pretrained model doesn't even use the final decoder layers. - *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)], - "model.is_pad_head.", -] - -for to_remove in start_removals: - for k in list(state_dict.keys()): - if k.startswith(to_remove): - del state_dict[k] missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
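# Editor's sketch (not part of the patch): the key surgery in scripts/convert_act_weights.py above
# amounts to a prefix remap over the checkpoint's state dict, dropping keys the refactored model no
# longer has and renaming the rest. A self-contained version of that pattern, with made-up keys:
import torch


def remap_state_dict(state_dict, start_removals, start_replacements):
    out = {}
    for key, value in state_dict.items():
        if any(key.startswith(prefix) for prefix in start_removals):
            continue  # e.g. the unused decoder layers and the is_pad head
        for old, new in start_replacements:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        out[key] = value
    return out


ckpt = {"model.transformer.encoder.x": torch.zeros(1), "model.is_pad_head.x": torch.zeros(1)}
remapped = remap_state_dict(
    ckpt, ["model.is_pad_head."], [("model.transformer.encoder.", "model.encoder.")]
)
assert list(remapped) == ["model.encoder.x"]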