diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py
index 9be9eb40..272eb846 100644
--- a/lerobot/common/policies/act/detr_vae.py
+++ b/lerobot/common/policies/act/detr_vae.py
@@ -97,13 +97,13 @@ class DETRVAE(nn.Module):
             )  # (bs, seq+1, hidden_dim)
             encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
             # do not mask cls token
-            cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device)  # False: not a padding
-            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
+            # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device)  # False: not a padding
+            # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
             # obtain position embedding
             pos_embed = self.pos_table.clone().detach()
             pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
             # query model
-            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
+            encoder_output = self.encoder(encoder_input, pos=pos_embed)  # , src_key_padding_mask=is_pad)
             encoder_output = encoder_output[0]  # take cls output only
             latent_info = self.latent_proj(encoder_output)
             mu = latent_info[:, : self.latent_dim]
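
The hunk above disables the VAE encoder's key padding mask instead of fixing its shape (the commented-out lines prepend `False` entries for the cls and qpos tokens). A minimal sketch, assuming the custom encoder forwards `src_key_padding_mask` to `nn.MultiheadAttention` as upstream DETR does, of why dropping an all-False mask is a no-op:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq, bs, hidden_dim = 3, 2, 8
attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=2)

x = torch.randn(seq, bs, hidden_dim)  # (seq+1, bs, hidden_dim) layout from above
no_pad = torch.full((bs, seq), False)  # False: not a padding

# With key_padding_mask all False, nothing is masked out of attention,
# so the output matches the unmasked call.
out_masked, _ = attn(x, x, x, key_padding_mask=no_pad)
out_unmasked, _ = attn(x, x, x)
assert torch.allclose(out_masked, out_unmasked, atol=1e-6)
```

This is only equivalent while the dataloader never feeds action sequences shorter than `num_queries`; if padded episodes ever reach the encoder, the latent will attend to padding.
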
- """ - super().__init__() - self.camera_names = camera_names - self.action_head = nn.Linear(1000, state_dim) # TODO add more - if backbones is not None: - self.backbones = nn.ModuleList(backbones) - backbone_down_projs = [] - for backbone in backbones: - down_proj = nn.Sequential( - nn.Conv2d(backbone.num_channels, 128, kernel_size=5), - nn.Conv2d(128, 64, kernel_size=5), - nn.Conv2d(64, 32, kernel_size=5), - ) - backbone_down_projs.append(down_proj) - self.backbone_down_projs = nn.ModuleList(backbone_down_projs) - - mlp_in_dim = 768 * len(backbones) + 14 - self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2) - else: - raise NotImplementedError - - def forward(self, qpos, image, env_state, actions=None): - """ - qpos: batch, qpos_dim - image: batch, num_cam, channel, height, width - env_state: None - actions: batch, seq, action_dim - """ - del env_state, actions - bs, _ = qpos.shape - # Image observation features and position embeddings - all_cam_features = [] - for cam_id, _ in enumerate(self.camera_names): - features, pos = self.backbones[cam_id](image[:, cam_id]) - features = features[0] # take the last layer feature - pos = pos[0] # not used - all_cam_features.append(self.backbone_down_projs[cam_id](features)) - # flatten everything - flattened_features = [] - for cam_feature in all_cam_features: - flattened_features.append(cam_feature.reshape([bs, -1])) - flattened_features = torch.cat(flattened_features, axis=1) # 768 each - features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14 - a_hat = self.mlp(features) - return a_hat - - def mlp(input_dim, hidden_dim, output_dim, hidden_depth): if hidden_depth == 0: mods = [nn.Linear(input_dim, output_dim)] @@ -263,26 +206,3 @@ def build(args): print("number of parameters: {:.2f}M".format(n_parameters / 1e6)) return model - - -def build_cnnmlp(args): - state_dim = 14 # TODO hardcode - - # From state - # backbone = None # from state for now, no need for conv nets - # From image - backbones = [] - for _ in args.camera_names: - backbone = build_backbone(args) - backbones.append(backbone) - - model = CNNMLP( - backbones, - state_dim=state_dim, - camera_names=args.camera_names, - ) - - n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) - print("number of parameters: {:.2f}M".format(n_parameters / 1e6)) - - return model diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 70221dbb..77d3d4a1 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -1,4 +1,5 @@ import logging +import time import torch import torch.nn as nn @@ -64,13 +65,128 @@ def kl_divergence(mu, logvar): class ActionChunkingTransformerPolicy(nn.Module): - def __init__(self, cfg): + def __init__(self, cfg, device): super().__init__() + self.cfg = cfg + self.device = device self.model, self.optimizer = build_act_model_and_optimizer(cfg) - self.kl_weight = cfg.kl_weight + self.kl_weight = self.cfg.kl_weight logging.info(f"KL Weight {self.kl_weight}") - def forward(self, qpos, image, actions=None, is_pad=None): + def update(self, replay_buffer, step): + del step + + start_time = time.time() + + self.train() + + num_slices = self.cfg.batch_size + batch_size = self.cfg.horizon * num_slices + + assert batch_size % self.cfg.horizon == 0 + assert batch_size % num_slices == 0 + + def process_batch(batch, horizon, num_slices): + # trajectory t = 64, horizon h = 16 + # (t h) ... -> t h ... 
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index 70221dbb..77d3d4a1 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -1,4 +1,5 @@
 import logging
+import time
 
 import torch
 import torch.nn as nn
@@ -64,13 +65,128 @@ def kl_divergence(mu, logvar):
 
 
 class ActionChunkingTransformerPolicy(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, cfg, device):
         super().__init__()
+        self.cfg = cfg
+        self.device = device
         self.model, self.optimizer = build_act_model_and_optimizer(cfg)
-        self.kl_weight = cfg.kl_weight
+        self.kl_weight = self.cfg.kl_weight
         logging.info(f"KL Weight {self.kl_weight}")
 
-    def forward(self, qpos, image, actions=None, is_pad=None):
+    def update(self, replay_buffer, step):
+        del step
+
+        start_time = time.time()
+
+        self.train()
+
+        num_slices = self.cfg.batch_size
+        batch_size = self.cfg.horizon * num_slices
+
+        assert batch_size % self.cfg.horizon == 0
+        assert batch_size % num_slices == 0
+
+        def process_batch(batch, horizon, num_slices):
+            # trajectory t = 64, horizon h = 16
+            # (t h) ... -> t h ...
+            batch = batch.reshape(num_slices, horizon)
+
+            image = batch["observation", "image", "top"]
+            image = image[:, 0]  # first observation t=0
+            # batch, num_cam, channel, height, width
+            image = image.unsqueeze(1)
+            assert image.ndim == 5
+            image = image.float()
+
+            state = batch["observation", "state"]
+            state = state[:, 0]  # first observation t=0
+            # batch, qpos_dim
+            assert state.ndim == 2
+
+            action = batch["action"]
+            # batch, seq, action_dim
+            assert action.ndim == 3
+            assert action.shape[1] == horizon
+
+            if self.cfg.n_obs_steps > 1:
+                raise NotImplementedError()
+                # # keep first n observations of the slice corresponding to t=[-1,0]
+                # image = image[:, : self.cfg.n_obs_steps]
+                # state = state[:, : self.cfg.n_obs_steps]
+
+            out = {
+                "obs": {
+                    "image": image.to(self.device, non_blocking=True),
+                    "agent_pos": state.to(self.device, non_blocking=True),
+                },
+                "action": action.to(self.device, non_blocking=True),
+            }
+            return out
+
+        batch = replay_buffer.sample(batch_size)
+        batch = process_batch(batch, self.cfg.horizon, num_slices)
+
+        data_s = time.time() - start_time
+
+        loss = self.compute_loss(batch)
+        loss.backward()
+
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.model.parameters(),
+            self.cfg.grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        # self.lr_scheduler.step()
+
+        info = {
+            "loss": loss.item(),
+            "grad_norm": float(grad_norm),
+            # "lr": self.lr_scheduler.get_last_lr()[0],
+            "lr": self.cfg.lr,
+            "data_s": data_s,
+            "update_s": time.time() - start_time,
+        }
+
+        return info
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+    def compute_loss(self, batch):
+        loss_dict = self._forward(
+            qpos=batch["obs"]["agent_pos"],
+            image=batch["obs"]["image"],
+            actions=batch["action"],
+        )
+        loss = loss_dict["loss"]
+        return loss
+
+    @torch.no_grad()
+    def forward(self, observation, step_count):
+        # TODO(rcadene): remove unused step_count
+        del step_count
+
+        self.eval()
+
+        # TODO(rcadene): remove unsqueeze hack to add bsize=1
+        observation["image"] = observation["image"].unsqueeze(0)
+        observation["state"] = observation["state"].unsqueeze(0)
+
+        obs_dict = {
+            "image": observation["image"],
+            "agent_pos": observation["state"],
+        }
+        action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
+        return action
+
+    def _forward(self, qpos, image, actions=None, is_pad=None):
         env_state = None
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         image = normalize(image)
@@ -78,23 +194,21 @@ class ActionChunkingTransformerPolicy(nn.Module):
         is_train_mode = actions is not None
         if is_train_mode:  # training time
             actions = actions[:, : self.model.num_queries]
-            is_pad = is_pad[:, : self.model.num_queries]
+            if is_pad is not None:
+                is_pad = is_pad[:, : self.model.num_queries]
 
             a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
             total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
             loss_dict = {}
             all_l1 = F.l1_loss(actions, a_hat, reduction="none")
-            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
+            l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
             loss_dict["l1"] = l1
             loss_dict["kl"] = total_kld[0]
             loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
             return loss_dict
         else:
-            a_hat, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
-            return a_hat
-
-    def configure_optimizers(self):
-        return self.optimizer
+            action, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
+            return action
 
 
 # class CNNMLPPolicy(nn.Module):
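
The new `update()` samples `horizon * num_slices` flat frames and regroups them so dim 0 indexes slices and dim 1 indexes time, conditioning on the first timestep of each slice, while `_forward()` now tolerates a missing `is_pad`. An illustrative sketch of both pieces, with tensor shapes assumed to match the config below:

```python
import torch
import torch.nn.functional as F

# Slice layout assumed by update(): (t h) ... -> t h ...
num_slices, horizon, action_dim = 8, 100, 14
flat = torch.randn(num_slices * horizon, action_dim)
action = flat.reshape(num_slices, horizon, action_dim)

# The two L1 reductions in _forward(). With is_pad given, padded timesteps
# are zeroed before the mean; note the mean still divides by the full
# element count, not by the number of unpadded elements.
a_hat = torch.randn(num_slices, horizon, action_dim)
all_l1 = F.l1_loss(action, a_hat, reduction="none")
l1_unmasked = all_l1.mean()

is_pad = torch.zeros(num_slices, horizon, dtype=torch.bool)
is_pad[:, 90:] = True  # e.g. the last 10 steps of a slice are padding
l1_masked = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
```
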
diff --git a/lerobot/common/policies/act/position_encoding.py b/lerobot/common/policies/act/position_encoding.py
index b8107079..94e862f6 100644
--- a/lerobot/common/policies/act/position_encoding.py
+++ b/lerobot/common/policies/act/position_encoding.py
@@ -3,14 +3,11 @@ Various positional encodings for the transformer.
 """
 import math
 
-import IPython
 import torch
 from torch import nn
 
 from .utils import NestedTensor
 
-e = IPython.embed
-
 
 class PositionEmbeddingSine(nn.Module):
     """
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 8acab33f..5ccd1fc4 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -18,9 +18,9 @@ def make_policy(cfg):
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
-        from lerobot.common.policies.diffusion.policy import ActionChunkingTransformerPolicy
+        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
 
-        policy = ActionChunkingTransformerPolicy(cfg.policy)
+        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
     else:
         raise ValueError(cfg.policy.name)
 
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index 42341707..98bd92c5 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -1,22 +1,50 @@
 # @package _global_
 
-n_action_steps: 1
-
 state_dim: 14
 
+offline_steps: 1344000
+online_steps: 0
+
+eval_episodes: 1
+eval_freq: 10000
+save_freq: 100000
+log_freq: 250
+
+horizon: 100
+n_obs_steps: 1
+n_action_steps: 1
+
 policy:
   name: act
 
+  pretrained_model_path:
+
   lr: 1e-5
   lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
   backbone: resnet18
-  num_queries: 100 # chunk_size
+  num_queries: ${horizon} # chunk_size
+  horizon: ${horizon} # chunk_size
   kl_weight: 10
   hidden_dim: 512
   dim_feedforward: 3200
   enc_layers: 7
   dec_layers: 8
   nheads: 8
-  camera_names: top
+  camera_names: [top]
+  position_embedding: sine
+  masks: false
+  dilation: false
+  dropout: 0.1
+  pre_norm: false
 
   batch_size: 8
+
+  per_alpha: 0.6
+  per_beta: 0.4
+
+  balanced_sampling: false
+  utd: 1
+
+  n_obs_steps: ${n_obs_steps}
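
In the new config, `num_queries` (the ACT chunk size) and `policy.horizon` both interpolate the top-level `horizon`, so the chunk size and the slice length sampled in `update()` cannot drift apart. A minimal sketch of how the `${horizon}` references resolve, assuming standard OmegaConf/Hydra interpolation semantics:

```python
from omegaconf import OmegaConf

# ${...} interpolations resolve against the config root.
cfg = OmegaConf.create(
    """
    horizon: 100
    policy:
      num_queries: ${horizon}  # chunk_size
      horizon: ${horizon}
    """
)
assert cfg.policy.num_queries == 100
assert cfg.policy.horizon == cfg.horizon
```
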