backup wip

Alexander Soare 2024-04-04 18:34:41 +01:00
parent 278336a39a
commit 3a4dfa82fe
8 changed files with 538 additions and 1227 deletions

View File

@@ -1,115 +0,0 @@
from typing import List
import torch
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from .position_encoding import build_position_encoding
from .utils import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.ops.misc with added eps before rsqrt,
without which any models other than torchvision.models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super().__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
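# Illustrative sanity check added for clarity (an assumption, not part of the original module):
# with matching statistics and affine parameters, FrozenBatchNorm2d should reproduce the output
# of nn.BatchNorm2d in eval mode, since both use eps=1e-5.
def _frozen_batch_norm_sanity_check():
    frozen = FrozenBatchNorm2d(8)
    frozen.running_mean.normal_()
    frozen.running_var.uniform_(0.5, 2.0)
    reference = nn.BatchNorm2d(8).eval()
    reference.load_state_dict(frozen.state_dict(), strict=False)  # num_batches_tracked is absent
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(frozen(x), reference(x), atol=1e-6)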
class BackboneBase(nn.Module):
def __init__(
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {"layer4": "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=FrozenBatchNorm2d,
) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for _, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model

View File

@@ -1,229 +0,0 @@
import einops
import numpy as np
import torch
from torch import nn
from .backbone import build_backbone
from .transformer import Transformer, TransformerEncoder
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class ActionChunkingTransformer(nn.Module):
"""
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
(paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around conditional variational auto-encoders (cVAE), the
part of the model that encodes the target data (here, a sequence of actions), and the condition
(here, we include the robot joint-space state as an input to the encoder).
- The `transformer` is the cVAE's decoder. But since we have an option to train this model without the
variational objective (in which case we drop the `vae_encoder` altogether), we don't call it the
`vae_decoder`.
# TODO(now): remove the following
- The `encoder` is actually a component of the cVAE's "decoder". But we refer to it as an "encoder"
because, in terms of the transformer with cross-attention that forms the cVAE's decoder, it is the
"encoder" part. We drop the `vae_` prefix because we have an option to train this model without the
variational objective (in which case we drop the `vae_encoder` altogether), and nothing about this
model has anything to do with a VAE.
- The `decoder` is a building block of the VAE decoder, and is just the "decoder" part of a
transformer with cross-attention. For the same reasoning behind the naming of `encoder`, we make
this term agnostic to the option to use a variational objective for training.
"""
def __init__(
self, backbones, transformer, vae_encoder, state_dim, action_dim, horizon, camera_names, use_vae
):
"""Initializes the model.
Args:
backbones: List of torch modules for image feature extraction (only the first one is currently
used). See backbone.py.
transformer: Torch module for the transformer, which acts as the VAE decoder when training with
the variational objective. See transformer.py.
vae_encoder: Torch module for the VAE encoder. See transformer.py.
camera_names: Names of the cameras whose frames are passed as input.
state_dim: Robot positional state dimension.
action_dim: Action dimension.
horizon: The number of actions to generate in one forward pass.
use_vae: Whether to use the variational objective. TODO(now): Give more details.
"""
super().__init__()
self.camera_names = camera_names
self.transformer = transformer
self.vae_encoder = vae_encoder
self.use_vae = use_vae
hidden_dim = transformer.d_model
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if use_vae:
self.cls_embed = nn.Embedding(1, hidden_dim)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(state_dim, hidden_dim)
# Final size of latent z. TODO(now): Add to hyperparams.
self.latent_dim = 32
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(hidden_dim, self.latent_dim * 2)
# Fixed sinusoidal positional embedding for the whole input to the VAE encoder.
self.register_buffer(
"vae_encoder_pos_enc", get_sinusoid_encoding_table(1 + 1 + horizon, hidden_dim)
)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.backbones = nn.ModuleList(backbones)
self.encoder_img_feat_input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, hidden_dim)
# TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image
# feature dimension with a dry run.
self.additional_pos_embed = nn.Embedding(
2, hidden_dim
) # learned position embedding for proprio and latent
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(horizon, hidden_dim)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(hidden_dim, action_dim)
def forward(self, robot_state, image, actions=None):
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
image: (B, N, C, H, W) batch of N camera frames.
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
VAE is enabled and the model is in training mode.
"""
if self.use_vae and self.training:
assert (
actions is not None
), "actions must be provided when using the variational objective in training mode."
batch_size, _ = robot_state.shape
# Prepare the latent for input to the transformer.
if self.use_vae and actions is not None:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
# Prepare fixed positional embedding.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2)
)[0] # (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim]
logvar = latent_pdf_params[:, self.latent_dim :]
# Use reparameterization trick to sample from the latent's PDF.
latent_sample = mu + logvar.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = logvar = None
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=robot_state.dtype).to(
robot_state.device
)
# Prepare all other transformer inputs.
# Image observation features and position embeddings.
all_cam_features = []
all_cam_pos = []
for cam_id, _ in enumerate(self.camera_names):
# TODO(now): remove the positional embedding from the backbones.
cam_features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
cam_features = cam_features[0] # take the last layer feature
pos = pos[0]
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos.append(pos)
# Concatenate image observation feature maps along the width dimension.
transformer_input = torch.cat(all_cam_features, axis=3)
# TODO(now): remove the positional embedding from the backbones.
pos = torch.cat(all_cam_pos, axis=3)
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
latent_embed = self.encoder_latent_input_proj(latent_sample)
# TODO(now): Explain all of this madness.
transformer_input = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
transformer_input.flatten(2).permute(2, 0, 1),
]
)
pos_embed = torch.cat(
[self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0
)
# Run the transformer and project the outputs to the action space.
transformer_output = self.transformer(
transformer_input,
encoder_pos=pos_embed,
decoder_pos=self.decoder_pos_embed.weight.unsqueeze(1),
).transpose(0, 1) # back to (B, S, C)
actions = self.action_head(transformer_output)
return actions, [mu, logvar]
def build(args):
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
backbone = build_backbone(args)
backbones.append(backbone)
transformer = Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
)
# TODO(now): args.enc_layers shouldn't be shared with the transformer decoder
vae_encoder = TransformerEncoder(
num_layers=args.enc_layers,
d_model=args.hidden_dim,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
activation="relu",
normalize_before=args.pre_norm,
)
model = ActionChunkingTransformer(
backbones,
transformer,
vae_encoder,
state_dim=args.state_dim,
action_dim=args.action_dim,
horizon=args.num_queries,
camera_names=args.camera_names,
use_vae=args.vae,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
return model

View File

@@ -1,50 +1,32 @@
import logging
import time
"""Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
"""
import logging
import math
import time
from itertools import chain
from typing import Callable, Optional
import einops
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
from lerobot.common.utils import get_safe_torch_device
def build_act_model_and_optimizer(cfg):
model = build(cfg)
param_dicts = [
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr": cfg.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
return model, optimizer
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld
class ActionChunkingTransformerPolicy(AbstractPolicy):
"""
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
(https://arxiv.org/abs/2304.13705).
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (https://arxiv.org/abs/2304.13705).
"""
name = "act"
@@ -68,7 +50,35 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.model = ActionChunkingTransformer(
cfg,
state_dim=cfg.state_dim,
action_dim=cfg.action_dim,
horizon=cfg.horizon,
camera_names=cfg.camera_names,
use_vae=cfg.vae,
)
optimizer_params_dicts = [
{
"params": [
p
for n, p in self.model.named_parameters()
if not n.startswith("backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in self.model.named_parameters()
if n.startswith("backbone") and p.requires_grad
],
"lr": cfg.lr_backbone,
},
]
self.optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
@@ -140,12 +150,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
self.optimizer.step()
self.optimizer.zero_grad()
# self.lr_scheduler.step()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
# "lr": self.lr_scheduler.get_last_lr()[0],
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
@@ -213,31 +221,495 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
action = action[: self.n_action_steps]
return action
def _forward(self, qpos, image, actions=None, is_pad=None):
env_state = None
def _forward(self, qpos, image, actions=None):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
is_training = actions is not None
if is_training: # training time
actions = actions[:, : self.model.num_queries]
if is_pad is not None:
is_pad = is_pad[:, : self.model.num_queries]
actions = actions[:, : self.model.horizon]
a_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
a_hat, (mu, log_sigma_x2) = self.model(qpos, image, actions)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
l1 = all_l1.mean()
loss_dict = {}
loss_dict["l1"] = l1
if self.cfg.vae:
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict["kl"] = total_kld[0]
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _ = self.model(qpos, image, env_state) # no action, sample from prior
action, _ = self.model(qpos, image) # no action, sample from prior
return action
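# Hedged verification sketch (hypothetical helper, not part of the original file): the closed-form
# Dkl(N(mu, sigma^2) || N(0, 1)) used in _forward above should agree with torch.distributions.
def _kl_to_standard_normal_example():
    from torch.distributions import Normal, kl_divergence
    mu, log_sigma_x2 = torch.randn(4, 32), torch.randn(4, 32)
    closed_form = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
    reference = kl_divergence(Normal(mu, (log_sigma_x2 / 2).exp()), Normal(0.0, 1.0)).sum(-1).mean()
    assert torch.allclose(closed_form, reference, atol=1e-5)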
def create_sinusoidal_position_embedding(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
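# Illustrative usage (an assumption, not part of the original file): the VAE encoder buffer is
# built as create_sinusoidal_position_embedding(1 + 1 + horizon, hidden_dim) and has shape
# (1, n_position, d_hid), with sine on even dimensions and cosine on odd dimensions.
def _sinusoidal_position_embedding_example():
    table = create_sinusoidal_position_embedding(n_position=102, d_hid=512)
    assert table.shape == (1, 102, 512)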
# TODO(alexander-soare) move all this code into the policy when we have the policy API established.
class ActionChunkingTransformer(nn.Module):
"""
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
(paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
model that encodes the target data (a sequence of actions), and the condition (the robot
joint-space).
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
have an option to train this model without the variational objective (in which case we drop the
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
[ASCII architecture diagram not reproduced here. In summary: the inputs feed the VAE encoder and
the transformer encoder; the transformer decoder produces the outputs. The transformer
(encoder + decoder) is used alone for inference and acts as the VAE decoder during training.]
"""
def __init__(self, args, state_dim, action_dim, horizon, camera_names, use_vae):
"""Initializes the model.
Args:
args: Configuration object holding the model hyperparameters (hidden_dim, nheads,
dim_feedforward, dropout, activation, pre_norm, enc_layers, dec_layers, backbone, dilation).
camera_names: Names of the cameras whose frames are passed as input.
state_dim: Robot positional state dimension.
action_dim: Action dimension.
horizon: The number of actions to generate in one forward pass.
use_vae: Whether to use the variational objective. TODO(now): Give more details.
"""
super().__init__()
self.camera_names = camera_names
self.use_vae = use_vae
self.horizon = horizon
self.hidden_dim = args.hidden_dim
transformer_common_kwargs = dict( # noqa: C408
d_model=self.hidden_dim,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
activation=args.activation,
normalize_before=args.pre_norm,
)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if use_vae:
# TODO(now): args.enc_layers shouldn't be shared with the transformer decoder
self.vae_encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs)
self.cls_embed = nn.Embedding(1, self.hidden_dim)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(state_dim, self.hidden_dim)
# Final size of latent z. TODO(now): Add to hyperparams.
self.latent_dim = 32
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(self.hidden_dim, self.latent_dim * 2)
# Fixed sinusoidal positional embedding for the whole input to the VAE encoder.
self.register_buffer(
"vae_encoder_pos_enc", create_sinusoidal_position_embedding(1 + 1 + horizon, self.hidden_dim)
)
# Backbone for image feature extraction.
self.backbone_position_embedding = SinusoidalPositionEmbedding2D(self.hidden_dim // 2)
backbone_model = getattr(torchvision.models, args.backbone)(
replace_stride_with_dilation=[False, False, args.dilation],
pretrained=True, # TODO(now): Add pretrained option
norm_layer=FrozenBatchNorm2d,
)
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs)
self.decoder = TransformerDecoder(num_layers=args.dec_layers, **transformer_common_kwargs)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, self.hidden_dim, kernel_size=1
)
self.encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.hidden_dim)
# TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image
# feature dimension with a dry run.
self.additional_pos_embed = nn.Embedding(
2, self.hidden_dim
) # learned position embedding for proprio and latent
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed_embed = nn.Embedding(horizon, self.hidden_dim)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(self.hidden_dim, action_dim)
self._reset_parameters()
def _reset_parameters(self):
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, robot_state, image, actions=None):
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
image: (B, N, C, H, W) batch of N camera frames.
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
VAE is enabled and the model is in training mode.
"""
if self.use_vae and self.training:
assert (
actions is not None
), "actions must be provided when using the variational objective in training mode."
batch_size, _ = robot_state.shape
# Prepare the latent for input to the transformer.
if self.use_vae and actions is not None:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
# Prepare fixed positional embedding.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2)
)[0] # (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
# Use reparameterization trick to sample from the latent's PDF.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
robot_state.device
)
# Prepare all other transformer inputs.
# Image observation features and position embeddings.
all_cam_features = []
all_cam_pos = []
for cam_id, _ in enumerate(self.camera_names):
cam_features = self.backbone(image[:, cam_id])["feature_map"]
pos = self.backbone_position_embedding(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos.append(pos)
# Concatenate image observation feature maps along the width dimension.
encoder_in = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
latent_embed = self.encoder_latent_input_proj(latent_sample)
# TODO(now): Explain all of this madness.
encoder_in = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
encoder_in.flatten(2).permute(2, 0, 1),
]
)
pos_embed = torch.cat(
[self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0
)
encoder_out = self.encoder(encoder_in, pos=pos_embed)
decoder_in = torch.zeros(
(self.horizon, batch_size, self.hidden_dim), dtype=pos_embed.dtype, device=pos_embed.device
)
decoder_out = self.decoder(
decoder_in,
encoder_out,
encoder_pos_embed=pos_embed,
decoder_pos_embed=self.decoder_pos_embed_embed.weight.unsqueeze(1),
).transpose(0, 1) # back to (B, S, C)
actions = self.action_head(decoder_out)
return actions, [mu, log_sigma_x2]
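# Hedged construction/shape sketch (hypothetical hyperparameters, not part of the original file).
# Building the model downloads ImageNet weights for the backbone. Shown only to illustrate the
# expected input/output shapes of the forward pass.
def _action_chunking_transformer_example():
    from types import SimpleNamespace

    args = SimpleNamespace(
        hidden_dim=512, nheads=8, dim_feedforward=3200, dropout=0.1, activation="relu",
        pre_norm=False, enc_layers=4, dec_layers=1, backbone="resnet18", dilation=False,
    )
    model = ActionChunkingTransformer(
        args, state_dim=14, action_dim=14, horizon=100, camera_names=["top"], use_vae=True
    )
    robot_state = torch.randn(2, 14)        # (B, J)
    image = torch.randn(2, 1, 3, 480, 640)  # (B, num_cameras, C, H, W)
    actions = torch.randn(2, 100, 14)       # (B, horizon, A)
    a_hat, (mu, log_sigma_x2) = model(robot_state, image, actions)
    assert a_hat.shape == (2, 100, 14) and mu.shape == log_sigma_x2.shape == (2, 32)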
class TransformerEncoder(nn.Module):
def __init__(
self,
num_layers,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.layers = nn.ModuleList(
[
TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
def forward(self, x, pos: Optional[Tensor] = None):
for layer in self.layers:
x = layer(x, pos_embed=pos)
x = self.norm(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def forward(self, x, pos_embed: Optional[Tensor] = None):
skip = x
if self.normalize_before:
x = self.norm1(x)
q = k = x if pos_embed is None else x + pos_embed
x = self.self_attn(q, k, value=x)[0]
x = skip + self.dropout1(x)
if self.normalize_before:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout2(x)
if not self.normalize_before:
x = self.norm2(x)
return x
class TransformerDecoder(nn.Module):
def __init__(
self,
num_layers,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.layers = nn.ModuleList(
[
TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
for _ in range(num_layers)
]
)
self.num_layers = num_layers
self.norm = nn.LayerNorm(d_model)
def forward(
self, x, encoder_out, decoder_pos_embed: Tensor | None = None, encoder_pos_embed: Tensor | None = None
):
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
if self.norm is not None:
x = self.norm(x)
return x
class TransformerDecoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
return tensor if pos_embed is None else tensor + pos_embed
def forward(
self,
x: Tensor,
encoder_out: Tensor,
decoder_pos_embed: Tensor | None = None,
encoder_pos_embed: Tensor | None = None,
) -> Tensor:
"""
Args:
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
encoder_pos_embed: (ES, 1, C) positional embedding for the keys (from the encoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""
skip = x
if self.normalize_before:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0]
x = skip + self.dropout1(x)
if self.normalize_before:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.multihead_attn(
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
value=encoder_out,
)[0]
x = skip + self.dropout2(x)
if self.normalize_before:
skip = x
x = self.norm3(x)
else:
x = self.norm2(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout3(x)
if not self.normalize_before:
x = self.norm3(x)
return x
class SinusoidalPositionEmbedding2D(nn.Module):
"""Sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized to [0, 2π] (not quite: the lower bound
is 2π/H for the vertical direction and 2π/W for the horizontal direction).
"""
def __init__(self, dimension: int):
"""
Args:
dimension: The desired dimension of the embeddings.
"""
super().__init__()
self.dimension = dimension
self._two_pi = 2 * math.pi
self._eps = 1e-6
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
self._temperature = 10000
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
Returns:
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
"""
not_mask = torch.ones_like(x[0, [0]]) # (1, H, W)
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
# they would be range(0, H) and range(0, W). Keeping it this way to match the original code.
y_range = not_mask.cumsum(1, dtype=torch.float32)
x_range = not_mask.cumsum(2, dtype=torch.float32)
# "Normalize" the position index such that it ranges in [0, 2π].
# Note: Adding epsilon to the denominator should not be needed as all values of y_range and
# x_range are non-zero by construction. This is an artifact of the original code.
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, self.dimension)
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, self.dimension)
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
return pos_embed
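# Illustrative shape check (an assumption, not part of the original file): for a (B, C, H, W)
# feature map the embedding is (1, 2 * dimension, H, W), so dimension = hidden_dim // 2 yields an
# embedding that matches the transformer's hidden size.
def _sinusoidal_position_embedding_2d_example():
    feature_map = torch.randn(2, 512, 15, 20)
    pos = SinusoidalPositionEmbedding2D(256)(feature_map)
    assert pos.shape == (1, 512, 15, 20)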
def _get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

View File

@@ -1,102 +0,0 @@
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from .utils import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos
def build_position_encoding(args):
n_steps = args.hidden_dim // 2
if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(n_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View File

@@ -1,240 +0,0 @@
"""
TODO(now)
"""
from typing import Optional
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
class Transformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.encoder = TransformerEncoder(
num_encoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
self.decoder = TransformerDecoder(
num_decoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
self.d_model = d_model
self.nhead = nhead
self._init_params() # TODO(now): move to somewhere common
def _init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, x, encoder_pos, decoder_pos):
"""
Args:
x: ((E)ncoder (S)equence, (B)atch, (C)hannels)
decoder_pos: (Decoder Sequence, 1, C) tensor for the decoder's positional embedding.
encoder_pos: (ES, 1, C) tensor for the encoder's positional embedding.
"""
# TODO flatten only when input has H and W
bs = x.shape[1]
encoder_out = self.encoder(x, pos=encoder_pos)
decoder_in = torch.zeros(
(decoder_pos.shape[0], bs, decoder_pos.shape[2]),
dtype=decoder_pos.dtype,
device=decoder_pos.device,
)
decoder_out = self.decoder(decoder_in, encoder_out, encoder_pos=encoder_pos, decoder_pos=decoder_pos)
return decoder_out
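# Hedged shape sketch (hypothetical sizes, not part of the original file): the decoder input is
# all zeros, so only the learned decoder positional embedding distinguishes the output slots.
def _transformer_shape_example():
    transformer = Transformer(d_model=32, nhead=4, num_encoder_layers=1, num_decoder_layers=1)
    x = torch.randn(10, 2, 32)            # (encoder sequence, batch, channels)
    encoder_pos = torch.randn(10, 1, 32)  # broadcast over the batch dimension
    decoder_pos = torch.randn(5, 1, 32)   # one learned embedding per output slot
    out = transformer(x, encoder_pos=encoder_pos, decoder_pos=decoder_pos)
    assert out.shape == (5, 2, 32)        # (decoder sequence, batch, channels)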
class TransformerEncoder(nn.Module):
def __init__(
self,
num_layers,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.layers = nn.ModuleList(
[
TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
def forward(self, x, pos: Optional[Tensor] = None):
for layer in self.layers:
x = layer(x, pos=pos)
x = self.norm(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def forward(self, x, pos: Optional[Tensor] = None):
skip = x
if self.normalize_before:
x = self.norm1(x)
q = k = x if pos is None else x + pos
x = self.self_attn(q, k, value=x)[0]
x = skip + self.dropout1(x)
if self.normalize_before:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout2(x)
if not self.normalize_before:
x = self.norm2(x)
return x
class TransformerDecoder(nn.Module):
def __init__(
self,
num_layers,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.layers = nn.ModuleList(
[
TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
for _ in range(num_layers)
]
)
self.num_layers = num_layers
self.norm = nn.LayerNorm(d_model)
def forward(self, x, encoder_out, decoder_pos: Tensor | None = None, encoder_pos: Tensor | None = None):
for layer in self.layers:
x = layer(x, encoder_out, decoder_pos=decoder_pos, encoder_pos=encoder_pos)
if self.norm is not None:
x = self.norm(x)
return x
class TransformerDecoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def maybe_add_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
return tensor if pos is None else tensor + pos
def forward(
self,
x: Tensor,
encoder_out: Tensor,
decoder_pos: Tensor | None = None,
encoder_pos: Tensor | None = None,
) -> Tensor:
"""
Args:
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
decoder_pos: (DS, 1, C) positional embedding for the queries (from the decoder).
encoder_pos: (ES, 1, C) positional embedding for the keys (from the encoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""
skip = x
if self.normalize_before:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos)
x = self.self_attn(q, k, value=x)[0]
x = skip + self.dropout1(x)
if self.normalize_before:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.multihead_attn(
query=self.maybe_add_pos_embed(x, decoder_pos),
key=self.maybe_add_pos_embed(encoder_out, encoder_pos),
value=encoder_out,
)[0]
x = skip + self.dropout2(x)
if self.normalize_before:
skip = x
x = self.norm3(x)
else:
x = self.norm2(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout3(x)
if not self.normalize_before:
x = self.norm3(x)
return x
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

View File

@@ -1,478 +0,0 @@
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import datetime
import os
import pickle
import subprocess
import time
from collections import defaultdict, deque
from typing import List, Optional
import torch
import torch.distributed as dist
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
from packaging import version
from torch import Tensor
if version.parse(torchvision.__version__) < version.parse("0.7"):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
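# Illustrative usage (an assumption, not part of the original file): track a scalar and read back
# windowed and global statistics.
def _smoothed_value_example():
    meter = SmoothedValue(window_size=3, fmt="{avg:.2f}")
    for value in (1.0, 2.0, 3.0, 4.0):
        meter.update(value)
    assert meter.global_avg == 2.5  # (1 + 2 + 3 + 4) / 4 over everything seen
    assert meter.avg == 3.0         # mean of the last window_size values: (2 + 3 + 4) / 3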
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list, strict=False):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
return reduced_dict
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
mega_b = 1024.0 * 1024.0
for i, obj in enumerate(iterable):
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / mega_b,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
sha = "N/A"
diff = "clean"
branch = "N/A"
try:
sha = _run(["git", "rev-parse", "HEAD"])
subprocess.check_output(["git", "diff"], cwd=cwd)
diff = _run(["git", "diff-index", "HEAD"])
diff = "has uncommited changes" if diff else "clean"
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch, strict=False))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor:
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)
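# Illustrative usage (an assumption, not part of the original file): images of different sizes are
# zero-padded to a common shape and the boolean mask marks the padded region.
def _nested_tensor_example():
    images = [torch.randn(3, 4, 6), torch.randn(3, 5, 5)]
    nested = nested_tensor_from_tensor_list(images)
    assert nested.tensors.shape == (2, 3, 5, 6)
    assert nested.mask.shape == (2, 5, 6)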
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
torch.int64
)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
function can go away.
"""
if version.parse(torchvision.__version__) < version.parse("0.7"):
if input.numel() > 0:
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)

View File

@@ -33,11 +33,10 @@ policy:
nheads: 8
#camera_names: [top, front_close, left_pillar, right_pillar]
camera_names: [top]
position_embedding: sine
masks: false
dilation: false
dropout: 0.1
pre_norm: false
activation: relu
vae: true

View File

@@ -11,6 +11,19 @@ policy = make_policy(cfg)
state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
# Remove keys based on what they start with.
start_removals = [
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
"model.is_pad_head.",
]
for to_remove in start_removals:
for k in list(state_dict.keys()):
if k.startswith(to_remove):
del state_dict[k]
# Replace keys based on what they start with.
@@ -26,6 +39,9 @@ start_replacements = [
("model.input_proj.", "model.encoder_img_feat_input_proj."),
("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
("model.latent_out_proj.", "model.encoder_latent_input_proj."),
("model.transformer.encoder.", "model.encoder."),
("model.transformer.decoder.", "model.decoder."),
("model.backbones.0.0.body.", "model.backbone."),
]
for to_replace, replace_with in start_replacements:
@@ -35,18 +51,6 @@ for to_replace, replace_with in start_replacements:
state_dict[k_] = state_dict[k]
del state_dict[k]
# Remove keys based on what they start with.
start_removals = [
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
"model.is_pad_head.",
]
for to_remove in start_removals:
for k in list(state_dict.keys()):
if k.startswith(to_remove):
del state_dict[k]
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)