Training can run (TODO: eval)
parent 5395829596
commit f1230cdac0
@@ -97,13 +97,13 @@ class DETRVAE(nn.Module):
             )  # (bs, seq+1, hidden_dim)
             encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
             # do not mask cls token
-            cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device)  # False: not a padding
-            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
+            # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device)  # False: not a padding
+            # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
             # obtain position embedding
             pos_embed = self.pos_table.clone().detach()
             pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
             # query model
-            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
+            encoder_output = self.encoder(encoder_input, pos=pos_embed)  # , src_key_padding_mask=is_pad)
             encoder_output = encoder_output[0]  # take cls output only
             latent_info = self.latent_proj(encoder_output)
             mu = latent_info[:, : self.latent_dim]
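For reference, the src_key_padding_mask argument that this hunk comments out is the standard mechanism for keeping transformer attention away from padded timesteps. A minimal sketch of that mechanism using the stock torch.nn.TransformerEncoder (illustrative shapes and values; the repository's encoder is its own module with an extra pos argument):

import torch
import torch.nn as nn

# Toy encoder: batch of 2 sequences, length 5, hidden size 8.
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)

x = torch.randn(2, 5, 8)
# True marks padded slots that attention should ignore; the first two slots
# (standing in for the cls and joint tokens above) are never padding.
is_pad = torch.tensor([[False, False, False, True, True],
                       [False, False, False, False, True]])

out = encoder(x, src_key_padding_mask=is_pad)  # (2, 5, 8); padded slots are masked out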
@@ -149,63 +149,6 @@ class DETRVAE(nn.Module):
         return a_hat, is_pad_hat, [mu, logvar]


-class CNNMLP(nn.Module):
-    def __init__(self, backbones, state_dim, camera_names):
-        """Initializes the model.
-        Parameters:
-            backbones: torch module of the backbone to be used. See backbone.py
-            transformer: torch module of the transformer architecture. See transformer.py
-            state_dim: robot state dimension of the environment
-            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
-                         DETR can detect in a single image. For COCO, we recommend 100 queries.
-            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
-        """
-        super().__init__()
-        self.camera_names = camera_names
-        self.action_head = nn.Linear(1000, state_dim)  # TODO add more
-        if backbones is not None:
-            self.backbones = nn.ModuleList(backbones)
-            backbone_down_projs = []
-            for backbone in backbones:
-                down_proj = nn.Sequential(
-                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
-                    nn.Conv2d(128, 64, kernel_size=5),
-                    nn.Conv2d(64, 32, kernel_size=5),
-                )
-                backbone_down_projs.append(down_proj)
-            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
-
-            mlp_in_dim = 768 * len(backbones) + 14
-            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
-        else:
-            raise NotImplementedError
-
-    def forward(self, qpos, image, env_state, actions=None):
-        """
-        qpos: batch, qpos_dim
-        image: batch, num_cam, channel, height, width
-        env_state: None
-        actions: batch, seq, action_dim
-        """
-        del env_state, actions
-        bs, _ = qpos.shape
-        # Image observation features and position embeddings
-        all_cam_features = []
-        for cam_id, _ in enumerate(self.camera_names):
-            features, pos = self.backbones[cam_id](image[:, cam_id])
-            features = features[0]  # take the last layer feature
-            pos = pos[0]  # not used
-            all_cam_features.append(self.backbone_down_projs[cam_id](features))
-        # flatten everything
-        flattened_features = []
-        for cam_feature in all_cam_features:
-            flattened_features.append(cam_feature.reshape([bs, -1]))
-        flattened_features = torch.cat(flattened_features, axis=1)  # 768 each
-        features = torch.cat([flattened_features, qpos], axis=1)  # qpos: 14
-        a_hat = self.mlp(features)
-        return a_hat
-
-
 def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
     if hidden_depth == 0:
         mods = [nn.Linear(input_dim, output_dim)]
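The mlp helper kept in this file is only partially visible in the context lines above (the hidden_depth == 0 branch). A builder of this shape conventionally continues roughly as follows; this is a hedged sketch in the style of the common ACT implementation, not a verbatim copy of the file:

import torch.nn as nn

def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    # Single linear layer when no hidden layers are requested (as in the hunk above).
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        # Otherwise stack Linear+ReLU blocks and finish with a linear head.
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*mods)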
@@ -263,26 +206,3 @@ def build(args):
     print("number of parameters: {:.2f}M".format(n_parameters / 1e6))

     return model
-
-
-def build_cnnmlp(args):
-    state_dim = 14  # TODO hardcode
-
-    # From state
-    # backbone = None # from state for now, no need for conv nets
-    # From image
-    backbones = []
-    for _ in args.camera_names:
-        backbone = build_backbone(args)
-        backbones.append(backbone)
-
-    model = CNNMLP(
-        backbones,
-        state_dim=state_dim,
-        camera_names=args.camera_names,
-    )
-
-    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
-
-    return model
@@ -1,4 +1,5 @@
 import logging
+import time

 import torch
 import torch.nn as nn
@@ -64,13 +65,128 @@ def kl_divergence(mu, logvar):


 class ActionChunkingTransformerPolicy(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, cfg, device):
         super().__init__()
+        self.cfg = cfg
+        self.device = device
         self.model, self.optimizer = build_act_model_and_optimizer(cfg)
-        self.kl_weight = cfg.kl_weight
+        self.kl_weight = self.cfg.kl_weight
         logging.info(f"KL Weight {self.kl_weight}")

-    def forward(self, qpos, image, actions=None, is_pad=None):
+    def update(self, replay_buffer, step):
+        del step
+
+        start_time = time.time()
+
+        self.train()
+
+        num_slices = self.cfg.batch_size
+        batch_size = self.cfg.horizon * num_slices
+
+        assert batch_size % self.cfg.horizon == 0
+        assert batch_size % num_slices == 0
+
+        def process_batch(batch, horizon, num_slices):
+            # trajectory t = 64, horizon h = 16
+            # (t h) ... -> t h ...
+            batch = batch.reshape(num_slices, horizon)
+
+            image = batch["observation", "image", "top"]
+            image = image[:, 0]  # first observation t=0
+            # batch, num_cam, channel, height, width
+            image = image.unsqueeze(1)
+            assert image.ndim == 5
+            image = image.float()
+
+            state = batch["observation", "state"]
+            state = state[:, 0]  # first observation t=0
+            # batch, qpos_dim
+            assert state.ndim == 2
+
+            action = batch["action"]
+            # batch, seq, action_dim
+            assert action.ndim == 3
+            assert action.shape[1] == horizon
+
+            if self.cfg.n_obs_steps > 1:
+                raise NotImplementedError()
+                # # keep first n observations of the slice corresponding to t=[-1,0]
+                # image = image[:, : self.cfg.n_obs_steps]
+                # state = state[:, : self.cfg.n_obs_steps]
+
+            out = {
+                "obs": {
+                    "image": image.to(self.device, non_blocking=True),
+                    "agent_pos": state.to(self.device, non_blocking=True),
+                },
+                "action": action.to(self.device, non_blocking=True),
+            }
+            return out
+
+        batch = replay_buffer.sample(batch_size)
+        batch = process_batch(batch, self.cfg.horizon, num_slices)
+
+        data_s = time.time() - start_time
+
+        loss = self.compute_loss(batch)
+        loss.backward()
+
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.model.parameters(),
+            self.cfg.grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        # self.lr_scheduler.step()
+
+        info = {
+            "loss": loss.item(),
+            "grad_norm": float(grad_norm),
+            # "lr": self.lr_scheduler.get_last_lr()[0],
+            "lr": self.cfg.lr,
+            "data_s": data_s,
+            "update_s": time.time() - start_time,
+        }
+
+        return info
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+    def compute_loss(self, batch):
+        loss_dict = self._forward(
+            qpos=batch["obs"]["agent_pos"],
+            image=batch["obs"]["image"],
+            actions=batch["action"],
+        )
+        loss = loss_dict["loss"]
+        return loss
+
+    @torch.no_grad()
+    def forward(self, observation, step_count):
+        # TODO(rcadene): remove unused step_count
+        del step_count
+
+        self.eval()
+
+        # TODO(rcadene): remove unsqueeze hack to add bsize=1
+        observation["image"] = observation["image"].unsqueeze(0)
+        observation["state"] = observation["state"].unsqueeze(0)
+
+        obs_dict = {
+            "image": observation["image"],
+            "agent_pos": observation["state"],
+        }
+        action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
+        return action
+
+    def _forward(self, qpos, image, actions=None, is_pad=None):
         env_state = None
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         image = normalize(image)
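Taken together, the new update, save, and no-grad forward methods imply a driver loop along the lines of the sketch below. The constructor call and the update(replay_buffer, step) signature come from this diff; the step counts and file name are illustrative, and cfg and replay_buffer are assumed to be built elsewhere (e.g. by the Hydra config and the dataset code):

# Hypothetical offline-training driver; not part of this commit.
policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)

for step in range(cfg.offline_steps):
    info = policy.update(replay_buffer, step)  # one gradient step on a sampled batch
    if step % cfg.log_freq == 0:
        print(step, info["loss"], info["grad_norm"], info["update_s"])
    if step % cfg.save_freq == 0:
        policy.save("act_policy.pt")  # torch.save of the state_dict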
@@ -78,23 +194,21 @@ class ActionChunkingTransformerPolicy(nn.Module):
         is_train_mode = actions is not None
         if is_train_mode:  # training time
             actions = actions[:, : self.model.num_queries]
-            is_pad = is_pad[:, : self.model.num_queries]
+            if is_pad is not None:
+                is_pad = is_pad[:, : self.model.num_queries]

             a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
             total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
             loss_dict = {}
             all_l1 = F.l1_loss(actions, a_hat, reduction="none")
-            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
+            l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
             loss_dict["l1"] = l1
             loss_dict["kl"] = total_kld[0]
             loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
             return loss_dict
         else:
-            a_hat, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
-            return a_hat
+            action, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
+            return action

-    def configure_optimizers(self):
-        return self.optimizer
-

 # class CNNMLPPolicy(nn.Module):
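The loss assembled here is the usual ACT objective: an L1 reconstruction term over the predicted action chunk plus a KL term scaled by kl_weight. kl_divergence(mu, logvar) is defined earlier in this file (outside the hunk); for a Gaussian posterior against a unit-Gaussian prior it is conventionally computed as below. This sketch only matches the (total_kld, dim_wise_kld, mean_kld) return signature seen above and may differ from the file's exact code:

import torch

def kl_divergence(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ), reported in three reductions.
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # (bs, latent_dim)
    total_kld = klds.sum(1).mean(0, keepdim=True)   # summed over dims, averaged over batch
    dim_wise_kld = klds.mean(0)                     # per-dimension average
    mean_kld = klds.mean(1).mean(0, keepdim=True)   # averaged over dims and batch
    return total_kld, dim_wise_kld, mean_kld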
@@ -3,14 +3,11 @@ Various positional encodings for the transformer.
 """
 import math

-import IPython
 import torch
 from torch import nn

 from .utils import NestedTensor

-e = IPython.embed
-

 class PositionEmbeddingSine(nn.Module):
     """
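PositionEmbeddingSine, whose header provides the context for this hunk, follows the standard DETR-style sinusoidal scheme. As a reminder of the idea, here is a generic 1-D sketch; it is not the class's actual forward, which operates on NestedTensor image features:

import math
import torch

def sine_position_table(seq_len, dim, temperature=10000.0):
    # Even channels get sin, odd channels get cos, with geometrically spaced frequencies.
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq, 1)
    freqs = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(temperature) / dim))
    table = torch.zeros(seq_len, dim)
    table[:, 0::2] = torch.sin(position * freqs)
    table[:, 1::2] = torch.cos(position * freqs)
    return table  # (seq, dim)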
@@ -18,9 +18,9 @@ def make_policy(cfg):
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
-        from lerobot.common.policies.diffusion.policy import ActionChunkingTransformerPolicy
+        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy

-        policy = ActionChunkingTransformerPolicy(cfg.policy)
+        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
     else:
         raise ValueError(cfg.policy.name)

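With this change the factory imports the ACT policy from its own act module (rather than the diffusion module) and forwards the target device. A sketch of the call site, assuming a Hydra/OmegaConf-style config; the fragment is illustrative only, since a real run needs the full key set from the yaml further down:

from omegaconf import OmegaConf

# Illustrative fragment; the real config is composed by Hydra from the act yaml below.
cfg = OmegaConf.create({"device": "cuda", "policy": {"name": "act", "lr": 1e-5, "kl_weight": 10}})
policy = make_policy(cfg)  # dispatches on cfg.policy.name and passes cfg.device through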
@@ -1,22 +1,50 @@
 # @package _global_

-n_action_steps: 1
-
 state_dim: 14

+offline_steps: 1344000
+online_steps: 0
+
+eval_episodes: 1
+eval_freq: 10000
+save_freq: 100000
+log_freq: 250
+
+horizon: 100
+n_obs_steps: 1
+n_action_steps: 1
+
 policy:
   name: act

+  pretrained_model_path:
+
   lr: 1e-5
   lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
   backbone: resnet18
-  num_queries: 100 # chunk_size
+  num_queries: ${horizon} # chunk_size
+  horizon: ${horizon} # chunk_size
   kl_weight: 10
   hidden_dim: 512
   dim_feedforward: 3200
   enc_layers: 7
   dec_layers: 8
   nheads: 8
-  camera_names: top
+  camera_names: [top]
+  position_embedding: sine
+  masks: false
+  dilation: false
+  dropout: 0.1
+  pre_norm: false

   batch_size: 8

+  per_alpha: 0.6
+  per_beta: 0.4
+
+  balanced_sampling: false
+  utd: 1
+
+  n_obs_steps: ${n_obs_steps}
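Note the ${horizon} interpolations: policy.num_queries and policy.horizon now track the top-level horizon value instead of a hard-coded 100. A small standalone check of how that resolution behaves with OmegaConf (illustrative values):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "horizon": 100,
        "policy": {"num_queries": "${horizon}", "horizon": "${horizon}"},
    }
)
assert cfg.policy.num_queries == 100  # chunk size follows the top-level horizon
assert cfg.policy.horizon == 100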