pretrained config for act

commit 783a40c9d4
parent 659c69a1c0
@@ -1,8 +1,7 @@
-from dataclasses import dataclass, field
+from transformers.configuration_utils import PretrainedConfig


-@dataclass
-class ActionChunkingTransformerConfig:
+class ActionChunkingTransformerConfig(PretrainedConfig):
     """Configuration class for the Action Chunking Transformers policy.

     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
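For readers unfamiliar with the pattern: subclassing `transformers.PretrainedConfig` is what gives a config class `save_pretrained()` / `from_pretrained()` serialization. Below is a minimal sketch of that mechanism, not the actual lerobot code; the `model_type` value, the class name, and the keyword forwarding are assumptions not shown in this diff.

```python
# Minimal sketch of the PretrainedConfig pattern; names here are hypothetical.
from transformers.configuration_utils import PretrainedConfig


class SketchConfig(PretrainedConfig):
    model_type = "act_sketch"  # hypothetical registry key, an assumption

    def __init__(self, chunk_size: int = 100, n_action_steps: int = 100, **kwargs):
        self.chunk_size = chunk_size
        self.n_action_steps = n_action_steps
        # PretrainedConfig consumes common kwargs and provides JSON
        # (de)serialization via save_pretrained / from_pretrained.
        super().__init__(**kwargs)


config = SketchConfig(chunk_size=50)
config.save_pretrained("/tmp/act_sketch")  # writes config.json
restored = SketchConfig.from_pretrained("/tmp/act_sketch")
assert restored.chunk_size == 50
```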
@@ -55,37 +54,41 @@ class ActionChunkingTransformerConfig:
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
-    """
+
+    Example:
+
+    ```python
+    >>> from lerobot import ActionChunkingTransformerConfig
+
+    >>> # Initializing an ACT style configuration
+    >>> configuration = ActionChunkingTransformerConfig()
+
+    >>> # Initializing a model (with random weights) from the ACT style configuration
+    >>> model = ActionChunkingTransformerPolicy(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""

     # Input / output structure.
     n_obs_steps: int = 1
     chunk_size: int = 100
     n_action_steps: int = 100

-    input_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "observation.images.top": [3, 480, 640],
-            "observation.state": [14],
-        }
-    )
-    output_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "action": [14],
-        }
-    )
+    input_shapes: dict[str, list[str]] = {
+        "observation.images.top": [3, 480, 640],
+        "observation.state": [14],
+    }
+    output_shapes: dict[str, list[str]] = {"action": [14]}

     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "mean_std",
-        }
-    )
-    unnormalize_output_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "action": "mean_std",
-        }
-    )
+    normalize_input_modes: dict[str, str] = {
+        "observation.image": "mean_std",
+        "observation.state": "mean_std",
+    }
+    unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}

     # Architecture.
     # Vision backbone.
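A usage sketch building on the docstring example above (the `lerobot` import path is taken from that example; attribute names are from this diff): since `input_shapes` and the normalization modes are now plain class-level dicts rather than dataclass `default_factory` fields, per-instance customization should assign fresh dicts rather than mutate the shared defaults.

```python
# Usage sketch under the attribute names in this diff; not an official API example.
from lerobot import ActionChunkingTransformerConfig

configuration = ActionChunkingTransformerConfig()

# Assign fresh dicts per instance; mutating the class-level defaults in place
# would leak into every other instance sharing the same dict object.
configuration.input_shapes = {
    "observation.images.top": [3, 480, 640],  # C, H, W
    "observation.state": [14],
}
configuration.normalize_input_modes = {
    "observation.image": "mean_std",  # normalize as (x - mean) / std
    "observation.state": "mean_std",
}
configuration.unnormalize_output_modes = {"action": "mean_std"}
```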