merge branch + cleanup

This commit is contained in:
root 2025-02-16 19:28:17 +00:00
commit 3b8a85a3f2
5 changed files with 610 additions and 1 deletions

View File

@ -20,7 +20,7 @@ from pathlib import Path
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler
from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.datasets.utils import write_json
@ -120,3 +120,16 @@ def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
scheduler.load_state_dict(state_dict)
return scheduler
@LRSchedulerConfig.register_subclass("cosine_annealing")
@dataclass
class CosineAnnealingSchedulerConfig(LRSchedulerConfig):
    """Scheduler config that builds a ``torch.optim.lr_scheduler.CosineAnnealingLR``.

    The learning rate is annealed from the optimizer's value towards ``min_lr``
    over ``T_max`` scheduler steps.
    """

    # Lower bound of the learning rate (forwarded as ``eta_min``).
    min_lr: float = 0
    # Number of iterations for a full decay (half of a cosine cycle).
    T_max: int = 100000
    # Unused by this scheduler; kept because the parent class expects the field.
    num_warmup_steps: int = 0

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler:
        """Create the cosine-annealing scheduler for ``optimizer``.

        ``num_training_steps`` is accepted for interface compatibility but is
        not used: the decay length comes from ``T_max``.
        """
        scheduler_kwargs = {"T_max": self.T_max, "eta_min": self.min_lr}
        return CosineAnnealingLR(optimizer, **scheduler_kwargs)

View File

