From 78d3ba8db26143a4ba9610686f0c438753533bbc Mon Sep 17 00:00:00 2001 From: Ilia Date: Sat, 8 Feb 2025 18:21:23 +0700 Subject: [PATCH 1/4] Added dot-policy with new config (to be debugged) --- lerobot/common/optim/schedulers.py | 15 +- lerobot/common/policies/__init__.py | 1 + .../common/policies/dot/configuration_dot.py | 133 ++++++ lerobot/common/policies/dot/modeling_dot.py | 433 ++++++++++++++++++ lerobot/common/policies/factory.py | 7 + 5 files changed, 588 insertions(+), 1 deletion(-) create mode 100644 lerobot/common/policies/dot/configuration_dot.py create mode 100644 lerobot/common/policies/dot/modeling_dot.py diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index 80d83bdf..da1e048e 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass import draccus from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR, LRScheduler +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler @dataclass @@ -89,3 +89,16 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): return cosine_decay_schedule(current_step) return LambdaLR(optimizer, lr_lambda, -1) + + +@LRSchedulerConfig.register_subclass("cosine_annealing") +@dataclass +class CosineAnnealingSchedulerConfig(LRSchedulerConfig): + """Implements Cosine Annealing learning rate scheduler""" + + min_lr: float = 0 # Minimum learning rate + T_max: int = 100000 # Number of iterations for a full decay (half-cycle) + num_warmup_steps: int = 0 # Unused here; field is required by the LRSchedulerConfig parent interface + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler: + return CosineAnnealingLR(optimizer, T_max=self.T_max, eta_min=self.min_lr) diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index 2e4486ef..8432b684 100644 --- a/lerobot/common/policies/__init__.py +++ 
b/lerobot/common/policies/__init__.py @@ -1,5 +1,6 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig +from .dot.configuration_dot import DOTConfig as DOTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig diff --git a/lerobot/common/policies/dot/configuration_dot.py b/lerobot/common/policies/dot/configuration_dot.py new file mode 100644 index 00000000..5b5a7aba --- /dev/null +++ b/lerobot/common/policies/dot/configuration_dot.py @@ -0,0 +1,133 @@ +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import CosineAnnealingSchedulerConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass("dot") +@dataclass +class DOTConfig(PreTrainedConfig): + """Configuration for DOT (Decoder-Only Transformer) policy. + + You need to change some parameters in this configuration to make it work for your problem: + + FPS/prediction horizon related features - may need to adjust: + - train_horizon: the number of steps to predict during training + - inference_horizon: the number of steps to predict during validation + - alpha: exponential factor for weighting of each next action + - train_alpha: exponential factor for action weighting during training + + For inference speed optimization: + - predict_every_n: number of frames to predict in the future + - return_every_n: instead of returning next predicted actions, returns nth future action + """ + + # Input / output structure. 
+ n_obs_steps: int = 3 + train_horizon: int = 20 + inference_horizon: int = 20 + lookback_obs_steps: int = 10 + lookback_aug: int = 5 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Architecture. + vision_backbone: str = "resnet18" + pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" + pre_norm: bool = True + lora_rank: int = 20 + merge_lora: bool = True + + dim_model: int = 128 + n_heads: int = 8 + dim_feedforward: int = 512 + n_decoder_layers: int = 8 + rescale_shape: tuple[int, int] = (96, 96) + + # Augmentation. + crop_scale: float = 1.0 + state_noise: float = 0.01 + noise_decay: float = 0.999995 + + # Training and loss computation. + dropout: float = 0.1 + + # Weighting and inference. + alpha: float = 0.75 + train_alpha: float = 0.9 + predict_every_n: int = 1 + return_every_n: int = 1 + + # Training preset + optimizer_lr: float = 1.0e-4 + optimizer_min_lr: float = 1.0e-4 + optimizer_lr_cycle_steps: int = 300000 + optimizer_weight_decay: float = 1e-5 + + def __post_init__(self): + super().__post_init__() + if self.predict_every_n > self.inference_horizon: + raise ValueError( + f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon ({self.inference_horizon})." + ) + if self.return_every_n > self.inference_horizon: + raise ValueError( + f"return_every_n ({self.return_every_n}) must be less than or equal to horizon ({self.inference_horizon})." + ) + if self.predict_every_n > self.inference_horizon // self.return_every_n: + raise ValueError( + f"predict_every_n ({self.predict_every_n}) must be less than or equal to horizon // return_every_n({self.inference_horizon // self.return_every_n})." 
+ ) + if self.train_horizon < self.inference_horizon: + raise ValueError( + f"train_horizon ({self.train_horizon}) must be greater than or equal to horizon ({self.inference_horizon})." + ) + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> None: + return CosineAnnealingSchedulerConfig( + min_lr=self.optimizer_min_lr, T_max=self.optimizer_lr_cycle_steps + ) + + def validate_features(self) -> None: + if not self.image_features and not self.env_state_feature: + raise ValueError("You must provide at least one image or the environment state among the inputs.") + + @property + def observation_delta_indices(self) -> None: + far_past_obs = list( + range( + -self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps + ) + ) + recent_obs = list(range(2 - self.n_obs_steps, 1)) + + return far_past_obs + recent_obs + + @property + def action_delta_indices(self) -> list: + far_past_actions = list( + range( + -self.lookback_aug - self.lookback_obs_steps, self.lookback_aug + 1 - self.lookback_obs_steps + ) + ) + recent_actions = list(range(2 - self.n_obs_steps, self.train_horizon)) + + return far_past_actions + recent_actions + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/dot/modeling_dot.py b/lerobot/common/policies/dot/modeling_dot.py new file mode 100644 index 00000000..477cfdf8 --- /dev/null +++ b/lerobot/common/policies/dot/modeling_dot.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python + +"""The implementation of the Decoder-Only Transformer (DOT) policy. 
+ +More details here: https://github.com/IliaLarchenko/dot_policy +""" + +import math + +import torch +import torchvision +from torch import Tensor, nn +from torchvision import transforms +from torchvision.ops.misc import FrozenBatchNorm2d +from torchvision.transforms.functional import InterpolationMode + +from lerobot.common.policies.dot.configuration_dot import DOTConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy + + +class DOT(nn.Module): + def __init__(self, config: DOTConfig): + super().__init__() + self.config = config + + self.projections = nn.ModuleDict() + self.n_features = 0 + + # I use one backbone for all cameras and simply project the output to the model dimension + if len(self.config.image_features) > 0: + backbone = getattr(torchvision.models, self.config.vision_backbone)( + weights=self.config.pretrained_backbone_weights, + norm_layer=FrozenBatchNorm2d, + ) + backbone.fc = nn.Linear(backbone.fc.in_features, self.config.dim_model) + + self.projections["images"] = add_lora_to_backbone(backbone, rank=config.lora_rank) + self.n_features += len(self.config.image_features) * self.config.n_obs_steps + + if self.config.robot_state_feature: + self.projections["state"] = nn.Linear( + self.config.robot_state_feature.shape[0], self.config.dim_model + ) + self.n_features += self.config.n_obs_steps + + if self.config.env_state_feature: + self.projections["env_state"] = nn.Linear( + self.config.env_state_feature.shape[0], self.config.dim_model + ) + self.n_features += self.config.n_obs_steps + + self.projections_names = sorted(self.projections.keys()) + obs_mapping = { + "images": "observation.images", + "state": "observation.state", + "env_state": "observation.environment_state", + } + self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names} + + # Extra trainable vector that I add to the input features (not necessary) + self.prefix_input = 
nn.Parameter(torch.randn(1, 1, config.dim_model)) + + # Setup transformer decoder + dec_layer = nn.TransformerDecoderLayer( + d_model=self.config.dim_model, + nhead=self.config.n_heads, + dim_feedforward=self.config.dim_feedforward, + dropout=self.config.dropout, + batch_first=True, + norm_first=self.config.pre_norm, + ) + + decoder_norm = nn.LayerNorm(self.config.dim_model) + self.decoder = nn.TransformerDecoder( + dec_layer, num_layers=self.config.n_decoder_layers, norm=decoder_norm + ) + + # Decoder uses as input not-trainable positional encodings + decoder_pos = create_sinusoidal_pos_embedding( + config.train_horizon + config.lookback_obs_steps, config.dim_model + ) + decoder_pos = torch.cat( + [ + decoder_pos[:1], + decoder_pos[-config.train_horizon - config.n_obs_steps + 2 :], + ], + dim=0, + ) + self.register_buffer("decoder_pos", decoder_pos) + + decoder_pos_inf = self.decoder_pos[ + : self.decoder_pos.shape[0] + self.config.inference_horizon - self.config.train_horizon + ] + self.register_buffer("decoder_pos_inf", decoder_pos_inf) + + mask = torch.zeros(len(decoder_pos), len(decoder_pos), dtype=torch.bool) + mask[ + : len(decoder_pos) + config.inference_horizon - config.train_horizon, + len(decoder_pos) + config.inference_horizon - config.train_horizon :, + ] = True + self.register_buffer("mask", mask) + + # Input features need a trainable positional embeddings + self.inputs_pos_emb = nn.Parameter(torch.empty(1, self.n_features, self.config.dim_model)) + nn.init.uniform_( + self.inputs_pos_emb, + -((1 / self.config.dim_model) ** 0.5), + (1 / self.config.dim_model) ** 0.5, + ) + + # The output actions are generated by a linear layer + self.action_head = nn.Linear(self.config.dim_model, self.config.action_feature.shape[0]) + + def _process_inputs(self, batch): + # Project all inputs to the model dimension and concatenate them + inputs_projections_list = [] + + for state in self.projections_names: + batch_state = self.obs_mapping[state] + if batch_state in 
batch: + bs, n_obs, *obs_shape = batch[batch_state].shape + enc = self.projections[state](batch[batch_state].view(bs * n_obs, *obs_shape)).view( + bs, n_obs, -1 + ) + inputs_projections_list.append(enc) + + return torch.cat(inputs_projections_list, dim=1) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + inputs_projections = self._process_inputs(batch) + bs = inputs_projections.shape[0] + + inputs_projections += self.inputs_pos_emb.expand(bs, -1, -1) + inputs_projections = torch.cat([self.prefix_input.expand(bs, -1, -1), inputs_projections], dim=1) + + if self.training: + decoder_out = self.decoder(self.decoder_pos.expand(bs, -1, -1), inputs_projections, self.mask) + else: + decoder_out = self.decoder(self.decoder_pos_inf.expand(bs, -1, -1), inputs_projections) + return self.action_head(decoder_out) + + +class DOTPolicy(PreTrainedPolicy): + name = "dot" + config_class = DOTConfig + + def __init__( + self, + config: DOTConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.model = DOT(self.config) + + self.return_every_n = self.config.return_every_n + self.state_noise = self.config.state_noise + self.crop_scale = self.config.crop_scale + self.alpha = self.config.alpha + self.inference_horizon = self.config.inference_horizon + + # Inference action chunking and observation queues + self._old_predictions = None + self._input_buffers = {} + + # Weights used for chunking + action_weights = self.alpha ** torch.arange(self.inference_horizon).float() + action_weights /= action_weights.sum() + action_weights = 
action_weights.view(1, -1, 1) + self.register_buffer("action_weights", action_weights) + + # Weights for the loss computations + # Actions that are further in the future are weighted less + loss_weights = torch.ones(self.config.train_horizon + self.config.n_obs_steps - 1) + loss_weights[-self.config.train_horizon :] = ( + self.config.train_alpha ** torch.arange(self.config.train_horizon).float() + ) + loss_weights /= loss_weights.mean() + loss_weights = loss_weights.view(1, -1, 1) + self.register_buffer("loss_weights", loss_weights) + + # TODO: properly move it to dataloader and process on CPU + # Nearest interpolation is required for PushT but may be not the best in general + self.resize_transform = transforms.Resize( + config.rescale_shape, interpolation=InterpolationMode.NEAREST + ) + + self.predict_every_n = self.config.predict_every_n + self.step = 0 + + def reset(self): + self._old_predictions = None + self._input_buffers = {} + + def get_optim_params(self) -> dict: + return self.model.parameters() + + def _update_observation_buffers(self, buffer_name: str, observation: Tensor) -> Tensor: + # We keep the last lookback_obs_steps + 1 of each input in the queue + # Every step they are updated and the oldest one is removed + if buffer_name not in self._input_buffers: + self._input_buffers[buffer_name] = observation.unsqueeze(1).repeat( + 1, + self.config.lookback_obs_steps + 1, + *torch.ones(len(observation.shape[1:])).int(), + ) + else: + self._input_buffers[buffer_name] = self._input_buffers[buffer_name].roll(shifts=-1, dims=1) + self._input_buffers[buffer_name][:, -1] = observation + + return torch.cat( + [ + self._input_buffers[buffer_name][:, :1], + self._input_buffers[buffer_name][:, -(self.config.n_obs_steps - 1) :], + ], + dim=1, + ) + + def _prepare_batch_for_inference(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = self.normalize_inputs(batch) + + # Resize and stack all images + if len(self.config.image_features) > 0: + 
batch["observation.images"] = torch.stack( + [self.resize_transform(batch[k]) for k in self.config.image_features], + dim=1, + ) # bs, n_cam, c, h, w + + # Update observation queues for all inputs and stack the last n_obs_steps + for name, batch_name in self.model.obs_mapping.items(): + batch[batch_name] = self._update_observation_buffers(name, batch[batch_name]) + + # Reshape images tensor to keep the same order as during training + if "observation.images" in batch: + batch["observation.images"] = batch["observation.images"].flatten(1, 2) + # bs, n_obs * n_cam, c, h, w + + return batch + + def _chunk_actions(self, actions: Tensor) -> Tensor: + # Store the previous action predictions in a buffer + # Compute the weighted average of the inference horizon action predictions + if self._old_predictions is not None: + self._old_predictions[:, 0] = actions + else: + self._old_predictions = actions.unsqueeze(1).repeat(1, self.config.inference_horizon, 1, 1) + + action = (self._old_predictions[:, :, 0] * self.action_weights).sum(dim=1) + self._old_predictions = self._old_predictions.roll(shifts=(1, -1), dims=(1, 2)) + + return action + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + self.eval() + + batch = self._prepare_batch_for_inference(batch) + + # Only run model prediction every predict_every_n steps + if self.step % self.predict_every_n == 0: + actions_pred = self.model(batch)[:, -self.config.inference_horizon :] + self.last_action = self.unnormalize_outputs({"action": actions_pred})["action"] + else: + # Otherwise shift previous predictions and repeat last action + self.last_action = self.last_action.roll(-1, dims=1) + self.last_action[:, -1] = self.last_action[:, -2] + + self.step += 1 + + # Return chunked actions for return_every_n steps + action = self._chunk_actions(self.last_action) + for _ in range(self.return_every_n - 1): + self.last_action = self.last_action.roll(-1, dims=1) + self.last_action[:, -1] = self.last_action[:, 
-2] + action = self._chunk_actions(self.last_action) + + return action + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + lookback_ind = torch.randint(0, 2 * self.config.lookback_aug + 1, (1,)).item() + for k in ( + list(self.model.obs_mapping.values()) + + list(self.config.image_features.keys()) + + ["action", "action_is_pad"] + ): + if k != "observation.images": + batch[k] = torch.cat( + [ + batch[k][:, lookback_ind : lookback_ind + 1], + batch[k][:, 2 * self.config.lookback_aug + 1 :], + ], + 1, + ) + batch = self.normalize_targets(self.normalize_inputs(batch)) + + if len(self.config.image_features) > 0: + # Maybe not the best way but it works well + scale = 1 - torch.rand(1) * (1 - self.crop_scale) + new_shape = ( + int(self.config.rescale_shape[0] * scale), + int(self.config.rescale_shape[1] * scale), + ) + crop_transform = transforms.RandomCrop(new_shape) + + for k in self.config.image_features: + bs, n_obs, c, h, w = batch[k].shape + batch[k] = batch[k].view(bs * n_obs, c, h, w) + batch[k] = crop_transform(self.resize_transform(batch[k])) + batch[k] = batch[k].view(bs, n_obs, c, *batch[k].shape[-2:]) + batch["observation.images"] = torch.stack( + [batch[k] for k in self.config.image_features], dim=2 + ).flatten(1, 2) # bs, n_obs * n_cam, c, h, w + + # Add random noise to states during training + # TODO: it should be done in the dataloader + if self.state_noise is not None: + for k in self.model.obs_mapping.values(): + if k != "observation.images": + batch[k] += (torch.rand_like(batch[k]) * 2 - 1) * self.state_noise + + actions_hat = self.model(batch) + + loss = nn.functional.l1_loss(batch["action"], actions_hat, reduction="none") + rev_padding = (~batch["action_is_pad"]).unsqueeze(-1) + + # Apply padding, weights and decay to the loss + loss = (loss * rev_padding * self.loss_weights).mean() + + loss_dict = {"loss": loss} + + # Reduce the aggressiveness of augmentations + self.state_noise *= self.config.noise_decay + self.crop_scale = 1 - 
(1 - self.crop_scale) * self.config.noise_decay + + return loss_dict + + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Load model from pretrained checkpoint and merge LoRA after loading""" + policy = super().from_pretrained(*args, **kwargs) + + if getattr(policy.config, "merge_lora", False): + print("Merging LoRA after loading pretrained model...") + policy.model = merge_lora_weights(policy.model) + + return policy + + +class LoRAConv2d(nn.Module): + def __init__(self, base_conv, rank=4): + super().__init__() + self.base_conv = base_conv + + # Flatten the original conv weight + out_channels, in_channels, kh, kw = base_conv.weight.shape + self.weight_shape = (out_channels, in_channels, kh, kw) + fan_in = in_channels * kh * kw + + # LoRA parameters + self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank))) + self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in))) + + def forward(self, x): + lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape) + + return nn.functional.conv2d( + x, + self.base_conv.weight + lora_update, + self.base_conv.bias, + stride=self.base_conv.stride, + padding=self.base_conv.padding, + dilation=self.base_conv.dilation, + groups=self.base_conv.groups, + ) + + def merge_lora(self): + """Merge LoRA weights into the base convolution and return a standard Conv2d layer""" + lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape) + self.base_conv.weight.copy_(self.base_conv.weight + lora_update) + + return self.base_conv + + +def replace_conv2d_with_lora(module, rank=4): + """Recursively replace Conv2d layers with LoRAConv2d in the module""" + for name, child in list(module.named_children()): + if isinstance(child, nn.Conv2d): + setattr(module, name, LoRAConv2d(child, rank)) + else: + replace_conv2d_with_lora(child, rank) + return module + + +def merge_lora_weights(module): + """Recursively merge LoRA weights in the module""" + for name, child in 
list(module.named_children()): + if isinstance(child, LoRAConv2d): + setattr(module, name, child.merge_lora()) + else: + merge_lora_weights(child) + return module + + +def add_lora_to_backbone(backbone, rank=4, verbose=True): + replace_conv2d_with_lora(backbone, rank) + + for name, param in backbone.named_parameters(): + if "lora_" in name or name.startswith("fc"): + param.requires_grad = True + else: + param.requires_grad = False + + return backbone + + +def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: + position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dimension, 2, dtype=torch.float) * (-math.log(10000.0) / dimension)) + pe = torch.zeros(num_positions, dimension) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + return pe diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index cd440f7a..a213d5e1 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -25,6 +25,7 @@ from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.dot.configuration_dot import DOTConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig @@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy return PI0Policy + elif name == "dot": + from lerobot.common.policies.dot.modeling_dot import DOTPolicy + + return DOTPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -70,6 +75,8 @@ def 
make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return VQBeTConfig(**kwargs) elif policy_type == "pi0": return PI0Config(**kwargs) + elif policy_type == "dot": + return DOTConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") From 489cdc2aceac21f6070a44ead618309f5ce7075f Mon Sep 17 00:00:00 2001 From: Ilia Date: Sun, 9 Feb 2025 19:09:41 +0700 Subject: [PATCH 2/4] Fixed small issues --- .../common/policies/dot/configuration_dot.py | 13 ++++++++- lerobot/common/policies/dot/modeling_dot.py | 29 ++++++++++++++----- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/dot/configuration_dot.py b/lerobot/common/policies/dot/configuration_dot.py index 5b5a7aba..33d23f73 100644 --- a/lerobot/common/policies/dot/configuration_dot.py +++ b/lerobot/common/policies/dot/configuration_dot.py @@ -35,16 +35,27 @@ class DOTConfig(PreTrainedConfig): default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, "STATE": NormalizationMode.MIN_MAX, + "ENV": NormalizationMode.MIN_MAX, "ACTION": NormalizationMode.MIN_MAX, } ) + # Not sure if there is a better way to do this with new config system. + override_dataset_stats: bool = False + new_dataset_stats: dict[str, dict[str, list[float]]] = field( + default_factory=lambda: { + "action": {"max": [512.0] * 2, "min": [0.0] * 2}, + "observation.environment_state": {"max": [512.0] * 16, "min": [0.0] * 16}, + "observation.state": {"max": [512.0] * 2, "min": [0.0] * 2}, + } + ) + # Architecture. 
vision_backbone: str = "resnet18" pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" pre_norm: bool = True lora_rank: int = 20 - merge_lora: bool = True + merge_lora: bool = False dim_model: int = 128 n_heads: int = 8 diff --git a/lerobot/common/policies/dot/modeling_dot.py b/lerobot/common/policies/dot/modeling_dot.py index 477cfdf8..5f392cbc 100644 --- a/lerobot/common/policies/dot/modeling_dot.py +++ b/lerobot/common/policies/dot/modeling_dot.py @@ -45,7 +45,7 @@ class DOT(nn.Module): self.n_features += self.config.n_obs_steps if self.config.env_state_feature: - self.projections["env_state"] = nn.Linear( + self.projections["environment_state"] = nn.Linear( self.config.env_state_feature.shape[0], self.config.dim_model ) self.n_features += self.config.n_obs_steps @@ -54,7 +54,7 @@ class DOT(nn.Module): obs_mapping = { "images": "observation.images", "state": "observation.state", - "env_state": "observation.environment_state", + "environment_state": "observation.environment_state", } self.obs_mapping = {k: v for k, v in obs_mapping.items() if k in self.projections_names} @@ -154,6 +154,15 @@ class DOTPolicy(PreTrainedPolicy): config.validate_features() self.config = config + if config.override_dataset_stats: + if dataset_stats is None: + dataset_stats = {} + for k, v in config.new_dataset_stats.items(): + if k not in dataset_stats: + dataset_stats[k] = {} + for k1, v1 in v.items(): + dataset_stats[k][k1] = torch.tensor(v1) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats @@ -164,11 +173,12 @@ class DOTPolicy(PreTrainedPolicy): self.model = DOT(self.config) - self.return_every_n = self.config.return_every_n self.state_noise = self.config.state_noise self.crop_scale = self.config.crop_scale self.alpha = self.config.alpha self.inference_horizon = self.config.inference_horizon + 
self.return_every_n = self.config.return_every_n + self.predict_every_n = self.config.predict_every_n # Inference action chunking and observation queues self._old_predictions = None @@ -196,12 +206,14 @@ class DOTPolicy(PreTrainedPolicy): config.rescale_shape, interpolation=InterpolationMode.NEAREST ) - self.predict_every_n = self.config.predict_every_n self.step = 0 + self.last_action = None def reset(self): self._old_predictions = None self._input_buffers = {} + self.last_action = None + self.step = 0 def get_optim_params(self) -> dict: return self.model.parameters() @@ -346,9 +358,9 @@ class DOTPolicy(PreTrainedPolicy): return loss_dict @classmethod - def from_pretrained(cls, *args, **kwargs): + def from_pretrained(cls, pretrained_name_or_path, *args, **kwargs): """Load model from pretrained checkpoint and merge LoRA after loading""" - policy = super().from_pretrained(*args, **kwargs) + policy = super().from_pretrained(pretrained_name_or_path, *args, **kwargs) if getattr(policy.config, "merge_lora", False): print("Merging LoRA after loading pretrained model...") @@ -368,8 +380,9 @@ class LoRAConv2d(nn.Module): fan_in = in_channels * kh * kw # LoRA parameters - self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank))) - self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in))) + std = 1 / math.sqrt(fan_in) + self.lora_A = nn.Parameter(torch.normal(0, std, (out_channels, rank))) + self.lora_B = nn.Parameter(torch.normal(0, std, (rank, fan_in))) def forward(self, x): lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape) From b6face0179610d707402da8c933bd49b05604bd2 Mon Sep 17 00:00:00 2001 From: Ilia Date: Mon, 10 Feb 2025 21:12:08 +0700 Subject: [PATCH 3/4] Fixed cameras order --- lerobot/common/policies/dot/modeling_dot.py | 32 ++++++++++----------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/lerobot/common/policies/dot/modeling_dot.py b/lerobot/common/policies/dot/modeling_dot.py index 
5f392cbc..ac486f0e 100644 --- a/lerobot/common/policies/dot/modeling_dot.py +++ b/lerobot/common/policies/dot/modeling_dot.py @@ -27,8 +27,10 @@ class DOT(nn.Module): self.projections = nn.ModuleDict() self.n_features = 0 + self.image_names = sorted(config.image_features.keys()) + # I use one backbone for all cameras and simply project the output to the model dimension - if len(self.config.image_features) > 0: + if len(self.image_names) > 0: backbone = getattr(torchvision.models, self.config.vision_backbone)( weights=self.config.pretrained_backbone_weights, norm_layer=FrozenBatchNorm2d, @@ -36,7 +38,7 @@ class DOT(nn.Module): backbone.fc = nn.Linear(backbone.fc.in_features, self.config.dim_model) self.projections["images"] = add_lora_to_backbone(backbone, rank=config.lora_rank) - self.n_features += len(self.config.image_features) * self.config.n_obs_steps + self.n_features += len(self.image_names) * self.config.n_obs_steps if self.config.robot_state_feature: self.projections["state"] = nn.Linear( @@ -154,6 +156,8 @@ class DOTPolicy(PreTrainedPolicy): config.validate_features() self.config = config + self.image_names = sorted(config.image_features.keys()) + if config.override_dataset_stats: if dataset_stats is None: dataset_stats = {} @@ -243,9 +247,9 @@ class DOTPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) # Resize and stack all images - if len(self.config.image_features) > 0: + if len(self.image_names) > 0: batch["observation.images"] = torch.stack( - [self.resize_transform(batch[k]) for k in self.config.image_features], + [self.resize_transform(batch[k]) for k in self.image_names], dim=1, ) # bs, n_cam, c, h, w @@ -301,11 +305,7 @@ class DOTPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: lookback_ind = torch.randint(0, 2 * self.config.lookback_aug + 1, (1,)).item() - for k in ( - list(self.model.obs_mapping.values()) - + list(self.config.image_features.keys()) - + ["action", "action_is_pad"] - ): + for 
k in list(self.model.obs_mapping.values()) + list(self.image_names) + ["action", "action_is_pad"]: if k != "observation.images": batch[k] = torch.cat( [ @@ -325,14 +325,14 @@ class DOTPolicy(PreTrainedPolicy): ) crop_transform = transforms.RandomCrop(new_shape) - for k in self.config.image_features: + for k in self.image_names: bs, n_obs, c, h, w = batch[k].shape batch[k] = batch[k].view(bs * n_obs, c, h, w) batch[k] = crop_transform(self.resize_transform(batch[k])) batch[k] = batch[k].view(bs, n_obs, c, *batch[k].shape[-2:]) - batch["observation.images"] = torch.stack( - [batch[k] for k in self.config.image_features], dim=2 - ).flatten(1, 2) # bs, n_obs * n_cam, c, h, w + batch["observation.images"] = torch.stack([batch[k] for k in self.image_names], dim=2).flatten( + 1, 2 + ) # bs, n_obs * n_cam, c, h, w # Add random noise to states during training # TODO: it should be done in the dataloader @@ -379,10 +379,8 @@ class LoRAConv2d(nn.Module): self.weight_shape = (out_channels, in_channels, kh, kw) fan_in = in_channels * kh * kw - # LoRA parameters - std = 1 / math.sqrt(fan_in) - self.lora_A = nn.Parameter(torch.normal(0, std, (out_channels, rank))) - self.lora_B = nn.Parameter(torch.normal(0, std, (rank, fan_in))) + self.lora_A = nn.Parameter(torch.normal(0, 0.02, (out_channels, rank))) + self.lora_B = nn.Parameter(torch.normal(0, 0.02, (rank, fan_in))) def forward(self, x): lora_update = torch.matmul(self.lora_A, self.lora_B).view(self.weight_shape) From 42cca283322bbb68b6d3f1b4a436ada5b5d3393a Mon Sep 17 00:00:00 2001 From: Ilia Date: Mon, 10 Feb 2025 21:12:24 +0700 Subject: [PATCH 4/4] Default crop --- lerobot/common/policies/dot/configuration_dot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lerobot/common/policies/dot/configuration_dot.py b/lerobot/common/policies/dot/configuration_dot.py index 33d23f73..9406b956 100644 --- a/lerobot/common/policies/dot/configuration_dot.py +++ b/lerobot/common/policies/dot/configuration_dot.py @@ -64,7 
+64,7 @@ class DOTConfig(PreTrainedConfig): rescale_shape: tuple[int, int] = (96, 96) # Augmentation. - crop_scale: float = 1.0 + crop_scale: float = 0.8 state_noise: float = 0.01 noise_decay: float = 0.999995