Add act yaml (TODO: try train.py)

This commit is contained in:
Cadene 2024-03-08 18:08:28 +00:00
parent a45802c281
commit 5395829596
3 changed files with 36 additions and 8 deletions

View File

@ -1,3 +1,5 @@
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
@ -61,20 +63,20 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld
class ACTPolicy(nn.Module):
class ActionChunkingTransformerPolicy(nn.Module):
def __init__(self, cfg):
super().__init__()
model, optimizer = build_act_model_and_optimizer(cfg)
self.model = model # CVAE decoder
self.optimizer = optimizer
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = cfg.kl_weight
print(f"KL Weight {self.kl_weight}")
logging.info(f"KL Weight {self.kl_weight}")
def __call__(self, qpos, image, actions=None, is_pad=None):
def forward(self, qpos, image, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
if actions is not None: # training time
is_train_mode = actions is not None
if is_train_mode: # training time
actions = actions[:, : self.model.num_queries]
is_pad = is_pad[:, : self.model.num_queries]
@ -87,7 +89,7 @@ class ACTPolicy(nn.Module):
loss_dict["kl"] = total_kld[0]
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
return loss_dict
else: # inference time
else:
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
return a_hat

View File

@ -17,6 +17,10 @@ def make_policy(cfg):
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.diffusion.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(cfg.policy)
else:
raise ValueError(cfg.policy.name)

View File

@ -0,0 +1,22 @@
# @package _global_
n_action_steps: 1
state_dim: 14
policy:
name: act
lr: 1e-5
lr_backbone: 1e-5
backbone: resnet18
num_queries: 100 # chunk_size
kl_weight: 10
hidden_dim: 512
dim_feedforward: 3200
enc_layers: 7
dec_layers: 8
nheads: 8
camera_names: top
batch_size: 8