@ -1,5 +1,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .dot.configuration_dot import DOTConfig as DOTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@ -0,0 +1,144 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import CosineAnnealingSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("dot")
@dataclass
class DOTConfig(PreTrainedConfig):
    """Configuration for the DOT (Decoder-Only Transformer) policy.

    You need to change some parameters in this configuration to make it work for your problem:

    FPS/prediction horizon related features - may need to adjust:
    - train_horizon: the number of steps to predict during training
    - inference_horizon: the number of steps to predict during validation
    - alpha: exponential factor for weighting of each next action
    - train_alpha: exponential factor for action weighting during training

    For inference speed optimization:
    - predict_every_n: number of frames to predict in the future
    - return_every_n: instead of returning next predicted actions, returns nth future action

    Raises:
        ValueError: from ``__post_init__`` when the horizon / stride settings are inconsistent.
    """

    # Input / output structure.
    # Number of observation steps per input: one far-past step plus the most recent ones.
    n_obs_steps: int = 3
    train_horizon: int = 20
    inference_horizon: int = 20
    # The far-past observation is centred `lookback_obs_steps` steps in the past ...
    lookback_obs_steps: int = 10
    # ... and the delta indices cover +/- `lookback_aug` steps around that point.
    lookback_aug: int = 5

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ENV": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Not sure if there is a better way to do this with new config system.
    # When `override_dataset_stats` is set, `new_dataset_stats` entries replace the dataset ones.
    override_dataset_stats: bool = False
    new_dataset_stats: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"max": [512.0] * 2, "min": [0.0] * 2},
            "observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16},
            "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2},
        }
    )

    # Architecture.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    pre_norm: bool = True
    lora_rank: int = 20
    merge_lora: bool = False
    dim_model: int = 128
    n_heads: int = 8
    dim_feedforward: int = 512
    n_decoder_layers: int = 8
    rescale_shape: tuple[int, int] = (96, 96)

    # Augmentation.
    crop_scale: float = 0.8
    state_noise: float = 0.01
    noise_decay: float = 0.999995

    # Training and loss computation.
    dropout: float = 0.1

    # Weighting and inference.
    alpha: float = 0.75
    train_alpha: float = 0.9
    predict_every_n: int = 1
    return_every_n: int = 1

    # Training preset
    optimizer_lr: float = 1.0e-4
    optimizer_min_lr: float = 1.0e-4
    optimizer_lr_cycle_steps: int = 300000
    optimizer_weight_decay: float = 1e-5

    def __post_init__(self):
        """Validate that the stride and horizon settings are mutually consistent."""
        super().__post_init__()

        # Both strides are used as divisors (here and in the policy's step counter),
        # so reject non-positive values with a clear error instead of a ZeroDivisionError.
        if self.predict_every_n < 1:
            raise ValueError(
                f"predict_every_n ({self.predict_every_n}) must be greater than or equal to 1."
            )
        if self.return_every_n < 1:
            raise ValueError(f"return_every_n ({self.return_every_n}) must be greater than or equal to 1.")

        if self.predict_every_n > self.inference_horizon:
            raise ValueError(
                f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon ({self.inference_horizon})."
            )

        if self.return_every_n > self.inference_horizon:
            raise ValueError(
                f"return_every_n ({self.return_every_n}) must be less than or equal to horizon ({self.inference_horizon})."
            )

        if self.predict_every_n > self.inference_horizon // self.return_every_n:
            raise ValueError(
                f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon // return_every_n({self.inference_horizon // self.return_every_n})."
            )

        if self.train_horizon < self.inference_horizon:
            raise ValueError(
                f"train_horizon ({self.train_horizon}) must be greater than or equal to horizon ({self.inference_horizon})."
            )

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW preset built from `optimizer_lr` and `optimizer_weight_decay`."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> CosineAnnealingSchedulerConfig:
        """Return the cosine-annealing preset built from `optimizer_min_lr` and `optimizer_lr_cycle_steps`."""
        return CosineAnnealingSchedulerConfig(
            min_lr=self.optimizer_min_lr, T_max=self.optimizer_lr_cycle_steps
        )

    def validate_features(self) -> None:
        """Raise ValueError unless at least one image or the environment state is an input."""
        if not self.image_features and not self.env_state_feature:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

    def _far_past_delta_indices(self) -> list[int]:
        """Return the `2 * lookback_aug + 1` delta indices centred on `-lookback_obs_steps`."""
        centre = -self.lookback_obs_steps
        return list(range(centre - self.lookback_aug, centre + self.lookback_aug + 1))

    @property
    def observation_delta_indices(self) -> list:
        """Far-past window followed by the `n_obs_steps - 1` most recent steps (ending at 0)."""
        recent_obs = list(range(2 - self.n_obs_steps, 1))
        return self._far_past_delta_indices() + recent_obs

    @property
    def action_delta_indices(self) -> list:
        """Far-past window followed by recent and future steps up to `train_horizon - 1`."""
        recent_actions = list(range(2 - self.n_obs_steps, self.train_horizon))
        return self._far_past_delta_indices() + recent_actions

    @property
    def reward_delta_indices(self) -> None:
        """This configuration requests no reward deltas."""
        return None

View File

