diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py
index ff137a34..aaf4d098 100644
--- a/lerobot/common/policies/act/detr_vae.py
+++ b/lerobot/common/policies/act/detr_vae.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 from .backbone import build_backbone
-from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
+from .transformer import Transformer, TransformerEncoder
 
 
 def get_sinusoid_encoding_table(n_position, d_hid):
@@ -124,16 +124,14 @@ class ActionChunkingTransformer(nn.Module):
             robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(actions)  # (B, S, D)
             vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
-            vae_encoder_input = vae_encoder_input.permute(1, 0, 2)  # (S+2, B, D)
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
             # Prepare fixed positional embedding.
-            pos_embed = self.vae_encoder_pos_enc.clone().detach().permute(1, 0, 2)  # (S+2, 1, D)
+            pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)
             # Forward pass through VAE encoder and sample the latent with the reparameterization trick.
-            vae_encoder_output = self.vae_encoder(
-                vae_encoder_input, pos=pos_embed
-            )  # , src_key_padding_mask=is_pad)  # TODO(now)
-            vae_encoder_output = vae_encoder_output[0]  # take cls output only
-            latent_pdf_params = self.vae_encoder_latent_output_proj(vae_encoder_output)
+            cls_token_out = self.vae_encoder(
+                vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2)
+            )[0]  # (B, D)
+            latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
             mu = latent_pdf_params[:, : self.latent_dim]
             logvar = latent_pdf_params[:, self.latent_dim :]
             # Use reparameterization trick to sample from the latent's PDF.
@@ -151,10 +149,11 @@ class ActionChunkingTransformer(nn.Module):
         all_cam_pos = []
         for cam_id, _ in enumerate(self.camera_names):
             # TODO(now): remove the positional embedding from the backbones.
-            features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED
-            features = features[0]  # take the last layer feature
+            cam_features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED
+            cam_features = cam_features[0]  # take the last layer feature
             pos = pos[0]
-            all_cam_features.append(self.encoder_img_feat_input_proj(features))
+            cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
+            all_cam_features.append(cam_features)
             all_cam_pos.append(pos)
         # Concatenate image observation feature maps along the width dimension.
         transformer_input = torch.cat(all_cam_features, axis=3)
@@ -163,36 +162,25 @@ class ActionChunkingTransformer(nn.Module):
 
         robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
         latent_embed = self.encoder_latent_input_proj(latent_sample)
 
+        # TODO(now): Explain all of this madness.
+        transformer_input = torch.cat(
+            [
+                torch.stack([latent_embed, robot_state_embed], axis=0),
+                transformer_input.flatten(2).permute(2, 0, 1),
+            ]
+        )
+        pos_embed = torch.cat(
+            [self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0
+        )
+        # Run the transformer and project the outputs to the action space.
         transformer_output = self.transformer(
             transformer_input,
-            query_embed=self.decoder_pos_embed.weight,
-            pos_embed=pos,
-            latent_input=latent_embed,
-            proprio_input=robot_state_embed,
-            additional_pos_embed=self.additional_pos_embed.weight,
-        )
-        a_hat = self.action_head(transformer_output)
-
-        return a_hat, [mu, logvar]
-
-
-def build_vae_encoder(args):
-    d_model = args.hidden_dim  # 256
-    dropout = args.dropout  # 0.1
-    nhead = args.nheads  # 8
-    dim_feedforward = args.dim_feedforward  # 2048
-    num_encoder_layers = args.enc_layers  # 4 # TODO shared with VAE decoder
-    normalize_before = args.pre_norm  # False
-    activation = "relu"
-
-    encoder_layer = TransformerEncoderLayer(
-        d_model, nhead, dim_feedforward, dropout, activation, normalize_before
-    )
-    encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
-    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
-
-    return encoder
+            encoder_pos=pos_embed,
+            decoder_pos=self.decoder_pos_embed.weight.unsqueeze(1),
+        ).transpose(0, 1)  # back to (B, S, C)
+        actions = self.action_head(transformer_output)
+        return actions, [mu, logvar]
 
 
 def build(args):
@@ -203,9 +191,26 @@ def build(args):
         backbone = build_backbone(args)
         backbones.append(backbone)
 
-    transformer = build_transformer(args)
+    transformer = Transformer(
+        d_model=args.hidden_dim,
+        dropout=args.dropout,
+        nhead=args.nheads,
+        dim_feedforward=args.dim_feedforward,
+        num_encoder_layers=args.enc_layers,
+        num_decoder_layers=args.dec_layers,
+        normalize_before=args.pre_norm,
+    )
 
-    vae_encoder = build_vae_encoder(args)
+    # TODO(now): args.enc_layers shouldn't be shared with the transformer decoder
+    vae_encoder = TransformerEncoder(
+        num_layers=args.enc_layers,
+        d_model=args.hidden_dim,
+        nhead=args.nheads,
+        dim_feedforward=args.dim_feedforward,
+        dropout=args.dropout,
+        activation="relu",
+        normalize_before=args.pre_norm,
+    )
 
     model = ActionChunkingTransformer(
         backbones,
diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py
index 11d5a013..7e71f3ea 100644
--- a/lerobot/common/policies/act/transformer.py
+++ b/lerobot/common/policies/act/transformer.py
@@ -1,13 +1,7 @@
 """
-DETR Transformer class.
-
-Copy-paste from torch.nn.Transformer with modifications:
-    * positional encodings are passed in MHattention
-    * extra LN at the end of encoder is removed
-    * decoder returns a stack of activations from all decoding layers
+TODO(now)
 """
 
-import copy
 from typing import Optional
 
 import torch
@@ -28,117 +22,68 @@ class Transformer(nn.Module):
         normalize_before=False,
     ):
         super().__init__()
-        encoder_layer = TransformerEncoderLayer(
-            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        self.encoder = TransformerEncoder(
+            num_encoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
         )
-        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
-        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
-
-        decoder_layer = TransformerDecoderLayer(
-            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        self.decoder = TransformerDecoder(
+            num_decoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
         )
-        decoder_norm = nn.LayerNorm(d_model)
-        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
-
-        self._reset_parameters()
-
         self.d_model = d_model
         self.nhead = nhead
+        self._init_params()  # TODO(now): move to somewhere common
 
-    def _reset_parameters(self):
+    def _init_params(self):
         for p in self.parameters():
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def forward(
-        self,
-        src,
-        query_embed,
-        pos_embed,
-        latent_input=None,
-        proprio_input=None,
-        additional_pos_embed=None,
-    ):
+    def forward(self, x, encoder_pos, decoder_pos):
+        """
+        Args:
+            x: ((E)ncoder (S)equence, (B)atch, (C)hannels)
+            decoder_pos: (Decoder Sequence, C) tensor for the decoder's positional embedding.
+            encoder_pos: (ES, C) tensor for the encoder's positional embedding.
+        """
         # TODO flatten only when input has H and W
-        if len(src.shape) == 4:  # has H and W
-            # flatten NxCxHxW to HWxNxC
-            bs, c, h, w = src.shape
-            # Each "pixel" on the feature maps will form a token.
-            src = src.flatten(2).permute(2, 0, 1)
-            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
-            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+        bs = x.shape[1]
 
-            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1)  # seq, bs, dim
-            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
-
-            addition_input = torch.stack([latent_input, proprio_input], axis=0)
-            src = torch.cat([addition_input, src], axis=0)
-        else:
-            assert len(src.shape) == 3
-            # flatten NxHWxC to HWxNxC
-            bs, hw, c = src.shape
-            src = src.permute(1, 0, 2)
-            pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
-            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
-
-        tgt = torch.zeros_like(query_embed)
-        memory = self.encoder(src, pos=pos_embed)
-        hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
-        hs = hs.transpose(0, 1)
-        return hs
+        encoder_out = self.encoder(x, pos=encoder_pos)
+        decoder_in = torch.zeros(
+            (decoder_pos.shape[0], bs, decoder_pos.shape[2]),
+            dtype=decoder_pos.dtype,
+            device=decoder_pos.device,
+        )
+        decoder_out = self.decoder(decoder_in, encoder_out, encoder_pos=encoder_pos, decoder_pos=decoder_pos)
+        return decoder_out
 
 
 class TransformerEncoder(nn.Module):
-    def __init__(self, encoder_layer, num_layers, norm=None):
-        super().__init__()
-        self.layers = _get_clones(encoder_layer, num_layers)
-        self.num_layers = num_layers
-        self.norm = norm
-
-    def forward(
+    def __init__(
         self,
-        src,
-        pos: Optional[Tensor] = None,
+        num_layers,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
     ):
-        output = src
-
-        for layer in self.layers:
-            output = layer(output, pos=pos)
-
-        if self.norm is not None:
-            output = self.norm(output)
-
-        return output
-
-
-class TransformerDecoder(nn.Module):
-    def __init__(self, decoder_layer, num_layers, norm=None):
         super().__init__()
-        self.layers = _get_clones(decoder_layer, num_layers)
-        self.num_layers = num_layers
-        self.norm = norm
-
-    def forward(
-        self,
-        tgt,
-        memory,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
-    ):
-        output = tgt
+        self.layers = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
 
+    def forward(self, x, pos: Optional[Tensor] = None):
         for layer in self.layers:
-            output = layer(
-                output,
-                memory,
-                pos=pos,
-                query_pos=query_pos,
-            )
-
-        if self.norm is not None:
-            output = self.norm(output)
-
-        return output
+            x = layer(x, pos=pos)
+        x = self.norm(x)
+        return x
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -160,45 +105,55 @@ class TransformerEncoderLayer(nn.Module):
         self.activation = _get_activation_fn(activation)
         self.normalize_before = normalize_before
 
-    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
-        return tensor if pos is None else tensor + pos
-
-    def forward_post(
-        self,
-        src,
-        pos: Optional[Tensor] = None,
-    ):
-        q = k = self.with_pos_embed(src, pos)
-        src2 = self.self_attn(q, k, value=src)[0]
-        src = src + self.dropout1(src2)
-        src = self.norm1(src)
-        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
-        src = src + self.dropout2(src2)
-        src = self.norm2(src)
-        return src
-
-    def forward_pre(
-        self,
-        src,
-        pos: Optional[Tensor] = None,
-    ):
-        src2 = self.norm1(src)
-        q = k = self.with_pos_embed(src2, pos)
-        src2 = self.self_attn(q, k, value=src2)[0]
-        src = src + self.dropout1(src2)
-        src2 = self.norm2(src)
-        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
-        src = src + self.dropout2(src2)
-        return src
-
-    def forward(
-        self,
-        src,
-        pos: Optional[Tensor] = None,
-    ):
+    def forward(self, x, pos: Optional[Tensor] = None):
+        skip = x
         if self.normalize_before:
-            return self.forward_pre(src, pos)
-        return self.forward_post(src, pos)
+            x = self.norm1(x)
+        q = k = x if pos is None else x + pos
+        x = self.self_attn(q, k, value=x)[0]
+        x = skip + self.dropout1(x)
+        if self.normalize_before:
+            skip = x
+            x = self.norm2(x)
+        else:
+            x = self.norm1(x)
+            skip = x
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        x = skip + self.dropout2(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+        return x
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(
+        self,
+        num_layers,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [
+                TransformerDecoderLayer(
+                    d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.num_layers = num_layers
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, x, encoder_out, decoder_pos: Tensor | None = None, encoder_pos: Tensor | None = None):
+        for layer in self.layers:
+            x = layer(x, encoder_out, decoder_pos=decoder_pos, encoder_pos=encoder_pos)
+        if self.norm is not None:
+            x = self.norm(x)
+        return x
 
 
 class TransformerDecoderLayer(nn.Module):
@@ -223,86 +178,55 @@ class TransformerDecoderLayer(nn.Module):
         self.activation = _get_activation_fn(activation)
         self.normalize_before = normalize_before
 
-    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+    def maybe_add_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
         return tensor if pos is None else tensor + pos
 
-    def forward_post(
-        self,
-        tgt,
-        memory,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
-    ):
-        q = k = self.with_pos_embed(tgt, query_pos)
-        tgt2 = self.self_attn(q, k, value=tgt)[0]
-        tgt = tgt + self.dropout1(tgt2)
-        tgt = self.norm1(tgt)
-        tgt2 = self.multihead_attn(
-            query=self.with_pos_embed(tgt, query_pos),
-            key=self.with_pos_embed(memory, pos),
-            value=memory,
-        )[0]
-        tgt = tgt + self.dropout2(tgt2)
-        tgt = self.norm2(tgt)
-        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
-        tgt = tgt + self.dropout3(tgt2)
-        tgt = self.norm3(tgt)
-        return tgt
-
-    def forward_pre(
-        self,
-        tgt,
-        memory,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
-    ):
-        tgt2 = self.norm1(tgt)
-        q = k = self.with_pos_embed(tgt2, query_pos)
-        tgt2 = self.self_attn(q, k, value=tgt2)[0]
-        tgt = tgt + self.dropout1(tgt2)
-        tgt2 = self.norm2(tgt)
-        tgt2 = self.multihead_attn(
-            query=self.with_pos_embed(tgt2, query_pos),
-            key=self.with_pos_embed(memory, pos),
-            value=memory,
-        )[0]
-        tgt = tgt + self.dropout2(tgt2)
-        tgt2 = self.norm3(tgt)
-        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
-        tgt = tgt + self.dropout3(tgt2)
-        return tgt
-
     def forward(
         self,
-        tgt,
-        memory,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
-    ):
+        x: Tensor,
+        encoder_out: Tensor,
+        decoder_pos: Tensor | None = None,
+        encoder_pos: Tensor | None = None,
+    ) -> Tensor:
+        """
+        Args:
+            x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
+            encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
+                cross-attending with.
+            decoder_pos: (DS, 1, C) positional embedding for the queries (from the decoder).
+            encoder_pos: (ES, 1, C) positional embedding for the keys (from the encoder).
+        Returns:
+            (DS, B, C) tensor of decoder output features.
+        """
+        skip = x
         if self.normalize_before:
-            return self.forward_pre(
-                tgt,
-                memory,
-                pos,
-                query_pos,
-            )
-        return self.forward_post(tgt, memory, pos, query_pos)
-
-
-def _get_clones(module, n):
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
-
-
-def build_transformer(args):
-    return Transformer(
-        d_model=args.hidden_dim,
-        dropout=args.dropout,
-        nhead=args.nheads,
-        dim_feedforward=args.dim_feedforward,
-        num_encoder_layers=args.enc_layers,
-        num_decoder_layers=args.dec_layers,
-        normalize_before=args.pre_norm,
-    )
+            x = self.norm1(x)
+        q = k = self.maybe_add_pos_embed(x, decoder_pos)
+        x = self.self_attn(q, k, value=x)[0]
+        x = skip + self.dropout1(x)
+        if self.normalize_before:
+            skip = x
+            x = self.norm2(x)
+        else:
+            x = self.norm1(x)
+            skip = x
+        x = self.multihead_attn(
+            query=self.maybe_add_pos_embed(x, decoder_pos),
+            key=self.maybe_add_pos_embed(encoder_out, encoder_pos),
+            value=encoder_out,
+        )[0]
+        x = skip + self.dropout2(x)
+        if self.normalize_before:
+            skip = x
+            x = self.norm3(x)
+        else:
+            x = self.norm2(x)
+            skip = x
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        x = skip + self.dropout3(x)
+        if not self.normalize_before:
+            x = self.norm3(x)
+        return x
 
 
 def _get_activation_fn(activation):
@@ -313,4 +237,4 @@ def _get_activation_fn(activation):
         return F.gelu
     if activation == "glu":
         return F.glu
-    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
diff --git a/pyproject.toml b/pyproject.toml
index b2526e5c..6d76cffc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -101,3 +101,6 @@ enable = true
 [build-system]
 requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
 build-backend = "poetry_dynamic_versioning.backend"
+
+[tool.black]
+line-length = 110
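
A quick way to sanity-check the refactored interface above is a standalone forward pass. The sketch below is not part of the diff: it assumes the package is importable as shown in the file paths, and the shapes and hyperparameters are arbitrary placeholders rather than values taken from the change.

# Minimal smoke-test sketch for the refactored Transformer API (assumptions noted above).
import torch

from lerobot.common.policies.act.transformer import Transformer

B, ES, DS, C = 2, 10, 5, 256  # batch, encoder sequence, decoder sequence, channels

transformer = Transformer(
    d_model=C,
    dropout=0.1,
    nhead=8,
    dim_feedforward=2048,
    num_encoder_layers=4,
    num_decoder_layers=1,
    normalize_before=False,
)

x = torch.randn(ES, B, C)            # encoder tokens, sequence-first as in ActionChunkingTransformer
encoder_pos = torch.randn(ES, 1, C)  # positional embedding for the encoder tokens (broadcast over batch)
decoder_pos = torch.randn(DS, 1, C)  # positional embedding for the decoder queries

out = transformer(x, encoder_pos=encoder_pos, decoder_pos=decoder_pos)
print(out.shape)  # (DS, B, C); the caller transposes back to batch-first, as in detr_vae.py
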