From 5395829596f5714894f29bfb7df8e58f0f3babdc Mon Sep 17 00:00:00 2001
From: Cadene
Date: Fri, 8 Mar 2024 18:08:28 +0000
Subject: [PATCH] Add act yaml (TODO: try train.py)

---
 lerobot/common/policies/act/policy.py | 18 ++++++++++--------
 lerobot/common/policies/factory.py    |  4 ++++
 lerobot/configs/policy/act.yaml       | 22 ++++++++++++++++++++++
 3 files changed, 36 insertions(+), 8 deletions(-)
 create mode 100644 lerobot/configs/policy/act.yaml

diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index 50aa3607..70221dbb 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -1,3 +1,5 @@
+import logging
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
@@ -61,20 +63,20 @@ def kl_divergence(mu, logvar):
     return total_kld, dimension_wise_kld, mean_kld
 
 
-class ACTPolicy(nn.Module):
+class ActionChunkingTransformerPolicy(nn.Module):
     def __init__(self, cfg):
         super().__init__()
-        model, optimizer = build_act_model_and_optimizer(cfg)
-        self.model = model  # CVAE decoder
-        self.optimizer = optimizer
+        self.model, self.optimizer = build_act_model_and_optimizer(cfg)
         self.kl_weight = cfg.kl_weight
-        print(f"KL Weight {self.kl_weight}")
+        logging.info(f"KL Weight {self.kl_weight}")
 
-    def __call__(self, qpos, image, actions=None, is_pad=None):
+    def forward(self, qpos, image, actions=None, is_pad=None):
         env_state = None
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         image = normalize(image)
-        if actions is not None:  # training time
+
+        is_train_mode = actions is not None
+        if is_train_mode:  # training time
             actions = actions[:, : self.model.num_queries]
             is_pad = is_pad[:, : self.model.num_queries]
 
@@ -87,7 +89,7 @@ class ACTPolicy(nn.Module):
             loss_dict["kl"] = total_kld[0]
             loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
             return loss_dict
-        else:  # inference time
+        else:
             a_hat, _, (_, _) = self.model(qpos, image, env_state)  # no action, sample from prior
             return a_hat
 
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 9507586c..8acab33f 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -17,6 +17,10 @@ def make_policy(cfg):
             n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
         )
+    elif cfg.policy.name == "act":
+        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+
+        policy = ActionChunkingTransformerPolicy(cfg.policy)
     else:
         raise ValueError(cfg.policy.name)
 
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
new file mode 100644
index 00000000..42341707
--- /dev/null
+++ b/lerobot/configs/policy/act.yaml
@@ -0,0 +1,22 @@
+# @package _global_
+
+n_action_steps: 1
+
+state_dim: 14
+
+policy:
+  name: act
+
+  lr: 1e-5
+  lr_backbone: 1e-5
+  backbone: resnet18
+  num_queries: 100 # chunk_size
+  kl_weight: 10
+  hidden_dim: 512
+  dim_feedforward: 3200
+  enc_layers: 7
+  dec_layers: 8
+  nheads: 8
+  camera_names: top
+
+  batch_size: 8
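
Usage sketch (illustrative, not part of the patch): the snippet below shows how
the new factory branch and the forward() train/inference contract fit together.
It assumes `cfg` is the Hydra-composed config carrying the policy group from
act.yaml (state_dim: 14, num_queries: 100, a single "top" camera); the tensor
shapes are placeholders chosen to match those fields, not values guaranteed by
this patch.

    import torch

    from lerobot.common.policies.factory import make_policy

    policy = make_policy(cfg)  # dispatches on cfg.policy.name == "act"

    qpos = torch.randn(8, 14)              # (batch, state_dim)
    image = torch.rand(8, 1, 3, 480, 640)  # (batch, num_cameras, 3, H, W)
    actions = torch.randn(8, 100, 14)      # (batch, num_queries, state_dim)
    is_pad = torch.zeros(8, 100, dtype=torch.bool)

    # Training-time call: actions are provided, so forward() returns a dict
    # with "l1", "kl", and the weighted total "loss".
    loss_dict = policy(qpos, image, actions, is_pad)
    policy.optimizer.zero_grad()
    loss_dict["loss"].backward()
    policy.optimizer.step()

    # Inference-time call: no actions, so the model samples from the prior and
    # returns a chunk of predicted actions, (batch, num_queries, state_dim).
    a_hat = policy(qpos, image)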
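
For reference, kl_divergence(mu, logvar) is defined outside this diff (only its
return line appears as hunk context). In the upstream ACT code this policy
derives from, the CVAE KL against a unit Gaussian reads as follows; treat it as
a sketch of what the helper is assumed to compute, not as part of the patch:

    import torch

    def kl_divergence(mu, logvar):
        # KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian latent.
        klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        total_kld = klds.sum(1).mean(0, keepdim=True)  # summed over latents; the "kl" loss term
        dimension_wise_kld = klds.mean(0)              # per-dimension diagnostic
        mean_kld = klds.mean(1).mean(0, keepdim=True)  # mean over latents; diagnostic only
        return total_kld, dimension_wise_kld, mean_kld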