@ -0,0 +1,444 @@
#!/usr/bin/env python
"""The implementation of the Decoder-Only Transformer (DOT) policy.
More details here: https://github.com/IliaLarchenko/dot_policy
"""
import math
import torch
import torchvision
from torch import Tensor, nn
from torchvision import transforms
from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.transforms.functional import InterpolationMode
from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
class DOT(nn.Module):
    """Core DOT network: projects every observation to `dim_model` and decodes actions.

    All inputs (images, robot state, environment state) are projected to tokens that
    serve as the memory of an `nn.TransformerDecoder`. The decoder queries are fixed
    sinusoidal positional encodings, one per predicted action step, and a linear head
    maps each decoder output to an action vector.
    """

    def __init__(self, config: DOTConfig):
        """Build the projections, decoder, positional buffers and action head from `config`."""
        super().__init__()

        self.config = config

        self.projections = nn.ModuleDict()
        self.n_features = 0

        self.image_names = sorted(config.image_features.keys())

        # I use one backbone for all cameras and simply project the output to the model dimension
        if len(self.image_names) > 0:
            backbone = getattr(torchvision.models, self.config.vision_backbone)(
                weights=self.config.pretrained_backbone_weights,
                norm_layer=FrozenBatchNorm2d,
            )
            backbone.fc = nn.Linear(backbone.fc.in_features, self.config.dim_model)

            self.projections["images"] = add_lora_to_backbone(backbone, rank=config.lora_rank)
            self.n_features += len(self.image_names) * self.config.n_obs_steps

        if self.config.robot_state_feature:
            self.projections["state"] = nn.Linear(
                self.config.robot_state_feature.shape[0], self.config.dim_model
            )
            self.n_features += self.config.n_obs_steps

        if self.config.env_state_feature:
            self.projections["environment_state"] = nn.Linear(
                self.config.env_state_feature.shape[0], self.config.dim_model
            )
            self.n_features += self.config.n_obs_steps

        self.projections_names = sorted(self.projections.keys())
        obs_mapping = {
            "images": "observation.images",
            "state": "observation.state",
            "environment_state": "observation.environment_state",
        }
        # Keep only the batch keys for which a projection was actually created
        self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names}

        # Extra trainable vector that I add to the input features (not necessary)
        self.prefix_input = nn.Parameter(torch.randn(1, 1, config.dim_model))

        # Setup transformer decoder
        dec_layer = nn.TransformerDecoderLayer(
            d_model=self.config.dim_model,
            nhead=self.config.n_heads,
            dim_feedforward=self.config.dim_feedforward,
            dropout=self.config.dropout,
            batch_first=True,
            norm_first=self.config.pre_norm,
        )

        decoder_norm = nn.LayerNorm(self.config.dim_model)
        self.decoder = nn.TransformerDecoder(
            dec_layer, num_layers=self.config.n_decoder_layers, norm=decoder_norm
        )

        # Decoder uses as input not-trainable positional encodings
        decoder_pos = create_sinusoidal_pos_embedding(
            config.train_horizon + config.lookback_obs_steps, config.dim_model
        )
        # Keep the first (far-past) position plus the last `n_recent` positions.
        # The tail start is computed explicitly because `x[-0:]` would select the
        # whole table instead of nothing when `n_recent` is 0.
        n_recent = config.train_horizon + config.n_obs_steps - 2
        tail_start = max(0, decoder_pos.shape[0] - n_recent)
        decoder_pos = torch.cat(
            [
                decoder_pos[:1],
                decoder_pos[tail_start:],
            ],
            dim=0,
        )
        self.register_buffer("decoder_pos", decoder_pos)

        # At inference only the positions up to `inference_horizon` are queried
        decoder_pos_inf = self.decoder_pos[
            : self.decoder_pos.shape[0] + self.config.inference_horizon - self.config.train_horizon
        ]
        self.register_buffer("decoder_pos_inf", decoder_pos_inf)

        # Boolean target mask used in training (True = attention not allowed): the
        # positions that are also queried at inference must not attend to the extra
        # training-only positions beyond the inference horizon.
        mask = torch.zeros(len(decoder_pos), len(decoder_pos), dtype=torch.bool)
        mask[
            : len(decoder_pos) + config.inference_horizon - config.train_horizon,
            len(decoder_pos) + config.inference_horizon - config.train_horizon :,
        ] = True
        self.register_buffer("mask", mask)

        # Input features need a trainable positional embeddings
        self.inputs_pos_emb = nn.Parameter(torch.empty(1, self.n_features, self.config.dim_model))
        nn.init.uniform_(
            self.inputs_pos_emb,
            -((1 / self.config.dim_model) ** 0.5),
            (1 / self.config.dim_model) ** 0.5,
        )

        # The output actions are generated by a linear layer
        self.action_head = nn.Linear(self.config.dim_model, self.config.action_feature.shape[0])

    def _process_inputs(self, batch):
        """Project every available input to `dim_model` and concatenate along the token axis."""
        # Project all inputs to the model dimension and concatenate them
        inputs_projections_list = []

        for state in self.projections_names:
            batch_state = self.obs_mapping[state]
            if batch_state in batch:
                # Fold the step axis into the batch axis so the projection sees one item per step
                bs, n_obs, *obs_shape = batch[batch_state].shape
                enc = self.projections[state](batch[batch_state].view(bs * n_obs, *obs_shape)).view(
                    bs, n_obs, -1
                )
                inputs_projections_list.append(enc)

        return torch.cat(inputs_projections_list, dim=1)

    def forward(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict actions for every decoder position.

        Args:
            batch: observation tensors keyed by the names in `self.obs_mapping`.

        Returns:
            A single tensor of action predictions, one action vector per decoder
            position (training positions in train mode, inference positions in eval mode).
        """
        inputs_projections = self._process_inputs(batch)
        bs = inputs_projections.shape[0]

        inputs_projections += self.inputs_pos_emb.expand(bs, -1, -1)
        inputs_projections = torch.cat([self.prefix_input.expand(bs, -1, -1), inputs_projections], dim=1)

        if self.training:
            decoder_out = self.decoder(self.decoder_pos.expand(bs, -1, -1), inputs_projections, self.mask)
        else:
            decoder_out = self.decoder(self.decoder_pos_inf.expand(bs, -1, -1), inputs_projections)
        return self.action_head(decoder_out)
class DOTPolicy(PreTrainedPolicy):
    """Policy wrapper around `DOT`: normalization, augmentation, loss and action chunking."""

    name = "dot"
    config_class = DOTConfig

    def __init__(
        self,
        config: DOTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Build the normalizers, the DOT model and the weighting buffers.

        Args:
            config: policy configuration.
            dataset_stats: statistics used by the (un)normalization modules. When
                `config.override_dataset_stats` is set, the entries of
                `config.new_dataset_stats` are written into this dict.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.image_names = sorted(config.image_features.keys())

        if config.override_dataset_stats:
            if dataset_stats is None:
                dataset_stats = {}
            for k, v in config.new_dataset_stats.items():
                if k not in dataset_stats:
                    dataset_stats[k] = {}
                for k1, v1 in v.items():
                    dataset_stats[k][k1] = torch.tensor(v1)

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.model = DOT(self.config)

        self.state_noise = self.config.state_noise
        self.crop_scale = self.config.crop_scale
        self.alpha = self.config.alpha
        self.inference_horizon = self.config.inference_horizon
        self.return_every_n = self.config.return_every_n
        self.predict_every_n = self.config.predict_every_n

        # Inference action chunking and observation queues
        self._old_predictions = None
        self._input_buffers = {}

        # Weights used for chunking
        action_weights = self.alpha ** torch.arange(self.inference_horizon).float()
        action_weights /= action_weights.sum()
        action_weights = action_weights.view(1, -1, 1)
        self.register_buffer("action_weights", action_weights)

        # Weights for the loss computations
        # Actions that are further in the future are weighted less
        loss_weights = torch.ones(self.config.train_horizon + self.config.n_obs_steps - 1)
        loss_weights[-self.config.train_horizon :] = (
            self.config.train_alpha ** torch.arange(self.config.train_horizon).float()
        )
        loss_weights /= loss_weights.mean()
        loss_weights = loss_weights.view(1, -1, 1)
        self.register_buffer("loss_weights", loss_weights)

        # TODO: properly move it to dataloader and process on CPU
        # Nearest interpolation is required for PushT but may be not the best in general
        self.resize_transform = transforms.Resize(
            config.rescale_shape, interpolation=InterpolationMode.NEAREST
        )

        self.step = 0
        self.last_action = None

    def reset(self):
        """Clear the prediction/observation buffers and the step counter."""
        self._old_predictions = None
        self._input_buffers = {}
        self.last_action = None
        self.step = 0

    def get_optim_params(self) -> dict:
        """Return the parameters of the DOT model to optimize.

        NOTE: the returned object is the iterator produced by
        `nn.Module.parameters()`, not a dict, despite the annotation.
        """
        return self.model.parameters()

    def _update_observation_buffers(self, buffer_name: str, observation: Tensor) -> Tensor:
        """Push `observation` into its queue and return the steps fed to the model.

        Returns the oldest entry of the queue followed by the `n_obs_steps - 1`
        most recent entries, concatenated along dim 1.
        """
        # We keep the last lookback_obs_steps + 1 of each input in the queue
        # Every step they are updated and the oldest one is removed
        buffer = self._input_buffers.get(buffer_name)
        if buffer is None:
            # First call: fill the whole queue with copies of the first observation
            buffer = observation.unsqueeze(1).repeat(
                1,
                self.config.lookback_obs_steps + 1,
                *([1] * (observation.dim() - 1)),
            )
        else:
            buffer = buffer.roll(shifts=-1, dims=1)
            buffer[:, -1] = observation
        self._input_buffers[buffer_name] = buffer

        # Explicit start index: `buffer[:, -0:]` would return the whole queue
        # instead of no recent steps when n_obs_steps == 1.
        n_recent = self.config.n_obs_steps - 1
        recent_start = max(0, buffer.shape[1] - n_recent)

        return torch.cat(
            [
                buffer[:, :1],
                buffer[:, recent_start:],
            ],
            dim=1,
        )

    def _prepare_batch_for_inference(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Normalize a single-step batch and expand it with the buffered past observations."""
        batch = self.normalize_inputs(batch)

        # Resize and stack all images
        if len(self.image_names) > 0:
            batch["observation.images"] = torch.stack(
                [self.resize_transform(batch[k]) for k in self.image_names],
                dim=1,
            )  # bs, n_cam, c, h, w

        # Update observation queues for all inputs and stack the last n_obs_steps
        for name, batch_name in self.model.obs_mapping.items():
            batch[batch_name] = self._update_observation_buffers(name, batch[batch_name])

        # Reshape images tensor to keep the same order as during training
        if "observation.images" in batch:
            batch["observation.images"] = batch["observation.images"].flatten(1, 2)
            # bs, n_obs * n_cam, c, h, w

        return batch

    def _chunk_actions(self, actions: Tensor) -> Tensor:
        """Blend the newest predictions with the stored earlier ones using `action_weights`."""
        # Store the previous action predictions in a buffer
        # Compute the weighted average of the inference horizon action predictions
        if self._old_predictions is not None:
            self._old_predictions[:, 0] = actions
        else:
            self._old_predictions = actions.unsqueeze(1).repeat(1, self.config.inference_horizon, 1, 1)

        action = (self._old_predictions[:, :, 0] * self.action_weights).sum(dim=1)
        # Age every stored prediction by one slot and advance each by one time step
        self._old_predictions = self._old_predictions.roll(shifts=(1, -1), dims=(1, 2))

        return action

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return the next action for the given observation batch.

        Switches the policy to eval mode, runs the model every `predict_every_n`
        calls (shifting the previous prediction otherwise) and returns the
        chunked action after advancing `return_every_n` steps.
        """
        self.eval()

        batch = self._prepare_batch_for_inference(batch)

        # Only run model prediction every predict_every_n steps
        if self.step % self.predict_every_n == 0:
            actions_pred = self.model(batch)[:, -self.config.inference_horizon :]
            self.last_action = self.unnormalize_outputs({"action": actions_pred})["action"]
        else:
            # Otherwise shift previous predictions and repeat last action
            self.last_action = self.last_action.roll(-1, dims=1)
            self.last_action[:, -1] = self.last_action[:, -2]

        self.step += 1

        # Return chunked actions for return_every_n steps
        action = self._chunk_actions(self.last_action)
        for _ in range(self.return_every_n - 1):
            self.last_action = self.last_action.roll(-1, dims=1)
            self.last_action[:, -1] = self.last_action[:, -2]
            action = self._chunk_actions(self.last_action)

        return action

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run one augmented forward pass and return `{"loss": ...}`.

        The far-past step is sampled from the lookback window, images are resized and
        randomly cropped, noise is added to the state inputs, and a padded, weighted
        L1 loss is computed. The augmentation strength decays on every call.
        Note that the entries of `batch` selected below are replaced in place.
        """
        # Pick one far-past step out of the 2 * lookback_aug + 1 candidates
        lookback_ind = torch.randint(0, 2 * self.config.lookback_aug + 1, (1,)).item()
        for k in list(self.model.obs_mapping.values()) + list(self.image_names) + ["action", "action_is_pad"]:
            if k != "observation.images":
                batch[k] = torch.cat(
                    [
                        batch[k][:, lookback_ind : lookback_ind + 1],
                        batch[k][:, 2 * self.config.lookback_aug + 1 :],
                    ],
                    1,
                )
        batch = self.normalize_targets(self.normalize_inputs(batch))

        if len(self.config.image_features) > 0:
            # Maybe not the best way but it works well
            scale = 1 - torch.rand(1) * (1 - self.crop_scale)
            new_shape = (
                int(self.config.rescale_shape[0] * scale),
                int(self.config.rescale_shape[1] * scale),
            )
            crop_transform = transforms.RandomCrop(new_shape)

            for k in self.image_names:
                bs, n_obs, c, h, w = batch[k].shape
                batch[k] = batch[k].view(bs * n_obs, c, h, w)
                batch[k] = crop_transform(self.resize_transform(batch[k]))
                batch[k] = batch[k].view(bs, n_obs, c, *batch[k].shape[-2:])
            batch["observation.images"] = torch.stack([batch[k] for k in self.image_names], dim=2).flatten(
                1, 2
            )  # bs, n_obs * n_cam, c, h, w

        # Add random noise to states during training
        # TODO: it should be done in the dataloader
        if self.state_noise is not None:
            for k in self.model.obs_mapping.values():
                if k != "observation.images":
                    batch[k] += (torch.rand_like(batch[k]) * 2 - 1) * self.state_noise

        actions_hat = self.model(batch)

        loss = nn.functional.l1_loss(batch["action"], actions_hat, reduction="none")
        rev_padding = (~batch["action_is_pad"]).unsqueeze(-1)

        # Apply padding, weights and decay to the loss
        loss = (loss * rev_padding * self.loss_weights).mean()

        loss_dict = {"loss": loss}

        # Reduce the aggressiveness of augmentations
        # (guarded like the noise above, so a disabled state noise does not raise)
        if self.state_noise is not None:
            self.state_noise *= self.config.noise_decay
        self.crop_scale = 1 - (1 - self.crop_scale) * self.config.noise_decay

        return loss_dict

    @classmethod
    def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs):
        """Load model from pretrained checkpoint and merge LoRA after loading"""
        policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs)

        # Merging is opt-in through the `merge_lora` flag of the loaded config
        if getattr(policy.config, "merge_lora", False):
            print("Merging LoRA after loading pretrained model...")
            policy.model = merge_lora_weights(policy.model)

        return policy
