Training can run (TODO: eval)

Cadene 2024-03-09 16:52:08 +00:00
parent 5395829596
commit f1230cdac0
5 changed files with 161 additions and 102 deletions

View File

@@ -97,13 +97,13 @@ class DETRVAE(nn.Module):
             ) # (bs, seq+1, hidden_dim)
             encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
             # do not mask cls token
-            cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
-            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
+            # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
+            # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
             # obtain position embedding
             pos_embed = self.pos_table.clone().detach()
             pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
             # query model
-            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
+            encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
             encoder_output = encoder_output[0] # take cls output only
             latent_info = self.latent_proj(encoder_output)
             mu = latent_info[:, : self.latent_dim]
@@ -149,63 +149,6 @@ class DETRVAE(nn.Module):
         return a_hat, is_pad_hat, [mu, logvar]


-class CNNMLP(nn.Module):
-    def __init__(self, backbones, state_dim, camera_names):
-        """Initializes the model.
-        Parameters:
-            backbones: torch module of the backbone to be used. See backbone.py
-            transformer: torch module of the transformer architecture. See transformer.py
-            state_dim: robot state dimension of the environment
-            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
-                         DETR can detect in a single image. For COCO, we recommend 100 queries.
-            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
-        """
-        super().__init__()
-        self.camera_names = camera_names
-        self.action_head = nn.Linear(1000, state_dim) # TODO add more
-        if backbones is not None:
-            self.backbones = nn.ModuleList(backbones)
-            backbone_down_projs = []
-            for backbone in backbones:
-                down_proj = nn.Sequential(
-                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
-                    nn.Conv2d(128, 64, kernel_size=5),
-                    nn.Conv2d(64, 32, kernel_size=5),
-                )
-                backbone_down_projs.append(down_proj)
-            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
-            mlp_in_dim = 768 * len(backbones) + 14
-            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
-        else:
-            raise NotImplementedError
-
-    def forward(self, qpos, image, env_state, actions=None):
-        """
-        qpos: batch, qpos_dim
-        image: batch, num_cam, channel, height, width
-        env_state: None
-        actions: batch, seq, action_dim
-        """
-        del env_state, actions
-        bs, _ = qpos.shape
-        # Image observation features and position embeddings
-        all_cam_features = []
-        for cam_id, _ in enumerate(self.camera_names):
-            features, pos = self.backbones[cam_id](image[:, cam_id])
-            features = features[0] # take the last layer feature
-            pos = pos[0] # not used
-            all_cam_features.append(self.backbone_down_projs[cam_id](features))
-        # flatten everything
-        flattened_features = []
-        for cam_feature in all_cam_features:
-            flattened_features.append(cam_feature.reshape([bs, -1]))
-        flattened_features = torch.cat(flattened_features, axis=1) # 768 each
-        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
-        a_hat = self.mlp(features)
-        return a_hat
-
-
 def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
     if hidden_depth == 0:
         mods = [nn.Linear(input_dim, output_dim)]
@@ -263,26 +206,3 @@ def build(args):
     print("number of parameters: {:.2f}M".format(n_parameters / 1e6))

     return model
-
-
-def build_cnnmlp(args):
-    state_dim = 14 # TODO hardcode
-
-    # From state
-    # backbone = None # from state for now, no need for conv nets
-    # From image
-    backbones = []
-    for _ in args.camera_names:
-        backbone = build_backbone(args)
-        backbones.append(backbone)
-
-    model = CNNMLP(
-        backbones,
-        state_dim=state_dim,
-        camera_names=args.camera_names,
-    )
-
-    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
-    return model
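
Editor's note: the first hunk above stops passing a key-padding mask into the CVAE encoder. As a point of reference, here is a minimal sketch of what src_key_padding_mask does for a vanilla torch.nn.TransformerEncoder; the repo's DETR-style encoder additionally takes a pos argument, which this sketch omits, and the tensor sizes are made up for illustration.

import torch
import torch.nn as nn

# (seq, batch, dim) layout, matching the seq-first permute used above
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(layer, num_layers=4)

seq, bs, dim = 101, 8, 512
x = torch.randn(seq, bs, dim)
is_pad = torch.zeros(bs, seq, dtype=torch.bool)  # True marks padded tokens
is_pad[:, 90:] = True

out_masked = encoder(x, src_key_padding_mask=is_pad)  # padded positions excluded from attention
out_unmasked = encoder(x)  # what the commit now does: every position attends to every other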

View File

@@ -1,4 +1,5 @@
 import logging
+import time

 import torch
 import torch.nn as nn
@@ -64,13 +65,128 @@ def kl_divergence(mu, logvar):
 class ActionChunkingTransformerPolicy(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, cfg, device):
         super().__init__()
+        self.cfg = cfg
+        self.device = device
         self.model, self.optimizer = build_act_model_and_optimizer(cfg)
-        self.kl_weight = cfg.kl_weight
+        self.kl_weight = self.cfg.kl_weight
         logging.info(f"KL Weight {self.kl_weight}")

-    def forward(self, qpos, image, actions=None, is_pad=None):
+    def update(self, replay_buffer, step):
+        del step
+
+        start_time = time.time()
+
+        self.train()
+
+        num_slices = self.cfg.batch_size
+        batch_size = self.cfg.horizon * num_slices
+
+        assert batch_size % self.cfg.horizon == 0
+        assert batch_size % num_slices == 0
+
+        def process_batch(batch, horizon, num_slices):
+            # trajectory t = 64, horizon h = 16
+            # (t h) ... -> t h ...
+            batch = batch.reshape(num_slices, horizon)
+
+            image = batch["observation", "image", "top"]
+            image = image[:, 0]  # first observation t=0
+            # batch, num_cam, channel, height, width
+            image = image.unsqueeze(1)
+            assert image.ndim == 5
+            image = image.float()
+
+            state = batch["observation", "state"]
+            state = state[:, 0]  # first observation t=0
+            # batch, qpos_dim
+            assert state.ndim == 2
+
+            action = batch["action"]
+            # batch, seq, action_dim
+            assert action.ndim == 3
+            assert action.shape[1] == horizon
+
+            if self.cfg.n_obs_steps > 1:
+                raise NotImplementedError()
+                # # keep first n observations of the slice corresponding to t=[-1,0]
+                # image = image[:, : self.cfg.n_obs_steps]
+                # state = state[:, : self.cfg.n_obs_steps]
+
+            out = {
+                "obs": {
+                    "image": image.to(self.device, non_blocking=True),
+                    "agent_pos": state.to(self.device, non_blocking=True),
+                },
+                "action": action.to(self.device, non_blocking=True),
+            }
+            return out
+
+        batch = replay_buffer.sample(batch_size)
+        batch = process_batch(batch, self.cfg.horizon, num_slices)
+
+        data_s = time.time() - start_time
+
+        loss = self.compute_loss(batch)
+        loss.backward()
+
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.model.parameters(),
+            self.cfg.grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        # self.lr_scheduler.step()
+
+        info = {
+            "loss": loss.item(),
+            "grad_norm": float(grad_norm),
+            # "lr": self.lr_scheduler.get_last_lr()[0],
+            "lr": self.cfg.lr,
+            "data_s": data_s,
+            "update_s": time.time() - start_time,
+        }
+
+        return info
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+    def compute_loss(self, batch):
+        loss_dict = self._forward(
+            qpos=batch["obs"]["agent_pos"],
+            image=batch["obs"]["image"],
+            actions=batch["action"],
+        )
+        loss = loss_dict["loss"]
+        return loss
+
+    @torch.no_grad()
+    def forward(self, observation, step_count):
+        # TODO(rcadene): remove unused step_count
+        del step_count
+
+        self.eval()
+
+        # TODO(rcadene): remove unsqueeze hack to add bsize=1
+        observation["image"] = observation["image"].unsqueeze(0)
+        observation["state"] = observation["state"].unsqueeze(0)
+
+        obs_dict = {
+            "image": observation["image"],
+            "agent_pos": observation["state"],
+        }
+        action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
+        return action
+
+    def _forward(self, qpos, image, actions=None, is_pad=None):
         env_state = None
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         image = normalize(image)
@@ -78,23 +194,21 @@ class ActionChunkingTransformerPolicy(nn.Module):
         is_train_mode = actions is not None
         if is_train_mode:  # training time
             actions = actions[:, : self.model.num_queries]
-            is_pad = is_pad[:, : self.model.num_queries]
+            if is_pad is not None:
+                is_pad = is_pad[:, : self.model.num_queries]
             a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
             total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
             loss_dict = {}
             all_l1 = F.l1_loss(actions, a_hat, reduction="none")
-            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
+            l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
             loss_dict["l1"] = l1
             loss_dict["kl"] = total_kld[0]
             loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
             return loss_dict
         else:
-            a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
-            return a_hat
+            action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
+            return action

-    def configure_optimizers(self):
-        return self.optimizer

 # class CNNMLPPolicy(nn.Module):
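
Editor's note: the new update()/forward() split suggests a driver loop roughly like the sketch below. This is not code from the commit: cfg is assumed to be the Hydra config built from the YAML file later in this commit, replay_buffer is assumed to expose sample() and the ("observation", ...)/"action" keys indexed above, and the checkpoint path is arbitrary.

from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy


def train_act(cfg, replay_buffer):
    # signature with (cfg.policy, cfg.device) matches the constructor added in this commit
    policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
    for step in range(cfg.offline_steps):
        info = policy.update(replay_buffer, step)  # sample one batch, take one gradient step
        if step % cfg.log_freq == 0:
            print(f"step={step} loss={info['loss']:.3f} update_s={info['update_s']:.3f}")
        if step % cfg.save_freq == 0:
            policy.save("act_model.pt")  # save() is a plain torch.save of the state_dict
    return policy

# At rollout time, forward() expects a single unbatched observation dict, e.g.
# action = policy({"image": image_chw, "state": qpos}, step_count=0)
# where image_chw and qpos are hypothetical placeholder tensors.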

View File

@@ -3,14 +3,11 @@ Various positional encodings for the transformer.
 """
 import math

-import IPython
 import torch
 from torch import nn

 from .utils import NestedTensor

-e = IPython.embed
-

 class PositionEmbeddingSine(nn.Module):
     """

View File

@@ -18,9 +18,9 @@ def make_policy(cfg):
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
-        from lerobot.common.policies.diffusion.policy import ActionChunkingTransformerPolicy
+        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy

-        policy = ActionChunkingTransformerPolicy(cfg.policy)
+        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
     else:
         raise ValueError(cfg.policy.name)
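
Editor's note: a hypothetical call site for this factory. The diff does not show the file path or the rest of make_policy, so the module path lerobot.common.policies.factory is an assumption.

from lerobot.common.policies.factory import make_policy  # assumed module path

policy = make_policy(cfg)   # dispatches on cfg.policy.name; "act" hits the branch changed above
policy.save("act_last.pt")  # save()/load() were added to the policy in this commit
policy.load("act_last.pt")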

View File

@@ -1,22 +1,50 @@
 # @package _global_

-n_action_steps: 1
 state_dim: 14

+offline_steps: 1344000
+online_steps: 0
+
+eval_episodes: 1
+eval_freq: 10000
+save_freq: 100000
+log_freq: 250
+
+horizon: 100
+n_obs_steps: 1
+n_action_steps: 1
+
 policy:
   name: act
+  pretrained_model_path:
+
   lr: 1e-5
   lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
   backbone: resnet18
-  num_queries: 100 # chunk_size
+  num_queries: ${horizon} # chunk_size
+  horizon: ${horizon} # chunk_size
   kl_weight: 10
   hidden_dim: 512
   dim_feedforward: 3200
   enc_layers: 7
   dec_layers: 8
   nheads: 8
-  camera_names: top
+  camera_names: [top]
+
+  position_embedding: sine
+  masks: false
+  dilation: false
+  dropout: 0.1
+  pre_norm: false

   batch_size: 8
+
+  per_alpha: 0.6
+  per_beta: 0.4
+  balanced_sampling: false
+  utd: 1
+
+  n_obs_steps: ${n_obs_steps}
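
Editor's note: the new ${horizon} and ${n_obs_steps} entries are OmegaConf/Hydra interpolations, so the policy block now tracks the top-level values instead of hard-coding num_queries: 100. A minimal sketch of how the resolution works, using standalone OmegaConf rather than the repo's Hydra entry point and only the relevant keys:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "horizon": 100,
        "n_obs_steps": 1,
        "policy": {
            "num_queries": "${horizon}",
            "horizon": "${horizon}",
            "n_obs_steps": "${n_obs_steps}",
        },
    }
)
assert cfg.policy.num_queries == 100  # chunk size follows the top-level horizon
assert cfg.policy.n_obs_steps == 1
print(OmegaConf.to_yaml(cfg, resolve=True))  # interpolations replaced by concrete values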