"""Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
2024-04-05 18:03:28 +08:00
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
2024-04-05 01:34:41 +08:00
"""
2024-04-08 17:23:26 +08:00
2024-04-05 01:34:41 +08:00
import math
import time
from collections import deque
from itertools import chain
from typing import Callable

import einops
import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.common.utils import get_safe_torch_device


class ActionChunkingTransformerPolicy(nn.Module):
    """
    Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
    Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act).

    Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
        - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
          model that encodes the target data (a sequence of actions) and the condition (the robot
          joint-space state).
        - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
          cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
          have an option to train this model without the variational objective (in which case we drop the
          `vae_encoder` altogether, and nothing about this model has anything to do with a VAE).

    [Architecture diagram: the inputs feed the VAE encoder, whose latent (together with the robot state
    and image feature tokens) feeds the transformer encoder; the transformer decoder produces the output
    action sequence. The transformer is used alone for inference, and acts as the VAE decoder during
    training.]
    """
name = "act"
2024-04-06 00:38:29 +08:00
_multiple_obs_steps_not_handled_msg = (
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
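
    # Illustrative usage sketch (assumes `cfg` is a config namespace exposing the fields read in `__init__`,
    # e.g. one of lerobot's ACT configs):
    #
    #     policy = ActionChunkingTransformerPolicy(cfg, device="cuda")
    #     policy.reset()                        # call on environment reset to clear the action queue
    #     action = policy.select_action(batch)  # one action per call; the model re-runs every n_action_steps
    #     info = policy.update(batch)           # one optimization step; returns loss / grad norm / timing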

    def __init__(self, cfg, device):
        """
        TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
        """
        super().__init__()
        if getattr(cfg, "n_obs_steps", 1) != 1:
            raise ValueError(self._multiple_obs_steps_not_handled_msg)
        self.cfg = cfg
        self.n_action_steps = cfg.n_action_steps
        self.device = get_safe_torch_device(device)
        self.camera_names = cfg.camera_names
        self.use_vae = cfg.use_vae
        self.horizon = cfg.horizon
        self.d_model = cfg.d_model
        self.latent_dim = cfg.latent_dim

        transformer_common_kwargs = dict(  # noqa: C408
            d_model=self.d_model,
            num_heads=cfg.num_heads,
            dim_feedforward=cfg.dim_feedforward,
            dropout=cfg.dropout,
            activation=cfg.activation,
            normalize_before=cfg.pre_norm,
        )

        # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
        # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
        if self.use_vae:
            self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs)
            self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model)
            # Projection layer for joint-space configuration to hidden dimension.
            self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
            # Projection layer for action (joint-space target) to hidden dimension.
            self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model)
            # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
            self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2)
            # Fixed sinusoidal positional embedding for the whole input to the VAE encoder. Unsqueeze for
            # the batch dimension.
            self.register_buffer(
                "vae_encoder_pos_enc",
                _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0),
            )

        # Backbone for image feature extraction.
        self.image_normalizer = transforms.Normalize(
            mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
        )
        backbone_model = getattr(torchvision.models, cfg.backbone)(
            replace_stride_with_dilation=[False, False, cfg.dilation],
            pretrained=cfg.pretrained_backbone,
            norm_layer=FrozenBatchNorm2d,
        )
        # Note: The forward method of this returns a dict: {"feature_map": output}.
        self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
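        # For a ResNet backbone this "feature_map" is layer4's output: a (B, C, H / 32, W / 32) tensor with
        # C = backbone_model.fc.in_features (or H / 16, W / 16 when cfg.dilation replaces the last stride).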

        # Transformer (acts as VAE decoder when training with the variational objective).
        self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs)
        self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs)

        # Transformer encoder input projections. The tokens will be structured like
        # [latent, robot_state, image_feature_map_pixels].
        self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model)
        self.encoder_img_feat_input_proj = nn.Conv2d(
            backbone_model.fc.in_features, self.d_model, kernel_size=1
        )
        # Transformer encoder positional embeddings.
        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model)
        self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2)

        # Transformer decoder.
        # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
        self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model)

        # Final action regression head on the output of the transformer's decoder.
        self.action_head = nn.Linear(self.d_model, cfg.action_dim)

        self._reset_parameters()
        self._create_optimizer()
        self.to(self.device)

    def _create_optimizer(self):
        optimizer_params_dicts = [
            {
                "params": [
                    p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
                ],
                "lr": self.cfg.lr_backbone,
            },
        ]
        self.optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
        )

    def _reset_parameters(self):
        """Xavier-uniform initialization of the transformer parameters as in the original code."""
        for p in chain(self.encoder.parameters(), self.decoder.parameters()):
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def reset(self):
        """This should be called whenever the environment is reset."""
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)

    def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
        """
        This method wraps `select_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
        queue is empty.
        """
        if len(self._action_queue) == 0:
            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively
            # has shape (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
        return self._action_queue.popleft()

    @torch.no_grad()
    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
        """Use the action chunking transformer to generate a sequence of actions."""
        self.eval()
        self._preprocess_batch(batch, add_obs_steps_dim=True)

        action = self.forward(batch, return_loss=False)

        if self.cfg.temporal_agg:
            # TODO(rcadene): implement temporal aggregation
            raise NotImplementedError()
            # all_time_actions[[t], t:t+num_queries] = action
            # actions_for_curr_step = all_time_actions[:, t]
            # actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
            # actions_for_curr_step = actions_for_curr_step[actions_populated]
            # k = 0.01
            # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
            # exp_weights = exp_weights / exp_weights.sum()
            # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
            # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)

        # The model outputs an action chunk spanning the full horizon; only the first n_action_steps are
        # returned for execution (sliced along the time dimension).
        return action[:, : self.n_action_steps]

    def __call__(self, *args, **kwargs) -> dict:
        # TODO(now): Temporary bridge until we know what to do about the `update` method.
        return self.update(*args, **kwargs)

    def _preprocess_batch(
        self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
    ) -> dict[str, Tensor]:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
            "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
            "action": (B, H, J) tensor of actions (positional targets for the robot joint configuration).
            "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
        }
        """
        if add_obs_steps_dim:
            # Add a dimension for the observation steps. Since n_obs_steps > 1 is not supported right now,
            # this just amounts to an unsqueeze.
            for k in batch:
                if k.startswith("observation."):
                    batch[k] = batch[k].unsqueeze(1)

        if batch["observation.state"].shape[1] != 1:
            raise ValueError(self._multiple_obs_steps_not_handled_msg)
        batch["observation.state"] = batch["observation.state"].squeeze(1)

        # TODO(alexander-soare): generalize this to multiple images.
        assert (
            sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
        ), "ACT only handles one image for now."
        # Note: no squeeze is required for "observation.images.top" because the (squeezable) dimension of
        # size 1 doubles as the camera index dimension.

    def update(self, batch, *_, **__) -> dict:
        start_time = time.time()
        self._preprocess_batch(batch)

        self.train()

        num_slices = self.cfg.batch_size
        batch_size = self.cfg.horizon * num_slices

        assert batch_size % self.cfg.horizon == 0
        assert batch_size % num_slices == 0

        loss = self.forward(batch, return_loss=True)["loss"]
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.parameters(),
            self.cfg.grad_clip_norm,
            error_if_nonfinite=False,
        )

        self.optimizer.step()
        self.optimizer.zero_grad()

        info = {
            "loss": loss.item(),
            "grad_norm": float(grad_norm),
            "lr": self.cfg.lr,
            "update_s": time.time() - start_time,
        }

        return info

    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
        images = self.image_normalizer(batch["observation.images.top"])

        if return_loss:  # training time
            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
                batch["observation.state"], images, batch["action"]
            )

            # Masked L1 reconstruction loss: padded action steps (beyond the episode end) are zeroed out
            # before averaging.
            l1_loss = (
                F.l1_loss(batch["action"], actions_hat, reduction="none")
                * ~batch["action_is_pad"].unsqueeze(-1)
            ).mean()

            loss_dict = {}
            loss_dict["l1"] = l1_loss
            if self.cfg.use_vae:
                # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
                # each dimension independently, we sum over the latent dimension to get the total
                # KL-divergence per batch element, then take the mean over the batch.
                # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
                mean_kld = (
                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
                )
                loss_dict["kl"] = mean_kld
                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
            else:
                loss_dict["loss"] = loss_dict["l1"]
            return loss_dict
        else:
            action, _ = self._forward(batch["observation.state"], images)
            return action

    def _forward(
        self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
    ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
        """
        Args:
            robot_state: (B, J) batch of robot joint configurations.
            image: (B, N, C, H, W) batch of N camera frames.
            actions: (B, S, A) batch of actions from the target dataset which must be provided if the
                VAE is enabled and the model is in training mode.
        Returns:
            (B, S, A) batch of action sequences.
            Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is
            the latent dimension.
        """
        if self.use_vae and self.training:
            assert (
                actions is not None
            ), "actions must be provided when using the variational objective in training mode."

        batch_size = robot_state.shape[0]

        # Prepare the latent for input to the transformer encoder.
        if self.use_vae and actions is not None:
            # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
            cls_embed = einops.repeat(
                self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
            )  # (B, 1, D)
            robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1)  # (B, 1, D)
            action_embed = self.vae_encoder_action_input_proj(actions)  # (B, S, D)
            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)

            # Prepare fixed positional embedding.
            # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
            pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)

            # Forward pass through VAE encoder to get the latent PDF parameters.
            cls_token_out = self.vae_encoder(
                vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
            )[0]  # select the class token, with shape (B, D)
            latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
            mu = latent_pdf_params[:, : self.latent_dim]
            # This is 2log(sigma). Done this way to match the original implementation.
            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]

            # Sample the latent with the reparameterization trick.
            latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
        else:
            # When not using the VAE encoder, we set the latent to be all zeros.
            mu = log_sigma_x2 = None
            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
                robot_state.device
            )
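        # Either way, the latent fed to the transformer encoder is a (B, L) tensor with L = self.latent_dim.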

        # Prepare all other transformer encoder inputs.
        # Camera observation features and positional embeddings.
        all_cam_features = []
        all_cam_pos_embeds = []
        for cam_id, _ in enumerate(self.camera_names):
            cam_features = self.backbone(image[:, cam_id])["feature_map"]
            cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
            cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
            all_cam_features.append(cam_features)
            all_cam_pos_embeds.append(cam_pos_embed)
        # Concatenate camera observation feature maps and positional embeddings along the width dimension.
        encoder_in = torch.cat(all_cam_features, axis=3)
        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)

        # Get positional embeddings for robot state and latent.
        robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
        latent_embed = self.encoder_latent_input_proj(latent_sample)

        # Stack encoder input and positional embeddings moving to (S, B, C).
        encoder_in = torch.cat(
            [
                torch.stack([latent_embed, robot_state_embed], axis=0),
                encoder_in.flatten(2).permute(2, 0, 1),
            ]
        )
        pos_embed = torch.cat(
            [
                self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
                cam_pos_embed.flatten(2).permute(2, 0, 1),
            ],
            axis=0,
        )
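        # encoder_in is now (2 + h * w, B, C) and pos_embed is (2 + h * w, 1, C); the token sequence is
        # [latent, robot_state, *image_feature_map_pixels].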

        # Forward pass through the transformer modules.
        encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
        decoder_in = torch.zeros(
            (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device
        )
        decoder_out = self.decoder(
            decoder_in,
            encoder_out,
            encoder_pos_embed=pos_embed,
            decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
        )

        # Move back to (B, S, C).
        decoder_out = decoder_out.transpose(0, 1)

        actions = self.action_head(decoder_out)

        return actions, (mu, log_sigma_x2)

    def save(self, fp):
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        d = torch.load(fp)
        self.load_state_dict(d)


class _TransformerEncoder(nn.Module):
    """Convenience module for running multiple encoder layers, maybe followed by normalization."""

    def __init__(self, num_layers: int, **encoder_layer_kwargs: dict):
        super().__init__()
        self.layers = nn.ModuleList(
            [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)]
        )
        self.norm = (
            nn.LayerNorm(encoder_layer_kwargs["d_model"])
            if encoder_layer_kwargs["normalize_before"]
            else nn.Identity()
        )

    def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
        for layer in self.layers:
            x = layer(x, pos_embed=pos_embed)
        x = self.norm(x)
        return x


class _TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int,
        dropout: float,
        activation: str,
        normalize_before: bool,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
        # Self-attention block (pre-norm or post-norm depending on `normalize_before`).
        skip = x
        if self.normalize_before:
            x = self.norm1(x)
        q = k = x if pos_embed is None else x + pos_embed
        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
        x = skip + self.dropout1(x)
        # Feed-forward block.
        if self.normalize_before:
            skip = x
            x = self.norm2(x)
        else:
            x = self.norm1(x)
            skip = x
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = skip + self.dropout2(x)
        if not self.normalize_before:
            x = self.norm2(x)
        return x


class _TransformerDecoder(nn.Module):
    def __init__(self, num_layers: int, **decoder_layer_kwargs):
        """Convenience module for running multiple decoder layers followed by normalization."""
        super().__init__()
        self.layers = nn.ModuleList(
            [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)]
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"])

    def forward(
        self,
        x: Tensor,
        encoder_out: Tensor,
        decoder_pos_embed: Tensor | None = None,
        encoder_pos_embed: Tensor | None = None,
    ) -> Tensor:
        for layer in self.layers:
            x = layer(
                x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
            )
        if self.norm is not None:
            x = self.norm(x)
        return x


class _TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int,
        dropout: float,
        activation: str,
        normalize_before: bool,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
        return tensor if pos_embed is None else tensor + pos_embed

    def forward(
        self,
        x: Tensor,
        encoder_out: Tensor,
        decoder_pos_embed: Tensor | None = None,
        encoder_pos_embed: Tensor | None = None,
    ) -> Tensor:
        """
        Args:
            x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
            encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
                cross-attending with.
            decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
            encoder_pos_embed: (ES, 1, C) positional embedding for the keys (from the encoder).
        Returns:
            (DS, B, C) tensor of decoder output features.
        """
        # Self-attention over the decoder tokens.
        skip = x
        if self.normalize_before:
            x = self.norm1(x)
        q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
        x = skip + self.dropout1(x)
        # Cross-attention: decoder tokens attend to the encoder output.
        if self.normalize_before:
            skip = x
            x = self.norm2(x)
        else:
            x = self.norm1(x)
            skip = x
        x = self.multihead_attn(
            query=self.maybe_add_pos_embed(x, decoder_pos_embed),
            key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
            value=encoder_out,
        )[0]  # select just the output, not the attention weights
        x = skip + self.dropout2(x)
        # Feed-forward block.
        if self.normalize_before:
            skip = x
            x = self.norm3(x)
        else:
            x = self.norm2(x)
            skip = x
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = skip + self.dropout3(x)
        if not self.normalize_before:
            x = self.norm3(x)
        return x


def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
    """1D sinusoidal positional embeddings as in Attention is All You Need.

    Args:
        num_positions: Number of token positions required.
        dimension: Dimension of each position embedding.
    Returns: (num_positions, dimension) position embeddings (the first dimension is the position index).
    """

    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.from_numpy(sinusoid_table).float()
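

# For example (illustrative): _create_sinusoidal_position_embedding(3, 4) yields a (3, 4) tensor whose row p
# is [sin(p / 10000^0), cos(p / 10000^0), sin(p / 10000^0.5), cos(p / 10000^0.5)].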


class _SinusoidalPositionEmbedding2D(nn.Module):
    """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.

    The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
    for the vertical direction, and 1/W for the horizontal direction).
    """

    def __init__(self, dimension: int):
        """
        Args:
            dimension: The desired dimension of the embeddings.
        """
        super().__init__()
        self.dimension = dimension
        self._two_pi = 2 * math.pi
        self._eps = 1e-6
        # Inverse "common ratio" for the geometric progression in sinusoid frequencies.
        self._temperature = 10000

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: A (B, C, H, W) batch of 2D feature maps to generate the embeddings for.
        Returns:
            A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
        """
        not_mask = torch.ones_like(x[0, :1])  # (1, H, W)
        # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
        # they would be range(0, H) and range(0, W). Keeping it as is to match the original code.
        y_range = not_mask.cumsum(1, dtype=torch.float32)
        x_range = not_mask.cumsum(2, dtype=torch.float32)

        # "Normalize" the position index such that it ranges in [0, 2π].
        # Note: Adding epsilon on the denominator should not be needed as all values of y_range and x_range
        # are non-zero by construction. This is an artifact of the original code.
        y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
        x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi

        inverse_frequency = self._temperature ** (
            2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
        )

        x_range = x_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, C // 2)
        y_range = y_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, C // 2)

        # Note: this stack then flatten operation results in interleaved sine and cosine terms.
        # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
        pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
        pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
        pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2)  # (1, C, H, W)
        return pos_embed


def _get_activation_fn(activation: str) -> Callable:
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")