Merge 3d8a29fe6e into 1c873df5c0
Commit 4b17751e5a
@@ -20,7 +20,7 @@ from pathlib import Path

import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler

from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.datasets.utils import write_json
@@ -120,3 +120,16 @@ def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
    state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
    scheduler.load_state_dict(state_dict)
    return scheduler


@LRSchedulerConfig.register_subclass("cosine_annealing")
@dataclass
class CosineAnnealingSchedulerConfig(LRSchedulerConfig):
    """Implements the cosine annealing learning rate scheduler."""

    min_lr: float = 0  # Minimum learning rate
    T_max: int = 100000  # Number of iterations for a full decay (half-cycle)
    num_warmup_steps: int = 0  # Unused here, but required by the parent class interface

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler:
        return CosineAnnealingLR(optimizer, T_max=self.T_max, eta_min=self.min_lr)
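As a quick standalone illustration (not part of the patch), the new preset is a thin wrapper around torch's CosineAnnealingLR. The sketch below only assumes torch and a toy parameter, and shows the half-cosine decay from the base learning rate down to min_lr over T_max steps:

    import torch
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = AdamW(params, lr=1e-4)

    # Roughly what CosineAnnealingSchedulerConfig(min_lr=1e-5, T_max=100).build(optimizer, ...) returns
    scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5)

    for step in range(100):
        optimizer.step()
        scheduler.step()
        if step % 25 == 0:
            print(step, scheduler.get_last_lr()[0])  # decays from 1e-4 toward 1e-5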
@@ -14,6 +14,7 @@

from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .dot.configuration_dot import DOTConfig as DOTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
@@ -0,0 +1,212 @@
#!/usr/bin/env python

# Copyright 2025 Ilia Larchenko and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import CosineAnnealingSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode

@PreTrainedConfig.register_subclass("dot")
@dataclass
class DOTConfig(PreTrainedConfig):
    """Configuration class for the Decoder-Only Transformer (DOT) policy.

    DOT is a transformer-based policy for sequential decision making that predicts future actions from a
    history of past observations and actions. This configuration enables fine-grained control over the
    model's temporal horizon, input normalization, architectural parameters, and augmentation strategies.

    Defaults are configured for general robot manipulation tasks such as Push-T and ALOHA insert/transfer.

    The parameters you will most likely need to modify are those related to temporal structure and
    normalization:
        - `train_horizon` and `inference_horizon`
        - `lookback_obs_steps` and `lookback_aug`
        - `alpha` and `train_alpha`
        - `normalization_mapping`

    Notes on the temporal design:
        - `train_horizon`: Length of the action sequence the model is trained on. Must be >= `inference_horizon`.
        - `inference_horizon`: How far into the future the model predicts during inference (in environment steps).
          A good rule of thumb is 2x FPS (e.g., 30-50 for 15-25 FPS environments).
        - `alpha` / `train_alpha`: Control the exponential decay of loss weights for inference and training.
          These should be tuned so that all predicted steps contribute a meaningful signal.

    Notes on the inputs:
        - Observations can come from:
            - Images (e.g., keys starting with `"observation.images"`)
            - Proprioceptive state (`"observation.state"`)
            - Environment state (`"observation.environment_state"`)
        - At least one image or environment state input must be provided.
        - The "action" key is required as an output.

    Args:
        n_obs_steps: Number of past steps passed to the model, including the current step.
        train_horizon: Number of future steps the model is trained to predict.
        inference_horizon: Number of future steps predicted during inference.
        lookback_obs_steps: Number of past steps to include for temporal context.
        lookback_aug: Number of steps into the far past from which to randomly sample for augmentation.
        normalization_mapping: Dictionary specifying the normalization mode for each input/output group.
        override_dataset_stats: If True, replaces the dataset's stats with manually defined `new_dataset_stats`.
        new_dataset_stats: Optional manual min/max overrides used if `override_dataset_stats=True`.
        vision_backbone: Name of the ResNet variant used for image encoding (e.g., "resnet18").
        pretrained_backbone_weights: Optional pretrained weights (e.g., "ResNet18_Weights.IMAGENET1K_V1").
        pre_norm: Whether to apply pre-norm in transformer layers.
        lora_rank: If > 0, applies LoRA adapters of the given rank to the vision backbone's convolutional layers.
        merge_lora: Whether to merge LoRA weights at inference time.
        dim_model: Dimension of the transformer hidden state.
        n_heads: Number of attention heads.
        dim_feedforward: Dimension of the feedforward MLP inside the transformer.
        n_decoder_layers: Number of transformer decoder layers.
        rescale_shape: Resize shape for input images (e.g., (96, 96)).
        crop_scale: Image crop scale for augmentation.
        state_noise: Magnitude of additive uniform noise for state inputs.
        noise_decay: Decay factor applied to `crop_scale` and `state_noise` during training.
        dropout: Dropout rate used in transformer layers.
        alpha: Decay factor for inference loss weighting.
        train_alpha: Decay factor for training loss weighting.
        predict_every_n: Predict actions every `n` frames instead of every frame.
        return_every_n: Return every `n`-th predicted action during inference.
        optimizer_lr: Initial learning rate.
        optimizer_min_lr: Minimum learning rate for the cosine scheduler.
        optimizer_lr_cycle_steps: Total steps in one learning rate cycle.
        optimizer_weight_decay: L2 weight decay for the optimizer.

    Raises:
        ValueError: If the temporal settings are inconsistent (e.g., `train_horizon < inference_horizon`,
            or `predict_every_n` exceeds the allowed bounds).
    """

    # Input / output structure.
    n_obs_steps: int = 3
    train_horizon: int = 20
    inference_horizon: int = 20
    lookback_obs_steps: int = 10
    lookback_aug: int = 5

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ENV": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Align with the new config system
    override_dataset_stats: bool = False
    new_dataset_stats: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"max": [512.0] * 2, "min": [0.0] * 2},
            "observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16},
            "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
        }
    )

    # Architecture.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    pre_norm: bool = True
    lora_rank: int = 20
    merge_lora: bool = False

    dim_model: int = 128
    n_heads: int = 8
    dim_feedforward: int = 512
    n_decoder_layers: int = 8
    rescale_shape: tuple[int, int] = (96, 96)

    # Augmentation.
    crop_scale: float = 0.8
    state_noise: float = 0.01
    noise_decay: float = 0.999995

    # Training and loss computation.
    dropout: float = 0.1

    # Weighting and inference.
    alpha: float = 0.75
    train_alpha: float = 0.9
    predict_every_n: int = 1
    return_every_n: int = 1

    # Training preset
    optimizer_lr: float = 1.0e-4
    optimizer_min_lr: float = 1.0e-4
    optimizer_lr_cycle_steps: int = 300000
    optimizer_weight_decay: float = 1e-5

    def __post_init__(self):
        super().__post_init__()
        if self.predict_every_n > self.inference_horizon:
            raise ValueError(
                f"predict_every_n ({self.predict_every_n}) must be less than or equal to inference_horizon ({self.inference_horizon})."
            )
        if self.return_every_n > self.inference_horizon:
            raise ValueError(
                f"return_every_n ({self.return_every_n}) must be less than or equal to inference_horizon ({self.inference_horizon})."
            )
        if self.predict_every_n > self.inference_horizon // self.return_every_n:
            raise ValueError(
                f"predict_every_n ({self.predict_every_n}) must be less than or equal to inference_horizon // return_every_n ({self.inference_horizon // self.return_every_n})."
            )
        if self.train_horizon < self.inference_horizon:
            raise ValueError(
                f"train_horizon ({self.train_horizon}) must be greater than or equal to inference_horizon ({self.inference_horizon})."
            )

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> CosineAnnealingSchedulerConfig:
        return CosineAnnealingSchedulerConfig(
            min_lr=self.optimizer_min_lr, T_max=self.optimizer_lr_cycle_steps
        )

    def validate_features(self) -> None:
        if not self.image_features and not self.env_state_feature:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

    @property
    def observation_delta_indices(self) -> list:
        far_past_obs = list(
            range(
                -self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps
            )
        )
        recent_obs = list(range(2 - self.n_obs_steps, 1))

        return far_past_obs + recent_obs

    @property
    def action_delta_indices(self) -> list:
        far_past_actions = list(
            range(
                -self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps
            )
        )
        recent_actions = list(range(2 - self.n_obs_steps, self.train_horizon))

        return far_past_actions + recent_actions

    @property
    def reward_delta_indices(self) -> None:
        return None
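For orientation (an illustration, not part of the patch), the arithmetic in the two delta-index properties above can be replayed standalone to see which relative frame indices the default configuration requests from the dataset:

    # Default values from DOTConfig above
    lookback_obs_steps, lookback_aug, n_obs_steps, train_horizon = 10, 5, 3, 20

    far_past = list(range(-lookback_aug - lookback_obs_steps, lookback_aug + 1 - lookback_obs_steps))
    recent_obs = list(range(2 - n_obs_steps, 1))
    recent_actions = list(range(2 - n_obs_steps, train_horizon))

    print(far_past)        # [-15, -14, ..., -5]: 11 far-past candidates, one is sampled during training
    print(recent_obs)      # [-1, 0]: with the sampled far-past frame this gives n_obs_steps observations
    print(recent_actions)  # [-1, 0, ..., 19]: action targets from one step back through train_horizon - 1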

@@ -0,0 +1,558 @@
#!/usr/bin/env python

# Copyright 2025 Ilia Larchenko and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The implementation of the Decoder-Only Transformer (DOT) policy.

More details here: https://github.com/IliaLarchenko/dot_policy
"""

import math

import torch
import torchvision
from torch import Tensor, nn
from torchvision import transforms
from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.transforms.functional import InterpolationMode

from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy

class DOT(nn.Module):
    """The underlying neural network for DOT.

    Note: Unlike ACT, DOT has no encoder, no VAE, and no cross-attention. All inputs are directly projected
    to the model dimension and passed as memory to a Transformer decoder.

    - Inputs (images, state, env_state) are linearly projected and concatenated.
    - A trainable prefix token and positional embeddings are added.
    - The Transformer decoder predicts a sequence of future actions autoregressively.

                        DOT Transformer
            Used for autoregressive action prediction
                     (no encoder, no VAE)

    ┌──────────────────────────────────────────────────────┐
    │   image emb.        state emb.      env_state emb.   │
    │        │                 │                  │        │
    │        ▼                 ▼                  ▼        │
    │      ┌──────────────────────────────────────────┐    │
    │      │    Concatenate + Add Positional Emb.     │    │
    │      └──────────────────────────────────────────┘    │
    │                           │                          │
    │                           ▼                          │
    │         ┌───────────────────────────────────┐        │
    │         │   Transformer Decoder (L layers)  │        │
    │         └───────────────────────────────────┘        │
    │                           │                          │
    │                           ▼                          │
    │          Linear projection to action space           │
    │                           │                          │
    │                           ▼                          │
    │                        Outputs                       │
    └──────────────────────────────────────────────────────┘
    """

    def __init__(self, config: DOTConfig):
        super().__init__()
        self.config = config

        self.projections = nn.ModuleDict()
        self.n_features = 0

        self.image_names = sorted(config.image_features.keys())

        # Set up a shared visual backbone (e.g., ResNet18) for all cameras.
        # The final layer is replaced with a linear projection to match dim_model.
        if len(self.image_names) > 0:
            backbone = getattr(torchvision.models, self.config.vision_backbone)(
                weights=self.config.pretrained_backbone_weights,
                norm_layer=FrozenBatchNorm2d,
            )
            backbone.fc = nn.Linear(backbone.fc.in_features, self.config.dim_model)

            self.projections["images"] = add_lora_to_backbone(backbone, rank=config.lora_rank)
            self.n_features += len(self.image_names) * self.config.n_obs_steps

        if self.config.robot_state_feature:
            self.projections["state"] = nn.Linear(
                self.config.robot_state_feature.shape[0], self.config.dim_model
            )
            self.n_features += self.config.n_obs_steps

        if self.config.env_state_feature:
            self.projections["environment_state"] = nn.Linear(
                self.config.env_state_feature.shape[0], self.config.dim_model
            )
            self.n_features += self.config.n_obs_steps

        self.projections_names = sorted(self.projections.keys())
        obs_mapping = {
            "images": "observation.images",
            "state": "observation.state",
            "environment_state": "observation.environment_state",
        }
        self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names}

        # Optional trainable prefix token added to the input sequence
        # (can be used for task conditioning or extra context)
        self.prefix_input = nn.Parameter(torch.randn(1, 1, config.dim_model))

        # Set up the transformer decoder
        dec_layer = nn.TransformerDecoderLayer(
            d_model=self.config.dim_model,
            nhead=self.config.n_heads,
            dim_feedforward=self.config.dim_feedforward,
            dropout=self.config.dropout,
            batch_first=True,
            norm_first=self.config.pre_norm,
        )

        decoder_norm = nn.LayerNorm(self.config.dim_model)
        self.decoder = nn.TransformerDecoder(
            dec_layer, num_layers=self.config.n_decoder_layers, norm=decoder_norm
        )

        # Sinusoidal positional encodings for the decoder input tokens (fixed, not trainable)
        decoder_pos = create_sinusoidal_pos_embedding(
            config.train_horizon + config.lookback_obs_steps, config.dim_model
        )
        decoder_pos = torch.cat(
            [
                decoder_pos[:1],
                decoder_pos[-config.train_horizon - config.n_obs_steps + 2 :],
            ],
            dim=0,
        )
        self.register_buffer("decoder_pos", decoder_pos)

        # Truncated positional encodings for inference (inference_horizon <= train_horizon)
        decoder_pos_inf = self.decoder_pos[
            : self.decoder_pos.shape[0] + self.config.inference_horizon - self.config.train_horizon
        ]
        self.register_buffer("decoder_pos_inf", decoder_pos_inf)
        # Causal mask for the decoder: prevent attending to future positions
        mask = torch.zeros(len(decoder_pos), len(decoder_pos), dtype=torch.bool)
        mask[
            : len(decoder_pos) + config.inference_horizon - config.train_horizon,
            len(decoder_pos) + config.inference_horizon - config.train_horizon :,
        ] = True
        self.register_buffer("mask", mask)

        # Learnable positional embeddings for input tokens (state/image/env projections)
        self.inputs_pos_emb = nn.Parameter(torch.empty(1, self.n_features, self.config.dim_model))
        nn.init.uniform_(
            self.inputs_pos_emb,
            -((1 / self.config.dim_model) ** 0.5),
            (1 / self.config.dim_model) ** 0.5,
        )

        # The output actions are generated by a linear layer
        self.action_head = nn.Linear(self.config.dim_model, self.config.action_feature.shape[0])

    def _process_inputs(self, batch):
        # Project all inputs to the model dimension and concatenate them
        inputs_projections_list = []

        for state in self.projections_names:
            batch_state = self.obs_mapping[state]
            if batch_state in batch:
                batch_size, n_obs, *obs_shape = batch[batch_state].shape
                enc = self.projections[state](batch[batch_state].view(batch_size * n_obs, *obs_shape)).view(
                    batch_size, n_obs, -1
                )
                inputs_projections_list.append(enc)

        return torch.cat(inputs_projections_list, dim=1)

    def forward(self, batch: dict[str, Tensor]) -> Tensor:
        """A forward pass through the DOT model.

        The model uses a transformer decoder to predict a sequence of future actions from projected
        and positionally embedded image, state, and environment features.

        Args:
            batch (dict): A dictionary containing the following keys (if available):
                - "observation.images": (B, T, C, H, W) tensor of camera frames.
                - "observation.state": (B, T, D) tensor of proprioceptive robot states.
                - "observation.environment_state": (B, T, D) tensor of environment states.

        Returns:
            Tensor: A tensor of shape (B, horizon, action_dim) containing predicted future actions.
        """
        # Project image/state/env_state inputs to the model dimension and concatenate along the time axis.
        inputs_projections = self._process_inputs(batch)  # (B, T, D)
        batch_size = inputs_projections.shape[0]

        # Add learnable positional embeddings to each projected input token.
        inputs_projections += self.inputs_pos_emb.expand(batch_size, -1, -1)

        # Prepend a trainable prefix token to the input sequence
        inputs_projections = torch.cat(
            [self.prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1
        )  # (B, T+1, D)

        # Use different positional encodings and masks for training vs. inference.
        if self.training:
            decoder_out = self.decoder(
                self.decoder_pos.expand(batch_size, -1, -1), inputs_projections, self.mask
            )
        else:
            decoder_out = self.decoder(self.decoder_pos_inf.expand(batch_size, -1, -1), inputs_projections)
        return self.action_head(decoder_out)


class DOTPolicy(PreTrainedPolicy):
    """Decoder-Only Transformer (DOT) Policy (github: https://github.com/IliaLarchenko/dot_policy).

    A minimal transformer decoder-based policy for autoregressive action prediction in robot control.
    This is a simplified alternative to ACT: no encoder, no VAE, and no cross-attention, making it efficient
    for deployment in low-dimensional environments with visual and proprioceptive inputs.
    """

    name = "dot"
    config_class = DOTConfig

    def __init__(
        self,
        config: DOTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config (DOTConfig): Configuration for the DOT model and policy behavior.
            dataset_stats (optional): Dataset statistics used for normalizing inputs/outputs.
                If not provided, stats should be set later via `load_state_dict()` before inference.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.image_names = sorted(config.image_features.keys())

        if config.override_dataset_stats:
            if dataset_stats is None:
                dataset_stats = {}
            for k, v in config.new_dataset_stats.items():
                if k not in dataset_stats:
                    dataset_stats[k] = {}
                for k1, v1 in v.items():
                    dataset_stats[k][k1] = torch.tensor(v1)

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.model = DOT(self.config)

        self.state_noise = self.config.state_noise
        self.crop_scale = self.config.crop_scale
        self.alpha = self.config.alpha
        self.inference_horizon = self.config.inference_horizon
        self.return_every_n = self.config.return_every_n
        self.predict_every_n = self.config.predict_every_n

        # Inference action chunking and observation queues
        self._old_predictions = None
        self._input_buffers = {}

        # Weights used for chunking: exponentially decaying over the inference horizon
        action_weights = self.alpha ** torch.arange(self.inference_horizon).float()
        action_weights /= action_weights.sum()
        action_weights = action_weights.view(1, -1, 1)
        self.register_buffer("action_weights", action_weights)

        # Weights for the loss computation:
        # actions that are further in the future are weighted less
        loss_weights = torch.ones(self.config.train_horizon + self.config.n_obs_steps - 1)
        loss_weights[-self.config.train_horizon :] = (
            self.config.train_alpha ** torch.arange(self.config.train_horizon).float()
        )
        loss_weights /= loss_weights.mean()
        loss_weights = loss_weights.view(1, -1, 1)
        self.register_buffer("loss_weights", loss_weights)

        # TODO(jadechoghari): Move augmentations to the dataloader (__getitem__) for CPU-side processing.
        # Nearest interpolation is required for PushT but may not be the best choice in general.
        self.resize_transform = transforms.Resize(
            config.rescale_shape, interpolation=InterpolationMode.NEAREST
        )

        self.step = 0
        self.last_action = None

    def reset(self):
        self._old_predictions = None
        self._input_buffers = {}
        self.last_action = None
        self.step = 0

    def get_optim_params(self) -> dict:
        return self.model.parameters()

    def _update_observation_buffers(self, buffer_name: str, observation: Tensor) -> Tensor:
        # Maintain a rolling buffer of lookback_obs_steps + 1 observations;
        # shift left and append the new observation each step
        if buffer_name not in self._input_buffers:
            self._input_buffers[buffer_name] = observation.unsqueeze(1).repeat(
                1,
                self.config.lookback_obs_steps + 1,
                *torch.ones(len(observation.shape[1:])).int(),
            )
        else:
            self._input_buffers[buffer_name] = self._input_buffers[buffer_name].roll(shifts=-1, dims=1)
            self._input_buffers[buffer_name][:, -1] = observation

        return torch.cat(
            [
                self._input_buffers[buffer_name][:, :1],
                self._input_buffers[buffer_name][:, -(self.config.n_obs_steps - 1) :],
            ],
            dim=1,
        )

    def _prepare_batch_for_inference(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = self.normalize_inputs(batch)

        # Resize and stack all images
        if len(self.image_names) > 0:
            batch["observation.images"] = torch.stack(
                [self.resize_transform(batch[k]) for k in self.image_names],
                dim=1,
            )  # batch_size, n_cam, c, h, w

        # Update observation queues for all inputs and stack the last n_obs_steps
        for name, batch_name in self.model.obs_mapping.items():
            batch[batch_name] = self._update_observation_buffers(name, batch[batch_name])

        # Reshape the images tensor to keep the same order as during training
        if "observation.images" in batch:
            batch["observation.images"] = batch["observation.images"].flatten(1, 2)
            # batch_size, n_obs * n_cam, c, h, w

        return batch

    def _chunk_actions(self, actions: Tensor) -> Tensor:
        # Store the previous action predictions in a buffer and compute the weighted
        # average of the inference-horizon predictions for the current step
        if self._old_predictions is not None:
            self._old_predictions[:, 0] = actions
        else:
            self._old_predictions = actions.unsqueeze(1).repeat(1, self.config.inference_horizon, 1, 1)

        action = (self._old_predictions[:, :, 0] * self.action_weights).sum(dim=1)
        self._old_predictions = self._old_predictions.roll(shifts=(1, -1), dims=(1, 2))

        return action

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select an action given current environment observations.

        This function handles the autoregressive rollout during inference using a fixed prediction horizon.
        The model predicts every `predict_every_n` steps and returns actions every `return_every_n` steps.
        Between predictions, previously predicted actions are reused by shifting and repeating the last step.
        """
        self.eval()

        batch = self._prepare_batch_for_inference(batch)

        # Only run the model prediction every predict_every_n steps
        if self.step % self.predict_every_n == 0:
            actions_pred = self.model(batch)[:, -self.config.inference_horizon :]
            self.last_action = self.unnormalize_outputs({"action": actions_pred})["action"]
        else:
            # Otherwise shift the previous predictions and repeat the last action
            self.last_action = self.last_action.roll(-1, dims=1)
            self.last_action[:, -1] = self.last_action[:, -2]

        self.step += 1

        # Return chunked actions for return_every_n steps
        action = self._chunk_actions(self.last_action)
        for _ in range(self.return_every_n - 1):
            self.last_action = self.last_action.roll(-1, dims=1)
            self.last_action[:, -1] = self.last_action[:, -2]
            action = self._chunk_actions(self.last_action)

        return action

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, float]]:
        """Run the batch through the model and compute the loss for training or validation."""
        # Randomly sample one far-past frame from the lookback window (temporal augmentation)
        lookback_ind = torch.randint(0, 2 * self.config.lookback_aug + 1, (1,)).item()
        for k in list(self.model.obs_mapping.values()) + list(self.image_names) + ["action", "action_is_pad"]:
            if k != "observation.images":
                batch[k] = torch.cat(
                    [
                        batch[k][:, lookback_ind : lookback_ind + 1],
                        batch[k][:, 2 * self.config.lookback_aug + 1 :],
                    ],
                    1,
                )
        batch = self.normalize_targets(self.normalize_inputs(batch))

        if len(self.config.image_features) > 0:
            # Random resized-crop augmentation shared by all cameras and timesteps
            scale = 1 - torch.rand(1) * (1 - self.crop_scale)
            new_shape = (
                int(self.config.rescale_shape[0] * scale),
                int(self.config.rescale_shape[1] * scale),
            )
            crop_transform = transforms.RandomCrop(new_shape)

            for k in self.image_names:
                batch_size, n_obs, c, h, w = batch[k].shape
                batch[k] = batch[k].view(batch_size * n_obs, c, h, w)
                batch[k] = crop_transform(self.resize_transform(batch[k]))
                batch[k] = batch[k].view(batch_size, n_obs, c, *batch[k].shape[-2:])
            batch["observation.images"] = torch.stack([batch[k] for k in self.image_names], dim=2).flatten(
                1, 2
            )  # batch_size, n_obs * n_cam, c, h, w

        # Add random noise to states during training
        # TODO(jadechoghari): better to move this to the dataloader
        if self.state_noise is not None:
            for k in self.model.obs_mapping.values():
                if k != "observation.images":
                    batch[k] += (torch.rand_like(batch[k]) * 2 - 1) * self.state_noise

        actions_hat = self.model(batch)

        l1_loss = nn.functional.l1_loss(batch["action"], actions_hat, reduction="none")
        rev_padding = (~batch["action_is_pad"]).unsqueeze(-1)

        # Apply padding, weights and decay to the loss
        l1_loss = (l1_loss * rev_padding * self.loss_weights).mean()

        loss_dict = {"l1_loss": l1_loss.item()}
        loss = l1_loss

        # Reduce the aggressiveness of the augmentations over training
        self.state_noise *= self.config.noise_decay
        self.crop_scale = 1 - (1 - self.crop_scale) * self.config.noise_decay

        return loss, loss_dict

    @classmethod
    def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs):
        """Load a model from a pretrained checkpoint and merge the LoRA weights after loading."""
        policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs)

        if getattr(policy.config, "merge_lora", False):
            print("Merging LoRA after loading pretrained model...")
            policy.model = merge_lora_weights(policy.model)

        return policy


class LoRAConv2d(nn.Module):
    """Applies Low-Rank Adaptation (LoRA) to a Conv2d layer.

    LoRA adds trainable low-rank matrices (A and B) to adapt pretrained weights without full fine-tuning.
    The adaptation is merged into the base conv weights via `merge_lora()` after training.

    Args:
        base_conv (nn.Conv2d): The original convolutional layer to be adapted.
        rank (int): The rank of the low-rank approximation (default: 4).
    """

    def __init__(self, base_conv: nn.Conv2d, rank: int = 4):
        super().__init__()
        self.base_conv = base_conv

        # Record the original conv weight shape; the low-rank update is built over
        # the weight flattened to (out_channels, fan_in)
        out_channels, in_channels, kh, kw = base_conv.weight.shape
        self.weight_shape = (out_channels, in_channels, kh, kw)
        fan_in = in_channels * kh * kw

        # Low-rank trainable matrices A and B
        self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank)))
        self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)

        return nn.functional.conv2d(
            x,
            self.base_conv.weight + lora_update,
            self.base_conv.bias,
            stride=self.base_conv.stride,
            padding=self.base_conv.padding,
            dilation=self.base_conv.dilation,
            groups=self.base_conv.groups,
        )

    def merge_lora(self) -> nn.Conv2d:
        """Merge the LoRA weights into the base convolution and return a standard Conv2d layer."""
        lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
        self.base_conv.weight.copy_(self.base_conv.weight + lora_update)

        return self.base_conv


def replace_conv2d_with_lora(module: nn.Module, rank: int = 4) -> nn.Module:
    """Recursively replace Conv2d layers with LoRAConv2d in the module."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d):
            setattr(module, name, LoRAConv2d(child, rank))
        else:
            replace_conv2d_with_lora(child, rank)
    return module


def merge_lora_weights(module: nn.Module) -> nn.Module:
    """Recursively merge LoRA weights in the module."""
    for name, child in list(module.named_children()):
        if isinstance(child, LoRAConv2d):
            setattr(module, name, child.merge_lora())
        else:
            merge_lora_weights(child)
    return module


def add_lora_to_backbone(backbone: nn.Module, rank: int = 4) -> nn.Module:
    """Adds LoRA to a convolutional backbone by replacing Conv2d layers
    and freezing all other weights except the LoRA layers and the final classifier.
    """
    replace_conv2d_with_lora(backbone, rank)

    for name, param in backbone.named_parameters():
        if "lora_" in name or name.startswith("fc"):
            param.requires_grad = True
        else:
            param.requires_grad = False

    return backbone


def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
    """Generates sinusoidal positional embeddings as in the original Transformer paper."""
    position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dimension, 2, dtype=torch.float) * (-math.log(10000.0) / dimension))
    pe = torch.zeros(num_positions, dimension)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
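As an aside (an illustration, not part of the patch), the exponential weights that _chunk_actions uses to blend several past predictions of the same action step can be previewed in isolation. This sketch uses the default alpha but a shortened horizon of 5 for readability:

    import torch

    alpha, inference_horizon = 0.75, 5
    action_weights = alpha ** torch.arange(inference_horizon).float()
    action_weights /= action_weights.sum()
    print(action_weights)  # tensor([0.3278, 0.2459, 0.1844, 0.1383, 0.1037]); index 0 is the newest prediction

    # Blend five predictions of the same action step; the newest (index 0) gets the largest weight
    predictions = torch.arange(1.0, 6.0).view(1, -1, 1)  # (batch, horizon, action_dim)
    blended = (predictions * action_weights.view(1, -1, 1)).sum(dim=1)
    print(blended)  # about 2.44, pulled toward the newest prediction (1.0) relative to the plain mean of 3.0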

@@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy

@@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
        from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy

        return PI0FASTPolicy
    elif name == "dot":
        from lerobot.common.policies.dot.modeling_dot import DOTPolicy

        return DOTPolicy
    else:
        raise NotImplementedError(f"Policy with name {name} is not implemented.")

@@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return PI0Config(**kwargs)
    elif policy_type == "pi0fast":
        return PI0FASTConfig(**kwargs)
    elif policy_type == "dot":
        return DOTConfig(**kwargs)
    else:
        raise ValueError(f"Policy type '{policy_type}' is not available.")
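Once these factory branches are in place, the new policy resolves by name like the existing ones. A minimal usage sketch (not part of the patch, and assuming a lerobot checkout with this change applied):

    from lerobot.common.policies.factory import get_policy_class, make_policy_config

    config = make_policy_config("dot", train_horizon=30, inference_horizon=30)
    policy_cls = get_policy_class("dot")
    print(policy_cls.__name__)    # DOTPolicy
    print(type(config).__name__)  # DOTConfig
    print(config.train_horizon)   # 30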