backup wip

Alexander Soare 2024-04-05 11:03:28 +01:00
parent 3a4dfa82fe
commit edb125b351
3 changed files with 188 additions and 213 deletions

@ -1,13 +1,13 @@
"""Action Chunking Transformer Policy """Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
""" """
import logging
import math import math
import time import time
from itertools import chain from itertools import chain
from typing import Callable, Optional from typing import Callable
import einops import einops
import numpy as np import numpy as np
@ -26,40 +26,56 @@ from lerobot.common.utils import get_safe_torch_device
class ActionChunkingTransformerPolicy(AbstractPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)

Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
model that encodes the target data (a sequence of actions), and the condition (the robot
joint-space).
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
have an option to train this model without the variational objective (in which case we drop the
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).

[ASCII architecture diagram omitted: the inputs feed the VAE encoder and the transformer encoder;
the transformer (encoder plus decoder) is used alone for inference and acts as the VAE decoder
during training, with the outputs coming from the transformer decoder.]
"""
name = "act" name = "act"
def __init__(self, cfg, device, n_action_steps=1):
"""
TODO(alexander-soare): Add documentation for all parameters.

Args (argument documentation removed by this commit):
    vae: Whether to use the variational objective. TODO(now): Give more details.
    temporal_agg: Whether to do temporal aggregation. For each timestep during rollout, the action
        returned is an exponential moving average of previously generated actions for that timestep.
    n_obs_steps: Number of time steps worth of observation to use as input.
    horizon: The number of actions to generate in one forward pass.
    kl_weight: Weight for KL divergence. Defaults to None. Only applicable when using the variational
        objective.
    batch_size: Training batch size.
    grad_clip_norm: Optionally clip the gradients to have this value as the norm at most. Defaults to
        None meaning gradient clipping is not applied.
    lr: Learning rate.
"""
super().__init__(n_action_steps) super().__init__(n_action_steps)
self.cfg = cfg self.cfg = cfg
self.n_action_steps = n_action_steps self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device) self.device = get_safe_torch_device(device)
self.model = ActionChunkingTransformer( self.model = _ActionChunkingTransformer(cfg)
cfg, self._create_optimizer()
state_dim=cfg.state_dim, self.to(self.device)
action_dim=cfg.action_dim,
horizon=cfg.horizon,
camera_names=cfg.camera_names,
use_vae=cfg.vae,
)
def _create_optimizer(self):
optimizer_params_dicts = [ optimizer_params_dicts = [
{ {
"params": [ "params": [
@ -74,14 +90,12 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
for n, p in self.model.named_parameters() for n, p in self.model.named_parameters()
if n.startswith("backbone") and p.requires_grad if n.startswith("backbone") and p.requires_grad
], ],
"lr": cfg.lr_backbone, "lr": self.cfg.lr_backbone,
}, },
] ]
self.optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay) self.optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
self.kl_weight = self.cfg.kl_weight )
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
def update(self, replay_buffer, step): def update(self, replay_buffer, step):
del step del step
@ -137,7 +151,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
batch = process_batch(batch, self.cfg.horizon, num_slices) batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time data_s = time.time() - start_time
print(data_s)
loss = self.compute_loss(batch) loss = self.compute_loss(batch)
loss.backward() loss.backward()
@ -192,16 +205,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
"image": observation["image", "top"], "image": observation["image", "top"],
"agent_pos": observation["state"], "agent_pos": observation["state"],
} }
# qpos = obs_dict["agent_pos"]
# img = obs_dict["image"]
# qpos_ = torch.load('/tmp/qpos.pth')
# img_ = torch.load('/tmp/curr_image.pth')
# out_ = torch.load('/tmp/out.pth')
# import cv2, numpy as np
# cv2.imwrite("ours.png", (obs_dict["image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
# cv2.imwrite("theirs.png", (img_[0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
# out = self._forward(qpos_, img_)
# breakpoint()
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"]) action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
if self.cfg.temporal_agg: if self.cfg.temporal_agg:
@ -236,14 +239,14 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
loss_dict = {} loss_dict = {}
loss_dict["l1"] = l1 loss_dict["l1"] = l1
if self.cfg.vae: if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean() mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
loss_dict["kl"] = mean_kld loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else: else:
loss_dict["loss"] = loss_dict["l1"] loss_dict["loss"] = loss_dict["l1"]
return loss_dict return loss_dict
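As a quick sanity check on the closed form used in the KL term above, the per-dimension divergence between the predicted Gaussian and a standard normal can be compared against torch.distributions. This snippet is illustrative only and is not part of the diff.

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(8, 32)            # (batch, latent_dim)
log_sigma_x2 = torch.randn(8, 32)  # log(sigma^2)

# Closed form: D_KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2), per dimension.
closed_form = -0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())

reference = kl_divergence(Normal(mu, (log_sigma_x2 / 2).exp()), Normal(0.0, 1.0))
assert torch.allclose(closed_form, reference, atol=1e-5)

# Summing over the latent dimension and averaging over the batch gives the scalar used in the loss.
mean_kld = closed_form.sum(-1).mean()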
@ -252,135 +255,74 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
return action return action
def create_sinusoidal_position_embedding(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
# TODO(alexander-soare) move all this code into the policy when we have the policy API established.
class ActionChunkingTransformer(nn.Module): class _ActionChunkingTransformer(nn.Module):
""" def __init__(self, cfg):
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
(paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
model that encodes the target data (a sequence of actions), and the condition (the robot
joint-space).
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
have an option to train this model without the variational objective (in which case we drop the
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
[ASCII architecture diagram omitted (same diagram as in the policy class docstring above).]
"""
def __init__(self, args, state_dim, action_dim, horizon, camera_names, use_vae):
"""Initializes the model.
Parameters:
state_dim: robot state dimension of the environment
horizon: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
Args:
state_dim: Robot positional state dimension.
action_dim: Action dimension.
horizon: The number of actions to generate in one forward pass.
use_vae: Whether to use the variational objective. TODO(now): Give more details.
"""
super().__init__() super().__init__()
self.camera_names = camera_names self.camera_names = cfg.camera_names
self.use_vae = use_vae self.use_vae = cfg.use_vae
self.horizon = horizon self.horizon = cfg.horizon
self.hidden_dim = args.hidden_dim self.d_model = cfg.d_model
transformer_common_kwargs = dict( # noqa: C408 transformer_common_kwargs = dict( # noqa: C408
d_model=self.hidden_dim, d_model=self.d_model,
nhead=args.nheads, num_heads=cfg.num_heads,
dim_feedforward=args.dim_feedforward, dim_feedforward=cfg.dim_feedforward,
dropout=args.dropout, dropout=cfg.dropout,
activation=args.activation, activation=cfg.activation,
normalize_before=args.pre_norm, normalize_before=cfg.pre_norm,
) )
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if use_vae: if self.use_vae:
# TODO(now): args.enc_layers shouldn't be shared with the transformer decoder self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs)
self.vae_encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs) self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model)
self.cls_embed = nn.Embedding(1, self.hidden_dim)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim) self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(state_dim, self.hidden_dim) self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model)
# Final size of latent z. TODO(now): Add to hyperparams. self.latent_dim = cfg.latent_dim
self.latent_dim = 32
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(self.hidden_dim, self.latent_dim * 2) self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding for the whole input to the VAE encoder. Unsqueeze for batch
# dimension.
self.register_buffer( self.register_buffer(
"vae_encoder_pos_enc", create_sinusoidal_position_embedding(1 + 1 + horizon, self.hidden_dim) "vae_encoder_pos_enc",
_create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0),
) )
# Backbone for image feature extraction.
self.backbone_position_embedding = SinusoidalPositionEmbedding2D(self.hidden_dim // 2) backbone_model = getattr(torchvision.models, cfg.backbone)(
backbone_model = getattr(torchvision.models, args.backbone)( replace_stride_with_dilation=[False, False, cfg.dilation],
replace_stride_with_dilation=[False, False, args.dilation], pretrained=cfg.pretrained_backbone,
pretrained=True, # TODO(now): Add pretrained option
norm_layer=FrozenBatchNorm2d, norm_layer=FrozenBatchNorm2d,
) )
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs) self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs)
self.decoder = TransformerDecoder(num_layers=args.dec_layers, **transformer_common_kwargs) self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model)
self.encoder_img_feat_input_proj = nn.Conv2d( self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, self.hidden_dim, kernel_size=1 backbone_model.fc.in_features, self.d_model, kernel_size=1
) )
self.encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim) # Transformer encoder positional embeddings.
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.hidden_dim) self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model)
# TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image
# feature dimension with a dry run.
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2)
self.additional_pos_embed = nn.Embedding(
2, self.hidden_dim
) # learned position embedding for proprio and latent
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed_embed = nn.Embedding(horizon, self.hidden_dim) self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(self.hidden_dim, action_dim) self.action_head = nn.Linear(self.d_model, cfg.action_dim)
self._reset_parameters() self._reset_parameters()
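A minimal sketch of the DETR-style decoding set up above: the decoder input tokens are all zeros, and the `horizon` learned positional embeddings act as the queries that cross-attend into the encoder output. Shapes and module names here are illustrative, assuming the sequence-first (S, B, C) layout used in the forward pass below.

import torch
import torch.nn as nn

horizon, batch_size, d_model = 5, 2, 16

decoder_pos_embed = nn.Embedding(horizon, d_model)      # one learned "query" position per future action
decoder_in = torch.zeros(horizon, batch_size, d_model)  # (S, B, C): the content is all zeros
encoder_out = torch.randn(10, batch_size, d_model)      # stand-in for the transformer encoder output

# In the real decoder the positional embeddings are added to queries/keys inside every layer;
# a single cross-attention call is enough to illustrate where the information comes from.
cross_attn = nn.MultiheadAttention(d_model, num_heads=4)
queries = decoder_in + decoder_pos_embed.weight.unsqueeze(1)  # broadcast over the batch dimension
out, _ = cross_attn(queries, encoder_out, encoder_out)
print(out.shape)  # torch.Size([5, 2, 16]) -> one output token per action in the chunk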
@ -390,7 +332,7 @@ class ActionChunkingTransformer(nn.Module):
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
def forward(self, robot_state, image, actions=None): def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
""" """
Args: Args:
robot_state: (B, J) batch of robot joint configurations. robot_state: (B, J) batch of robot joint configurations.
@ -405,10 +347,12 @@ class ActionChunkingTransformer(nn.Module):
batch_size, _ = robot_state.shape batch_size, _ = robot_state.shape
# Prepare the latent for input to the transformer. # Prepare the latent for input to the transformer encoder.
if self.use_vae and actions is not None: if self.use_vae and actions is not None:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D) cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
@ -417,7 +361,7 @@ class ActionChunkingTransformer(nn.Module):
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
cls_token_out = self.vae_encoder( cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2) vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # (B, D) )[0] # (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim] mu = latent_pdf_params[:, : self.latent_dim]
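The sampling step itself falls between this hunk and the next, so it is not shown here. A standard reparameterization-trick sketch consistent with the `mu` / `log_sigma_x2` parameterization above (an assumption for illustration, not code copied from the diff):

import torch

mu = torch.randn(8, 32, requires_grad=True)            # latent means
log_sigma_x2 = torch.randn(8, 32, requires_grad=True)  # log-variances

# latent = mu + sigma * eps with eps ~ N(0, 1). The randomness is isolated in eps, so gradients
# flow through mu and sigma, which is what makes the variational objective differentiable.
eps = torch.randn_like(mu)
latent_sample = mu + (log_sigma_x2 / 2).exp() * eps

latent_sample.sum().backward()  # gradients reach both mu and log_sigma_x2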
@ -432,23 +376,25 @@ class ActionChunkingTransformer(nn.Module):
robot_state.device robot_state.device
) )
# Prepare all other transformer inputs. # Prepare all other transformer encoder inputs.
# Image observation features and position embeddings. # Camera observation features and positional embeddings.
all_cam_features = [] all_cam_features = []
all_cam_pos = [] all_cam_pos_embeds = []
for cam_id, _ in enumerate(self.camera_names): for cam_id, _ in enumerate(self.camera_names):
cam_features = self.backbone(image[:, cam_id])["feature_map"] cam_features = self.backbone(image[:, cam_id])["feature_map"]
pos = self.backbone_position_embedding(cam_features).to(dtype=cam_features.dtype) cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features) all_cam_features.append(cam_features)
all_cam_pos.append(pos) all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate image observation feature maps along the width dimension. # Concatenate camera observation feature maps and positional embeddings along the width dimension.
encoder_in = torch.cat(all_cam_features, axis=3) encoder_in = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3) cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(robot_state) robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
latent_embed = self.encoder_latent_input_proj(latent_sample) latent_embed = self.encoder_latent_input_proj(latent_sample)
# TODO(now): Explain all of this madness. # Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in = torch.cat( encoder_in = torch.cat(
[ [
torch.stack([latent_embed, robot_state_embed], axis=0), torch.stack([latent_embed, robot_state_embed], axis=0),
@ -456,60 +402,68 @@ class ActionChunkingTransformer(nn.Module):
] ]
) )
pos_embed = torch.cat( pos_embed = torch.cat(
[self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0 [
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
cam_pos_embed.flatten(2).permute(2, 0, 1),
],
axis=0,
) )
encoder_out = self.encoder(encoder_in, pos=pos_embed) # Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
decoder_in = torch.zeros( decoder_in = torch.zeros(
(self.horizon, batch_size, self.hidden_dim), dtype=pos_embed.dtype, device=pos_embed.device (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device
) )
decoder_out = self.decoder( decoder_out = self.decoder(
decoder_in, decoder_in,
encoder_out, encoder_out,
encoder_pos_embed=pos_embed, encoder_pos_embed=pos_embed,
decoder_pos_embed=self.decoder_pos_embed_embed.weight.unsqueeze(1), decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
).transpose(0, 1) # back to (B, S, C) )
# Move back to (B, S, C).
decoder_out = decoder_out.transpose(0, 1)
actions = self.action_head(decoder_out) actions = self.action_head(decoder_out)
return actions, [mu, log_sigma_x2] return actions, [mu, log_sigma_x2]
class TransformerEncoder(nn.Module): class _TransformerEncoder(nn.Module):
def __init__( """Convenience module for running multiple encoder layers, maybe followed by normalization."""
self,
num_layers, def __init__(self, num_layers: int, **encoder_layer_kwargs: dict):
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)]
TransformerEncoderLayer( )
d_model, nhead, dim_feedforward, dropout, activation, normalize_before self.norm = (
) nn.LayerNorm(encoder_layer_kwargs["d_model"])
for _ in range(num_layers) if encoder_layer_kwargs["normalize_before"]
] else nn.Identity()
) )
self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
def forward(self, x, pos: Optional[Tensor] = None): def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer(x, pos=pos) x = layer(x, pos_embed=pos_embed)
x = self.norm(x) x = self.norm(x)
return x return x
class TransformerEncoderLayer(nn.Module): class _TransformerEncoderLayer(nn.Module):
def __init__( def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False self,
d_model: int,
num_heads: int,
dim_feedforward: int,
dropout: float,
activation: str,
normalize_before: bool,
): ):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
# Implementation of Feedforward model
# Feed forward layers.
self.linear1 = nn.Linear(d_model, dim_feedforward) self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model) self.linear2 = nn.Linear(dim_feedforward, d_model)
@ -522,7 +476,7 @@ class TransformerEncoderLayer(nn.Module):
self.activation = _get_activation_fn(activation) self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before self.normalize_before = normalize_before
def forward(self, x, pos_embed: Optional[Tensor] = None): def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
skip = x skip = x
if self.normalize_before: if self.normalize_before:
x = self.norm1(x) x = self.norm1(x)
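The `normalize_before` flag toggles between pre-norm and post-norm residual blocks. A generic sketch of the two orderings follows; it is illustrative only, since the actual layer interleaves two norms, self-attention, and the feed-forward block.

import torch
import torch.nn as nn

def residual_block(x: torch.Tensor, sublayer: nn.Module, norm: nn.Module, normalize_before: bool) -> torch.Tensor:
    skip = x
    if normalize_before:                 # pre-norm: normalize the input, add the skip afterwards
        return skip + sublayer(norm(x))
    return norm(skip + sublayer(x))      # post-norm: add the skip first, then normalize the sum

x = torch.randn(4, 16)
ff = nn.Linear(16, 16)
norm = nn.LayerNorm(16)
print(residual_block(x, ff, norm, normalize_before=True).shape)   # torch.Size([4, 16])
print(residual_block(x, ff, norm, normalize_before=False).shape)  # torch.Size([4, 16])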
@ -542,32 +496,23 @@ class TransformerEncoderLayer(nn.Module):
return x return x
class TransformerDecoder(nn.Module): class _TransformerDecoder(nn.Module):
def __init__( def __init__(self, num_layers: int, **decoder_layer_kwargs):
self, """Convenience module for running multiple decoder layers followed by normalization."""
num_layers,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)]
TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
for _ in range(num_layers)
]
) )
self.num_layers = num_layers self.num_layers = num_layers
self.norm = nn.LayerNorm(d_model) self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"])
def forward( def forward(
self, x, encoder_out, decoder_pos_embed: Tensor | None = None, encoder_pos_embed: Tensor | None = None self,
): x: Tensor,
encoder_out: Tensor,
decoder_pos_embed: Tensor | None = None,
encoder_pos_embed: Tensor | None = None,
) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer( x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
@ -577,14 +522,21 @@ class TransformerDecoder(nn.Module):
return x return x
class TransformerDecoderLayer(nn.Module): class _TransformerDecoderLayer(nn.Module):
def __init__( def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False self,
d_model: int,
num_heads: int,
dim_feedforward: int,
dropout: float,
activation: str,
normalize_before: bool,
): ):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
# Implementation of Feedforward model
# Feed forward layers.
self.linear1 = nn.Linear(d_model, dim_feedforward) self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model) self.linear2 = nn.Linear(dim_feedforward, d_model)
@ -650,8 +602,26 @@ class TransformerDecoderLayer(nn.Module):
return x return x
class SinusoidalPositionEmbedding2D(nn.Module): def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
"""Sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. """1D sinusoidal positional embeddings as in Attention is All You Need.
Args:
num_positions: Number of token positions required.
Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).
"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
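For reference, the table built above behaves as follows. This standalone snippet reimplements the same construction purely for illustration.

import numpy as np
import torch

def sinusoidal_table(num_positions: int, dimension: int) -> torch.Tensor:
    # Same construction as above: angle = pos / 10000^(2*(j//2)/dimension), sin on even dims, cos on odd dims.
    angles = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dimension) for j in range(dimension)] for pos in range(num_positions)]
    )
    angles[:, 0::2] = np.sin(angles[:, 0::2])
    angles[:, 1::2] = np.cos(angles[:, 1::2])
    return torch.from_numpy(angles).float()

# For the VAE encoder the table covers [cls, robot_state, *action_sequence] = 1 + 1 + horizon positions.
horizon, d_model = 4, 8
table = sinusoidal_table(1 + 1 + horizon, d_model)
print(table.shape)               # torch.Size([6, 8])
print(table.unsqueeze(0).shape)  # torch.Size([1, 6, 8]) -- the shape of the registered buffer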
class _SinusoidalPositionEmbedding2D(nn.Module):
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
for the vertical direction, and 1/W for the horizontal direction).
@ -705,7 +675,7 @@ class SinusoidalPositionEmbedding2D(nn.Module):
def _get_activation_fn(activation: str) -> Callable: def _get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string""" """Return an activation function given a string."""
if activation == "relu": if activation == "relu":
return F.relu return F.relu
if activation == "gelu": if activation == "gelu":


@ -21,24 +21,27 @@ policy:
lr: 1e-5 lr: 1e-5
lr_backbone: 1e-5 lr_backbone: 1e-5
pretrained_backbone: true
weight_decay: 1e-4 weight_decay: 1e-4
grad_clip_norm: 10 grad_clip_norm: 10
backbone: resnet18 backbone: resnet18
horizon: ${horizon} # chunk_size horizon: ${horizon} # chunk_size
kl_weight: 10 kl_weight: 10
hidden_dim: 512 d_model: 512
dim_feedforward: 3200 dim_feedforward: 3200
vae_enc_layers: 4
enc_layers: 4 enc_layers: 4
dec_layers: 1 dec_layers: 1
nheads: 8 num_heads: 8
#camera_names: [top, front_close, left_pillar, right_pillar]
camera_names: [top] camera_names: [top]
dilation: false dilation: false
dropout: 0.1 dropout: 0.1
pre_norm: false pre_norm: false
activation: relu activation: relu
latent_dim: 32
vae: true use_vae: true
batch_size: 8 batch_size: 8
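For anyone carrying an older config, the renames in this hunk amount to a simple key mapping. A hedged sketch of migrating a plain dict is shown below; the new `pretrained_backbone`, `vae_enc_layers`, and `latent_dim` entries have no old counterpart, so they are filled with the defaults from this file.

old_to_new = {"hidden_dim": "d_model", "nheads": "num_heads", "vae": "use_vae"}

def migrate_policy_cfg(old_cfg: dict) -> dict:
    new_cfg = {old_to_new.get(k, k): v for k, v in old_cfg.items()}
    # Keys introduced by this commit, with the values used in this config file.
    new_cfg.setdefault("pretrained_backbone", True)
    new_cfg.setdefault("vae_enc_layers", 4)
    new_cfg.setdefault("latent_dim", 32)
    return new_cfg

print(migrate_policy_cfg({"hidden_dim": 512, "nheads": 8, "vae": True, "dropout": 0.1}))
# {'d_model': 512, 'num_heads': 8, 'use_vae': True, 'dropout': 0.1, 'pretrained_backbone': True, ...}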


@ -42,6 +42,8 @@ start_replacements = [
("model.transformer.encoder.", "model.encoder."), ("model.transformer.encoder.", "model.encoder."),
("model.transformer.decoder.", "model.decoder."), ("model.transformer.decoder.", "model.decoder."),
("model.backbones.0.0.body.", "model.backbone."), ("model.backbones.0.0.body.", "model.backbone."),
("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
] ]
for to_replace, replace_with in start_replacements: for to_replace, replace_with in start_replacements:
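The body of this loop is not shown in the hunk. A typical implementation of such prefix renaming over a checkpoint's state dict looks like the following; this is an illustrative sketch, not necessarily the exact code in this script.

import torch

def rename_keys(state_dict: dict, start_replacements: list[tuple[str, str]]) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for to_replace, replace_with in start_replacements:
            # Only keys that start with the old prefix (or match the full old name) are rewritten.
            if new_key.startswith(to_replace):
                new_key = replace_with + new_key[len(to_replace):]
        renamed[new_key] = value
    return renamed

start_replacements = [
    ("model.transformer.encoder.", "model.encoder."),
    ("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
]
sd = {"model.transformer.encoder.layers.0.linear1.weight": torch.zeros(1)}
print(list(rename_keys(sd, start_replacements)))  # ['model.encoder.layers.0.linear1.weight']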