Copy past from act repo

This commit is contained in:
Cadene 2024-03-08 16:54:43 +00:00
parent 603455e313
commit fbc66a082b
6 changed files with 1492 additions and 0 deletions

View File

@ -0,0 +1,115 @@
from typing import List
import torch
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from .position_encoding import build_position_encoding
from .utils import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super().__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {"layer4": "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=FrozenBatchNorm2d,
) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for _, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model

View File

@ -0,0 +1,288 @@
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DETRVAE(nn.Module):
"""This is the DETR module that performs object detection"""
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, state_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(
hidden_dim, self.latent_dim * 2
) # project hidden state to latent std, var
self.register_buffer(
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
) # [CLS], qpos, a_seq
# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(
2, hidden_dim
) # learned position embedding for proprio and latent
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
### Obtain latent z from action sequence
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat(
[cls_embed, qpos_embed, action_embed], axis=1
) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, : self.latent_dim]
logvar = latent_info[:, self.latent_dim :]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, _ in enumerate(self.camera_names):
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar]
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
for backbone in backbones:
down_proj = nn.Sequential(
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
nn.Conv2d(128, 64, kernel_size=5),
nn.Conv2d(64, 32, kernel_size=5),
)
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + 14
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
else:
raise NotImplementedError
def forward(self, qpos, image, env_state, actions=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
del env_state, actions
bs, _ = qpos.shape
# Image observation features and position embeddings
all_cam_features = []
for cam_id, _ in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0] # not used
all_cam_features.append(self.backbone_down_projs[cam_id](features))
# flatten everything
flattened_features = []
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
a_hat = self.mlp(features)
return a_hat
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
if hidden_depth == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
for _ in range(hidden_depth - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
mods.append(nn.Linear(hidden_dim, output_dim))
trunk = nn.Sequential(*mods)
return trunk
def build_encoder(args):
d_model = args.hidden_dim # 256
dropout = args.dropout # 0.1
nhead = args.nheads # 8
dim_feedforward = args.dim_feedforward # 2048
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
normalize_before = args.pre_norm # False
activation = "relu"
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
return encoder
def build(args):
state_dim = 14 # TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
backbone = build_backbone(args)
backbones.append(backbone)
transformer = build_transformer(args)
encoder = build_encoder(args)
model = DETRVAE(
backbones,
transformer,
encoder,
state_dim=state_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
return model
def build_cnnmlp(args):
state_dim = 14 # TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
model = CNNMLP(
backbones,
state_dim=state_dim,
camera_names=args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
return model

View File

@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from lerobot.common.policies.act.detr_vae import build
def build_act_model_and_optimizer(cfg):
model = build(cfg)
model.cuda()
param_dicts = [
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr": cfg.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
return model, optimizer
# def build_CNNMLP_model_and_optimizer(cfg):
# parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
# args = parser.parse_args()
# for k, v in cfg.items():
# setattr(args, k, v)
# model = build_CNNMLP_model(args)
# model.cuda()
# param_dicts = [
# {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
# {
# "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
# "lr": args.lr_backbone,
# },
# ]
# optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
# weight_decay=args.weight_decay)
# return model, optimizer
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld
class ACTPolicy(nn.Module):
def __init__(self, cfg):
super().__init__()
model, optimizer = build_act_model_and_optimizer(cfg)
self.model = model # CVAE decoder
self.optimizer = optimizer
self.kl_weight = cfg.kl_weight
print(f"KL Weight {self.kl_weight}")
def __call__(self, qpos, image, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
if actions is not None: # training time
actions = actions[:, : self.model.num_queries]
is_pad = is_pad[:, : self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = {}
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict["l1"] = l1
loss_dict["kl"] = total_kld[0]
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
return loss_dict
else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
return a_hat
def configure_optimizers(self):
return self.optimizer
# class CNNMLPPolicy(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# model, optimizer = build_CNNMLP_model_and_optimizer(cfg)
# self.model = model # decoder
# self.optimizer = optimizer
# def __call__(self, qpos, image, actions=None, is_pad=None):
# env_state = None # TODO
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225])
# image = normalize(image)
# if actions is not None: # training time
# actions = actions[:, 0]
# a_hat = self.model(qpos, image, env_state, actions)
# mse = F.mse_loss(actions, a_hat)
# loss_dict = dict()
# loss_dict['mse'] = mse
# loss_dict['loss'] = loss_dict['mse']
# return loss_dict
# else: # inference time
# a_hat = self.model(qpos, image, env_state) # no action, sample from prior
# return a_hat
# def configure_optimizers(self):
# return self.optimizer
# def kl_divergence(mu, logvar):
# batch_size = mu.size(0)
# assert batch_size != 0
# if mu.data.ndimension() == 4:
# mu = mu.view(mu.size(0), mu.size(1))
# if logvar.data.ndimension() == 4:
# logvar = logvar.view(logvar.size(0), logvar.size(1))
# klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
# total_kld = klds.sum(1).mean(0, True)
# dimension_wise_kld = klds.mean(0)
# mean_kld = klds.mean(1).mean(0, True)
# return total_kld, dimension_wise_kld, mean_kld

View File

@ -0,0 +1,104 @@
"""
Various positional encodings for the transformer.
"""
import math
import IPython
import torch
from torch import nn
from .utils import NestedTensor
e = IPython.embed
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos
def build_position_encoding(args):
n_steps = args.hidden_dim // 2
if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(n_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View File

@ -0,0 +1,370 @@
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
class Transformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
):
super().__init__()
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self,
src,
mask,
query_embed,
pos_embed,
latent_input=None,
proprio_input=None,
additional_pos_embed=None,
):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
bs, hw, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(1, 2)
return hs
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
)
def _get_clones(module, n):
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")

View File

@ -0,0 +1,477 @@
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import datetime
import os
import pickle
import subprocess
import time
from collections import defaultdict, deque
from typing import List, Optional
import torch
import torch.distributed as dist
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
from packaging import version
from torch import Tensor
if version.parse(torchvision.__version__) < version.parse("0.7"):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list, strict=False):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
return reduced_dict
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
mega_b = 1024.0 * 1024.0
for i, obj in enumerate(iterable):
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / mega_b,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
sha = "N/A"
diff = "clean"
branch = "N/A"
try:
sha = _run(["git", "rev-parse", "HEAD"])
subprocess.check_output(["git", "diff"], cwd=cwd)
diff = _run(["git", "diff-index", "HEAD"])
diff = "has uncommited changes" if diff else "clean"
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch, strict=False))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor:
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
torch.int64
)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if version.parse(torchvision.__version__) < version.parse("0.7"):
if input.numel() > 0:
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)