class LoRAConv2d(nn.Module):
    """Wrap a Conv2d layer and add a trainable low-rank update to its kernel.

    The effective kernel is ``base_conv.weight + (lora_A @ lora_B)`` reshaped to
    the kernel shape; bias, stride, padding, dilation and groups are taken from
    the wrapped layer.
    """

    def __init__(self, base_conv, rank=4):
        """Create the low-rank factors for `base_conv`.

        Args:
            base_conv: the convolution layer to wrap; its weight is read for its shape.
            rank: inner dimension of the two low-rank factors.
        """
        super().__init__()
        self.base_conv = base_conv

        # Flatten the original conv weight
        out_channels, in_channels, kh, kw = base_conv.weight.shape
        self.weight_shape = (out_channels, in_channels, kh, kw)
        fan_in = in_channels * kh * kw

        # lora_A @ lora_B has shape (out_channels, fan_in), i.e. a flattened kernel
        self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank)))
        self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in)))

    def forward(self, x):
        """Convolve `x` with the base kernel plus the low-rank update."""
        lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)

        return nn.functional.conv2d(
            x,
            self.base_conv.weight + lora_update,
            self.base_conv.bias,
            stride=self.base_conv.stride,
            padding=self.base_conv.padding,
            dilation=self.base_conv.dilation,
            groups=self.base_conv.groups,
        )

    def merge_lora(self):
        """Merge LoRA weights into the base convolution and return a standard Conv2d layer"""
        # The in-place update must not be tracked by autograd: on a weight that
        # requires grad it raises, and otherwise it would tie the merged weight
        # to the LoRA parameters' graph.
        with torch.no_grad():
            lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape)
            self.base_conv.weight.add_(lora_update)

        return self.base_conv
