pretrained config for act
parent 659c69a1c0
commit 783a40c9d4
@@ -1,8 +1,7 @@
-from dataclasses import dataclass, field
+from transformers.configuration_utils import PretrainedConfig


-@dataclass
-class ActionChunkingTransformerConfig:
+class ActionChunkingTransformerConfig(PretrainedConfig):
     """Configuration class for the Action Chunking Transformers policy.

     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -55,37 +54,41 @@ class ActionChunkingTransformerConfig:
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
-    """
+
+    Example:
+
+    ```python
+    >>> from lerobot import ActionChunkingTransformerConfig
+
+    >>> # Initializing an ACT style configuration
+    >>> configuration = ActionChunkingTransformerConfig()
+
+    >>> # Initializing a model (with random weights) from the ACT style configuration
+    >>> model = ActionChunkingTransformerPolicy(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""

     # Input / output structure.
     n_obs_steps: int = 1
     chunk_size: int = 100
     n_action_steps: int = 100

-    input_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "observation.images.top": [3, 480, 640],
-            "observation.state": [14],
-        }
-    )
-    output_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "action": [14],
-        }
-    )
+    input_shapes: dict[str, list[str]] = {
+        "observation.images.top": [3, 480, 640],
+        "observation.state": [14],
+    }
+
+    output_shapes: dict[str, list[str]] = {"action": [14]}

     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "mean_std",
-        }
-    )
-    unnormalize_output_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "action": "mean_std",
-        }
-    )
+    normalize_input_modes: dict[str, str] = {
+        "observation.image": "mean_std",
+        "observation.state": "mean_std",
+    }
+
+    unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}

     # Architecture.
     # Vision backbone.
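The switch from a plain dataclass to a `PretrainedConfig` subclass means the config inherits the standard `transformers` serialization API instead of dataclass semantics. A minimal sketch of the resulting round trip, assuming the `from lerobot import ActionChunkingTransformerConfig` import shown in the docstring example; the directory name and overridden values here are arbitrary:

```python
from lerobot import ActionChunkingTransformerConfig  # import path as in the docstring example

# PretrainedConfig.__init__ accepts arbitrary keyword arguments and sets them
# as instance attributes, so defaults can be overridden at construction time.
configuration = ActionChunkingTransformerConfig(chunk_size=50, n_action_steps=50)

# Inherited from PretrainedConfig: write config.json to a directory, then load it back.
configuration.save_pretrained("act_config")  # hypothetical local directory
reloaded = ActionChunkingTransformerConfig.from_pretrained("act_config")
assert reloaded.chunk_size == 50
```

One consequence of this design: `save_pretrained` serializes instance attributes, so class-level defaults such as `input_shapes` above appear in `config.json` only once they have been set on the instance.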