pretrained config for act

Quentin Gallouédec 2024-04-25 16:06:57 +02:00
parent 659c69a1c0
commit 783a40c9d4
1 changed file with 29 additions and 26 deletions


@@ -1,8 +1,7 @@
-from dataclasses import dataclass, field
+from transformers.configuration_utils import PretrainedConfig
 
 
-@dataclass
-class ActionChunkingTransformerConfig:
+class ActionChunkingTransformerConfig(PretrainedConfig):
     """Configuration class for the Action Chunking Transformers policy.
 
     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
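The first hunk drops the `@dataclass` decorator and makes the config subclass `transformers`' `PretrainedConfig`, which brings the library's JSON serialization helpers along with it. Below is a minimal sketch of what that enables; the stripped-down constructor and the `model_type` value are illustrative assumptions, not taken from this commit:

```python
# Sketch only: a simplified config subclassing PretrainedConfig.
# The real ActionChunkingTransformerConfig has many more fields; the
# constructor signature and model_type below are assumptions.
from transformers.configuration_utils import PretrainedConfig


class ActionChunkingTransformerConfig(PretrainedConfig):
    model_type = "act"  # hypothetical identifier, not from the diff

    def __init__(self, chunk_size: int = 100, n_action_steps: int = 100, **kwargs):
        self.chunk_size = chunk_size
        self.n_action_steps = n_action_steps
        super().__init__(**kwargs)


config = ActionChunkingTransformerConfig(chunk_size=50)
config.save_pretrained("act_config")  # writes act_config/config.json
reloaded = ActionChunkingTransformerConfig.from_pretrained("act_config")
assert reloaded.chunk_size == 50
```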
@@ -55,37 +54,41 @@ class ActionChunkingTransformerConfig:
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
-    """
+
+    Example:
+
+    ```python
+    >>> from lerobot import ActionChunkingTransformerConfig
+    >>> # Initializing an ACT style configuration
+    >>> configuration = ActionChunkingTransformerConfig()
+    >>> # Initializing a model (with random weights) from the ACT style configuration
+    >>> model = ActionChunkingTransformerPolicy(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
 
     # Input / output structure.
     n_obs_steps: int = 1
     chunk_size: int = 100
     n_action_steps: int = 100
 
-    input_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "observation.images.top": [3, 480, 640],
-            "observation.state": [14],
-        }
-    )
-    output_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "action": [14],
-        }
-    )
+    input_shapes: dict[str, list[str]] = {
+        "observation.images.top": [3, 480, 640],
+        "observation.state": [14],
+    }
+
+    output_shapes: dict[str, list[str]] = {"action": [14]}
 
     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "mean_std",
-        }
-    )
-    unnormalize_output_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "action": "mean_std",
-        }
-    )
+    normalize_input_modes: dict[str, str] = {
+        "observation.image": "mean_std",
+        "observation.state": "mean_std",
+    }
+
+    unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}
 
     # Architecture.
     # Vision backbone.
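The second hunk replaces the dataclass-style `field(default_factory=...)` defaults with plain class-level dicts. Under `@dataclass`, the factory existed to give each instance its own copy of a mutable default; declared as bare class attributes, the dicts are shared across instances unless the config's `__init__` copies them. A short, generic illustration of that sharing (not code from the commit, and the attribute name is hypothetical):

```python
# Generic Python behaviour: a mutable default declared as a class
# attribute is a single object shared by every instance.
class SharedDefault:
    shapes: dict[str, list[int]] = {"action": [14]}


a, b = SharedDefault(), SharedDefault()
a.shapes["action"] = [7]  # mutates the one class-level dict in place
print(b.shapes)           # {'action': [7]} -- the other instance sees it

# PretrainedConfig-style classes usually avoid this by binding copies to
# instance attributes in __init__, e.g. self.shapes = dict(shapes).
```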