backup wip

Alexander Soare 2024-04-05 11:03:28 +01:00
parent 3a4dfa82fe
commit edb125b351
3 changed files with 188 additions and 213 deletions
lerobot/
  common/policies/act
  configs/policy
  scripts

View File

@@ -1,13 +1,13 @@
"""Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
"""
import logging
import math
import time
from itertools import chain
from typing import Callable, Optional
from typing import Callable
import einops
import numpy as np
@@ -26,40 +26,56 @@ from lerobot.common.utils import get_safe_torch_device
class ActionChunkingTransformerPolicy(AbstractPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (https://arxiv.org/abs/2304.13705).
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
model that encodes the target data (a sequence of actions), and the condition (the robot
joint-space).
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
have an option to train this model without the variational objective (in which case we drop the
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
[Architecture diagram (ASCII art garbled in extraction): inputs feed the VAE encoder and the
transformer encoder; the transformer decoder cross-attends to the encoder output and produces
the outputs. The transformer is used alone for inference, and acts as the VAE decoder during
training.]
"""
name = "act"
def __init__(self, cfg, device, n_action_steps=1):
"""
Args:
vae: Whether to use the variational objective. TODO(now): Give more details.
temporal_agg: Whether to do temporal aggregation. For each timestep during rollout, the action
returned is an exponential moving average of previously generated actions for that timestep.
n_obs_steps: Number of time steps worth of observation to use as input.
horizon: The number of actions to generate in one forward pass.
kl_weight: Weight for KL divergence. Defaults to None. Only applicable when using the variational
objective.
batch_size: Training batch size.
grad_clip_norm: Optionally clip the gradients to have this value as the norm at most. Defaults to
None, in which case gradient clipping is not applied.
lr: Learning rate.
TODO(alexander-soare): Add documentation for all parameters.
"""
super().__init__(n_action_steps)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.model = ActionChunkingTransformer(
cfg,
state_dim=cfg.state_dim,
action_dim=cfg.action_dim,
horizon=cfg.horizon,
camera_names=cfg.camera_names,
use_vae=cfg.vae,
)
self.model = _ActionChunkingTransformer(cfg)
self._create_optimizer()
self.to(self.device)
def _create_optimizer(self):
optimizer_params_dicts = [
{
"params": [
@@ -74,14 +90,12 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
for n, p in self.model.named_parameters()
if n.startswith("backbone") and p.requires_grad
],
"lr": cfg.lr_backbone,
"lr": self.cfg.lr_backbone,
},
]
self.optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
self.optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
)
def update(self, replay_buffer, step):
del step
@@ -137,7 +151,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
print(data_s)
loss = self.compute_loss(batch)
loss.backward()
@@ -192,16 +205,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
# qpos = obs_dict["agent_pos"]
# img = obs_dict["image"]
# qpos_ = torch.load('/tmp/qpos.pth')
# img_ = torch.load('/tmp/curr_image.pth')
# out_ = torch.load('/tmp/out.pth')
# import cv2, numpy as np
# cv2.imwrite("ours.png", (obs_dict["image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
# cv2.imwrite("theirs.png", (img_[0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
# out = self._forward(qpos_, img_)
# breakpoint()
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
if self.cfg.temporal_agg:
@@ -236,14 +239,14 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
loss_dict = {}
loss_dict["l1"] = l1
if self.cfg.vae:
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
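
A minimal standalone sketch (not part of the diff) to sanity-check the closed-form KL term above;
`mu` and `log_sigma_x2` are stand-in tensors, with `log_sigma_x2` read as log σ², so the standard
deviation is exp(log_sigma_x2 / 2):

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(8, 32)            # stand-in latent means, (batch, latent_dim)
log_sigma_x2 = torch.randn(8, 32)  # stand-in log-variances

closed_form = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
reference = (
    kl_divergence(
        Normal(mu, (0.5 * log_sigma_x2).exp()),             # latent pdf
        Normal(torch.zeros_like(mu), torch.ones_like(mu)),  # standard normal
    )
    .sum(-1)  # total KL per batch element
    .mean()   # mean over the batch
)
assert torch.allclose(closed_form, reference, atol=1e-5)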
@@ -252,135 +255,74 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
return action
def create_sinusoidal_position_embedding(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
# TODO(alexander-soare) move all this code into the policy when we have the policy API established.
class ActionChunkingTransformer(nn.Module):
"""
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
(paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
model that encodes the target data (a sequence of actions), and the condition (the robot
joint-space).
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
have an option to train this model without the variational objective (in which case we drop the
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
[Architecture diagram (ASCII art garbled in extraction): inputs feed the VAE encoder and the
transformer encoder; the transformer decoder cross-attends to the encoder output and produces
the outputs. The transformer is used alone for inference, and acts as the VAE decoder during
training.]
"""
def __init__(self, args, state_dim, action_dim, horizon, camera_names, use_vae):
"""Initializes the model.
Parameters:
state_dim: robot state dimension of the environment
horizon: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
Args:
state_dim: Robot positional state dimension.
action_dim: Action dimension.
horizon: The number of actions to generate in one forward pass.
use_vae: Whether to use the variational objective. TODO(now): Give more details.
"""
class _ActionChunkingTransformer(nn.Module):
def __init__(self, cfg):
super().__init__()
self.camera_names = camera_names
self.use_vae = use_vae
self.horizon = horizon
self.hidden_dim = args.hidden_dim
self.camera_names = cfg.camera_names
self.use_vae = cfg.use_vae
self.horizon = cfg.horizon
self.d_model = cfg.d_model
transformer_common_kwargs = dict( # noqa: C408
d_model=self.hidden_dim,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
activation=args.activation,
normalize_before=args.pre_norm,
d_model=self.d_model,
num_heads=cfg.num_heads,
dim_feedforward=cfg.dim_feedforward,
dropout=cfg.dropout,
activation=cfg.activation,
normalize_before=cfg.pre_norm,
)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if use_vae:
# TODO(now): args.enc_layers shouldn't be shared with the transformer decoder
self.vae_encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs)
self.cls_embed = nn.Embedding(1, self.hidden_dim)
if self.use_vae:
self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs)
self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim)
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(state_dim, self.hidden_dim)
# Final size of latent z. TODO(now): Add to hyperparams.
self.latent_dim = 32
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model)
self.latent_dim = cfg.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(self.hidden_dim, self.latent_dim * 2)
# Fixed sinusoidal positional embedding for the whole input to the VAE encoder.
self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding for the whole input to the VAE encoder. Unsqueeze for batch
# dimension.
self.register_buffer(
"vae_encoder_pos_enc", create_sinusoidal_position_embedding(1 + 1 + horizon, self.hidden_dim)
"vae_encoder_pos_enc",
_create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0),
)
# Backbone for image feature extraction.
self.backbone_position_embedding = SinusoidalPositionEmbedding2D(self.hidden_dim // 2)
backbone_model = getattr(torchvision.models, args.backbone)(
replace_stride_with_dilation=[False, False, args.dilation],
pretrained=True, # TODO(now): Add pretrained option
backbone_model = getattr(torchvision.models, cfg.backbone)(
replace_stride_with_dilation=[False, False, cfg.dilation],
pretrained=cfg.pretrained_backbone,
norm_layer=FrozenBatchNorm2d,
)
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs)
self.decoder = TransformerDecoder(num_layers=args.dec_layers, **transformer_common_kwargs)
self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs)
self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, self.hidden_dim, kernel_size=1
backbone_model.fc.in_features, self.d_model, kernel_size=1
)
self.encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.hidden_dim)
# TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image
# feature dimension with a dry run.
self.additional_pos_embed = nn.Embedding(
2, self.hidden_dim
) # learned position embedding for proprio and latent
# Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model)
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed_embed = nn.Embedding(horizon, self.hidden_dim)
self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(self.hidden_dim, action_dim)
self.action_head = nn.Linear(self.d_model, cfg.action_dim)
self._reset_parameters()
@@ -390,7 +332,7 @@ class ActionChunkingTransformer(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, robot_state, image, actions=None):
def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
@@ -405,10 +347,12 @@ class ActionChunkingTransformer(nn.Module):
batch_size, _ = robot_state.shape
# Prepare the latent for input to the transformer.
# Prepare the latent for input to the transformer encoder.
if self.use_vae and actions is not None:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D)
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
@@ -417,7 +361,7 @@ class ActionChunkingTransformer(nn.Module):
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2)
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim]
@@ -432,23 +376,25 @@ class ActionChunkingTransformer(nn.Module):
robot_state.device
)
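
The sampling step itself falls in the region elided between the two hunks above; a minimal
sketch of the reparameterization trick the earlier comment refers to, assuming `log_sigma_x2`
holds log σ²:

# Hypothetical sketch of the elided sampling step: draw eps ~ N(0, I), then shift and scale it
# so that gradients flow back through mu and log_sigma_x2.
latent_sample = mu + (0.5 * log_sigma_x2).exp() * torch.randn_like(mu)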
# Prepare all other transformer inputs.
# Image observation features and position embeddings.
# Prepare all other transformer encoder inputs.
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos = []
all_cam_pos_embeds = []
for cam_id, _ in enumerate(self.camera_names):
cam_features = self.backbone(image[:, cam_id])["feature_map"]
pos = self.backbone_position_embedding(cam_features).to(dtype=cam_features.dtype)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos.append(pos)
# Concatenate image observation feature maps along the width dimension.
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
encoder_in = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
latent_embed = self.encoder_latent_input_proj(latent_sample)
# TODO(now): Explain all of this madness.
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
@@ -456,60 +402,68 @@ class ActionChunkingTransformer(nn.Module):
]
)
pos_embed = torch.cat(
[self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0
[
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
cam_pos_embed.flatten(2).permute(2, 0, 1),
],
axis=0,
)
encoder_out = self.encoder(encoder_in, pos=pos_embed)
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
decoder_in = torch.zeros(
(self.horizon, batch_size, self.hidden_dim), dtype=pos_embed.dtype, device=pos_embed.device
(self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device
)
decoder_out = self.decoder(
decoder_in,
encoder_out,
encoder_pos_embed=pos_embed,
decoder_pos_embed=self.decoder_pos_embed_embed.weight.unsqueeze(1),
).transpose(0, 1) # back to (B, S, C)
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
)
# Move back to (B, S, C).
decoder_out = decoder_out.transpose(0, 1)
actions = self.action_head(decoder_out)
return actions, [mu, log_sigma_x2]
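
A hypothetical usage sketch of the two modes described in the class docstring, assuming a
populated `cfg` (see the config sketch after the yaml file further below) and an image
resolution of 480x640:

model = _ActionChunkingTransformer(cfg)
robot_state = torch.randn(8, cfg.state_dim)                 # (B, J)
image = torch.randn(8, len(cfg.camera_names), 3, 480, 640)  # (B, num_cams, C, H, W)

# Training with the variational objective: provide target actions so the VAE encoder runs
# and the latent is sampled with the reparameterization trick.
actions = torch.randn(8, cfg.horizon, cfg.action_dim)
actions_hat, (mu, log_sigma_x2) = model(robot_state, image, actions)

# Inference: no actions, so the latent defaults to zeros and the transformer runs alone.
actions_hat, _ = model(robot_state, image)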
class TransformerEncoder(nn.Module):
def __init__(
self,
num_layers,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
class _TransformerEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, num_layers: int, **encoder_layer_kwargs: dict):
super().__init__()
self.layers = nn.ModuleList(
[
TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
for _ in range(num_layers)
]
[_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)]
)
self.norm = (
nn.LayerNorm(encoder_layer_kwargs["d_model"])
if encoder_layer_kwargs["normalize_before"]
else nn.Identity()
)
self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
def forward(self, x, pos: Optional[Tensor] = None):
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
for layer in self.layers:
x = layer(x, pos=pos)
x = layer(x, pos_embed=pos_embed)
x = self.norm(x)
return x
class TransformerEncoderLayer(nn.Module):
class _TransformerEncoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
self,
d_model: int,
num_heads: int,
dim_feedforward: int,
dropout: float,
activation: str,
normalize_before: bool,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
# Feed forward layers.
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
@@ -522,7 +476,7 @@ class TransformerEncoderLayer(nn.Module):
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def forward(self, x, pos_embed: Optional[Tensor] = None):
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
skip = x
if self.normalize_before:
x = self.norm1(x)
@@ -542,32 +496,23 @@ class TransformerEncoderLayer(nn.Module):
return x
class TransformerDecoder(nn.Module):
def __init__(
self,
num_layers,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
class _TransformerDecoder(nn.Module):
def __init__(self, num_layers: int, **decoder_layer_kwargs):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList(
[
TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
for _ in range(num_layers)
]
[_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)]
)
self.num_layers = num_layers
self.norm = nn.LayerNorm(d_model)
self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"])
def forward(
self, x, encoder_out, decoder_pos_embed: Tensor | None = None, encoder_pos_embed: Tensor | None = None
):
self,
x: Tensor,
encoder_out: Tensor,
decoder_pos_embed: Tensor | None = None,
encoder_pos_embed: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
@@ -577,14 +522,21 @@ class TransformerDecoder(nn.Module):
return x
class TransformerDecoderLayer(nn.Module):
class _TransformerDecoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
self,
d_model: int,
num_heads: int,
dim_feedforward: int,
dropout: float,
activation: str,
normalize_before: bool,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
# Feed forward layers.
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
@@ -650,8 +602,26 @@ class TransformerDecoderLayer(nn.Module):
return x
class SinusoidalPositionEmbedding2D(nn.Module):
"""Sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
"""1D sinusoidal positional embeddings as in Attention is All You Need.
Args:
num_positions: Number of token positions required.
Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).
"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
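
A quick shape sketch (horizon of 100 assumed) showing why the `vae_encoder_pos_enc` buffer
earlier in the diff unsqueezes the result for the batch dimension:

pos = _create_sinusoidal_position_embedding(1 + 1 + 100, 512)  # [cls, robot_state, *actions]
assert pos.shape == (102, 512)
assert pos.unsqueeze(0).shape == (1, 102, 512)  # ready to broadcast over the batch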
class _SinusoidalPositionEmbedding2D(nn.Module):
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
for the vertical direction, and 1/W for the horizontal direction).
@@ -705,7 +675,7 @@ class SinusoidalPositionEmbedding2D(nn.Module):
def _get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string"""
"""Return an activation function given a string."""
if activation == "relu":
return F.relu
if activation == "gelu":

View File

@@ -21,24 +21,27 @@ policy:
lr: 1e-5
lr_backbone: 1e-5
pretrained_backbone: true
weight_decay: 1e-4
grad_clip_norm: 10
backbone: resnet18
horizon: ${horizon} # chunk_size
kl_weight: 10
hidden_dim: 512
d_model: 512
dim_feedforward: 3200
vae_enc_layers: 4
enc_layers: 4
dec_layers: 1
nheads: 8
num_heads: 8
#camera_names: [top, front_close, left_pillar, right_pillar]
camera_names: [top]
dilation: false
dropout: 0.1
pre_norm: false
activation: relu
latent_dim: 32
vae: true
use_vae: true
batch_size: 8
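
A hypothetical sketch (not part of the commit) of how these renamed keys line up with the
attributes read in `_ActionChunkingTransformer.__init__`; `state_dim` and `action_dim` are
assumed to be 14, as for the bimanual ALOHA setup:

from types import SimpleNamespace

cfg = SimpleNamespace(
    state_dim=14, action_dim=14, horizon=100, camera_names=["top"],
    d_model=512, num_heads=8, dim_feedforward=3200, dropout=0.1,
    activation="relu", pre_norm=False, vae_enc_layers=4, enc_layers=4,
    dec_layers=1, latent_dim=32, use_vae=True,
    backbone="resnet18", dilation=False, pretrained_backbone=True,
)
model = _ActionChunkingTransformer(cfg)  # each attribute matches a yaml key above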

View File

@@ -42,6 +42,8 @@ start_replacements = [
("model.transformer.encoder.", "model.encoder."),
("model.transformer.decoder.", "model.decoder."),
("model.backbones.0.0.body.", "model.backbone."),
("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
]
for to_replace, replace_with in start_replacements:
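
The loop body is cut off by the diff here; a hypothetical continuation, assuming `state_dict`
holds the weights loaded from an original ACT checkpoint:

    for k in list(state_dict.keys()):
        if k.startswith(to_replace):
            # Rename the key so the old checkpoint loads into the renamed module.
            state_dict[replace_with + k[len(to_replace):]] = state_dict.pop(k)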