backup wip
This commit is contained in:
parent
278336a39a
commit
3a4dfa82fe
|
@ -1,115 +0,0 @@
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
from torch import nn
|
|
||||||
from torchvision.models._utils import IntermediateLayerGetter
|
|
||||||
|
|
||||||
from .position_encoding import build_position_encoding
|
|
||||||
from .utils import NestedTensor, is_main_process
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenBatchNorm2d(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
|
||||||
|
|
||||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
|
||||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
|
||||||
produce nans.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n):
|
|
||||||
super().__init__()
|
|
||||||
self.register_buffer("weight", torch.ones(n))
|
|
||||||
self.register_buffer("bias", torch.zeros(n))
|
|
||||||
self.register_buffer("running_mean", torch.zeros(n))
|
|
||||||
self.register_buffer("running_var", torch.ones(n))
|
|
||||||
|
|
||||||
def _load_from_state_dict(
|
|
||||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
|
||||||
):
|
|
||||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
|
||||||
if num_batches_tracked_key in state_dict:
|
|
||||||
del state_dict[num_batches_tracked_key]
|
|
||||||
|
|
||||||
super()._load_from_state_dict(
|
|
||||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# move reshapes to the beginning
|
|
||||||
# to make it fuser-friendly
|
|
||||||
w = self.weight.reshape(1, -1, 1, 1)
|
|
||||||
b = self.bias.reshape(1, -1, 1, 1)
|
|
||||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
|
||||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
|
||||||
eps = 1e-5
|
|
||||||
scale = w * (rv + eps).rsqrt()
|
|
||||||
bias = b - rm * scale
|
|
||||||
return x * scale + bias
|
|
||||||
|
|
||||||
|
|
||||||
class BackboneBase(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
|
||||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
||||||
# parameter.requires_grad_(False)
|
|
||||||
if return_interm_layers:
|
|
||||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
|
||||||
else:
|
|
||||||
return_layers = {"layer4": "0"}
|
|
||||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
|
||||||
self.num_channels = num_channels
|
|
||||||
|
|
||||||
def forward(self, tensor):
|
|
||||||
xs = self.body(tensor)
|
|
||||||
return xs
|
|
||||||
# out: Dict[str, NestedTensor] = {}
|
|
||||||
# for name, x in xs.items():
|
|
||||||
# m = tensor_list.mask
|
|
||||||
# assert m is not None
|
|
||||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
|
||||||
# out[name] = NestedTensor(x, mask)
|
|
||||||
# return out
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(BackboneBase):
|
|
||||||
"""ResNet backbone with frozen BatchNorm."""
|
|
||||||
|
|
||||||
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
|
|
||||||
backbone = getattr(torchvision.models, name)(
|
|
||||||
replace_stride_with_dilation=[False, False, dilation],
|
|
||||||
pretrained=is_main_process(),
|
|
||||||
norm_layer=FrozenBatchNorm2d,
|
|
||||||
) # pretrained # TODO do we want frozen batch_norm??
|
|
||||||
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
|
||||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
|
||||||
|
|
||||||
|
|
||||||
class Joiner(nn.Sequential):
|
|
||||||
def __init__(self, backbone, position_embedding):
|
|
||||||
super().__init__(backbone, position_embedding)
|
|
||||||
|
|
||||||
def forward(self, tensor_list: NestedTensor):
|
|
||||||
xs = self[0](tensor_list)
|
|
||||||
out: List[NestedTensor] = []
|
|
||||||
pos = []
|
|
||||||
for _, x in xs.items():
|
|
||||||
out.append(x)
|
|
||||||
# position encoding
|
|
||||||
pos.append(self[1](x).to(x.dtype))
|
|
||||||
|
|
||||||
return out, pos
|
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(args):
|
|
||||||
position_embedding = build_position_encoding(args)
|
|
||||||
train_backbone = args.lr_backbone > 0
|
|
||||||
return_interm_layers = args.masks
|
|
||||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
|
||||||
model = Joiner(backbone, position_embedding)
|
|
||||||
model.num_channels = backbone.num_channels
|
|
||||||
return model
|
|
|
@ -1,229 +0,0 @@
|
||||||
import einops
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from .backbone import build_backbone
|
|
||||||
from .transformer import Transformer, TransformerEncoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
|
||||||
def get_position_angle_vec(position):
|
|
||||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
|
||||||
|
|
||||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
|
||||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
|
||||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
|
||||||
|
|
||||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformer(nn.Module):
|
|
||||||
"""
|
|
||||||
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
|
|
||||||
(paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
|
||||||
|
|
||||||
Note: In this code we use the symbols `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
|
|
||||||
- The `vae_encoder` is, as per the literature around conditional variational auto-encoders (cVAE), the
|
|
||||||
part of the model that encodes the target data (here, a sequence of actions), and the condition
|
|
||||||
(here, we include the robot joint-space state as an input to the encoder).
|
|
||||||
- The `transformer` is the cVAE's decoder. But since we have an option to train this model without the
|
|
||||||
variational objective (in which case we drop the `vae_encoder` altogether), we don't call it the
|
|
||||||
`vae_decoder`.
|
|
||||||
# TODO(now): remove the following
|
|
||||||
- The `encoder` is actually a component of the cVAE's "decoder". But we refer to it as an "encoder"
|
|
||||||
because, in terms of the transformer with cross-attention that forms the cVAE's decoder, it is the
|
|
||||||
"encoder" part. We drop the `vae_` prefix because we have an option to train this model without the
|
|
||||||
variational objective (in which case we drop the `vae_encoder` altogether), and nothing about this
|
|
||||||
model has anything to do with a VAE).
|
|
||||||
- The `decoder` is a building block of the VAE decoder, and is just the "decoder" part of a
|
|
||||||
transformer with cross-attention. For the same reasoning behind the naming of `encoder`, we make
|
|
||||||
this term agnostic to the option to use a variational objective for training.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, backbones, transformer, vae_encoder, state_dim, action_dim, horizon, camera_names, use_vae
|
|
||||||
):
|
|
||||||
"""Initializes the model.
|
|
||||||
Parameters:
|
|
||||||
backbones: torch module of the backbone to be used. See backbone.py
|
|
||||||
transformer: torch module of the transformer architecture. See transformer.py
|
|
||||||
state_dim: robot state dimension of the environment
|
|
||||||
horizon: number of object queries, ie detection slot. This is the maximal number of objects
|
|
||||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_dim: Robot positional state dimension.
|
|
||||||
action_dim: Action dimension.
|
|
||||||
horizon: The number of actions to generate in one forward pass.
|
|
||||||
use_vae: Whether to use the variational objective. TODO(now): Give more details.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.camera_names = camera_names
|
|
||||||
self.transformer = transformer
|
|
||||||
self.vae_encoder = vae_encoder
|
|
||||||
self.use_vae = use_vae
|
|
||||||
hidden_dim = transformer.d_model
|
|
||||||
|
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
|
||||||
if use_vae:
|
|
||||||
self.cls_embed = nn.Embedding(1, hidden_dim)
|
|
||||||
# Projection layer for joint-space configuration to hidden dimension.
|
|
||||||
self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim)
|
|
||||||
# Projection layer for action (joint-space target) to hidden dimension.
|
|
||||||
self.vae_encoder_action_input_proj = nn.Linear(state_dim, hidden_dim)
|
|
||||||
# Final size of latent z. TODO(now): Add to hyperparams.
|
|
||||||
self.latent_dim = 32
|
|
||||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
|
||||||
self.vae_encoder_latent_output_proj = nn.Linear(hidden_dim, self.latent_dim * 2)
|
|
||||||
# Fixed sinusoidal positional embedding the whole input to the VAE encoder.
|
|
||||||
self.register_buffer(
|
|
||||||
"vae_encoder_pos_enc", get_sinusoid_encoding_table(1 + 1 + horizon, hidden_dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Transformer encoder input projections. The tokens will be structured like
|
|
||||||
# [latent, robot_state, image_feature_map_pixels].
|
|
||||||
self.backbones = nn.ModuleList(backbones)
|
|
||||||
self.encoder_img_feat_input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
|
||||||
self.encoder_robot_state_input_proj = nn.Linear(state_dim, hidden_dim)
|
|
||||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, hidden_dim)
|
|
||||||
# TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image
|
|
||||||
# feature dimension with a dry run.
|
|
||||||
self.additional_pos_embed = nn.Embedding(
|
|
||||||
2, hidden_dim
|
|
||||||
) # learned position embedding for proprio and latent
|
|
||||||
|
|
||||||
# Transformer decoder.
|
|
||||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
|
||||||
self.decoder_pos_embed = nn.Embedding(horizon, hidden_dim)
|
|
||||||
# Final action regression head on the output of the transformer's decoder.
|
|
||||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
|
||||||
|
|
||||||
def forward(self, robot_state, image, actions=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
robot_state: (B, J) batch of robot joint configurations.
|
|
||||||
image: (B, N, C, H, W) batch of N camera frames.
|
|
||||||
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
|
|
||||||
VAE is enabled and the model is in training mode.
|
|
||||||
"""
|
|
||||||
if self.use_vae and self.training:
|
|
||||||
assert (
|
|
||||||
actions is not None
|
|
||||||
), "actions must be provided when using the variational objective in training mode."
|
|
||||||
|
|
||||||
batch_size, _ = robot_state.shape
|
|
||||||
|
|
||||||
# Prepare the latent for input to the transformer.
|
|
||||||
if self.use_vae and actions is not None:
|
|
||||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
|
||||||
cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D)
|
|
||||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
|
|
||||||
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
|
|
||||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
|
||||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
|
||||||
# Prepare fixed positional embedding.
|
|
||||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
|
||||||
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
|
|
||||||
cls_token_out = self.vae_encoder(
|
|
||||||
vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2)
|
|
||||||
)[0] # (B, D)
|
|
||||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
|
||||||
mu = latent_pdf_params[:, : self.latent_dim]
|
|
||||||
logvar = latent_pdf_params[:, self.latent_dim :]
|
|
||||||
# Use reparameterization trick to sample from the latent's PDF.
|
|
||||||
latent_sample = mu + logvar.div(2).exp() * torch.randn_like(mu)
|
|
||||||
else:
|
|
||||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
|
||||||
mu = logvar = None
|
|
||||||
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=robot_state.dtype).to(
|
|
||||||
robot_state.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare all other transformer inputs.
|
|
||||||
# Image observation features and position embeddings.
|
|
||||||
all_cam_features = []
|
|
||||||
all_cam_pos = []
|
|
||||||
for cam_id, _ in enumerate(self.camera_names):
|
|
||||||
# TODO(now): remove the positional embedding from the backbones.
|
|
||||||
cam_features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
|
||||||
cam_features = cam_features[0] # take the last layer feature
|
|
||||||
pos = pos[0]
|
|
||||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
|
||||||
all_cam_features.append(cam_features)
|
|
||||||
all_cam_pos.append(pos)
|
|
||||||
# Concatenate image observation feature maps along the width dimension.
|
|
||||||
transformer_input = torch.cat(all_cam_features, axis=3)
|
|
||||||
# TODO(now): remove the positional embedding from the backbones.
|
|
||||||
pos = torch.cat(all_cam_pos, axis=3)
|
|
||||||
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
|
|
||||||
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
|
||||||
|
|
||||||
# TODO(now): Explain all of this madness.
|
|
||||||
transformer_input = torch.cat(
|
|
||||||
[
|
|
||||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
|
||||||
transformer_input.flatten(2).permute(2, 0, 1),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
pos_embed = torch.cat(
|
|
||||||
[self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the transformer and project the outputs to the action space.
|
|
||||||
transformer_output = self.transformer(
|
|
||||||
transformer_input,
|
|
||||||
encoder_pos=pos_embed,
|
|
||||||
decoder_pos=self.decoder_pos_embed.weight.unsqueeze(1),
|
|
||||||
).transpose(0, 1) # back to (B, S, C)
|
|
||||||
actions = self.action_head(transformer_output)
|
|
||||||
return actions, [mu, logvar]
|
|
||||||
|
|
||||||
|
|
||||||
def build(args):
|
|
||||||
# From state
|
|
||||||
# backbone = None # from state for now, no need for conv nets
|
|
||||||
# From image
|
|
||||||
backbones = []
|
|
||||||
backbone = build_backbone(args)
|
|
||||||
backbones.append(backbone)
|
|
||||||
|
|
||||||
transformer = Transformer(
|
|
||||||
d_model=args.hidden_dim,
|
|
||||||
dropout=args.dropout,
|
|
||||||
nhead=args.nheads,
|
|
||||||
dim_feedforward=args.dim_feedforward,
|
|
||||||
num_encoder_layers=args.enc_layers,
|
|
||||||
num_decoder_layers=args.dec_layers,
|
|
||||||
normalize_before=args.pre_norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(now): args.enc_layers shouldn't be shared with the transformer decoder
|
|
||||||
vae_encoder = TransformerEncoder(
|
|
||||||
num_layers=args.enc_layers,
|
|
||||||
d_model=args.hidden_dim,
|
|
||||||
nhead=args.nheads,
|
|
||||||
dim_feedforward=args.dim_feedforward,
|
|
||||||
dropout=args.dropout,
|
|
||||||
activation="relu",
|
|
||||||
normalize_before=args.pre_norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = ActionChunkingTransformer(
|
|
||||||
backbones,
|
|
||||||
transformer,
|
|
||||||
vae_encoder,
|
|
||||||
state_dim=args.state_dim,
|
|
||||||
action_dim=args.action_dim,
|
|
||||||
horizon=args.num_queries,
|
|
||||||
camera_names=args.camera_names,
|
|
||||||
use_vae=args.vae,
|
|
||||||
)
|
|
||||||
|
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
|
||||||
|
|
||||||
return model
|
|
|
@ -1,50 +1,32 @@
|
||||||
import logging
|
"""Action Chunking Transformer Policy
|
||||||
import time
|
|
||||||
|
|
||||||
|
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
import torchvision
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
from lerobot.common.policies.abstract import AbstractPolicy
|
||||||
from lerobot.common.policies.act.detr_vae import build
|
|
||||||
from lerobot.common.utils import get_safe_torch_device
|
from lerobot.common.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
|
||||||
def build_act_model_and_optimizer(cfg):
|
|
||||||
model = build(cfg)
|
|
||||||
|
|
||||||
param_dicts = [
|
|
||||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
|
||||||
"lr": cfg.lr_backbone,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
|
||||||
|
|
||||||
return model, optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def kl_divergence(mu, logvar):
|
|
||||||
batch_size = mu.size(0)
|
|
||||||
assert batch_size != 0
|
|
||||||
if mu.data.ndimension() == 4:
|
|
||||||
mu = mu.view(mu.size(0), mu.size(1))
|
|
||||||
if logvar.data.ndimension() == 4:
|
|
||||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
|
||||||
|
|
||||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
|
||||||
total_kld = klds.sum(1).mean(0, True)
|
|
||||||
dimension_wise_kld = klds.mean(0)
|
|
||||||
mean_kld = klds.mean(1).mean(0, True)
|
|
||||||
|
|
||||||
return total_kld, dimension_wise_kld, mean_kld
|
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||||
"""
|
"""
|
||||||
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
|
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||||
(https://arxiv.org/abs/2304.13705).
|
Hardware (https://arxiv.org/abs/2304.13705).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "act"
|
name = "act"
|
||||||
|
@ -68,7 +50,35 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.n_action_steps = n_action_steps
|
self.n_action_steps = n_action_steps
|
||||||
self.device = get_safe_torch_device(device)
|
self.device = get_safe_torch_device(device)
|
||||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
|
||||||
|
self.model = ActionChunkingTransformer(
|
||||||
|
cfg,
|
||||||
|
state_dim=cfg.state_dim,
|
||||||
|
action_dim=cfg.action_dim,
|
||||||
|
horizon=cfg.horizon,
|
||||||
|
camera_names=cfg.camera_names,
|
||||||
|
use_vae=cfg.vae,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_params_dicts = [
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in self.model.named_parameters()
|
||||||
|
if not n.startswith("backbone") and p.requires_grad
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in self.model.named_parameters()
|
||||||
|
if n.startswith("backbone") and p.requires_grad
|
||||||
|
],
|
||||||
|
"lr": cfg.lr_backbone,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
self.optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
||||||
|
|
||||||
self.kl_weight = self.cfg.kl_weight
|
self.kl_weight = self.cfg.kl_weight
|
||||||
logging.info(f"KL Weight {self.kl_weight}")
|
logging.info(f"KL Weight {self.kl_weight}")
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
@ -140,12 +150,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||||
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
# self.lr_scheduler.step()
|
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"grad_norm": float(grad_norm),
|
"grad_norm": float(grad_norm),
|
||||||
# "lr": self.lr_scheduler.get_last_lr()[0],
|
|
||||||
"lr": self.cfg.lr,
|
"lr": self.cfg.lr,
|
||||||
"data_s": data_s,
|
"data_s": data_s,
|
||||||
"update_s": time.time() - start_time,
|
"update_s": time.time() - start_time,
|
||||||
|
@ -213,31 +221,495 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||||
action = action[: self.n_action_steps]
|
action = action[: self.n_action_steps]
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def _forward(self, qpos, image, actions=None, is_pad=None):
|
def _forward(self, qpos, image, actions=None):
|
||||||
env_state = None
|
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
image = normalize(image)
|
image = normalize(image)
|
||||||
|
|
||||||
is_training = actions is not None
|
is_training = actions is not None
|
||||||
if is_training: # training time
|
if is_training: # training time
|
||||||
actions = actions[:, : self.model.num_queries]
|
actions = actions[:, : self.model.horizon]
|
||||||
if is_pad is not None:
|
|
||||||
is_pad = is_pad[:, : self.model.num_queries]
|
|
||||||
|
|
||||||
a_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
a_hat, (mu, log_sigma_x2) = self.model(qpos, image, actions)
|
||||||
|
|
||||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
l1 = all_l1.mean()
|
||||||
|
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
loss_dict["l1"] = l1
|
loss_dict["l1"] = l1
|
||||||
if self.cfg.vae:
|
if self.cfg.vae:
|
||||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
loss_dict["kl"] = total_kld[0]
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
|
# KL-divergence per batch element, then take the mean over the batch.
|
||||||
|
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||||
|
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
|
||||||
|
loss_dict["kl"] = mean_kld
|
||||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||||
else:
|
else:
|
||||||
loss_dict["loss"] = loss_dict["l1"]
|
loss_dict["loss"] = loss_dict["l1"]
|
||||||
return loss_dict
|
return loss_dict
|
||||||
else:
|
else:
|
||||||
action, _ = self.model(qpos, image, env_state) # no action, sample from prior
|
action, _ = self.model(qpos, image) # no action, sample from prior
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
def create_sinusoidal_position_embedding(n_position, d_hid):
|
||||||
|
def get_position_angle_vec(position):
|
||||||
|
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||||
|
|
||||||
|
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||||
|
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||||
|
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||||
|
|
||||||
|
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(alexander-soare) move all this code into the policy when we have the policy API established.
|
||||||
|
class ActionChunkingTransformer(nn.Module):
|
||||||
|
"""
|
||||||
|
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
|
||||||
|
(paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||||
|
|
||||||
|
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
|
||||||
|
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
||||||
|
model that encodes the target data (a sequence of actions), and the condition (the robot
|
||||||
|
joint-space).
|
||||||
|
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
|
||||||
|
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
|
||||||
|
have an option to train this model without the variational objective (in which case we drop the
|
||||||
|
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
|
||||||
|
|
||||||
|
Transformer
|
||||||
|
Used alone for inference
|
||||||
|
(acts as VAE decoder
|
||||||
|
during training)
|
||||||
|
┌───────────────────────┐
|
||||||
|
│ Outputs │
|
||||||
|
│ ▲ │
|
||||||
|
│ ┌─────►┌───────┐ │
|
||||||
|
┌──────┐ │ │ │Transf.│ │
|
||||||
|
│ │ │ ├─────►│decoder│ │
|
||||||
|
┌────┴────┐ │ │ │ │ │ │
|
||||||
|
│ │ │ │ ┌───┴───┬─►│ │ │
|
||||||
|
│ VAE │ │ │ │ │ └───────┘ │
|
||||||
|
│ encoder │ │ │ │Transf.│ │
|
||||||
|
│ │ │ │ │encoder│ │
|
||||||
|
└───▲─────┘ │ │ │ │ │
|
||||||
|
│ │ │ └───▲───┘ │
|
||||||
|
│ │ │ │ │
|
||||||
|
inputs └─────┼─────┘ │
|
||||||
|
│ │
|
||||||
|
└───────────────────────┘
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args, state_dim, action_dim, horizon, camera_names, use_vae):
|
||||||
|
"""Initializes the model.
|
||||||
|
Parameters:
|
||||||
|
state_dim: robot state dimension of the environment
|
||||||
|
horizon: number of object queries, ie detection slot. This is the maximal number of objects
|
||||||
|
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dim: Robot positional state dimension.
|
||||||
|
action_dim: Action dimension.
|
||||||
|
horizon: The number of actions to generate in one forward pass.
|
||||||
|
use_vae: Whether to use the variational objective. TODO(now): Give more details.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.use_vae = use_vae
|
||||||
|
self.horizon = horizon
|
||||||
|
self.hidden_dim = args.hidden_dim
|
||||||
|
|
||||||
|
transformer_common_kwargs = dict( # noqa: C408
|
||||||
|
d_model=self.hidden_dim,
|
||||||
|
nhead=args.nheads,
|
||||||
|
dim_feedforward=args.dim_feedforward,
|
||||||
|
dropout=args.dropout,
|
||||||
|
activation=args.activation,
|
||||||
|
normalize_before=args.pre_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||||
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
|
if use_vae:
|
||||||
|
# TODO(now): args.enc_layers shouldn't be shared with the transformer decoder
|
||||||
|
self.vae_encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs)
|
||||||
|
self.cls_embed = nn.Embedding(1, self.hidden_dim)
|
||||||
|
# Projection layer for joint-space configuration to hidden dimension.
|
||||||
|
self.vae_encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim)
|
||||||
|
# Projection layer for action (joint-space target) to hidden dimension.
|
||||||
|
self.vae_encoder_action_input_proj = nn.Linear(state_dim, self.hidden_dim)
|
||||||
|
# Final size of latent z. TODO(now): Add to hyperparams.
|
||||||
|
self.latent_dim = 32
|
||||||
|
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||||
|
self.vae_encoder_latent_output_proj = nn.Linear(self.hidden_dim, self.latent_dim * 2)
|
||||||
|
# Fixed sinusoidal positional embedding the whole input to the VAE encoder.
|
||||||
|
self.register_buffer(
|
||||||
|
"vae_encoder_pos_enc", create_sinusoidal_position_embedding(1 + 1 + horizon, self.hidden_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backbone for image feature extraction.
|
||||||
|
self.backbone_position_embedding = SinusoidalPositionEmbedding2D(self.hidden_dim // 2)
|
||||||
|
backbone_model = getattr(torchvision.models, args.backbone)(
|
||||||
|
replace_stride_with_dilation=[False, False, args.dilation],
|
||||||
|
pretrained=True, # TODO(now): Add pretrained option
|
||||||
|
norm_layer=FrozenBatchNorm2d,
|
||||||
|
)
|
||||||
|
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||||
|
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||||
|
|
||||||
|
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||||
|
self.encoder = TransformerEncoder(num_layers=args.enc_layers, **transformer_common_kwargs)
|
||||||
|
self.decoder = TransformerDecoder(num_layers=args.dec_layers, **transformer_common_kwargs)
|
||||||
|
|
||||||
|
# Transformer encoder input projections. The tokens will be structured like
|
||||||
|
# [latent, robot_state, image_feature_map_pixels].
|
||||||
|
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||||
|
backbone_model.fc.in_features, self.hidden_dim, kernel_size=1
|
||||||
|
)
|
||||||
|
self.encoder_robot_state_input_proj = nn.Linear(state_dim, self.hidden_dim)
|
||||||
|
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.hidden_dim)
|
||||||
|
# TODO(now): Fix this nonsense. One positional embedding is needed. We should extract the image
|
||||||
|
# feature dimension with a dry run.
|
||||||
|
self.additional_pos_embed = nn.Embedding(
|
||||||
|
2, self.hidden_dim
|
||||||
|
) # learned position embedding for proprio and latent
|
||||||
|
|
||||||
|
# Transformer decoder.
|
||||||
|
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||||
|
self.decoder_pos_embed_embed = nn.Embedding(horizon, self.hidden_dim)
|
||||||
|
# Final action regression head on the output of the transformer's decoder.
|
||||||
|
self.action_head = nn.Linear(self.hidden_dim, action_dim)
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||||
|
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
||||||
|
if p.dim() > 1:
|
||||||
|
nn.init.xavier_uniform_(p)
|
||||||
|
|
||||||
|
def forward(self, robot_state, image, actions=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
robot_state: (B, J) batch of robot joint configurations.
|
||||||
|
image: (B, N, C, H, W) batch of N camera frames.
|
||||||
|
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
|
||||||
|
VAE is enabled and the model is in training mode.
|
||||||
|
"""
|
||||||
|
if self.use_vae and self.training:
|
||||||
|
assert (
|
||||||
|
actions is not None
|
||||||
|
), "actions must be provided when using the variational objective in training mode."
|
||||||
|
|
||||||
|
batch_size, _ = robot_state.shape
|
||||||
|
|
||||||
|
# Prepare the latent for input to the transformer.
|
||||||
|
if self.use_vae and actions is not None:
|
||||||
|
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
||||||
|
cls_embed = einops.repeat(self.cls_embed.weight, "1 d -> b 1 d", b=batch_size) # (B, 1, D)
|
||||||
|
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
|
||||||
|
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
|
||||||
|
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
||||||
|
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||||
|
# Prepare fixed positional embedding.
|
||||||
|
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||||
|
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
|
||||||
|
cls_token_out = self.vae_encoder(
|
||||||
|
vae_encoder_input.permute(1, 0, 2), pos=pos_embed.permute(1, 0, 2)
|
||||||
|
)[0] # (B, D)
|
||||||
|
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||||
|
mu = latent_pdf_params[:, : self.latent_dim]
|
||||||
|
# This is 2log(sigma). Done this way to match the original implementation.
|
||||||
|
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
||||||
|
# Use reparameterization trick to sample from the latent's PDF.
|
||||||
|
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
||||||
|
else:
|
||||||
|
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||||
|
mu = log_sigma_x2 = None
|
||||||
|
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
||||||
|
robot_state.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare all other transformer inputs.
|
||||||
|
# Image observation features and position embeddings.
|
||||||
|
all_cam_features = []
|
||||||
|
all_cam_pos = []
|
||||||
|
for cam_id, _ in enumerate(self.camera_names):
|
||||||
|
cam_features = self.backbone(image[:, cam_id])["feature_map"]
|
||||||
|
pos = self.backbone_position_embedding(cam_features).to(dtype=cam_features.dtype)
|
||||||
|
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||||
|
all_cam_features.append(cam_features)
|
||||||
|
all_cam_pos.append(pos)
|
||||||
|
# Concatenate image observation feature maps along the width dimension.
|
||||||
|
encoder_in = torch.cat(all_cam_features, axis=3)
|
||||||
|
pos = torch.cat(all_cam_pos, axis=3)
|
||||||
|
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
|
||||||
|
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
||||||
|
|
||||||
|
# TODO(now): Explain all of this madness.
|
||||||
|
encoder_in = torch.cat(
|
||||||
|
[
|
||||||
|
torch.stack([latent_embed, robot_state_embed], axis=0),
|
||||||
|
encoder_in.flatten(2).permute(2, 0, 1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
pos_embed = torch.cat(
|
||||||
|
[self.additional_pos_embed.weight.unsqueeze(1), pos.flatten(2).permute(2, 0, 1)], axis=0
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = self.encoder(encoder_in, pos=pos_embed)
|
||||||
|
decoder_in = torch.zeros(
|
||||||
|
(self.horizon, batch_size, self.hidden_dim), dtype=pos_embed.dtype, device=pos_embed.device
|
||||||
|
)
|
||||||
|
decoder_out = self.decoder(
|
||||||
|
decoder_in,
|
||||||
|
encoder_out,
|
||||||
|
encoder_pos_embed=pos_embed,
|
||||||
|
decoder_pos_embed=self.decoder_pos_embed_embed.weight.unsqueeze(1),
|
||||||
|
).transpose(0, 1) # back to (B, S, C)
|
||||||
|
|
||||||
|
actions = self.action_head(decoder_out)
|
||||||
|
return actions, [mu, log_sigma_x2]
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
normalize_before=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TransformerEncoderLayer(
|
||||||
|
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, pos: Optional[Tensor] = None):
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, pos=pos)
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||||
|
# Implementation of Feedforward model
|
||||||
|
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.activation = _get_activation_fn(activation)
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
|
||||||
|
def forward(self, x, pos_embed: Optional[Tensor] = None):
|
||||||
|
skip = x
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.norm1(x)
|
||||||
|
q = k = x if pos_embed is None else x + pos_embed
|
||||||
|
x = self.self_attn(q, k, value=x)[0]
|
||||||
|
x = skip + self.dropout1(x)
|
||||||
|
if self.normalize_before:
|
||||||
|
skip = x
|
||||||
|
x = self.norm2(x)
|
||||||
|
else:
|
||||||
|
x = self.norm1(x)
|
||||||
|
skip = x
|
||||||
|
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||||
|
x = skip + self.dropout2(x)
|
||||||
|
if not self.normalize_before:
|
||||||
|
x = self.norm2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
normalize_before=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TransformerDecoderLayer(
|
||||||
|
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.norm = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x, encoder_out, decoder_pos_embed: Tensor | None = None, encoder_pos_embed: Tensor | None = None
|
||||||
|
):
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(
|
||||||
|
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||||
|
)
|
||||||
|
if self.norm is not None:
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||||
|
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||||
|
# Implementation of Feedforward model
|
||||||
|
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
self.norm3 = nn.LayerNorm(d_model)
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
self.dropout3 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.activation = _get_activation_fn(activation)
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
|
||||||
|
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
|
||||||
|
return tensor if pos_embed is None else tensor + pos_embed
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
encoder_out: Tensor,
|
||||||
|
decoder_pos_embed: Tensor | None = None,
|
||||||
|
encoder_pos_embed: Tensor | None = None,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
||||||
|
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
||||||
|
cross-attending with.
|
||||||
|
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||||
|
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
||||||
|
Returns:
|
||||||
|
(DS, B, C) tensor of decoder output features.
|
||||||
|
"""
|
||||||
|
skip = x
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.norm1(x)
|
||||||
|
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||||
|
x = self.self_attn(q, k, value=x)[0]
|
||||||
|
x = skip + self.dropout1(x)
|
||||||
|
if self.normalize_before:
|
||||||
|
skip = x
|
||||||
|
x = self.norm2(x)
|
||||||
|
else:
|
||||||
|
x = self.norm1(x)
|
||||||
|
skip = x
|
||||||
|
x = self.multihead_attn(
|
||||||
|
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
|
||||||
|
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
|
||||||
|
value=encoder_out,
|
||||||
|
)[0]
|
||||||
|
x = skip + self.dropout2(x)
|
||||||
|
if self.normalize_before:
|
||||||
|
skip = x
|
||||||
|
x = self.norm3(x)
|
||||||
|
else:
|
||||||
|
x = self.norm2(x)
|
||||||
|
skip = x
|
||||||
|
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||||
|
x = skip + self.dropout3(x)
|
||||||
|
if not self.normalize_before:
|
||||||
|
x = self.norm3(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPositionEmbedding2D(nn.Module):
|
||||||
|
"""Sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
|
||||||
|
|
||||||
|
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
|
||||||
|
for the vertical direction, and 1/W for the horizontal direction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dimension: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dimension: The desired dimension of the embeddings.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.dimension = dimension
|
||||||
|
self._two_pi = 2 * math.pi
|
||||||
|
self._eps = 1e-6
|
||||||
|
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
|
||||||
|
self._temperature = 10000
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
|
||||||
|
Returns:
|
||||||
|
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
|
||||||
|
"""
|
||||||
|
not_mask = torch.ones_like(x[0, [0]]) # (1, H, W)
|
||||||
|
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
||||||
|
# they would be range(0, H) and range(0, W). Keeping it at as to match the original code.
|
||||||
|
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
||||||
|
x_range = not_mask.cumsum(2, dtype=torch.float32)
|
||||||
|
|
||||||
|
# "Normalize" the position index such that it ranges in [0, 2π].
|
||||||
|
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
|
||||||
|
# are non-zero by construction. This is an artifact of the original code.
|
||||||
|
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
|
||||||
|
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||||
|
|
||||||
|
inverse_frequency = self._temperature ** (
|
||||||
|
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
||||||
|
)
|
||||||
|
|
||||||
|
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||||
|
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||||
|
|
||||||
|
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
||||||
|
# pos_embed_x and pos_embed are (1, H, W, C // 2).
|
||||||
|
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||||
|
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||||
|
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
||||||
|
|
||||||
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def _get_activation_fn(activation: str) -> Callable:
|
||||||
|
"""Return an activation function given a string"""
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
if activation == "gelu":
|
||||||
|
return F.gelu
|
||||||
|
if activation == "glu":
|
||||||
|
return F.glu
|
||||||
|
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
||||||
|
|
|
@ -1,102 +0,0 @@
|
||||||
"""
|
|
||||||
Various positional encodings for the transformer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from .utils import NestedTensor
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
|
|
||||||
"""
|
|
||||||
This is a more standard version of the position embedding, very similar to the one
|
|
||||||
used by the Attention is all you need paper, generalized to work on images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
|
||||||
super().__init__()
|
|
||||||
self.num_pos_feats = num_pos_feats
|
|
||||||
self.temperature = temperature
|
|
||||||
self.normalize = normalize
|
|
||||||
if scale is not None and normalize is False:
|
|
||||||
raise ValueError("normalize should be True if scale is passed")
|
|
||||||
if scale is None:
|
|
||||||
scale = 2 * math.pi
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, tensor):
|
|
||||||
x = tensor
|
|
||||||
# mask = tensor_list.mask
|
|
||||||
# assert mask is not None
|
|
||||||
# not_mask = ~mask
|
|
||||||
|
|
||||||
not_mask = torch.ones_like(x[0, [0]])
|
|
||||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
if self.normalize:
|
|
||||||
eps = 1e-6
|
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
|
||||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
|
||||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingLearned(nn.Module):
|
|
||||||
"""
|
|
||||||
Absolute pos embedding, learned.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_pos_feats=256):
|
|
||||||
super().__init__()
|
|
||||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
|
||||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
|
||||||
self.reset_parameters()
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
nn.init.uniform_(self.row_embed.weight)
|
|
||||||
nn.init.uniform_(self.col_embed.weight)
|
|
||||||
|
|
||||||
def forward(self, tensor_list: NestedTensor):
|
|
||||||
x = tensor_list.tensors
|
|
||||||
h, w = x.shape[-2:]
|
|
||||||
i = torch.arange(w, device=x.device)
|
|
||||||
j = torch.arange(h, device=x.device)
|
|
||||||
x_emb = self.col_embed(i)
|
|
||||||
y_emb = self.row_embed(j)
|
|
||||||
pos = (
|
|
||||||
torch.cat(
|
|
||||||
[
|
|
||||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
|
||||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
.permute(2, 0, 1)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.repeat(x.shape[0], 1, 1, 1)
|
|
||||||
)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
def build_position_encoding(args):
|
|
||||||
n_steps = args.hidden_dim // 2
|
|
||||||
if args.position_embedding in ("v2", "sine"):
|
|
||||||
# TODO find a better way of exposing other arguments
|
|
||||||
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
|
|
||||||
elif args.position_embedding in ("v3", "learned"):
|
|
||||||
position_embedding = PositionEmbeddingLearned(n_steps)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"not supported {args.position_embedding}")
|
|
||||||
|
|
||||||
return position_embedding
|
|
|
@ -1,240 +0,0 @@
|
||||||
"""
|
|
||||||
TODO(now)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model=512,
|
|
||||||
nhead=8,
|
|
||||||
num_encoder_layers=6,
|
|
||||||
num_decoder_layers=6,
|
|
||||||
dim_feedforward=2048,
|
|
||||||
dropout=0.1,
|
|
||||||
activation="relu",
|
|
||||||
normalize_before=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.encoder = TransformerEncoder(
|
|
||||||
num_encoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
|
||||||
)
|
|
||||||
self.decoder = TransformerDecoder(
|
|
||||||
num_decoder_layers, d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
|
||||||
)
|
|
||||||
self.d_model = d_model
|
|
||||||
self.nhead = nhead
|
|
||||||
self._init_params() # TODO(now): move to somewhere common
|
|
||||||
|
|
||||||
def _init_params(self):
|
|
||||||
for p in self.parameters():
|
|
||||||
if p.dim() > 1:
|
|
||||||
nn.init.xavier_uniform_(p)
|
|
||||||
|
|
||||||
def forward(self, x, encoder_pos, decoder_pos):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: ((E)ncoder (S)equence, (B)atch, (C)hannels)
|
|
||||||
decoder_pos: (Decoder Sequence, C) tensor for the decoder's positional embedding.
|
|
||||||
encoder_pos: (ES, C) tenso
|
|
||||||
"""
|
|
||||||
# TODO flatten only when input has H and W
|
|
||||||
bs = x.shape[1]
|
|
||||||
|
|
||||||
encoder_out = self.encoder(x, pos=encoder_pos)
|
|
||||||
decoder_in = torch.zeros(
|
|
||||||
(decoder_pos.shape[0], bs, decoder_pos.shape[2]),
|
|
||||||
dtype=decoder_pos.dtype,
|
|
||||||
device=decoder_pos.device,
|
|
||||||
)
|
|
||||||
decoder_out = self.decoder(decoder_in, encoder_out, encoder_pos=encoder_pos, decoder_pos=decoder_pos)
|
|
||||||
return decoder_out
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_layers,
|
|
||||||
d_model,
|
|
||||||
nhead,
|
|
||||||
dim_feedforward=2048,
|
|
||||||
dropout=0.1,
|
|
||||||
activation="relu",
|
|
||||||
normalize_before=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = nn.ModuleList(
|
|
||||||
[
|
|
||||||
TransformerEncoderLayer(
|
|
||||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
|
||||||
)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.norm = nn.LayerNorm(d_model) if normalize_before else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x, pos: Optional[Tensor] = None):
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(x, pos=pos)
|
|
||||||
x = self.norm(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model
|
|
||||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(d_model)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
def forward(self, x, pos: Optional[Tensor] = None):
|
|
||||||
skip = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.norm1(x)
|
|
||||||
q = k = x if pos is None else x + pos
|
|
||||||
x = self.self_attn(q, k, value=x)[0]
|
|
||||||
x = skip + self.dropout1(x)
|
|
||||||
if self.normalize_before:
|
|
||||||
skip = x
|
|
||||||
x = self.norm2(x)
|
|
||||||
else:
|
|
||||||
x = self.norm1(x)
|
|
||||||
skip = x
|
|
||||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
|
||||||
x = skip + self.dropout2(x)
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.norm2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_layers,
|
|
||||||
d_model,
|
|
||||||
nhead,
|
|
||||||
dim_feedforward=2048,
|
|
||||||
dropout=0.1,
|
|
||||||
activation="relu",
|
|
||||||
normalize_before=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = nn.ModuleList(
|
|
||||||
[
|
|
||||||
TransformerDecoderLayer(
|
|
||||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
|
||||||
)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.norm = nn.LayerNorm(d_model)
|
|
||||||
|
|
||||||
def forward(self, x, encoder_out, decoder_pos: Tensor | None = None, encoder_pos: Tensor | None = None):
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(x, encoder_out, decoder_pos=decoder_pos, encoder_pos=encoder_pos)
|
|
||||||
if self.norm is not None:
|
|
||||||
x = self.norm(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoderLayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model
|
|
||||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(d_model)
|
|
||||||
self.norm3 = nn.LayerNorm(d_model)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
self.dropout3 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
def maybe_add_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
encoder_out: Tensor,
|
|
||||||
decoder_pos: Tensor | None = None,
|
|
||||||
encoder_pos: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
|
||||||
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
|
||||||
cross-attending with.
|
|
||||||
decoder_pos: (ES, 1, C) positional embedding for keys (from the encoder).
|
|
||||||
encoder_pos: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
|
||||||
Returns:
|
|
||||||
(DS, B, C) tensor of decoder output features.
|
|
||||||
"""
|
|
||||||
skip = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.norm1(x)
|
|
||||||
q = k = self.maybe_add_pos_embed(x, decoder_pos)
|
|
||||||
x = self.self_attn(q, k, value=x)[0]
|
|
||||||
x = skip + self.dropout1(x)
|
|
||||||
if self.normalize_before:
|
|
||||||
skip = x
|
|
||||||
x = self.norm2(x)
|
|
||||||
else:
|
|
||||||
x = self.norm1(x)
|
|
||||||
skip = x
|
|
||||||
x = self.multihead_attn(
|
|
||||||
query=self.maybe_add_pos_embed(x, decoder_pos),
|
|
||||||
key=self.maybe_add_pos_embed(encoder_out, encoder_pos),
|
|
||||||
value=encoder_out,
|
|
||||||
)[0]
|
|
||||||
x = skip + self.dropout2(x)
|
|
||||||
if self.normalize_before:
|
|
||||||
skip = x
|
|
||||||
x = self.norm3(x)
|
|
||||||
else:
|
|
||||||
x = self.norm2(x)
|
|
||||||
skip = x
|
|
||||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
|
||||||
x = skip + self.dropout3(x)
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.norm3(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
|
|
@ -1,478 +0,0 @@
|
||||||
"""
|
|
||||||
Misc functions, including distributed helpers.
|
|
||||||
|
|
||||||
Mostly copy-paste from torchvision references.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
|
||||||
import torchvision
|
|
||||||
from packaging import version
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
|
||||||
from torchvision.ops import _new_empty_tensor
|
|
||||||
from torchvision.ops.misc import _output_size
|
|
||||||
|
|
||||||
|
|
||||||
class SmoothedValue:
|
|
||||||
"""Track a series of values and provide access to smoothed values over a
|
|
||||||
window or the global series average.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, window_size=20, fmt=None):
|
|
||||||
if fmt is None:
|
|
||||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
|
||||||
self.deque = deque(maxlen=window_size)
|
|
||||||
self.total = 0.0
|
|
||||||
self.count = 0
|
|
||||||
self.fmt = fmt
|
|
||||||
|
|
||||||
def update(self, value, n=1):
|
|
||||||
self.deque.append(value)
|
|
||||||
self.count += n
|
|
||||||
self.total += value * n
|
|
||||||
|
|
||||||
def synchronize_between_processes(self):
|
|
||||||
"""
|
|
||||||
Warning: does not synchronize the deque!
|
|
||||||
"""
|
|
||||||
if not is_dist_avail_and_initialized():
|
|
||||||
return
|
|
||||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
|
||||||
dist.barrier()
|
|
||||||
dist.all_reduce(t)
|
|
||||||
t = t.tolist()
|
|
||||||
self.count = int(t[0])
|
|
||||||
self.total = t[1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def median(self):
|
|
||||||
d = torch.tensor(list(self.deque))
|
|
||||||
return d.median().item()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def avg(self):
|
|
||||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
|
||||||
return d.mean().item()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def global_avg(self):
|
|
||||||
return self.total / self.count
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max(self):
|
|
||||||
return max(self.deque)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def value(self):
|
|
||||||
return self.deque[-1]
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.fmt.format(
|
|
||||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def all_gather(data):
|
|
||||||
"""
|
|
||||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
|
||||||
Args:
|
|
||||||
data: any picklable object
|
|
||||||
Returns:
|
|
||||||
list[data]: list of data gathered from each rank
|
|
||||||
"""
|
|
||||||
world_size = get_world_size()
|
|
||||||
if world_size == 1:
|
|
||||||
return [data]
|
|
||||||
|
|
||||||
# serialized to a Tensor
|
|
||||||
buffer = pickle.dumps(data)
|
|
||||||
storage = torch.ByteStorage.from_buffer(buffer)
|
|
||||||
tensor = torch.ByteTensor(storage).to("cuda")
|
|
||||||
|
|
||||||
# obtain Tensor size of each rank
|
|
||||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
|
||||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
|
||||||
dist.all_gather(size_list, local_size)
|
|
||||||
size_list = [int(size.item()) for size in size_list]
|
|
||||||
max_size = max(size_list)
|
|
||||||
|
|
||||||
# receiving Tensor from all ranks
|
|
||||||
# we pad the tensor because torch all_gather does not support
|
|
||||||
# gathering tensors of different shapes
|
|
||||||
tensor_list = []
|
|
||||||
for _ in size_list:
|
|
||||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
|
||||||
if local_size != max_size:
|
|
||||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
|
||||||
tensor = torch.cat((tensor, padding), dim=0)
|
|
||||||
dist.all_gather(tensor_list, tensor)
|
|
||||||
|
|
||||||
data_list = []
|
|
||||||
for size, tensor in zip(size_list, tensor_list, strict=False):
|
|
||||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
|
||||||
data_list.append(pickle.loads(buffer))
|
|
||||||
|
|
||||||
return data_list
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_dict(input_dict, average=True):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
input_dict (dict): all the values will be reduced
|
|
||||||
average (bool): whether to do average or sum
|
|
||||||
Reduce the values in the dictionary from all processes so that all processes
|
|
||||||
have the averaged results. Returns a dict with the same fields as
|
|
||||||
input_dict, after reduction.
|
|
||||||
"""
|
|
||||||
world_size = get_world_size()
|
|
||||||
if world_size < 2:
|
|
||||||
return input_dict
|
|
||||||
with torch.no_grad():
|
|
||||||
names = []
|
|
||||||
values = []
|
|
||||||
# sort the keys so that they are consistent across processes
|
|
||||||
for k in sorted(input_dict.keys()):
|
|
||||||
names.append(k)
|
|
||||||
values.append(input_dict[k])
|
|
||||||
values = torch.stack(values, dim=0)
|
|
||||||
dist.all_reduce(values)
|
|
||||||
if average:
|
|
||||||
values /= world_size
|
|
||||||
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
|
|
||||||
return reduced_dict
|
|
||||||
|
|
||||||
|
|
||||||
class MetricLogger:
|
|
||||||
def __init__(self, delimiter="\t"):
|
|
||||||
self.meters = defaultdict(SmoothedValue)
|
|
||||||
self.delimiter = delimiter
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
v = v.item()
|
|
||||||
assert isinstance(v, (float, int))
|
|
||||||
self.meters[k].update(v)
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
if attr in self.meters:
|
|
||||||
return self.meters[attr]
|
|
||||||
if attr in self.__dict__:
|
|
||||||
return self.__dict__[attr]
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
loss_str = []
|
|
||||||
for name, meter in self.meters.items():
|
|
||||||
loss_str.append("{}: {}".format(name, str(meter)))
|
|
||||||
return self.delimiter.join(loss_str)
|
|
||||||
|
|
||||||
def synchronize_between_processes(self):
|
|
||||||
for meter in self.meters.values():
|
|
||||||
meter.synchronize_between_processes()
|
|
||||||
|
|
||||||
def add_meter(self, name, meter):
|
|
||||||
self.meters[name] = meter
|
|
||||||
|
|
||||||
def log_every(self, iterable, print_freq, header=None):
|
|
||||||
if not header:
|
|
||||||
header = ""
|
|
||||||
start_time = time.time()
|
|
||||||
end = time.time()
|
|
||||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
|
||||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
|
||||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
log_msg = self.delimiter.join(
|
|
||||||
[
|
|
||||||
header,
|
|
||||||
"[{0" + space_fmt + "}/{1}]",
|
|
||||||
"eta: {eta}",
|
|
||||||
"{meters}",
|
|
||||||
"time: {time}",
|
|
||||||
"data: {data}",
|
|
||||||
"max mem: {memory:.0f}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log_msg = self.delimiter.join(
|
|
||||||
[
|
|
||||||
header,
|
|
||||||
"[{0" + space_fmt + "}/{1}]",
|
|
||||||
"eta: {eta}",
|
|
||||||
"{meters}",
|
|
||||||
"time: {time}",
|
|
||||||
"data: {data}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
mega_b = 1024.0 * 1024.0
|
|
||||||
for i, obj in enumerate(iterable):
|
|
||||||
data_time.update(time.time() - end)
|
|
||||||
yield obj
|
|
||||||
iter_time.update(time.time() - end)
|
|
||||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
|
||||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
|
||||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
print(
|
|
||||||
log_msg.format(
|
|
||||||
i,
|
|
||||||
len(iterable),
|
|
||||||
eta=eta_string,
|
|
||||||
meters=str(self),
|
|
||||||
time=str(iter_time),
|
|
||||||
data=str(data_time),
|
|
||||||
memory=torch.cuda.max_memory_allocated() / mega_b,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
log_msg.format(
|
|
||||||
i,
|
|
||||||
len(iterable),
|
|
||||||
eta=eta_string,
|
|
||||||
meters=str(self),
|
|
||||||
time=str(iter_time),
|
|
||||||
data=str(data_time),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
end = time.time()
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
|
||||||
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
|
|
||||||
|
|
||||||
|
|
||||||
def get_sha():
|
|
||||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
def _run(command):
|
|
||||||
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
|
||||||
|
|
||||||
sha = "N/A"
|
|
||||||
diff = "clean"
|
|
||||||
branch = "N/A"
|
|
||||||
try:
|
|
||||||
sha = _run(["git", "rev-parse", "HEAD"])
|
|
||||||
subprocess.check_output(["git", "diff"], cwd=cwd)
|
|
||||||
diff = _run(["git", "diff-index", "HEAD"])
|
|
||||||
diff = "has uncommited changes" if diff else "clean"
|
|
||||||
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(batch):
|
|
||||||
batch = list(zip(*batch, strict=False))
|
|
||||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
|
||||||
return tuple(batch)
|
|
||||||
|
|
||||||
|
|
||||||
def _max_by_axis(the_list):
|
|
||||||
# type: (List[List[int]]) -> List[int]
|
|
||||||
maxes = the_list[0]
|
|
||||||
for sublist in the_list[1:]:
|
|
||||||
for index, item in enumerate(sublist):
|
|
||||||
maxes[index] = max(maxes[index], item)
|
|
||||||
return maxes
|
|
||||||
|
|
||||||
|
|
||||||
class NestedTensor:
|
|
||||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
|
||||||
self.tensors = tensors
|
|
||||||
self.mask = mask
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
# type: (Device) -> NestedTensor # noqa
|
|
||||||
cast_tensor = self.tensors.to(device)
|
|
||||||
mask = self.mask
|
|
||||||
if mask is not None:
|
|
||||||
assert mask is not None
|
|
||||||
cast_mask = mask.to(device)
|
|
||||||
else:
|
|
||||||
cast_mask = None
|
|
||||||
return NestedTensor(cast_tensor, cast_mask)
|
|
||||||
|
|
||||||
def decompose(self):
|
|
||||||
return self.tensors, self.mask
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return str(self.tensors)
|
|
||||||
|
|
||||||
|
|
||||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
|
||||||
# TODO make this more general
|
|
||||||
if tensor_list[0].ndim == 3:
|
|
||||||
if torchvision._is_tracing():
|
|
||||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
|
||||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
|
||||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
|
||||||
|
|
||||||
# TODO make it support different-sized images
|
|
||||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
|
||||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
|
||||||
batch_shape = [len(tensor_list)] + max_size
|
|
||||||
b, c, h, w = batch_shape
|
|
||||||
dtype = tensor_list[0].dtype
|
|
||||||
device = tensor_list[0].device
|
|
||||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
|
||||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
|
||||||
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
|
|
||||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
|
||||||
m[: img.shape[1], : img.shape[2]] = False
|
|
||||||
else:
|
|
||||||
raise ValueError("not supported")
|
|
||||||
return NestedTensor(tensor, mask)
|
|
||||||
|
|
||||||
|
|
||||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
|
||||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
|
||||||
@torch.jit.unused
|
|
||||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
|
||||||
max_size = []
|
|
||||||
for i in range(tensor_list[0].dim()):
|
|
||||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
|
|
||||||
torch.int64
|
|
||||||
)
|
|
||||||
max_size.append(max_size_i)
|
|
||||||
max_size = tuple(max_size)
|
|
||||||
|
|
||||||
# work around for
|
|
||||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
|
||||||
# m[: img.shape[1], :img.shape[2]] = False
|
|
||||||
# which is not yet supported in onnx
|
|
||||||
padded_imgs = []
|
|
||||||
padded_masks = []
|
|
||||||
for img in tensor_list:
|
|
||||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
|
|
||||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
|
||||||
padded_imgs.append(padded_img)
|
|
||||||
|
|
||||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
|
||||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
|
||||||
padded_masks.append(padded_mask.to(torch.bool))
|
|
||||||
|
|
||||||
tensor = torch.stack(padded_imgs)
|
|
||||||
mask = torch.stack(padded_masks)
|
|
||||||
|
|
||||||
return NestedTensor(tensor, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_for_distributed(is_master):
|
|
||||||
"""
|
|
||||||
This function disables printing when not in master process
|
|
||||||
"""
|
|
||||||
import builtins as __builtin__
|
|
||||||
|
|
||||||
builtin_print = __builtin__.print
|
|
||||||
|
|
||||||
def print(*args, **kwargs):
|
|
||||||
force = kwargs.pop("force", False)
|
|
||||||
if is_master or force:
|
|
||||||
builtin_print(*args, **kwargs)
|
|
||||||
|
|
||||||
__builtin__.print = print
|
|
||||||
|
|
||||||
|
|
||||||
def is_dist_avail_and_initialized():
|
|
||||||
if not dist.is_available():
|
|
||||||
return False
|
|
||||||
if not dist.is_initialized():
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
|
||||||
if not is_dist_avail_and_initialized():
|
|
||||||
return 1
|
|
||||||
return dist.get_world_size()
|
|
||||||
|
|
||||||
|
|
||||||
def get_rank():
|
|
||||||
if not is_dist_avail_and_initialized():
|
|
||||||
return 0
|
|
||||||
return dist.get_rank()
|
|
||||||
|
|
||||||
|
|
||||||
def is_main_process():
|
|
||||||
return get_rank() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def save_on_master(*args, **kwargs):
|
|
||||||
if is_main_process():
|
|
||||||
torch.save(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def init_distributed_mode(args):
|
|
||||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
|
||||||
args.rank = int(os.environ["RANK"])
|
|
||||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
|
||||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
|
||||||
elif "SLURM_PROCID" in os.environ:
|
|
||||||
args.rank = int(os.environ["SLURM_PROCID"])
|
|
||||||
args.gpu = args.rank % torch.cuda.device_count()
|
|
||||||
else:
|
|
||||||
print("Not using distributed mode")
|
|
||||||
args.distributed = False
|
|
||||||
return
|
|
||||||
|
|
||||||
args.distributed = True
|
|
||||||
|
|
||||||
torch.cuda.set_device(args.gpu)
|
|
||||||
args.dist_backend = "nccl"
|
|
||||||
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
|
|
||||||
torch.distributed.init_process_group(
|
|
||||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
|
||||||
)
|
|
||||||
torch.distributed.barrier()
|
|
||||||
setup_for_distributed(args.rank == 0)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def accuracy(output, target, topk=(1,)):
|
|
||||||
"""Computes the precision@k for the specified values of k"""
|
|
||||||
if target.numel() == 0:
|
|
||||||
return [torch.zeros([], device=output.device)]
|
|
||||||
maxk = max(topk)
|
|
||||||
batch_size = target.size(0)
|
|
||||||
|
|
||||||
_, pred = output.topk(maxk, 1, True, True)
|
|
||||||
pred = pred.t()
|
|
||||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
|
||||||
|
|
||||||
res = []
|
|
||||||
for k in topk:
|
|
||||||
correct_k = correct[:k].view(-1).float().sum(0)
|
|
||||||
res.append(correct_k.mul_(100.0 / batch_size))
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
|
||||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
|
||||||
"""
|
|
||||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
|
||||||
This will eventually be supported natively by PyTorch, and this
|
|
||||||
class can go away.
|
|
||||||
"""
|
|
||||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
|
||||||
if input.numel() > 0:
|
|
||||||
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
|
|
||||||
|
|
||||||
output_shape = _output_size(2, input, size, scale_factor)
|
|
||||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
|
||||||
return _new_empty_tensor(input, output_shape)
|
|
||||||
else:
|
|
||||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
|
|
@ -33,11 +33,10 @@ policy:
|
||||||
nheads: 8
|
nheads: 8
|
||||||
#camera_names: [top, front_close, left_pillar, right_pillar]
|
#camera_names: [top, front_close, left_pillar, right_pillar]
|
||||||
camera_names: [top]
|
camera_names: [top]
|
||||||
position_embedding: sine
|
|
||||||
masks: false
|
|
||||||
dilation: false
|
dilation: false
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
pre_norm: false
|
pre_norm: false
|
||||||
|
activation: relu
|
||||||
|
|
||||||
vae: true
|
vae: true
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,19 @@ policy = make_policy(cfg)
|
||||||
|
|
||||||
state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
|
state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
|
||||||
|
|
||||||
|
# Remove keys based on what they start with.
|
||||||
|
|
||||||
|
start_removals = [
|
||||||
|
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
|
||||||
|
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
|
||||||
|
"model.is_pad_head.",
|
||||||
|
]
|
||||||
|
|
||||||
|
for to_remove in start_removals:
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if k.startswith(to_remove):
|
||||||
|
del state_dict[k]
|
||||||
|
|
||||||
|
|
||||||
# Replace keys based on what they start with.
|
# Replace keys based on what they start with.
|
||||||
|
|
||||||
|
@ -26,6 +39,9 @@ start_replacements = [
|
||||||
("model.input_proj.", "model.encoder_img_feat_input_proj."),
|
("model.input_proj.", "model.encoder_img_feat_input_proj."),
|
||||||
("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
|
("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
|
||||||
("model.latent_out_proj.", "model.encoder_latent_input_proj."),
|
("model.latent_out_proj.", "model.encoder_latent_input_proj."),
|
||||||
|
("model.transformer.encoder.", "model.encoder."),
|
||||||
|
("model.transformer.decoder.", "model.decoder."),
|
||||||
|
("model.backbones.0.0.body.", "model.backbone."),
|
||||||
]
|
]
|
||||||
|
|
||||||
for to_replace, replace_with in start_replacements:
|
for to_replace, replace_with in start_replacements:
|
||||||
|
@ -35,18 +51,6 @@ for to_replace, replace_with in start_replacements:
|
||||||
state_dict[k_] = state_dict[k]
|
state_dict[k_] = state_dict[k]
|
||||||
del state_dict[k]
|
del state_dict[k]
|
||||||
|
|
||||||
# Remove keys based on what they start with.
|
|
||||||
|
|
||||||
start_removals = [
|
|
||||||
# There is a bug that means the pretrained model doesn't even use the final decoder layers.
|
|
||||||
*[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
|
|
||||||
"model.is_pad_head.",
|
|
||||||
]
|
|
||||||
|
|
||||||
for to_remove in start_removals:
|
|
||||||
for k in list(state_dict.keys()):
|
|
||||||
if k.startswith(to_remove):
|
|
||||||
del state_dict[k]
|
|
||||||
|
|
||||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue