Add act yaml (TODO: try train.py)
parent a45802c281
commit 5395829596
@@ -1,3 +1,5 @@
+import logging
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F # noqa: N812
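The new logging import backs the print → logging.info swap in the next hunk. One caveat worth noting: logging.info is silent under the default WARNING root level, so the "KL Weight" line only appears if the caller configures logging. A minimal sketch of that assumption, not part of the commit:

import logging

# The root logger defaults to WARNING, so without this call the
# logging.info("KL Weight ...") line added below would be suppressed.
logging.basicConfig(level=logging.INFO)
logging.info("KL Weight 10")  # now visible on stderr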
@@ -61,20 +63,20 @@ def kl_divergence(mu, logvar):
     return total_kld, dimension_wise_kld, mean_kld
 
 
-class ACTPolicy(nn.Module):
+class ActionChunkingTransformerPolicy(nn.Module):
     def __init__(self, cfg):
         super().__init__()
-        model, optimizer = build_act_model_and_optimizer(cfg)
-        self.model = model # CVAE decoder
-        self.optimizer = optimizer
+        self.model, self.optimizer = build_act_model_and_optimizer(cfg)
         self.kl_weight = cfg.kl_weight
-        print(f"KL Weight {self.kl_weight}")
+        logging.info(f"KL Weight {self.kl_weight}")
 
-    def __call__(self, qpos, image, actions=None, is_pad=None):
+    def forward(self, qpos, image, actions=None, is_pad=None):
         env_state = None
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         image = normalize(image)
-        if actions is not None: # training time
+
+        is_train_mode = actions is not None
+        if is_train_mode: # training time
             actions = actions[:, : self.model.num_queries]
             is_pad = is_pad[:, : self.model.num_queries]
 
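Renaming __call__ to forward is more than cosmetic in PyTorch: nn.Module.__call__ is what runs registered hooks and then dispatches to forward, so overriding __call__ directly bypasses that machinery. A minimal illustration of the dispatch, with a made-up module name (not from the commit):

import torch
import torch.nn as nn

class Doubler(nn.Module):  # hypothetical module for illustration
    def forward(self, x):
        return 2 * x

m = Doubler()
m.register_forward_hook(lambda mod, inp, out: print("hook fired"))
y = m(torch.ones(3))  # nn.Module.__call__ runs the hook, then forward(); prints "hook fired"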
@@ -87,7 +89,7 @@ class ACTPolicy(nn.Module):
             loss_dict["kl"] = total_kld[0]
             loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
             return loss_dict
-        else: # inference time
+        else:
             a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
             return a_hat
 
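With the hunks above applied, forward keeps ACT's two modes: given actions it truncates them to num_queries steps and returns a loss dict; without actions it samples an action chunk from the prior. A hedged call-pattern sketch, with assumed shapes (batch of 2, 14-dim state as in the yaml below, one 3x480x640 camera) and policy assumed to be an already-built ActionChunkingTransformerPolicy:

import torch

# Assumed inputs; shapes are illustrative, not taken from the commit.
qpos = torch.zeros(2, 14)                       # proprioceptive state
image = torch.zeros(2, 1, 3, 480, 640)          # (batch, camera, C, H, W)
actions = torch.zeros(2, 100, 14)               # ground-truth chunk (num_queries=100)
is_pad = torch.zeros(2, 100, dtype=torch.bool)  # padding mask for short episodes

loss_dict = policy(qpos, image, actions=actions, is_pad=is_pad)  # training mode
loss_dict["loss"].backward()  # loss = l1 + kl_weight * kl

with torch.no_grad():
    a_hat = policy(qpos, image)  # inference mode: chunk sampled from the prior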
@@ -17,6 +17,10 @@ def make_policy(cfg):
             n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
         )
+    elif cfg.policy.name == "act":
+        from lerobot.common.policies.diffusion.policy import ActionChunkingTransformerPolicy
+
+        policy = ActionChunkingTransformerPolicy(cfg.policy)
     else:
         raise ValueError(cfg.policy.name)
 
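One flag for review: this branch imports ActionChunkingTransformerPolicy from lerobot.common.policies.diffusion.policy, but the class renamed above lives in the ACT policy module, so the path looks like a copy-paste slip from the diffusion branch; the untested "TODO: try train.py" in the title is consistent with that. Assuming a Hydra-composed cfg built from the act yaml added below, usage would be:

policy = make_policy(cfg)  # cfg.policy.name == "act" selects the new branch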
@@ -0,0 +1,22 @@
+# @package _global_
+
+n_action_steps: 1
+
+state_dim: 14
+
+policy:
+  name: act
+
+  lr: 1e-5
+  lr_backbone: 1e-5
+  backbone: resnet18
+  num_queries: 100 # chunk_size
+  kl_weight: 10
+  hidden_dim: 512
+  dim_feedforward: 3200
+  enc_layers: 7
+  dec_layers: 8
+  nheads: 8
+  camera_names: top
+
+batch_size: 8
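The "# @package _global_" directive merges these keys at the config root instead of nesting them under the file's config group, which is what lets the factory read cfg.n_action_steps at top level while the model hyperparameters land under cfg.policy. A minimal OmegaConf sketch of the resolved structure (my reconstruction, assuming standard Hydra composition):

from omegaconf import OmegaConf

# Approximate shape of the composed config; values copied from the yaml above.
cfg = OmegaConf.create(
    {
        "n_action_steps": 1,
        "state_dim": 14,
        "batch_size": 8,
        "policy": {"name": "act", "kl_weight": 10, "num_queries": 100},
    }
)
assert cfg.policy.name == "act"    # what make_policy dispatches on
assert cfg.policy.kl_weight == 10  # read in ActionChunkingTransformerPolicy.__init__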