backup wip
This commit is contained in:
parent
110ac5ffa1
commit
278336a39a
|
@ -4,7 +4,7 @@ import torch
|
|||
from torch import nn
|
||||
|
||||
from .backbone import build_backbone
|
||||
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
|
||||
from .transformer import Transformer, TransformerEncoder
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
|
@ -124,16 +124,14 @@ class ActionChunkingTransformer(nn.Module):
|
|||
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
|
||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
||||
vae_encoder_input = vae_encoder_input.permute(1, 0, 2) # (S+2, B, D)
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
# Prepare fixed positional embedding.
|
||||
pos_embed = self.vae_encoder_pos_enc.clone().detach().permute(1, 0, 2) # (S+2, 1, D)
|
||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
|
||||
vae_encoder_output = self.vae_encoder(
|
||||
vae_encoder_input, pos=pos_embed
|
||||
) # , src_key_padding_mask=is_pad) # TODO(now)
|
||||
vae_encoder_output = vae_encoder_output[0] # take cls output only
|
||||
latent_pdf_params = self.vae_encoder_latent_output_proj(vae_encoder_output)
|
||||
cls_token_out = self.vae_encoder(
|
||||
vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2)
|
||||
)[0] # (B, D)
|
||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||
mu = latent_pdf_params[:, : self.latent_dim]
|
||||
logvar = latent_pdf_params[:, self.latent_dim :]
|
||||
# Use reparameterization trick to sample from the latent's PDF.
|
||||
|
@ -151,10 +149,11 @@ class ActionChunkingTransformer(nn.Module):
|
|||
all_cam_pos = []
|
||||
for cam_id, _ in enumerate(self.camera_names):
|
||||
# TODO(now): remove the positional embedding from the backbones.
|
||||
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features = features[0] # take the last layer feature
|
||||
cam_features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
cam_features = cam_features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.encoder_img_feat_input_proj(features))
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos.append(pos)
|
||||
# Concatenate image observation feature maps along the width dimension.
|
||||
transformer_input = torch.cat(all_cam_features, axis=3)
|
||||
|
@ -163,36 +162,25 @@ class ActionChunkingTransformer(nn.Module):
|
|||
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
|
||||
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
||||
|
||||
# TODO(now): Explain all of this madness.
|
||||
transformer_input = torch.cat(
|
||||
[
|
||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
||||
transformer_input.flatten(2).permute(2, 0, 1),
|
||||
]
|
||||
)
|
||||
pos_embed = torch.cat(
|
||||
[self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0
|
||||
)
|
||||
|
||||
# Run the transformer and project the outputs to the action space.
|
||||
transformer_output = self.transformer(
|
||||
transformer_input,
|
||||
query_embed=self.decoder_pos_embed.weight,
|
||||
pos_embed=pos,
|
||||
latent_input=latent_embed,
|
||||
proprio_input=robot_state_embed,
|
||||
additional_pos_embed=self.additional_pos_embed.weight,
|
||||
)
|
||||
a_hat = self.action_head(transformer_output)
|
||||
|
||||
return a_hat, [mu, logvar]
|
||||
|
||||
|
||||
def build_vae_encoder(args):
|
||||
d_model = args.hidden_dim # 256
|
||||
dropout = args.dropout # 0.1
|
||||
nhead = args.nheads # 8
|
||||
dim_feedforward = args.dim_feedforward # 2048
|
||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||
normalize_before = args.pre_norm # False
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
encoder_pos=pos_embed,
|
||||
decoder_pos=self.decoder_pos_embed.weight.unsqueeze(1),
|
||||
).transpose(0, 1) # back to (B, S, C)
|
||||
actions = self.action_head(transformer_output)
|
||||
return actions, [mu, logvar]
|
||||
|
||||
|
||||
def build(args):
|
||||
|
@ -203,9 +191,26 @@ def build(args):
|
|||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(args)
|
||||
transformer = Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
)
|
||||
|
||||
vae_encoder = build_vae_encoder(args)
|
||||
# TODO(now): args.enc_layers shouldn't be shared with the transformer decoder
|
||||
vae_encoder = TransformerEncoder(
|
||||
num_layers=args.enc_layers,
|
||||
d_model=args.hidden_dim,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
dropout=args.dropout,
|
||||
activation="relu",
|
||||
normalize_before=args.pre_norm,
|
||||
)
|
||||
|
||||
model = ActionChunkingTransformer(
|
||||
backbones,
|
||||
|
|
|
@ -1,13 +1,7 @@
|
|||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
TODO(now)
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
@ -28,117 +22,68 @@ class Transformer(nn.Module):
|
|||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
self.encoder = TransformerEncoder(
|
||||
num_encoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
self.decoder = TransformerDecoder(
|
||||
num_decoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
self._init_params() # TODO(now): move to somewhere common
|
||||
|
||||
def _reset_parameters(self):
|
||||
def _init_params(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
proprio_input=None,
|
||||
additional_pos_embed=None,
|
||||
):
|
||||
def forward(self, x, encoder_pos, decoder_pos):
|
||||
"""
|
||||
Args:
|
||||
x: ((E)ncoder (S)equence, (B)atch, (C)hannels)
|
||||
decoder_pos: (Decoder Sequence, C) tensor for the decoder's positional embedding.
|
||||
encoder_pos: (ES, C) tenso
|
||||
"""
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
# Each "pixel" on the feature maps will form a token.
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
bs = x.shape[1]
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(0, 1)
|
||||
return hs
|
||||
encoder_out = self.encoder(x, pos=encoder_pos)
|
||||
decoder_in = torch.zeros(
|
||||
(decoder_pos.shape[0], bs, decoder_pos.shape[2]),
|
||||
dtype=decoder_pos.dtype,
|
||||
device=decoder_pos.device,
|
||||
)
|
||||
decoder_out = self.decoder(decoder_in, encoder_out, encoder_pos=encoder_pos, decoder_pos=decoder_pos)
|
||||
return decoder_out
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
def __init__(
|
||||
self,
|
||||
src,
|
||||
pos: Optional[Tensor] = None,
|
||||
num_layers,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, decoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
|
||||
|
||||
def forward(self, x, pos: Optional[Tensor] = None):
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
x = layer(x, pos=pos)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
|
@ -160,45 +105,55 @@ class TransformerEncoderLayer(nn.Module):
|
|||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
def forward(self, x, pos: Optional[Tensor] = None):
|
||||
skip = x
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, pos)
|
||||
return self.forward_post(src, pos)
|
||||
x = self.norm1(x)
|
||||
q = k = x if pos is None else x + pos
|
||||
x = self.self_attn(q, k, value=x)[0]
|
||||
x = skip + self.dropout1(x)
|
||||
if self.normalize_before:
|
||||
skip = x
|
||||
x = self.norm2(x)
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
skip = x
|
||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||
x = skip + self.dropout2(x)
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, x, encoder_out, decoder_pos: Tensor | None = None, encoder_pos: Tensor | None = None):
|
||||
for layer in self.layers:
|
||||
x = layer(x, encoder_out, decoder_pos=decoder_pos, encoder_pos=encoder_pos)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
@ -223,86 +178,55 @@ class TransformerDecoderLayer(nn.Module):
|
|||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
def maybe_add_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
x: Tensor,
|
||||
encoder_out: Tensor,
|
||||
decoder_pos: Tensor | None = None,
|
||||
encoder_pos: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
||||
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
||||
cross-attending with.
|
||||
decoder_pos: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
encoder_pos: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
||||
Returns:
|
||||
(DS, B, C) tensor of decoder output features.
|
||||
"""
|
||||
skip = x
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(tgt, memory, pos, query_pos)
|
||||
|
||||
|
||||
def _get_clones(module, n):
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
)
|
||||
x = self.norm1(x)
|
||||
q = k = self.maybe_add_pos_embed(x, decoder_pos)
|
||||
x = self.self_attn(q, k, value=x)[0]
|
||||
x = skip + self.dropout1(x)
|
||||
if self.normalize_before:
|
||||
skip = x
|
||||
x = self.norm2(x)
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
skip = x
|
||||
x = self.multihead_attn(
|
||||
query=self.maybe_add_pos_embed(x, decoder_pos),
|
||||
key=self.maybe_add_pos_embed(encoder_out, encoder_pos),
|
||||
value=encoder_out,
|
||||
)[0]
|
||||
x = skip + self.dropout2(x)
|
||||
if self.normalize_before:
|
||||
skip = x
|
||||
x = self.norm3(x)
|
||||
else:
|
||||
x = self.norm2(x)
|
||||
skip = x
|
||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||
x = skip + self.dropout3(x)
|
||||
if not self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
return x
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
|
@ -313,4 +237,4 @@ def _get_activation_fn(activation):
|
|||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
||||
|
|
|
@ -101,3 +101,6 @@ enable = true
|
|||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
|
||||
build-backend = "poetry_dynamic_versioning.backend"
|
||||
|
||||
[tool.black]
|
||||
line-length = 110
|
||||
|
|
Loading…
Reference in New Issue