Add act yaml (TODO: try train.py)
This commit is contained in:
parent
a45802c281
commit
5395829596
|
@ -1,3 +1,5 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
@ -61,20 +63,20 @@ def kl_divergence(mu, logvar):
|
|||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ACTPolicy(nn.Module):
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
model, optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.model = model # CVAE decoder
|
||||
self.optimizer = optimizer
|
||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.kl_weight = cfg.kl_weight
|
||||
print(f"KL Weight {self.kl_weight}")
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
def forward(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
if actions is not None: # training time
|
||||
|
||||
is_train_mode = actions is not None
|
||||
if is_train_mode: # training time
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
|
@ -87,7 +89,7 @@ class ACTPolicy(nn.Module):
|
|||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
else:
|
||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
|
|
|
@ -17,6 +17,10 @@ def make_policy(cfg):
|
|||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.diffusion.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(cfg.policy)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# @package _global_
|
||||
|
||||
n_action_steps: 1
|
||||
|
||||
state_dim: 14
|
||||
|
||||
policy:
|
||||
name: act
|
||||
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
backbone: resnet18
|
||||
num_queries: 100 # chunk_size
|
||||
kl_weight: 10
|
||||
hidden_dim: 512
|
||||
dim_feedforward: 3200
|
||||
enc_layers: 7
|
||||
dec_layers: 8
|
||||
nheads: 8
|
||||
camera_names: top
|
||||
|
||||
batch_size: 8
|
Loading…
Reference in New Issue