Compare commits

166b2463b0 ... d5b8b26c65

29 Commits

| SHA1 |
|---|
| d5b8b26c65 |
| b43ece8934 |
| c10c5a0e64 |
| a8db91c40e |
| 0f5f7ac780 |
| 768e36660d |
| 790d6740ba |
| 3d8a29fe6e |
| 78c05cf0be |
| 035e95a41b |
| 48d7213a8a |
| ea1c582239 |
| d1bec3e8ae |
| 1329954dba |
| dc4e94fc65 |
| 17572b3211 |
| aa70e14033 |
| cf6e677485 |
| 6ca03b0dac |
| cee77f3d4e |
| f585cec385 |
| cae49528ee |
| c6bcfb3539 |
| 2642d58b7a |
| 3b8a85a3f2 |
| 42cca28332 |
| b6face0179 |
| 489cdc2ace |
| 78d3ba8db2 |
```diff
@@ -48,7 +48,7 @@ repos:
   - id: pyupgrade

 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.11.4
+  rev: v0.11.5
   hooks:
   - id: ruff
     args: [--fix]
@@ -57,7 +57,7 @@ repos:

 ##### Security #####
 - repo: https://github.com/gitleaks/gitleaks
-  rev: v8.24.2
+  rev: v8.24.3
   hooks:
   - id: gitleaks

```
````diff
@@ -103,13 +103,20 @@ When using `miniconda`, install `ffmpeg` in your environment:
 conda install ffmpeg -c conda-forge
 ```

+> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
+> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
+> ```bash
+> conda install ffmpeg=7.1.1 -c conda-forge
+> ```
+> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
+
 Install 🤗 LeRobot:
 ```bash
 pip install -e .
 ```

 > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
-`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
+`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)

 For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
 - [aloha](https://github.com/huggingface/gym-aloha)
````
```diff
@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     tcpdump sysstat screen tmux \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
     speech-dispatcher portaudio19-dev libgeos-dev \
-    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
+    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \
     && apt-get clean && rm -rf /var/lib/apt/lists/*

 # Install ffmpeg build dependencies. See:
```
```diff
@@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly

 ## The training script

-LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
+LeRobot offers a training script at [`lerobot/scripts/train.py`](../lerobot/scripts/train.py). At a high level it does the following:

 - Initialize/load a configuration for the following steps using.
 - Instantiates a dataset.
```
````diff
@@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig`
 def train(cfg: TrainPipelineConfig):
 ```

-You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
+You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)

 When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)

````
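As an aside on the paragraph above: the dataclass-plus-CLI pattern that `@parser.wrap()` builds on can be sketched with a minimal, standalone Draccus example. The config classes below are hypothetical illustrations, not LeRobot's, and assume Draccus's `draccus.wrap()` decorator behaves as described in its README:

```python
from dataclasses import dataclass, field

import draccus


@dataclass
class OptimizerConfig:
    lr: float = 1e-3
    weight_decay: float = 0.0


@dataclass
class ExampleConfig:
    # Nested dataclasses become dotted CLI arguments, e.g. --optimizer.lr=5e-4
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    batch_size: int = 8
    output_dir: str = "outputs/example"


@draccus.wrap()
def main(cfg: ExampleConfig):
    # cfg is a typed object, not a dict: cfg.optimizer.lr, cfg.batch_size, ...
    print(cfg)


if __name__ == "__main__":
    main()  # e.g. python example.py --batch_size=16 --optimizer.lr=5e-4
```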
````diff
@@ -50,7 +50,7 @@ By default, every field takes its default value specified in the dataclass. If a

 ## Specifying values from the CLI

-Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
+Let's say that we want to train [Diffusion Policy](../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
 ```bash
 python lerobot/scripts/train.py \
     --dataset.repo_id=lerobot/pusht \
````
````diff
@@ -60,10 +60,10 @@ python lerobot/scripts/train.py \

 Let's break this down:
 - To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
-- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
-- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
+- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../lerobot/common/policies)
+- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../lerobot/common/envs/configs.py)

-Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
+Let's see another example. Let's say you've been training [ACT](../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
 ```bash
 python lerobot/scripts/train.py \
     --policy.type=act \
````
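To illustrate the `draccus.ChoiceRegistry` / `register_subclass()` mechanism that `--policy.type` relies on, here is a minimal, standalone sketch; the classes and field names are hypothetical, not LeRobot's:

```python
from dataclasses import dataclass

import draccus


@dataclass
class ExamplePolicyConfig(draccus.ChoiceRegistry):
    # Base class: subclasses register themselves under a name selectable via `.type`.
    pass


@ExamplePolicyConfig.register_subclass("diffusion")
@dataclass
class DiffusionExampleConfig(ExamplePolicyConfig):
    num_inference_steps: int = 100


@ExamplePolicyConfig.register_subclass("act")
@dataclass
class ACTExampleConfig(ExamplePolicyConfig):
    chunk_size: int = 100


# With a field `policy: ExamplePolicyConfig` in a top-level config parsed by Draccus,
# `--policy.type=act --policy.chunk_size=50` would instantiate ACTExampleConfig(chunk_size=50).
```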
````diff
@@ -74,7 +74,7 @@ python lerobot/scripts/train.py \
 > Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.

 We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
-Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
+Looking at the [`AlohaEnv`](../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
 ```bash
 python lerobot/scripts/train.py \
     --policy.type=act \
````
```diff
@@ -830,11 +830,6 @@ It contains:
 - `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.

 Troubleshooting:
-- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
-  - install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
-    > **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge`
-  - or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
-  - and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
 - On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).

 At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:
```
```diff
@@ -20,7 +20,7 @@ from pathlib import Path

 import draccus
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LambdaLR, LRScheduler
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler

 from lerobot.common.constants import SCHEDULER_STATE
 from lerobot.common.datasets.utils import write_json
```
```diff
@@ -120,3 +120,16 @@ def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
     state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
     scheduler.load_state_dict(state_dict)
     return scheduler
+
+
+@LRSchedulerConfig.register_subclass("cosine_annealing")
+@dataclass
+class CosineAnnealingSchedulerConfig(LRSchedulerConfig):
+    """Implements Cosine Annealing learning rate scheduler"""
+
+    min_lr: float = 0  # Minimum learning rate
+    T_max: int = 100000  # Number of iterations for a full decay (half-cycle)
+    num_warmup_steps: int = 0  # Not used but somehow required by the parent class
+
+    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler:
+        return CosineAnnealingLR(optimizer, T_max=self.T_max, eta_min=self.min_lr)
```
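For context, a quick standalone sketch of what the new `cosine_annealing` preset boils down to; the linear model and step counts here are placeholders, and only the `CosineAnnealingLR` call mirrors the `build()` method above:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(4, 2)  # placeholder model
optimizer = AdamW(model.parameters(), lr=1e-4)

# Roughly what CosineAnnealingSchedulerConfig(min_lr=1e-5, T_max=1000).build(optimizer, ...) returns
scheduler = CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-5)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())  # lr decays from 1e-4 toward 1e-5 over T_max steps
```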
```diff
@@ -14,6 +14,7 @@

 from .act.configuration_act import ACTConfig as ACTConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
+from .dot.configuration_dot import DOTConfig as DOTConfig
 from .pi0.configuration_pi0 import PI0Config as PI0Config
 from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
 from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
```
@@ -0,0 +1,212 @@

```python
#!/usr/bin/env python

# Copyright 2025 Ilia Larchenko and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import CosineAnnealingSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode


@PreTrainedConfig.register_subclass("dot")
@dataclass
class DOTConfig(PreTrainedConfig):
    """Configuration class for the Decision Transformer (DOT) policy.

    DOT is a transformer-based policy for sequential decision making that predicts future actions based on
    a history of past observations and actions. This configuration enables fine-grained
    control over the model’s temporal horizon, input normalization, architectural parameters, and
    augmentation strategies.

    Defaults are configured for general robot manipulation tasks like Push-T and ALOHA insert/transfer.

    The parameters you will most likely need to modify are those related to temporal structure and
    normalization:
    - `train_horizon` and `inference_horizon`
    - `lookback_obs_steps` and `lookback_aug`
    - `alpha` and `train_alpha`
    - `normalization_mapping`

    Notes on the temporal design:
    - `train_horizon`: Length of action sequence the model is trained on. Must be ≥ `inference_horizon`.
    - `inference_horizon`: How far into the future the model predicts during inference (in environment steps).
      A good rule of thumb is 2×FPS (e.g., 30–50 for 15–25 FPS environments).
    - `alpha` / `train_alpha`: Control exponential decay of loss weights for inference and training.
      These should be tuned such that all predicted steps contribute meaningful signal.

    Notes on the inputs:
    - Observations can come from:
      - Images (e.g., keys starting with `"observation.images"`)
      - Proprioceptive state (`"observation.state"`)
      - Environment state (`"observation.environment_state"`)
    - At least one of image or environment state inputs must be provided.
    - The "action" key is required as an output.

    Args:
        n_obs_steps: Number of past steps passed to the model, including the current step.
        train_horizon: Number of future steps the model is trained to predict.
        inference_horizon: Number of future steps predicted during inference.
        lookback_obs_steps: Number of past steps to include for temporal context.
        lookback_aug: Number of steps into the far past from which to randomly sample for augmentation.
        normalization_mapping: Dictionary specifying normalization mode for each input/output group.
        override_dataset_stats: If True, replaces the dataset's stats with manually defined `new_dataset_stats`.
        new_dataset_stats: Optional manual min/max overrides used if `override_dataset_stats=True`.
        vision_backbone: Name of the ResNet variant used for image encoding (e.g., "resnet18").
        pretrained_backbone_weights: Optional pretrained weights (e.g., "ResNet18_Weights.IMAGENET1K_V1").
        pre_norm: Whether to apply pre-norm in transformer layers.
        lora_rank: If > 0, applies LoRA adapters of the given rank to transformer layers.
        merge_lora: Whether to merge LoRA weights at inference time.
        dim_model: Dimension of the transformer hidden state.
        n_heads: Number of attention heads.
        dim_feedforward: Dimension of the feedforward MLP inside the transformer.
        n_decoder_layers: Number of transformer decoder layers.
        rescale_shape: Resize shape for input images (e.g., (96, 96)).
        crop_scale: Image crop scale for augmentation.
        state_noise: Magnitude of additive uniform noise for state inputs.
        noise_decay: Decay factor applied to `crop_scale` and `state_noise` during training.
        dropout: Dropout rate used in transformer layers.
        alpha: Decay factor for inference loss weighting.
        train_alpha: Decay factor for training loss weighting.
        predict_every_n: Predict actions every `n` frames instead of every frame.
        return_every_n: Return every `n`-th predicted action during inference.
        optimizer_lr: Initial learning rate.
        optimizer_min_lr: Minimum learning rate for cosine scheduler.
        optimizer_lr_cycle_steps: Total steps in one learning rate cycle.
        optimizer_weight_decay: L2 weight decay for optimizer.

    Raises:
        ValueError: If the temporal settings are inconsistent (e.g., `train_horizon < inference_horizon`,
            or `predict_every_n` > allowed bounds).
    """

    # Input / output structure.
    n_obs_steps: int = 3
    train_horizon: int = 20
    inference_horizon: int = 20
    lookback_obs_steps: int = 10
    lookback_aug: int = 5

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ENV": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Align with the new config system
    override_dataset_stats: bool = False
    new_dataset_stats: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"max": [512.0] * 2, "min": [0.0] * 2},
            "observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16},
            "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
        }
    )

    # Architecture.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    pre_norm: bool = True
    lora_rank: int = 20
    merge_lora: bool = False

    dim_model: int = 128
    n_heads: int = 8
    dim_feedforward: int = 512
    n_decoder_layers: int = 8
    rescale_shape: tuple[int, int] = (96, 96)

    # Augmentation.
    crop_scale: float = 0.8
    state_noise: float = 0.01
    noise_decay: float = 0.999995

    # Training and loss computation.
    dropout: float = 0.1

    # Weighting and inference.
    alpha: float = 0.75
    train_alpha: float = 0.9
    predict_every_n: int = 1
    return_every_n: int = 1

    # Training preset
    optimizer_lr: float = 1.0e-4
    optimizer_min_lr: float = 1.0e-4
    optimizer_lr_cycle_steps: int = 300000
    optimizer_weight_decay: float = 1e-5

    def __post_init__(self):
        super().__post_init__()
        if self.predict_every_n > self.inference_horizon:
            raise ValueError(
                f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon ({self.inference_horizon})."
            )
        if self.return_every_n > self.inference_horizon:
            raise ValueError(
                f"return_every_n ({self.return_every_n}) must be less than or equal to horizon ({self.inference_horizon})."
            )
        if self.predict_every_n > self.inference_horizon // self.return_every_n:
            raise ValueError(
                f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon // return_every_n({self.inference_horizon // self.return_every_n})."
            )
        if self.train_horizon < self.inference_horizon:
            raise ValueError(
                f"train_horizon ({self.train_horizon}) must be greater than or equal to horizon ({self.inference_horizon})."
            )

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> None:
        return CosineAnnealingSchedulerConfig(
            min_lr=self.optimizer_min_lr, T_max=self.optimizer_lr_cycle_steps
        )

    def validate_features(self) -> None:
        if not self.image_features and not self.env_state_feature:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

    @property
    def observation_delta_indices(self) -> None:
        far_past_obs = list(
            range(
                -self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps
            )
        )
        recent_obs = list(range(2 - self.n_obs_steps, 1))

        return far_past_obs + recent_obs

    @property
    def action_delta_indices(self) -> list:
        far_past_actions = list(
            range(
                -self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps
            )
        )
        recent_actions = list(range(2 - self.n_obs_steps, self.train_horizon))

        return far_past_actions + recent_actions

    @property
    def reward_delta_indices(self) -> None:
        return None
```
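The `alpha` / `train_alpha` fields above control exponentially decaying weights over the prediction horizon; the exact computation lives in `modeling_dot.py` below. A tiny, standalone sketch of what the default values produce (pure PyTorch, nothing LeRobot-specific):

```python
import torch

alpha, inference_horizon = 0.75, 20  # defaults from DOTConfig above

# Mirrors DOTPolicy.__init__: weights used to average overlapping action predictions at inference.
action_weights = alpha ** torch.arange(inference_horizon).float()
action_weights /= action_weights.sum()

# The most recent prediction of a given timestep dominates; older ones decay by a factor of 0.75 per step.
print(action_weights[:4])  # roughly tensor([0.25, 0.19, 0.14, 0.11])
```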
@@ -0,0 +1,558 @@

```python
#!/usr/bin/env python

# Copyright 2025 Ilia Larchenko and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The implementation of the Decoder-Only Transformer (DOT) policy.

More details here: https://github.com/IliaLarchenko/dot_policy
"""

import math

import torch
import torchvision
from torch import Tensor, nn
from torchvision import transforms
from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.transforms.functional import InterpolationMode

from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy


class DOT(nn.Module):
    """The underlying neural network for DOT
    Note: Unlike ACT, DOT has no encoder, no VAE, and no cross-attention. All inputs are directly projected
    to the model dimension and passed as memory to a Transformer decoder.

    - Inputs (images, state, env_state) are linearly projected and concatenated.
    - A trainable prefix token and positional embeddings are added.
    - The Transformer decoder predicts a sequence of future actions autoregressively.

    DOT Transformer
    Used for autoregressive action prediction
    (no encoder, no VAE)

    ┌──────────────────────────────────────────────────────┐
    │ image emb.          state emb.       env_state emb.  │
    │     │                    │                  │        │
    │     └───────┐            │                  │        │
    │             │   ┌────────┘                  │        │
    │             ▼   ▼   ▼                       ...      │
    │    ┌──────────────────────────────────────────┐      │
    │    │    Concatenate + Add Positional Emb.     │      │
    │    └──────────────────────────────────────────┘      │
    │                        │                             │
    │                        ▼                             │
    │        ┌───────────────────────────────────┐         │
    │        │    Transformer Decoder (L layers) │         │
    │        └───────────────────────────────────┘         │
    │                        │                             │
    │                        ▼                             │
    │          Linear projection to action space           │
    │                        │                             │
    │                        ▼                             │
    │                     Outputs                          │
    └──────────────────────────────────────────────────────┘
    """

    def __init__(self, config: DOTConfig):
        super().__init__()
        self.config = config

        self.projections = nn.ModuleDict()
        self.n_features = 0

        self.image_names = sorted(config.image_features.keys())

        # Set up a shared visual backbone (e.g., ResNet18) for all cameras.
        # The final layer is replaced with a linear projection to match model_dim.
        if len(self.image_names) > 0:
            backbone = getattr(torchvision.models, self.config.vision_backbone)(
                weights=self.config.pretrained_backbone_weights,
                norm_layer=FrozenBatchNorm2d,
            )
            backbone.fc = nn.Linear(backbone.fc.in_features, self.config.dim_model)

            self.projections["images"] = add_lora_to_backbone(backbone, rank=config.lora_rank)
            self.n_features += len(self.image_names) * self.config.n_obs_steps

        if self.config.robot_state_feature:
            self.projections["state"] = nn.Linear(
                self.config.robot_state_feature.shape[0], self.config.dim_model
            )
            self.n_features += self.config.n_obs_steps

        if self.config.env_state_feature:
            self.projections["environment_state"] = nn.Linear(
                self.config.env_state_feature.shape[0], self.config.dim_model
            )
            self.n_features += self.config.n_obs_steps

        self.projections_names = sorted(self.projections.keys())
        obs_mapping = {
            "images": "observation.images",
            "state": "observation.state",
            "environment_state": "observation.environment_state",
        }
        self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names}

        # Optional trainable prefix token added to the input sequence (can be used for task conditioning or extra context)
        self.prefix_input = nn.Parameter(torch.randn(1, 1, config.dim_model))

        # Setup transformer decoder
        dec_layer = nn.TransformerDecoderLayer(
            d_model=self.config.dim_model,
            nhead=self.config.n_heads,
            dim_feedforward=self.config.dim_feedforward,
            dropout=self.config.dropout,
            batch_first=True,
            norm_first=self.config.pre_norm,
        )

        decoder_norm = nn.LayerNorm(self.config.dim_model)
        self.decoder = nn.TransformerDecoder(
            dec_layer, num_layers=self.config.n_decoder_layers, norm=decoder_norm
        )

        # Sinusoidal positional encodings for the decoder input tokens (fixed, not trainable)
        decoder_pos = create_sinusoidal_pos_embedding(
            config.train_horizon + config.lookback_obs_steps, config.dim_model
        )
        decoder_pos = torch.cat(
            [
                decoder_pos[:1],
                decoder_pos[-config.train_horizon - config.n_obs_steps + 2 :],
            ],
            dim=0,
        )
        self.register_buffer("decoder_pos", decoder_pos)

        # Extend positional encodings for inference (when inference_horizon > train_horizon)
        decoder_pos_inf = self.decoder_pos[
            : self.decoder_pos.shape[0] + self.config.inference_horizon - self.config.train_horizon
        ]
        self.register_buffer("decoder_pos_inf", decoder_pos_inf)
        # Causal mask for decoder: prevent attending to future positions
        mask = torch.zeros(len(decoder_pos), len(decoder_pos), dtype=torch.bool)
        mask[
            : len(decoder_pos) + config.inference_horizon - config.train_horizon,
            len(decoder_pos) + config.inference_horizon - config.train_horizon :,
        ] = True
        self.register_buffer("mask", mask)

        # Learnable positional embeddings for input tokens (state/image/env projections)
        self.inputs_pos_emb = nn.Parameter(torch.empty(1, self.n_features, self.config.dim_model))
        nn.init.uniform_(
            self.inputs_pos_emb,
            -((1 / self.config.dim_model) ** 0.5),
            (1 / self.config.dim_model) ** 0.5,
        )

        # The output actions are generated by a linear layer
        self.action_head = nn.Linear(self.config.dim_model, self.config.action_feature.shape[0])

    def _process_inputs(self, batch):
        # Project all inputs to the model dimension and concatenate them
        inputs_projections_list = []

        for state in self.projections_names:
            batch_state = self.obs_mapping[state]
            if batch_state in batch:
                batch_size, n_obs, *obs_shape = batch[batch_state].shape
                enc = self.projections[state](batch[batch_state].view(batch_size * n_obs, *obs_shape)).view(
                    batch_size, n_obs, -1
                )
                inputs_projections_list.append(enc)

        return torch.cat(inputs_projections_list, dim=1)

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """
        A forward pass through the Decision Transformer (DOT).

        The model uses a transformer decoder to predict a sequence of future actions from projected
        and positionally-embedded image, state, and environment features.

        Args:
            batch (dict): A dictionary containing the following keys (if available):
                - "observation.images": (B, T, C, H, W) tensor of camera frames.
                - "observation.state": (B, T, D) tensor of proprioceptive robot states.
                - "observation.environment_state": (B, T, D) tensor of environment states.

        Returns:
            Tensor: A tensor of shape (B, horizon, action_dim) containing predicted future actions.
        """
        # Project image/state/env_state inputs to the model dimension and concatenate along the time axis.
        inputs_projections = self._process_inputs(batch)  # (B, T, D)
        batch_size = inputs_projections.shape[0]

        # Add learnable positional embeddings to each projected input token.
        inputs_projections += self.inputs_pos_emb.expand(batch_size, -1, -1)

        # Prepend a trainable prefix token to the input sequence
        inputs_projections = torch.cat(
            [self.prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1
        )  # (B, T+1, D)

        # Use different positional encodings and masks for training vs. inference.
        if self.training:
            decoder_out = self.decoder(
                self.decoder_pos.expand(batch_size, -1, -1), inputs_projections, self.mask
            )
        else:
            decoder_out = self.decoder(self.decoder_pos_inf.expand(batch_size, -1, -1), inputs_projections)
        return self.action_head(decoder_out)


class DOTPolicy(PreTrainedPolicy):
    """
    Decision Transformer (DOT) Policy. (github: https://github.com/IliaLarchenko/dot_policy)

    A minimal transformer decoder-based policy for autoregressive action prediction in robot control.
    This is a simplified alternative to ACT: no encoder, no VAE, and no cross-attention, making it efficient
    for deployment in low-dimensional environments with visual and proprioceptive inputs.
    """

    name = "dot"
    config_class = DOTConfig

    def __init__(
        self,
        config: DOTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config (DOTConfig): Configuration for the DOT model and policy behavior.
            dataset_stats (optional): Dataset statistics used for normalizing inputs/outputs.
                If not provided, stats should be set later via `load_state_dict()` before inference.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.image_names = sorted(config.image_features.keys())

        if config.override_dataset_stats:
            if dataset_stats is None:
                dataset_stats = {}
            for k, v in config.new_dataset_stats.items():
                if k not in dataset_stats:
                    dataset_stats[k] = {}
                for k1, v1 in v.items():
                    dataset_stats[k][k1] = torch.tensor(v1)

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.model = DOT(self.config)

        self.state_noise = self.config.state_noise
        self.crop_scale = self.config.crop_scale
        self.alpha = self.config.alpha
        self.inference_horizon = self.config.inference_horizon
        self.return_every_n = self.config.return_every_n
        self.predict_every_n = self.config.predict_every_n

        # Inference action chunking and observation queues
        self._old_predictions = None
        self._input_buffers = {}

        # Weights used for chunking
        action_weights = self.alpha ** torch.arange(self.inference_horizon).float()
        action_weights /= action_weights.sum()
        action_weights = action_weights.view(1, -1, 1)
        self.register_buffer("action_weights", action_weights)

        # Weights for the loss computations
        # Actions that are further in the future are weighted less
        loss_weights = torch.ones(self.config.train_horizon + self.config.n_obs_steps - 1)
        loss_weights[-self.config.train_horizon :] = (
            self.config.train_alpha ** torch.arange(self.config.train_horizon).float()
        )
        loss_weights /= loss_weights.mean()
        loss_weights = loss_weights.view(1, -1, 1)
        self.register_buffer("loss_weights", loss_weights)

        # TODO(jadechoghari): Move augmentations to dataloader (__getitem__) for CPU-side processing.
        # Nearest interpolation is required for PushT but may be not the best in general
        self.resize_transform = transforms.Resize(
            config.rescale_shape, interpolation=InterpolationMode.NEAREST
        )

        self.step = 0
        self.last_action = None

    def reset(self):
        self._old_predictions = None
        self._input_buffers = {}
        self.last_action = None
        self.step = 0

    def get_optim_params(self) -> dict:
        return self.model.parameters()

    def _update_observation_buffers(self, buffer_name: str, observation: Tensor) -> Tensor:
        # Maintain a rolling buffer of lookback_obs_steps + 1;
        # shift left and append new observation each step
        if buffer_name not in self._input_buffers:
            self._input_buffers[buffer_name] = observation.unsqueeze(1).repeat(
                1,
                self.config.lookback_obs_steps + 1,
                *torch.ones(len(observation.shape[1:])).int(),
            )
        else:
            self._input_buffers[buffer_name] = self._input_buffers[buffer_name].roll(shifts=-1, dims=1)
            self._input_buffers[buffer_name][:, -1] = observation

        return torch.cat(
            [
                self._input_buffers[buffer_name][:, :1],
                self._input_buffers[buffer_name][:, -(self.config.n_obs_steps - 1) :],
            ],
            dim=1,
        )

    def _prepare_batch_for_inference(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = self.normalize_inputs(batch)

        # Resize and stack all images
        if len(self.image_names) > 0:
            batch["observation.images"] = torch.stack(
                [self.resize_transform(batch[k]) for k in self.image_names],
                dim=1,
            )  # batch_size, n_cam, c, h, w

        # Update observation queues for all inputs and stack the last n_obs_steps
        for name, batch_name in self.model.obs_mapping.items():
            batch[batch_name] = self._update_observation_buffers(name, batch[batch_name])

        # Reshape images tensor to keep the same order as during training
        if "observation.images" in batch:
            batch["observation.images"] = batch["observation.images"].flatten(1, 2)
            # batch_size, n_obs * n_cam, c, h, w

        return batch

    def _chunk_actions(self, actions: Tensor) -> Tensor:
        # Store the previous action predictions in a buffer
        # Compute the weighted average of the inference horizon action predictions
        if self._old_predictions is not None:
            self._old_predictions[:, 0] = actions
        else:
            self._old_predictions = actions.unsqueeze(1).repeat(1, self.config.inference_horizon, 1, 1)

        action = (self._old_predictions[:, :, 0] * self.action_weights).sum(dim=1)
        self._old_predictions = self._old_predictions.roll(shifts=(1, -1), dims=(1, 2))

        return action

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """
        Select an action given current environment observations.

        This function handles autoregressive rollout during inference using a fixed prediction horizon.
        The model predicts every `predict_every_n` steps, and returns actions every `return_every_n` steps.
        Between predictions, previously predicted actions are reused by shifting and repeating the last step.
        """
        self.eval()

        batch = self._prepare_batch_for_inference(batch)

        # Only run model prediction every predict_every_n steps
        if self.step % self.predict_every_n == 0:
            actions_pred = self.model(batch)[:, -self.config.inference_horizon :]
            self.last_action = self.unnormalize_outputs({"action": actions_pred})["action"]
        else:
            # Otherwise shift previous predictions and repeat last action
            self.last_action = self.last_action.roll(-1, dims=1)
            self.last_action[:, -1] = self.last_action[:, -2]

        self.step += 1

        # Return chunked actions for return_every_n steps
        action = self._chunk_actions(self.last_action)
        for _ in range(self.return_every_n - 1):
            self.last_action = self.last_action.roll(-1, dims=1)
            self.last_action[:, -1] = self.last_action[:, -2]
            action = self._chunk_actions(self.last_action)

        return action

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        lookback_ind = torch.randint(0, 2 * self.config.lookback_aug + 1, (1,)).item()
        for k in list(self.model.obs_mapping.values()) + list(self.image_names) + ["action", "action_is_pad"]:
            if k != "observation.images":
                batch[k] = torch.cat(
                    [
                        batch[k][:, lookback_ind : lookback_ind + 1],
                        batch[k][:, 2 * self.config.lookback_aug + 1 :],
                    ],
                    1,
                )
        batch = self.normalize_targets(self.normalize_inputs(batch))

        if len(self.config.image_features) > 0:
            scale = 1 - torch.rand(1) * (1 - self.crop_scale)
            new_shape = (
                int(self.config.rescale_shape[0] * scale),
                int(self.config.rescale_shape[1] * scale),
            )
            crop_transform = transforms.RandomCrop(new_shape)

            for k in self.image_names:
                batch_size, n_obs, c, h, w = batch[k].shape
                batch[k] = batch[k].view(batch_size * n_obs, c, h, w)
                batch[k] = crop_transform(self.resize_transform(batch[k]))
                batch[k] = batch[k].view(batch_size, n_obs, c, *batch[k].shape[-2:])
            batch["observation.images"] = torch.stack([batch[k] for k in self.image_names], dim=2).flatten(
                1, 2
            )  # batch_size, n_obs * n_cam, c, h, w

        # Add random noise to states during training
        # TODO(jadechoghari): better to move this to the dataloader
        if self.state_noise is not None:
            for k in self.model.obs_mapping.values():
                if k != "observation.images":
                    batch[k] += (torch.rand_like(batch[k]) * 2 - 1) * self.state_noise

        actions_hat = self.model(batch)

        l1_loss = nn.functional.l1_loss(batch["action"], actions_hat, reduction="none")
        rev_padding = (~batch["action_is_pad"]).unsqueeze(-1)

        # Apply padding, weights and decay to the loss
        l1_loss = (l1_loss * rev_padding * self.loss_weights).mean()

        loss_dict = {"l1_loss": l1_loss.item()}
        loss = l1_loss

        # Reduce the aggressiveness of augmentations
        self.state_noise *= self.config.noise_decay
        self.crop_scale = 1 - (1 - self.crop_scale) * self.config.noise_decay

        return loss, loss_dict

    @classmethod
    def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs):
        """Load model from pretrained checkpoint and merge LoRA after loading"""
        policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs)

        if getattr(policy.config, "merge_lora", False):
            print("Merging LoRA after loading pretrained model...")
            policy.model = merge_lora_weights(policy.model)

        return policy


class LoRAConv2d(nn.Module):
    """
    Applies Low-Rank Adaptation (LoRA) to a Conv2D layer.

    LoRA adds trainable low-rank matrices (A and B) to adapt pretrained weights without full fine-tuning.
    The adaptation is merged into the base conv weights via `merge_lora()` after training.

    Args:
        base_conv (nn.Conv2d): The original convolutional layer to be adapted.
        rank (int): The rank of the low-rank approximation (default: 4).
    """

    def __init__(self, base_conv: nn.Conv2d, rank: int = 4):
        super().__init__()
        self.base_conv = base_conv

        # Flatten the original conv weight
        out_channels, in_channels, kh, kw = base_conv.weight.shape
        self.weight_shape = (out_channels, in_channels, kh, kw)
        fan_in = in_channels * kh * kw

        # Low-rank trainable matrices A and B
        self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank)))
        self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)

        return nn.functional.conv2d(
            x,
            self.base_conv.weight + lora_update,
            self.base_conv.bias,
            stride=self.base_conv.stride,
            padding=self.base_conv.padding,
            dilation=self.base_conv.dilation,
            groups=self.base_conv.groups,
        )

    def merge_lora(self) -> nn.Conv2d:
        """Merge LoRA weights into the base convolution and return a standard Conv2d layer"""
        lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
        self.base_conv.weight.copy_(self.base_conv.weight + lora_update)

        return self.base_conv


def replace_conv2d_with_lora(module: nn.Module, rank: int = 4) -> nn.Module:
    """Recursively replace Conv2d layers with LoRAConv2d in the module"""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d):
            setattr(module, name, LoRAConv2d(child, rank))
        else:
            replace_conv2d_with_lora(child, rank)
    return module


def merge_lora_weights(module: nn.Module) -> nn.Module:
    """Recursively merge LoRA weights in the module"""
    for name, child in list(module.named_children()):
        if isinstance(child, LoRAConv2d):
            setattr(module, name, child.merge_lora())
        else:
            merge_lora_weights(child)
    return module


def add_lora_to_backbone(backbone: nn.Module, rank: int = 4) -> nn.Module:
    """
    Adds LoRA to a convolutional backbone by replacing Conv2d layers
    and freezing all other weights except LoRA layers and the final classifier.
    """
    replace_conv2d_with_lora(backbone, rank)

    for name, param in backbone.named_parameters():
        if "lora_" in name or name.startswith("fc"):
            param.requires_grad = True
        else:
            param.requires_grad = False

    return backbone


def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
    """Generates sinusoidal positional embeddings like in the original Transformer paper."""
    position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dimension, 2, dtype=torch.float) * (-math.log(10000.0) / dimension))
    pe = torch.zeros(num_positions, dimension)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
```
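The LoRA helpers at the bottom of this file can be exercised on their own. A small sketch, assuming the module path introduced by this compare and using torchvision's resnet18 purely as an example backbone (the projection dimension of 128 is illustrative):

```python
import torch
import torchvision
from torch import nn

from lerobot.common.policies.dot.modeling_dot import add_lora_to_backbone, merge_lora_weights

backbone = torchvision.models.resnet18(weights=None)
backbone.fc = nn.Linear(backbone.fc.in_features, 128)  # hypothetical projection to dim_model
backbone = add_lora_to_backbone(backbone, rank=4)

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total = sum(p.numel() for p in backbone.parameters())
print(f"trainable params: {trainable}/{total}")  # only LoRA matrices and the final fc remain trainable

# After training, the low-rank updates can be folded back into plain Conv2d layers:
with torch.no_grad():
    backbone = merge_lora_weights(backbone)
```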
```diff
@@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig
 from lerobot.common.envs.utils import env_to_policy_features
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.dot.configuration_dot import DOTConfig
 from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
         from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy

         return PI0FASTPolicy
+    elif name == "dot":
+        from lerobot.common.policies.dot.modeling_dot import DOTPolicy
+
+        return DOTPolicy
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")

@@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return PI0Config(**kwargs)
     elif policy_type == "pi0fast":
         return PI0FASTConfig(**kwargs)
+    elif policy_type == "dot":
+        return DOTConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{policy_type}' is not available.")

```
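With these registrations, the new policy becomes selectable by name through the existing factory functions. A brief sketch (the keyword overrides are illustrative values, not recommendations):

```python
from lerobot.common.policies.factory import get_policy_class, make_policy_config

cfg = make_policy_config("dot", train_horizon=30, inference_horizon=30)
policy_cls = get_policy_class("dot")
print(type(cfg).__name__, policy_cls.__name__)  # DOTConfig DOTPolicy
```

On the command line this corresponds to `--policy.type=dot` in the training script, via the same `.type` mechanism described in the tutorial diff earlier in this compare.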
```diff
@@ -512,13 +512,13 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--width",
-        type=str,
+        type=int,
         default=640,
         help="Set the width for all cameras. If not provided, use the default width of each camera.",
     )
     parser.add_argument(
         "--height",
-        type=str,
+        type=int,
         default=480,
         help="Set the height for all cameras. If not provided, use the default height of each camera.",
     )
```
```diff
@@ -492,13 +492,13 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--width",
-        type=str,
+        type=int,
         default=None,
         help="Set the width for all cameras. If not provided, use the default width of each camera.",
    )
     parser.add_argument(
         "--height",
-        type=str,
+        type=int,
         default=None,
         help="Set the height for all cameras. If not provided, use the default height of each camera.",
     )
```
```diff
@@ -174,7 +174,10 @@ def run_server(
             dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
         ]
         videos_info = [
-            {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
+            {
+                "url": url_for("static", filename=str(video_path).replace("\\", "/")),
+                "filename": video_path.parent.name,
+            }
             for video_path in video_paths
         ]
         tasks = dataset.meta.episodes[episode_id]["tasks"]
@@ -381,7 +384,7 @@ def visualize_dataset_html(
     if isinstance(dataset, LeRobotDataset):
         ln_videos_dir = static_dir / "videos"
         if not ln_videos_dir.exists():
-            ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
+            ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix())

     if serve:
         run_server(dataset, episodes, host, port, static_dir, template_dir)
```
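Both hunks address Windows path separators leaking into URLs and symlink targets. A standalone pathlib illustration of the idea (the file names are made up):

```python
from pathlib import PureWindowsPath

p = PureWindowsPath("videos") / "observation.images.top" / "episode_000000.mp4"
print(str(p))                     # videos\observation.images.top\episode_000000.mp4 (backslashes break URLs)
print(p.as_posix())               # videos/observation.images.top/episode_000000.mp4
print(str(p).replace("\\", "/"))  # same effect as the replace() used in run_server above
```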