From 110ac5ffa123c64eb61a313eb08638ed6efe84ee Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 3 Apr 2024 14:21:07 +0100 Subject: [PATCH] backup wip --- lerobot/common/envs/aloha/env.py | 1 - lerobot/common/policies/act/detr_vae.py | 216 ++++++++++----------- lerobot/common/policies/act/policy.py | 5 +- lerobot/common/policies/act/transformer.py | 85 ++------ lerobot/configs/policy/act.yaml | 2 +- scripts/convert_act_weights.py | 64 ++++++ 6 files changed, 182 insertions(+), 191 deletions(-) create mode 100644 scripts/convert_act_weights.py diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 8f907650..ad8087d0 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -191,7 +191,6 @@ class AlohaEnv(AbstractEnv): { "observation": TensorDict(obs, batch_size=[]), "reward": torch.tensor([reward], dtype=torch.float32), - # success and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([success], dtype=torch.bool), }, diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py index f21308ad..ff137a34 100644 --- a/lerobot/common/policies/act/detr_vae.py +++ b/lerobot/common/policies/act/detr_vae.py @@ -1,18 +1,12 @@ +import einops import numpy as np import torch from torch import nn -from torch.autograd import Variable from .backbone import build_backbone from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer -def reparametrize(mu, logvar): - std = logvar.div(2).exp() - eps = Variable(std.data.new(std.size()).normal_()) - return mu + std * eps - - def get_sinusoid_encoding_table(n_position, d_hid): def get_position_angle_vec(position): return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] @@ -27,7 +21,7 @@ def get_sinusoid_encoding_table(n_position, d_hid): class ActionChunkingTransformer(nn.Module): """ Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware - (https://arxiv.org/abs/2304.13705). + (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) Note: In this code we use the symbols `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. - The `vae_encoder` is, as per the literature around conditional variational auto-encoders (cVAE), the @@ -49,7 +43,7 @@ class ActionChunkingTransformer(nn.Module): """ def __init__( - self, backbones, transformer, vae_encoder, state_dim, action_dim, horizon, camera_names, vae + self, backbones, transformer, vae_encoder, state_dim, action_dim, horizon, camera_names, use_vae ): """Initializes the model. Parameters: @@ -63,134 +57,124 @@ class ActionChunkingTransformer(nn.Module): state_dim: Robot positional state dimension. action_dim: Action dimension. horizon: The number of actions to generate in one forward pass. - vae: Whether to use the variational objective. TODO(now): Give more details. + use_vae: Whether to use the variational objective. TODO(now): Give more details. 
""" super().__init__() + self.camera_names = camera_names self.transformer = transformer self.vae_encoder = vae_encoder - self.vae = vae + self.use_vae = use_vae hidden_dim = transformer.d_model - self.action_head = nn.Linear(hidden_dim, action_dim) - self.is_pad_head = nn.Linear(hidden_dim, 1) - # Positional embedding to be used as input to the latent vae_encoder (if applicable) and for the - self.pos_embed = nn.Embedding(horizon, hidden_dim) - if backbones is not None: - self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) - self.backbones = nn.ModuleList(backbones) - self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) - else: - # input_dim = 14 + 7 # robot_state + env_state - self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) - # TODO(rcadene): understand what is env_state, and why it needs to be 7 - self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim) - self.pos = torch.nn.Embedding(2, hidden_dim) - self.backbones = None - # vae_encoder extra parameters - self.latent_dim = 32 # final size of latent z # TODO tune - self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding - self.vae_encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding - self.vae_encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding - self.latent_proj = nn.Linear( - hidden_dim, self.latent_dim * 2 - ) # project hidden state to latent std, var - self.register_buffer( - "pos_table", get_sinusoid_encoding_table(1 + 1 + horizon, hidden_dim) - ) # [CLS], qpos, a_seq + # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. + # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + if use_vae: + self.cls_embed = nn.Embedding(1, hidden_dim) + # Projection layer for joint-space configuration to hidden dimension. + self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim) + # Projection layer for action (joint-space target) to hidden dimension. + self.vae_encoder_action_input_proj = nn.Linear(state_dim, hidden_dim) + # Final size of latent z. TODO(now): Add to hyperparams. + self.latent_dim = 32 + # Projection layer from the VAE encoder's output to the latent distribution's parameter space. + self.vae_encoder_latent_output_proj = nn.Linear(hidden_dim, self.latent_dim * 2) + # Fixed sinusoidal positional embedding the whole input to the VAE encoder. + self.register_buffer( + "vae_encoder_pos_enc", get_sinusoid_encoding_table(1 + 1 + horizon, hidden_dim) + ) - # decoder extra parameters - self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + # Transformer encoder input projections. The tokens will be structured like + # [latent, robot_state, image_feature_map_pixels]. + self.backbones = nn.ModuleList(backbones) + self.encoder_img_feat_input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) + self.encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, hidden_dim) + # TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image + # feature dimension with a dry run. self.additional_pos_embed = nn.Embedding( 2, hidden_dim ) # learned position embedding for proprio and latent - def forward(self, qpos, image, env_state, actions=None, is_pad=None): + # Transformer decoder. 
+ # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). + self.decoder_pos_embed = nn.Embedding(horizon, hidden_dim) + # Final action regression head on the output of the transformer's decoder. + self.action_head = nn.Linear(hidden_dim, action_dim) + + def forward(self, robot_state, image, actions=None): """ - qpos: batch, qpos_dim - image: batch, num_cam, channel, height, width - env_state: None - actions: batch, seq, action_dim + Args: + robot_state: (B, J) batch of robot joint configurations. + image: (B, N, C, H, W) batch of N camera frames. + actions: (B, S, A) batch of actions from the target dataset which must be provided if the + VAE is enabled and the model is in training mode. """ - is_training = actions is not None # train or val - bs, _ = qpos.shape - ### Obtain latent z from action sequence - if self.vae and is_training: - # project action sequence to embedding dim, and concat with a CLS token - action_embed = self.vae_encoder_action_proj(actions) # (bs, seq, hidden_dim) - qpos_embed = self.vae_encoder_joint_proj(qpos) # (bs, hidden_dim) - qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) - cls_embed = self.cls_embed.weight # (1, hidden_dim) - cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) - vae_encoder_input = torch.cat( - [cls_embed, qpos_embed, action_embed], axis=1 - ) # (bs, seq+1, hidden_dim) - vae_encoder_input = vae_encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) - # do not mask cls token - # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding - # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) - # obtain position embedding - pos_embed = self.pos_table.clone().detach() - pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) - # query model + if self.use_vae and self.training: + assert ( + actions is not None + ), "actions must be provided when using the variational objective in training mode." + + batch_size, _ = robot_state.shape + + # Prepare the latent for input to the transformer. + if self.use_vae and actions is not None: + # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. + cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D) + robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) + vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) + vae_encoder_input = vae_encoder_input.permute(1, 0, 2) # (S+2, B, D) + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + # Prepare fixed positional embedding. + pos_embed = self.vae_encoder_pos_enc.clone().detach().permute(1, 0, 2) # (S+2, 1, D) + # Forward pass through VAE encoder and sample the latent with the reparameterization trick. 
vae_encoder_output = self.vae_encoder( vae_encoder_input, pos=pos_embed - ) # , src_key_padding_mask=is_pad) + ) # , src_key_padding_mask=is_pad) # TODO(now) vae_encoder_output = vae_encoder_output[0] # take cls output only - latent_info = self.latent_proj(vae_encoder_output) - mu = latent_info[:, : self.latent_dim] - logvar = latent_info[:, self.latent_dim :] - latent_sample = reparametrize(mu, logvar) - latent_input = self.latent_out_proj(latent_sample) + latent_pdf_params = self.vae_encoder_latent_output_proj(vae_encoder_output) + mu = latent_pdf_params[:, : self.latent_dim] + logvar = latent_pdf_params[:, self.latent_dim :] + # Use reparameterization trick to sample from the latent's PDF. + latent_sample = mu + logvar.div(2).exp() * torch.randn_like(mu) else: + # When not using the VAE encoder, we set the latent to be all zeros. mu = logvar = None - latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device) - latent_input = self.latent_out_proj(latent_sample) + latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=robot_state.dtype).to( + robot_state.device + ) - if self.backbones is not None: - # Image observation features and position embeddings - all_cam_features = [] - all_cam_pos = [] - for cam_id, _ in enumerate(self.camera_names): - features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED - features = features[0] # take the last layer feature - pos = pos[0] - all_cam_features.append(self.input_proj(features)) - all_cam_pos.append(pos) - # proprioception features - proprio_input = self.input_proj_robot_state(qpos) - # fold camera dimension into width dimension - src = torch.cat(all_cam_features, axis=3) - pos = torch.cat(all_cam_pos, axis=3) - hs = self.transformer( - src, - None, - self.pos_embed.weight, - pos, - latent_input, - proprio_input, - self.additional_pos_embed.weight, - )[0] - else: - qpos = self.input_proj_robot_state(qpos) - env_state = self.input_proj_env_state(env_state) - transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 - hs = self.transformer(transformer_input, None, self.pos_embed.weight, self.pos.weight)[0] - a_hat = self.action_head(hs) - is_pad_hat = self.is_pad_head(hs) - return a_hat, is_pad_hat, [mu, logvar] + # Prepare all other transformer inputs. + # Image observation features and position embeddings. + all_cam_features = [] + all_cam_pos = [] + for cam_id, _ in enumerate(self.camera_names): + # TODO(now): remove the positional embedding from the backbones. + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.encoder_img_feat_input_proj(features)) + all_cam_pos.append(pos) + # Concatenate image observation feature maps along the width dimension. + transformer_input = torch.cat(all_cam_features, axis=3) + # TODO(now): remove the positional embedding from the backbones. + pos = torch.cat(all_cam_pos, axis=3) + robot_state_embed = self.encoder_robot_state_input_proj(robot_state) + latent_embed = self.encoder_latent_input_proj(latent_sample) + # Run the transformer and project the outputs to the action space. 
+ transformer_output = self.transformer( + transformer_input, + query_embed=self.decoder_pos_embed.weight, + pos_embed=pos, + latent_input=latent_embed, + proprio_input=robot_state_embed, + additional_pos_embed=self.additional_pos_embed.weight, + ) + a_hat = self.action_head(transformer_output) -def mlp(input_dim, hidden_dim, output_dim, hidden_depth): - if hidden_depth == 0: - mods = [nn.Linear(input_dim, output_dim)] - else: - mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] - for _ in range(hidden_depth - 1): - mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] - mods.append(nn.Linear(hidden_dim, output_dim)) - trunk = nn.Sequential(*mods) - return trunk + return a_hat, [mu, logvar] def build_vae_encoder(args): @@ -231,7 +215,7 @@ def build(args): action_dim=args.action_dim, horizon=args.num_queries, camera_names=args.camera_names, - vae=args.vae, + use_vae=args.vae, ) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 5cf74ae5..7d24620a 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -224,8 +224,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): if is_pad is not None: is_pad = is_pad[:, : self.model.num_queries] - breakpoint() - a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) + a_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) all_l1 = F.l1_loss(actions, a_hat, reduction="none") l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean() @@ -240,5 +239,5 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): loss_dict["loss"] = loss_dict["l1"] return loss_dict else: - action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior + action, _ = self.model(qpos, image, env_state) # no action, sample from prior return action diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py index 20cfc815..11d5a013 100644 --- a/lerobot/common/policies/act/transformer.py +++ b/lerobot/common/policies/act/transformer.py @@ -26,10 +26,8 @@ class Transformer(nn.Module): dropout=0.1, activation="relu", normalize_before=False, - return_intermediate_dec=False, ): super().__init__() - encoder_layer = TransformerEncoderLayer( d_model, nhead, dim_feedforward, dropout, activation, normalize_before ) @@ -40,9 +38,7 @@ class Transformer(nn.Module): d_model, nhead, dim_feedforward, dropout, activation, normalize_before ) decoder_norm = nn.LayerNorm(d_model) - self.decoder = TransformerDecoder( - decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec - ) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) self._reset_parameters() @@ -57,7 +53,6 @@ class Transformer(nn.Module): def forward( self, src, - mask, query_embed, pos_embed, latent_input=None, @@ -68,10 +63,10 @@ class Transformer(nn.Module): if len(src.shape) == 4: # has H and W # flatten NxCxHxW to HWxNxC bs, c, h, w = src.shape + # Each "pixel" on the feature maps will form a token. 
src = src.flatten(2).permute(2, 0, 1) pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) - # mask = mask.flatten(1) additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) @@ -87,9 +82,9 @@ class Transformer(nn.Module): query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) tgt = torch.zeros_like(query_embed) - memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) - hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) - hs = hs.transpose(1, 2) + memory = self.encoder(src, pos=pos_embed) + hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed) + hs = hs.transpose(0, 1) return hs @@ -103,14 +98,12 @@ class TransformerEncoder(nn.Module): def forward( self, src, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): output = src for layer in self.layers: - output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) + output = layer(output, pos=pos) if self.norm is not None: output = self.norm(output) @@ -119,52 +112,33 @@ class TransformerEncoder(nn.Module): class TransformerDecoder(nn.Module): - def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + def __init__(self, decoder_layer, num_layers, norm=None): super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm - self.return_intermediate = return_intermediate def forward( self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): output = tgt - intermediate = [] - for layer in self.layers: output = layer( output, memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, pos=pos, query_pos=query_pos, ) - if self.return_intermediate: - intermediate.append(self.norm(output)) if self.norm is not None: output = self.norm(output) - if self.return_intermediate: - intermediate.pop() - intermediate.append(output) - if self.return_intermediate: - return torch.stack(intermediate) - - return output.unsqueeze(0) + return output class TransformerEncoderLayer(nn.Module): @@ -192,12 +166,10 @@ class TransformerEncoderLayer(nn.Module): def forward_post( self, src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): q = k = self.with_pos_embed(src, pos) - src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src2 = self.self_attn(q, k, value=src)[0] src = src + self.dropout1(src2) src = self.norm1(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) @@ -208,13 +180,11 @@ class TransformerEncoderLayer(nn.Module): def forward_pre( self, src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): src2 = self.norm1(src) q = k = self.with_pos_embed(src2, pos) - src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src2 = self.self_attn(q, k, value=src2)[0] src = src + self.dropout1(src2) src2 = 
self.norm2(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) @@ -224,13 +194,11 @@ class TransformerEncoderLayer(nn.Module): def forward( self, src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): if self.normalize_before: - return self.forward_pre(src, src_mask, src_key_padding_mask, pos) - return self.forward_post(src, src_mask, src_key_padding_mask, pos) + return self.forward_pre(src, pos) + return self.forward_post(src, pos) class TransformerDecoderLayer(nn.Module): @@ -262,23 +230,17 @@ class TransformerDecoderLayer(nn.Module): self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): q = k = self.with_pos_embed(tgt, query_pos) - tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt2 = self.self_attn(q, k, value=tgt)[0] tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) @@ -291,24 +253,18 @@ class TransformerDecoderLayer(nn.Module): self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt2 = self.self_attn(q, k, value=tgt2)[0] tgt = tgt + self.dropout1(tgt2) tgt2 = self.norm2(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt2 = self.norm3(tgt) @@ -320,10 +276,6 @@ class TransformerDecoderLayer(nn.Module): self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): @@ -331,16 +283,10 @@ class TransformerDecoderLayer(nn.Module): return self.forward_pre( tgt, memory, - tgt_mask, - memory_mask, - tgt_key_padding_mask, - memory_key_padding_mask, pos, query_pos, ) - return self.forward_post( - tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos - ) + return self.forward_post(tgt, memory, pos, query_pos) def _get_clones(module, n): @@ -356,7 +302,6 @@ def build_transformer(args): num_encoder_layers=args.enc_layers, num_decoder_layers=args.dec_layers, normalize_before=args.pre_norm, - return_intermediate_dec=True, ) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 0244944b..1086b595 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -29,7 +29,7 @@ policy: hidden_dim: 512 dim_feedforward: 3200 enc_layers: 4 - dec_layers: 7 + dec_layers: 1 nheads: 8 #camera_names: [top, front_close, left_pillar, right_pillar] 
  camera_names: [top]
diff --git a/scripts/convert_act_weights.py b/scripts/convert_act_weights.py
new file mode 100644
index 00000000..d0c0c3e7
--- /dev/null
+++ b/scripts/convert_act_weights.py
@@ -0,0 +1,64 @@
+import torch
+
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils import init_hydra_config
+
+cfg = init_hydra_config(
+    "/home/alexander/Projects/lerobot/outputs/train/act_aloha_sim_transfer_cube_human/.hydra/config.yaml"
+)
+
+policy = make_policy(cfg)
+
+state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
+
+
+# Replace keys based on what they start with.
+
+start_replacements = [
+    ("model.query_embed.weight", "model.pos_embed.weight"),
+    ("model.pos_table", "model.vae_encoder_pos_enc"),
+    ("model.pos_embed.weight", "model.decoder_pos_embed.weight"),
+    ("model.encoder.", "model.vae_encoder."),
+    ("model.encoder_action_proj.", "model.vae_encoder_action_input_proj."),
+    ("model.encoder_joint_proj.", "model.vae_encoder_robot_state_input_proj."),
+    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
+    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
+    ("model.input_proj.", "model.encoder_img_feat_input_proj."),
+    ("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
+    ("model.latent_out_proj.", "model.encoder_latent_input_proj."),
+]
+
+for to_replace, replace_with in start_replacements:
+    for k in list(state_dict.keys()):
+        if k.startswith(to_replace):
+            k_ = replace_with + k.removeprefix(to_replace)
+            state_dict[k_] = state_dict[k]
+            del state_dict[k]
+
+# Remove keys based on what they start with.
+
+start_removals = [
+    # There is a bug that means the pretrained model doesn't even use the final decoder layers.
+    *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
+    "model.is_pad_head.",
+]
+
+for to_remove in start_removals:
+    for k in list(state_dict.keys()):
+        if k.startswith(to_remove):
+            del state_dict[k]
+
+missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
+
+if len(missing_keys) != 0:
+    print("MISSING KEYS")
+    print(missing_keys)
+if len(unexpected_keys) != 0:
+    print("UNEXPECTED KEYS")
+    print(unexpected_keys)
+
+# if len(missing_keys) != 0 or len(unexpected_keys) != 0:
+#     print("Failed due to mismatch in state dicts.")
+#     exit()
+
+policy.save("/tmp/weights.pth")
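
The detr_vae.py hunk above removes the `reparametrize` helper (which relied on the deprecated `torch.autograd.Variable`) and inlines the sampling as `mu + logvar.div(2).exp() * torch.randn_like(mu)`. The sketch below is a minimal standalone sanity check, not part of the patch: the `sample_latent` name and the moment test are illustrative additions, and it only confirms that the inlined expression is the standard reparameterization trick, i.e. it draws z ~ N(mu, diag(exp(logvar))).

import torch


def sample_latent(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Reparameterization trick as inlined in ActionChunkingTransformer.forward:
    # z = mu + std * eps, with std = exp(logvar / 2) and eps ~ N(0, I).
    return mu + logvar.div(2).exp() * torch.randn_like(mu)


if __name__ == "__main__":
    torch.manual_seed(0)
    latent_dim = 32  # matches self.latent_dim in the patch
    mu = torch.randn(latent_dim)
    logvar = 0.5 * torch.randn(latent_dim)
    # Draw many samples and compare empirical moments to mu and exp(logvar / 2).
    samples = torch.stack([sample_latent(mu, logvar) for _ in range(20_000)])
    assert torch.allclose(samples.mean(dim=0), mu, atol=0.1)
    assert torch.allclose(samples.std(dim=0), logvar.div(2).exp(), atol=0.1)
    print("Inlined sampling matches N(mu, exp(logvar)) moments.")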