This commit is contained in:
Jade Choghari 2025-04-04 10:09:08 +00:00 committed by GitHub
commit 4b17751e5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 792 additions and 1 deletions

View File

@ -20,7 +20,7 @@ from pathlib import Path
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler
from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.datasets.utils import write_json
@ -120,3 +120,16 @@ def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
scheduler.load_state_dict(state_dict)
return scheduler
@LRSchedulerConfig.register_subclass("cosine_annealing")
@dataclass
class CosineAnnealingSchedulerConfig(LRSchedulerConfig):
"""Implements Cosine Annealing learning rate scheduler"""
min_lr: float = 0 # Minimum learning rate
T_max: int = 100000 # Number of iterations for a full decay (half-cycle)
num_warmup_steps: int = 0 # Not used but somehow required by the parent class
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler:
return CosineAnnealingLR(optimizer, T_max=self.T_max, eta_min=self.min_lr)

View File

@ -14,6 +14,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .dot.configuration_dot import DOTConfig as DOTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@ -0,0 +1,212 @@
#!/usr/bin/env python
# Copyright 2025 Ilia Larchenko and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import CosineAnnealingSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("dot")
@dataclass
class DOTConfig(PreTrainedConfig):
"""Configuration class for the Decision Transformer (DOT) policy.
DOT is a transformer-based policy for sequential decision making that predicts future actions based on
a history of past observations and actions. This configuration enables fine-grained
control over the models temporal horizon, input normalization, architectural parameters, and
augmentation strategies.
Defaults are configured for general robot manipulation tasks like Push-T and ALOHA insert/transfer.
The parameters you will most likely need to modify are those related to temporal structure and
normalization:
- `train_horizon` and `inference_horizon`
- `lookback_obs_steps` and `lookback_aug`
- `alpha` and `train_alpha`
- `normalization_mapping`
Notes on the temporal design:
- `train_horizon`: Length of action sequence the model is trained on. Must be `inference_horizon`.
- `inference_horizon`: How far into the future the model predicts during inference (in environment steps).
A good rule of thumb is 2×FPS (e.g., 3050 for 1525 FPS environments).
- `alpha` / `train_alpha`: Control exponential decay of loss weights for inference and training.
These should be tuned such that all predicted steps contribute meaningful signal.
Notes on the inputs:
- Observations can come from:
- Images (e.g., keys starting with `"observation.images"`)
- Proprioceptive state (`"observation.state"`)
- Environment state (`"observation.environment_state"`)
- At least one of image or environment state inputs must be provided.
- The "action" key is required as an output.
Args:
n_obs_steps: Number of past steps passed to the model, including the current step.
train_horizon: Number of future steps the model is trained to predict.
inference_horizon: Number of future steps predicted during inference.
lookback_obs_steps: Number of past steps to include for temporal context.
lookback_aug: Number of steps into the far past from which to randomly sample for augmentation.
normalization_mapping: Dictionary specifying normalization mode for each input/output group.
override_dataset_stats: If True, replaces the dataset's stats with manually defined `new_dataset_stats`.
new_dataset_stats: Optional manual min/max overrides used if `override_dataset_stats=True`.
vision_backbone: Name of the ResNet variant used for image encoding (e.g., "resnet18").
pretrained_backbone_weights: Optional pretrained weights (e.g., "ResNet18_Weights.IMAGENET1K_V1").
pre_norm: Whether to apply pre-norm in transformer layers.
lora_rank: If > 0, applies LoRA adapters of the given rank to transformer layers.
merge_lora: Whether to merge LoRA weights at inference time.
dim_model: Dimension of the transformer hidden state.
n_heads: Number of attention heads.
dim_feedforward: Dimension of the feedforward MLP inside the transformer.
n_decoder_layers: Number of transformer decoder layers.
rescale_shape: Resize shape for input images (e.g., (96, 96)).
crop_scale: Image crop scale for augmentation.
state_noise: Magnitude of additive uniform noise for state inputs.
noise_decay: Decay factor applied to `crop_scale` and `state_noise` during training.
dropout: Dropout rate used in transformer layers.
alpha: Decay factor for inference loss weighting.
train_alpha: Decay factor for training loss weighting.
predict_every_n: Predict actions every `n` frames instead of every frame.
return_every_n: Return every `n`-th predicted action during inference.
optimizer_lr: Initial learning rate.
optimizer_min_lr: Minimum learning rate for cosine scheduler.
optimizer_lr_cycle_steps: Total steps in one learning rate cycle.
optimizer_weight_decay: L2 weight decay for optimizer.
Raises:
ValueError: If the temporal settings are inconsistent (e.g., `train_horizon < inference_horizon`,
or `predict_every_n` > allowed bounds).
"""
# Input / output structure.
n_obs_steps: int = 3
train_horizon: int = 20
inference_horizon: int = 20
lookback_obs_steps: int = 10
lookback_aug: int = 5
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Align with the new config system
override_dataset_stats: bool = False
new_dataset_stats: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"max": [512.0] * 2, "min": [0.0] * 2},
"observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16},
"observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
}
)
# Architecture.
vision_backbone: str = "resnet18"
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
pre_norm: bool = True
lora_rank: int = 20
merge_lora: bool = False
dim_model: int = 128
n_heads: int = 8
dim_feedforward: int = 512
n_decoder_layers: int = 8
rescale_shape: tuple[int, int] = (96, 96)
# Augmentation.
crop_scale: float = 0.8
state_noise: float = 0.01
noise_decay: float = 0.999995
# Training and loss computation.
dropout: float = 0.1
# Weighting and inference.
alpha: float = 0.75
train_alpha: float = 0.9
predict_every_n: int = 1
return_every_n: int = 1
# Training preset
optimizer_lr: float = 1.0e-4
optimizer_min_lr: float = 1.0e-4
optimizer_lr_cycle_steps: int = 300000
optimizer_weight_decay: float = 1e-5
def __post_init__(self):
super().__post_init__()
if self.predict_every_n > self.inference_horizon:
raise ValueError(
f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon ({self.inference_horizon})."
)
if self.return_every_n > self.inference_horizon:
raise ValueError(
f"return_every_n ({self.return_every_n}) must be less than or equal to horizon ({self.inference_horizon})."
)
if self.predict_every_n > self.inference_horizon // self.return_every_n:
raise ValueError(
f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon // return_every_n({self.inference_horizon // self.return_every_n})."
)
if self.train_horizon < self.inference_horizon:
raise ValueError(
f"train_horizon ({self.train_horizon}) must be greater than or equal to horizon ({self.inference_horizon})."
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> None:
return CosineAnnealingSchedulerConfig(
min_lr=self.optimizer_min_lr, T_max=self.optimizer_lr_cycle_steps
)
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:
far_past_obs = list(
range(
-self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps
)
)
recent_obs = list(range(2 - self.n_obs_steps, 1))
return far_past_obs + recent_obs
@property
def action_delta_indices(self) -> list:
far_past_actions = list(
range(
-self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps
)
)
recent_actions = list(range(2 - self.n_obs_steps, self.train_horizon))
return far_past_actions + recent_actions
@property
def reward_delta_indices(self) -> None:
return None

View File

@ -0,0 +1,558 @@
#!/usr/bin/env python
# Copyright 2025 Ilia Larchenko and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The implementation of the Decoder-Only Transformer (DOT) policy.
More details here: https://github.com/IliaLarchenko/dot_policy
"""
import math
import torch
import torchvision
from torch import Tensor, nn
from torchvision import transforms
from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.transforms.functional import InterpolationMode
from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
class DOT(nn.Module):
"""The underlying neural network for DOT
Note: Unlike ACT, DOT has no encoder, no VAE, and no cross-attention. All inputs are directly projected
to the model dimension and passed as memory to a Transformer decoder.
- Inputs (images, state, env_state) are linearly projected and concatenated.
- A trainable prefix token and positional embeddings are added.
- The Transformer decoder predicts a sequence of future actions autoregressively.
DOT Transformer
Used for autoregressive action prediction
(no encoder, no VAE)
image emb. state emb. env_state emb.
Concatenate + Add Positional Emb.
Transformer Decoder (L layers)
Linear projection to action space
Outputs
"""
def __init__(self, config: DOTConfig):
super().__init__()
self.config = config
self.projections = nn.ModuleDict()
self.n_features = 0
self.image_names = sorted(config.image_features.keys())
# Set up a shared visual backbone (e.g., ResNet18) for all cameras.
# The final layer is replaced with a linear projection to match model_dim.
if len(self.image_names) > 0:
backbone = getattr(torchvision.models, self.config.vision_backbone)(
weights=self.config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
backbone.fc = nn.Linear(backbone.fc.in_features, self.config.dim_model)
self.projections["images"] = add_lora_to_backbone(backbone, rank=config.lora_rank)
self.n_features += len(self.image_names) * self.config.n_obs_steps
if self.config.robot_state_feature:
self.projections["state"] = nn.Linear(
self.config.robot_state_feature.shape[0], self.config.dim_model
)
self.n_features += self.config.n_obs_steps
if self.config.env_state_feature:
self.projections["environment_state"] = nn.Linear(
self.config.env_state_feature.shape[0], self.config.dim_model
)
self.n_features += self.config.n_obs_steps
self.projections_names = sorted(self.projections.keys())
obs_mapping = {
"images": "observation.images",
"state": "observation.state",
"environment_state": "observation.environment_state",
}
self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names}
# Optional trainable prefix token added to the input sequence (can be used for task conditioning or extra context)
self.prefix_input = nn.Parameter(torch.randn(1, 1, config.dim_model))
# Setup transformer decoder
dec_layer = nn.TransformerDecoderLayer(
d_model=self.config.dim_model,
nhead=self.config.n_heads,
dim_feedforward=self.config.dim_feedforward,
dropout=self.config.dropout,
batch_first=True,
norm_first=self.config.pre_norm,
)
decoder_norm = nn.LayerNorm(self.config.dim_model)
self.decoder = nn.TransformerDecoder(
dec_layer, num_layers=self.config.n_decoder_layers, norm=decoder_norm
)
# Sinusoidal positional encodings for the decoder input tokens (fixed, not trainable)
decoder_pos = create_sinusoidal_pos_embedding(
config.train_horizon + config.lookback_obs_steps, config.dim_model
)
decoder_pos = torch.cat(
[
decoder_pos[:1],
decoder_pos[-config.train_horizon - config.n_obs_steps + 2 :],
],
dim=0,
)
self.register_buffer("decoder_pos", decoder_pos)
# Extend positional encodings for inference (when inference_horizon > train_horizon)
decoder_pos_inf = self.decoder_pos[
: self.decoder_pos.shape[0] + self.config.inference_horizon - self.config.train_horizon
]
self.register_buffer("decoder_pos_inf", decoder_pos_inf)
# Causal mask for decoder: prevent attending to future positions
mask = torch.zeros(len(decoder_pos), len(decoder_pos), dtype=torch.bool)
mask[
: len(decoder_pos) + config.inference_horizon - config.train_horizon,
len(decoder_pos) + config.inference_horizon - config.train_horizon :,
] = True
self.register_buffer("mask", mask)
# Learnable positional embeddings for input tokens (state/image/env projections)
self.inputs_pos_emb = nn.Parameter(torch.empty(1, self.n_features, self.config.dim_model))
nn.init.uniform_(
self.inputs_pos_emb,
-((1 / self.config.dim_model) ** 0.5),
(1 / self.config.dim_model) ** 0.5,
)
# The output actions are generated by a linear layer
self.action_head = nn.Linear(self.config.dim_model, self.config.action_feature.shape[0])
def _process_inputs(self, batch):
# Project all inputs to the model dimension and concatenate them
inputs_projections_list = []
for state in self.projections_names:
batch_state = self.obs_mapping[state]
if batch_state in batch:
batch_size, n_obs, *obs_shape = batch[batch_state].shape
enc = self.projections[state](batch[batch_state].view(batch_size * n_obs, *obs_shape)).view(
batch_size, n_obs, -1
)
inputs_projections_list.append(enc)
return torch.cat(inputs_projections_list, dim=1)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
"""
A forward pass through the Decision Transformer (DOT).
The model uses a transformer decoder to predict a sequence of future actions from projected
and positionally-embedded image, state, and environment features.
Args:
batch (dict): A dictionary containing the following keys (if available):
- "observation.images": (B, T, C, H, W) tensor of camera frames.
- "observation.state": (B, T, D) tensor of proprioceptive robot states.
- "observation.environment_state": (B, T, D) tensor of environment states.
Returns:
Tensor: A tensor of shape (B, horizon, action_dim) containing predicted future actions.
"""
# Project image/state/env_state inputs to the model dimension and concatenate along the time axis.
inputs_projections = self._process_inputs(batch) # (B, T, D)
batch_size = inputs_projections.shape[0]
# Add learnable positional embeddings to each projected input token.
inputs_projections += self.inputs_pos_emb.expand(batch_size, -1, -1)
# Prepend a trainable prefix token to the input sequence
inputs_projections = torch.cat(
[self.prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1
) # (B, T+1, D)
# Use different positional encodings and masks for training vs. inference.
if self.training:
decoder_out = self.decoder(
self.decoder_pos.expand(batch_size, -1, -1), inputs_projections, self.mask
)
else:
decoder_out = self.decoder(self.decoder_pos_inf.expand(batch_size, -1, -1), inputs_projections)
return self.action_head(decoder_out)
class DOTPolicy(PreTrainedPolicy):
"""
Decision Transformer (DOT) Policy. (github: https://github.com/IliaLarchenko/dot_policy)
A minimal transformer decoder-based policy for autoregressive action prediction in robot control.
This is a simplified alternative to ACT: no encoder, no VAE, and no cross-attention, making it efficient
for deployment in low-dimensional environments with visual and proprioceptive inputs.
"""
name = "dot"
config_class = DOTConfig
def __init__(
self,
config: DOTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config (DOTConfig): Configuration for the DOT model and policy behavior.
dataset_stats (optional): Dataset statistics used for normalizing inputs/outputs.
If not provided, stats should be set later via `load_state_dict()` before inference.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.image_names = sorted(config.image_features.keys())
if config.override_dataset_stats:
if dataset_stats is None:
dataset_stats = {}
for k, v in config.new_dataset_stats.items():
if k not in dataset_stats:
dataset_stats[k] = {}
for k1, v1 in v.items():
dataset_stats[k][k1] = torch.tensor(v1)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = DOT(self.config)
self.state_noise = self.config.state_noise
self.crop_scale = self.config.crop_scale
self.alpha = self.config.alpha
self.inference_horizon = self.config.inference_horizon
self.return_every_n = self.config.return_every_n
self.predict_every_n = self.config.predict_every_n
# Inference action chunking and observation queues
self._old_predictions = None
self._input_buffers = {}
# Weights used for chunking
action_weights = self.alpha ** torch.arange(self.inference_horizon).float()
action_weights /= action_weights.sum()
action_weights = action_weights.view(1, -1, 1)
self.register_buffer("action_weights", action_weights)
# Weights for the loss computations
# Actions that are further in the future are weighted less
loss_weights = torch.ones(self.config.train_horizon + self.config.n_obs_steps - 1)
loss_weights[-self.config.train_horizon :] = (
self.config.train_alpha ** torch.arange(self.config.train_horizon).float()
)
loss_weights /= loss_weights.mean()
loss_weights = loss_weights.view(1, -1, 1)
self.register_buffer("loss_weights", loss_weights)
# TODO(jadechoghari): Move augmentations to dataloader (__getitem__) for CPU-side processing.
# Nearest interpolation is required for PushT but may be not the best in general
self.resize_transform = transforms.Resize(
config.rescale_shape, interpolation=InterpolationMode.NEAREST
)
self.step = 0
self.last_action = None
def reset(self):
self._old_predictions = None
self._input_buffers = {}
self.last_action = None
self.step = 0
def get_optim_params(self) -> dict:
return self.model.parameters()
def _update_observation_buffers(self, buffer_name: str, observation: Tensor) -> Tensor:
# Maintain a rolling buffer of lookback_obs_steps + 1;
# shift left and append new observation each step
if buffer_name not in self._input_buffers:
self._input_buffers[buffer_name] = observation.unsqueeze(1).repeat(
1,
self.config.lookback_obs_steps + 1,
*torch.ones(len(observation.shape[1:])).int(),
)
else:
self._input_buffers[buffer_name] = self._input_buffers[buffer_name].roll(shifts=-1, dims=1)
self._input_buffers[buffer_name][:, -1] = observation
return torch.cat(
[
self._input_buffers[buffer_name][:, :1],
self._input_buffers[buffer_name][:, -(self.config.n_obs_steps - 1) :],
],
dim=1,
)
def _prepare_batch_for_inference(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = self.normalize_inputs(batch)
# Resize and stack all images
if len(self.image_names) > 0:
batch["observation.images"] = torch.stack(
[self.resize_transform(batch[k]) for k in self.image_names],
dim=1,
) # batch_size, n_cam, c, h, w
# Update observation queues for all inputs and stack the last n_obs_steps
for name, batch_name in self.model.obs_mapping.items():
batch[batch_name] = self._update_observation_buffers(name, batch[batch_name])
# Reshape images tensor to keep the same order as during training
if "observation.images" in batch:
batch["observation.images"] = batch["observation.images"].flatten(1, 2)
# batch_size, n_obs * n_cam, c, h, w
return batch
def _chunk_actions(self, actions: Tensor) -> Tensor:
# Store the previous action predictions in a buffer
# Compute the weighted average of the inference horizon action predictions
if self._old_predictions is not None:
self._old_predictions[:, 0] = actions
else:
self._old_predictions = actions.unsqueeze(1).repeat(1, self.config.inference_horizon, 1, 1)
action = (self._old_predictions[:, :, 0] * self.action_weights).sum(dim=1)
self._old_predictions = self._old_predictions.roll(shifts=(1, -1), dims=(1, 2))
return action
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""
Select an action given current environment observations.
This function handles autoregressive rollout during inference using a fixed prediction horizon.
The model predicts every `predict_every_n` steps, and returns actions every `return_every_n` steps.
Between predictions, previously predicted actions are reused by shifting and repeating the last step.
"""
self.eval()
batch = self._prepare_batch_for_inference(batch)
# Only run model prediction every predict_every_n steps
if self.step % self.predict_every_n == 0:
actions_pred = self.model(batch)[:, -self.config.inference_horizon :]
self.last_action = self.unnormalize_outputs({"action": actions_pred})["action"]
else:
# Otherwise shift previous predictions and repeat last action
self.last_action = self.last_action.roll(-1, dims=1)
self.last_action[:, -1] = self.last_action[:, -2]
self.step += 1
# Return chunked actions for return_every_n steps
action = self._chunk_actions(self.last_action)
for _ in range(self.return_every_n - 1):
self.last_action = self.last_action.roll(-1, dims=1)
self.last_action[:, -1] = self.last_action[:, -2]
action = self._chunk_actions(self.last_action)
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
lookback_ind = torch.randint(0, 2 * self.config.lookback_aug + 1, (1,)).item()
for k in list(self.model.obs_mapping.values()) + list(self.image_names) + ["action", "action_is_pad"]:
if k != "observation.images":
batch[k] = torch.cat(
[
batch[k][:, lookback_ind : lookback_ind + 1],
batch[k][:, 2 * self.config.lookback_aug + 1 :],
],
1,
)
batch = self.normalize_targets(self.normalize_inputs(batch))
if len(self.config.image_features) > 0:
scale = 1 - torch.rand(1) * (1 - self.crop_scale)
new_shape = (
int(self.config.rescale_shape[0] * scale),
int(self.config.rescale_shape[1] * scale),
)
crop_transform = transforms.RandomCrop(new_shape)
for k in self.image_names:
batch_size, n_obs, c, h, w = batch[k].shape
batch[k] = batch[k].view(batch_size * n_obs, c, h, w)
batch[k] = crop_transform(self.resize_transform(batch[k]))
batch[k] = batch[k].view(batch_size, n_obs, c, *batch[k].shape[-2:])
batch["observation.images"] = torch.stack([batch[k] for k in self.image_names], dim=2).flatten(
1, 2
) # batch_size, n_obs * n_cam, c, h, w
# Add random noise to states during training
# TODO(jadechoghari): better to move this to the dataloader
if self.state_noise is not None:
for k in self.model.obs_mapping.values():
if k != "observation.images":
batch[k] += (torch.rand_like(batch[k]) * 2 - 1) * self.state_noise
actions_hat = self.model(batch)
l1_loss = nn.functional.l1_loss(batch["action"], actions_hat, reduction="none")
rev_padding = (~batch["action_is_pad"]).unsqueeze(-1)
# Apply padding, weights and decay to the loss
l1_loss = (l1_loss * rev_padding * self.loss_weights).mean()
loss_dict = {"l1_loss": l1_loss.item()}
loss = l1_loss
# Reduce the aggressiveness of augmentations
self.state_noise *= self.config.noise_decay
self.crop_scale = 1 - (1 - self.crop_scale) * self.config.noise_decay
return loss, loss_dict
@classmethod
def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs):
"""Load model from pretrained checkpoint and merge LoRA after loading"""
policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs)
if getattr(policy.config, "merge_lora", False):
print("Merging LoRA after loading pretrained model...")
policy.model = merge_lora_weights(policy.model)
return policy
class LoRAConv2d(nn.Module):
"""
Applies Low-Rank Adaptation (LoRA) to a Conv2D layer.
LoRA adds trainable low-rank matrices (A and B) to adapt pretrained weights without full fine-tuning.
The adaptation is merged into the base conv weights via `merge_lora()` after training.
Args:
base_conv (nn.Conv2d): The original convolutional layer to be adapted.
rank (int): The rank of the low-rank approximation (default: 4).
"""
def __init__(self, base_conv: nn.Conv2d, rank: int = 4):
super().__init__()
self.base_conv = base_conv
# Flatten the original conv weight
out_channels, in_channels, kh, kw = base_conv.weight.shape
self.weight_shape = (out_channels, in_channels, kh, kw)
fan_in = in_channels * kh * kw
# Low-rank trainable matrices A and B
self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank)))
self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
return nn.functional.conv2d(
x,
self.base_conv.weight + lora_update,
self.base_conv.bias,
stride=self.base_conv.stride,
padding=self.base_conv.padding,
dilation=self.base_conv.dilation,
groups=self.base_conv.groups,
)
def merge_lora(self) -> nn.Conv2d:
"""Merge LoRA weights into the base convolution and return a standard Conv2d layer"""
lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
self.base_conv.weight.copy_(self.base_conv.weight + lora_update)
return self.base_conv
def replace_conv2d_with_lora(module: nn.Module, rank: int = 4) -> nn.Module:
"""Recursively replace Conv2d layers with LoRAConv2d in the module"""
for name, child in list(module.named_children()):
if isinstance(child, nn.Conv2d):
setattr(module, name, LoRAConv2d(child, rank))
else:
replace_conv2d_with_lora(child, rank)
return module
def merge_lora_weights(module: nn.Module) -> nn.Module:
"""Recursively merge LoRA weights in the module"""
for name, child in list(module.named_children()):
if isinstance(child, LoRAConv2d):
setattr(module, name, child.merge_lora())
else:
merge_lora_weights(child)
return module
def add_lora_to_backbone(backbone: nn.Module, rank: int = 4) -> nn.Module:
"""
Adds LoRA to a convolutional backbone by replacing Conv2d layers
and freezing all other weights except LoRA layers and the final classifier.
"""
replace_conv2d_with_lora(backbone, rank)
for name, param in backbone.named_parameters():
if "lora_" in name or name.startswith("fc"):
param.requires_grad = True
else:
param.requires_grad = False
return backbone
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
"""Generates sinusoidal positional embeddings like in the original Transformer paper."""
position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dimension, 2, dtype=torch.float) * (-math.log(10000.0) / dimension))
pe = torch.zeros(num_positions, dimension)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe

View File

@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "dot":
from lerobot.common.policies.dot.modeling_dot import DOTPolicy
return DOTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "dot":
return DOTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")