Use dataclass config for ACT
parent 34f00753eb
commit ef4bd9e25c
@@ -1,60 +1,104 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field


 @dataclass
 class ActConfig:
-    """
-    TODO(now): Document all variables
-    TODO(now): Pick sensible defaults for a use case?
+    """Configuration class for the Action Chunking Transformers policy.
+
+    Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
+
+    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
+    Those are: `state_dim`, `action_dim` and `camera_names`.
+
+    Args:
+        state_dim: Dimensionality of the observation state space (excluding images).
+        action_dim: Dimensionality of the action space.
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
+            current step and additional steps going back).
+        camera_names: The (unique) set of names for the cameras.
+        chunk_size: The size of the action prediction "chunks" in units of environment steps.
+        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
+            This should be no greater than the chunk size. For example, if the chunk size is 100, you may
+            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
+            environment, and throws the other 50 out.
+        image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
+            [0, 1]) for normalization.
+        image_normalization_std: Value by which to divide the input image pixels (after the mean has been
+            subtracted).
+        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
+        use_pretrained_backbone: Whether the backbone should be initialized with ImageNet pretrained weights
+            from torchvision.
+        replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
+            convolution.
+        pre_norm: Whether to use "pre-norm" in the transformer blocks.
+        d_model: The transformer blocks' main hidden dimension.
+        n_heads: The number of heads to use in the transformer blocks' multi-head attention.
+        dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
+            layers.
+        feedforward_activation: The activation to use in the transformer block's feed-forward layers.
+        n_encoder_layers: The number of transformer layers to use for the transformer encoder.
+        n_decoder_layers: The number of transformer layers to use for the transformer decoder.
+        use_vae: Whether to use a variational objective during training. This introduces another transformer
+            which is used as the VAE's encoder (not to be confused with the transformer encoder - see
+            documentation in the policy class).
+        latent_dim: The VAE's latent dimension.
+        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
+        use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
+            environment step.
+        dropout: Dropout to use in the transformer layers (see code for details).
+        kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
+            is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
     """

     # Environment.
-    state_dim: int
-    action_dim: int
+    state_dim: int = 14
+    action_dim: int = 14

     # Inputs / output structure.
-    n_obs_steps: int
-    camera_names: list[str]
-    chunk_size: int
-    n_action_steps: int
+    n_obs_steps: int = 1
+    camera_names: list[str] = field(default_factory=lambda: ["top"])
+    chunk_size: int = 100
+    n_action_steps: int = 100

     # Vision preprocessing.
-    image_normalization_mean: tuple[float, float, float]
-    image_normalization_std: tuple[float, float, float]
+    image_normalization_mean: tuple[float, float, float] = field(
+        default_factory=lambda: [0.485, 0.456, 0.406]
+    )
+    image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])

     # Architecture.
     # Vision backbone.
-    vision_backbone: str
-    use_pretrained_backbone: bool
-    replace_final_stride_with_dilation: int
+    vision_backbone: str = "resnet18"
+    use_pretrained_backbone: bool = True
+    replace_final_stride_with_dilation: int = False
     # Transformer layers.
-    pre_norm: bool
-    d_model: int
-    n_heads: int
-    dim_feedforward: int
-    feedforward_activation: str
-    n_encoder_layers: int
-    n_decoder_layers: int
+    pre_norm: bool = False
+    d_model: int = 512
+    n_heads: int = 8
+    dim_feedforward: int = 3200
+    feedforward_activation: str = "relu"
+    n_encoder_layers: int = 4
+    n_decoder_layers: int = 1
     # VAE.
-    use_vae: bool
-    latent_dim: int
-    n_vae_encoder_layers: int
+    use_vae: bool = True
+    latent_dim: int = 32
+    n_vae_encoder_layers: int = 4

     # Inference.
-    use_temporal_aggregation: bool
+    use_temporal_aggregation: bool = False

     # Training and loss computation.
-    dropout: float
-    kl_weight: float
+    dropout: float = 0.1
+    kl_weight: float = 10.0

     # ---
     # TODO(alexander-soare): Remove these from the policy config.
-    batch_size: int
-    lr: float
-    lr_backbone: float
-    weight_decay: float
-    grad_clip_norm: float
-    utd: int
+    batch_size: int = 8
+    lr: float = 1e-5
+    lr_backbone: float = 1e-5
+    weight_decay: float = 1e-4
+    grad_clip_norm: float = 10
+    utd: int = 1

     def __post_init__(self):
         """Input validation."""
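To illustrate the `chunk_size` / `n_action_steps` relationship described in the docstring above, a minimal sketch (the `model` and `obs` names are illustrative, not the actual policy internals): the model predicts a whole chunk of actions, but only the first `n_action_steps` of them are executed before the policy is invoked again.

    # Hypothetical illustration only: predict a full chunk, execute a prefix of it.
    predicted_actions = model(obs)                               # shape: (chunk_size, action_dim)
    actions_to_run = predicted_actions[: config.n_action_steps]  # run these in the environment
    # The remaining chunk_size - n_action_steps actions are discarded.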
@@ -66,3 +110,5 @@ class ActConfig:
             raise ValueError(
                 "The chunk size is the upper bound for the number of action steps per model invocation."
             )
+        if self.camera_names != ["top"]:
+            raise ValueError("For now, `camera_names` can only be ['top']")
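With defaults on every field, `ActConfig()` can now be constructed with no arguments. A minimal usage sketch follows; the override values are arbitrary examples, and the config module's import path is assumed rather than shown in this diff:

    config = ActConfig()                               # defaults for bimanual Aloha with the "top" camera
    custom = ActConfig(state_dim=7, action_dim=7)      # arbitrary example: override environment-dependent dims
    # __post_init__ validates inputs: n_action_steps may not exceed chunk_size, and
    # camera_names is currently restricted to ["top"].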
@@ -54,7 +54,7 @@ policy:

   # Training and loss computation.
   dropout: 0.1
-  kl_weight: 10
+  kl_weight: 10.0

   # ---
   # TODO(alexander-soare): Remove these from the policy config.
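The `kl_weight: 10.0` change keeps the YAML value consistent with the new float-typed dataclass field. As the docstring states, when the variational objective is enabled this weight combines the loss terms; a minimal sketch of that combination (variable names are illustrative, not the actual ActPolicy code):

    # Illustrative only: how kl_weight enters the training loss when use_vae is True.
    loss = reconstruction_loss + config.kl_weight * kld_loss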
@@ -18,14 +18,14 @@ from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset

-from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+from lerobot.common.policies.act.modeling_act import ActPolicy
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy


 def test_available():
     policy_classes = [
-        ActionChunkingTransformerPolicy,
+        ActPolicy,
         DiffusionPolicy,
         TDMPCPolicy,
     ]