pretrained config for act

This commit is contained in:
Quentin Gallouédec 2024-04-25 16:06:57 +02:00
parent 659c69a1c0
commit 783a40c9d4
1 changed file with 29 additions and 26 deletions

View File

@@ -1,8 +1,7 @@
from dataclasses import dataclass, field
from transformers.configuration_utils import PretrainedConfig
@dataclass
class ActionChunkingTransformerConfig:
class ActionChunkingTransformerConfig(PretrainedConfig):
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -55,37 +54,41 @@ class ActionChunkingTransformerConfig:
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""
Example:
```python
>>> from lerobot import ActionChunkingTransformerConfig
>>> # Initializing an ACT style configuration
>>> configuration = ActionChunkingTransformerConfig()
>>> # Initializing a model (with random weights) from the ACT style configuration
>>> model = ActionChunkingTransformerPolicy(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 100
n_action_steps: int = 100
input_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [14],
}
)
input_shapes: dict[str, list[str]] = {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
output_shapes: dict[str, list[str]] = {"action": [14]}
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}
)
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}
# Architecture.
# Vision backbone.