Training can runs (TODO: eval)
This commit is contained in:
parent
5395829596
commit
f1230cdac0
|
@ -97,13 +97,13 @@ class DETRVAE(nn.Module):
|
|||
) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
mu = latent_info[:, : self.latent_dim]
|
||||
|
@ -149,63 +149,6 @@ class DETRVAE(nn.Module):
|
|||
return a_hat, is_pad_hat, [mu, logvar]
|
||||
|
||||
|
||||
class CNNMLP(nn.Module):
|
||||
def __init__(self, backbones, state_dim, camera_names):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.camera_names = camera_names
|
||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
backbone_down_projs = []
|
||||
for backbone in backbones:
|
||||
down_proj = nn.Sequential(
|
||||
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
|
||||
nn.Conv2d(128, 64, kernel_size=5),
|
||||
nn.Conv2d(64, 32, kernel_size=5),
|
||||
)
|
||||
backbone_down_projs.append(down_proj)
|
||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||
|
||||
mlp_in_dim = 768 * len(backbones) + 14
|
||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
del env_state, actions
|
||||
bs, _ = qpos.shape
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
for cam_id, _ in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0] # not used
|
||||
all_cam_features.append(self.backbone_down_projs[cam_id](features))
|
||||
# flatten everything
|
||||
flattened_features = []
|
||||
for cam_feature in all_cam_features:
|
||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||
a_hat = self.mlp(features)
|
||||
return a_hat
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
|
@ -263,26 +206,3 @@ def build(args):
|
|||
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def build_cnnmlp(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
model = CNNMLP(
|
||||
backbones,
|
||||
state_dim=state_dim,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
||||
|
||||
return model
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -64,13 +65,128 @@ def kl_divergence(mu, logvar):
|
|||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg, device):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.device = device
|
||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.kl_weight = cfg.kl_weight
|
||||
self.kl_weight = self.cfg.kl_weight
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
|
||||
def forward(self, qpos, image, actions=None, is_pad=None):
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon)
|
||||
|
||||
image = batch["observation", "image", "top"]
|
||||
image = image[:, 0] # first observation t=0
|
||||
# batch, num_cam, channel, height, width
|
||||
image = image.unsqueeze(1)
|
||||
assert image.ndim == 5
|
||||
image = image.float()
|
||||
|
||||
state = batch["observation", "state"]
|
||||
state = state[:, 0] # first observation t=0
|
||||
# batch, qpos_dim
|
||||
assert state.ndim == 2
|
||||
|
||||
action = batch["action"]
|
||||
# batch, seq, action_dim
|
||||
assert action.ndim == 3
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if self.cfg.n_obs_steps > 1:
|
||||
raise NotImplementedError()
|
||||
# # keep first n observations of the slice corresponding to t=[-1,0]
|
||||
# image = image[:, : self.cfg.n_obs_steps]
|
||||
# state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
# self.lr_scheduler.step()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
# "lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
loss_dict = self._forward(
|
||||
qpos=batch["obs"]["agent_pos"],
|
||||
image=batch["obs"]["image"],
|
||||
actions=batch["action"],
|
||||
)
|
||||
loss = loss_dict["loss"]
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
self.eval()
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
||||
return action
|
||||
|
||||
def _forward(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
@ -78,23 +194,21 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
is_train_mode = actions is not None
|
||||
if is_train_mode: # training time
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict = {}
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
loss_dict["l1"] = l1
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
return loss_dict
|
||||
else:
|
||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer
|
||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return action
|
||||
|
||||
|
||||
# class CNNMLPPolicy(nn.Module):
|
||||
|
|
|
@ -3,14 +3,11 @@ Various positional encodings for the transformer.
|
|||
"""
|
||||
import math
|
||||
|
||||
import IPython
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .utils import NestedTensor
|
||||
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
|
|
|
@ -18,9 +18,9 @@ def make_policy(cfg):
|
|||
**cfg.policy,
|
||||
)
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.diffusion.policy import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(cfg.policy)
|
||||
policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
|
|
|
@ -1,22 +1,50 @@
|
|||
# @package _global_
|
||||
|
||||
n_action_steps: 1
|
||||
|
||||
state_dim: 14
|
||||
|
||||
offline_steps: 1344000
|
||||
online_steps: 0
|
||||
|
||||
eval_episodes: 1
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
|
||||
horizon: 100
|
||||
n_obs_steps: 1
|
||||
n_action_steps: 1
|
||||
|
||||
policy:
|
||||
name: act
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
backbone: resnet18
|
||||
num_queries: 100 # chunk_size
|
||||
num_queries: ${horizon} # chunk_size
|
||||
horizon: ${horizon} # chunk_size
|
||||
kl_weight: 10
|
||||
hidden_dim: 512
|
||||
dim_feedforward: 3200
|
||||
enc_layers: 7
|
||||
dec_layers: 8
|
||||
nheads: 8
|
||||
camera_names: top
|
||||
camera_names: [top]
|
||||
position_embedding: sine
|
||||
masks: false
|
||||
dilation: false
|
||||
dropout: 0.1
|
||||
pre_norm: false
|
||||
|
||||
batch_size: 8
|
||||
|
||||
per_alpha: 0.6
|
||||
per_beta: 0.4
|
||||
|
||||
balanced_sampling: false
|
||||
utd: 1
|
||||
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
|
|
Loading…
Reference in New Issue