def replace_conv2d_with_lora(module, rank=4):
    """Swap every nn.Conv2d found anywhere under `module` for a LoRAConv2d wrapper.

    The module tree is modified in place and `module` itself is returned.
    """
    # Snapshot the children first because the tree is edited while walking it
    direct_children = tuple(module.named_children())
    for child_name, submodule in direct_children:
        if not isinstance(submodule, nn.Conv2d):
            # Not a convolution: descend and look for convolutions further down
            replace_conv2d_with_lora(submodule, rank)
            continue
        wrapped = LoRAConv2d(submodule, rank)
        setattr(module, child_name, wrapped)
    return module
def merge_lora_weights(module):
    """Replace every LoRAConv2d under `module` by the layer its `merge_lora()` returns.

    The module tree is modified in place and `module` itself is returned.
    """
    # Snapshot the children first because the tree is edited while walking it
    direct_children = tuple(module.named_children())
    for child_name, submodule in direct_children:
        if not isinstance(submodule, LoRAConv2d):
            # Not a LoRA wrapper: keep searching deeper in the tree
            merge_lora_weights(submodule)
            continue
        merged_layer = submodule.merge_lora()
        setattr(module, child_name, merged_layer)
    return module
def add_lora_to_backbone(backbone, rank=4, verbose=True):
    """Wrap the backbone's convolutions with LoRA and freeze everything else.

    After the call only parameters whose name contains "lora_" or starts with
    "fc" require gradients. `verbose` is accepted but not used by this function.
    The backbone is modified in place and returned.
    """
    replace_conv2d_with_lora(backbone, rank)

    for param_name, parameter in backbone.named_parameters():
        is_trainable = "lora_" in param_name or param_name.startswith("fc")
        parameter.requires_grad = is_trainable

    return backbone
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
    """Build a fixed sinusoidal positional-encoding table.

    Even columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``, with ``freq = exp(-ln(10000) * 2i / dimension)``.

    Args:
        num_positions: number of rows (positions) in the table.
        dimension: embedding size; both even and odd values are supported.

    Returns:
        A float tensor of shape ``(num_positions, dimension)``.
    """
    position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dimension, 2, dtype=torch.float) * (-math.log(10000.0) / dimension))
    angles = position * div_term

    pe = torch.zeros(num_positions, dimension)
    pe[:, 0::2] = torch.sin(angles)
    # For an odd dimension there is one fewer cosine column than sine columns,
    # so drop the last frequency to keep the shapes aligned.
    pe[:, 1::2] = torch.cos(angles[:, : dimension // 2])
    return pe

View File

@ -25,6 +25,7 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.dot.configuration_dot import DOTConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "dot":
from lerobot.common.policies.dot.modeling_dot import DOTPolicy
return DOTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@ -70,6 +75,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "dot":
return DOTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")