diff --git a/lerobot/common/policies/dit_flow/configuration_dit_flow.py b/lerobot/common/policies/dit_flow/configuration_dit_flow.py
new file mode 100644
index 00000000..ce74c967
--- /dev/null
+++ b/lerobot/common/policies/dit_flow/configuration_dit_flow.py
@@ -0,0 +1,212 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Nur Muhammad Mahi Shafiullah,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass, field
+
+from lerobot.common.optim.optimizers import AdamConfig
+from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import NormalizationMode
+
+
+@PreTrainedConfig.register_subclass("ditflow")
+@dataclass
+class DiTFlowConfig(PreTrainedConfig):
+    """Configuration class for DiTFlowPolicy.
+
+    Defaults are configured for training with PushT providing proprioceptive and single camera observations.
+
+    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
+    Those are: `input_shapes` and `output_shapes`.
+
+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - Either:
+            - At least one key starting with "observation.image" is required as an input.
+              AND/OR
+            - The key "observation.environment_state" is required as input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
+        - "action" is required as an output key.
+
+    Args:
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
+            current step and additional steps going back).
+        horizon: DiT-flow model action prediction size as detailed in `DiTFlowPolicy.select_action`.
+        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
+            See `DiTFlowPolicy.select_action` for more details.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
+        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
+            and the value specifies the normalization mode to apply.
The two available modes are "mean_std",
+            which subtracts the mean and divides by the standard deviation, and "min_max", which rescales to a
+            [-1, 1] range.
+        output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the
+            original scale. Note that this is also used for normalizing the training targets.
+        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
+        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
+            within the image size. If None, no cropping is done.
+        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
+            mode).
+        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
+            `None` means no pretrained weights.
+        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
+            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
+        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
+        use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
+
+        frequency_embedding_dim: The embedding dimension for the time value embedding in the flow model.
+        num_blocks: The number of transformer blocks in the DiT flow model.
+        hidden_dim: The hidden dimension for the transformer blocks in the DiT flow model.
+        num_heads: The number of attention heads in the transformer blocks.
+        dropout: The dropout rate used inside the transformer blocks.
+        dim_feedforward: The expanded feedforward dimension in the MLPs used in the transformer block.
+        activation: The activation function used in the transformer blocks.
+        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
+            denoising step at inference time. WARNING: you will need to make sure your action-space is
+            normalized to fit within this range.
+        clip_sample_range: The magnitude of the clipping range as described above.
+        num_inference_steps: Number of Euler integration steps to use at inference time (steps are evenly
+            spaced).
+        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
+            `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
+            to False as the original Diffusion Policy implementation does the same.
+    """
+
+    # Inputs / output structure.
+    n_obs_steps: int = 2
+    horizon: int = 16
+    n_action_steps: int = 8
+
+    normalization_mapping: dict[str, NormalizationMode] = field(
+        default_factory=lambda: {
+            "VISUAL": NormalizationMode.MEAN_STD,
+            "STATE": NormalizationMode.MIN_MAX,
+            "ACTION": NormalizationMode.MIN_MAX,
+        }
+    )
+
+    # The original implementation doesn't sample frames for the last 7 steps,
+    # which avoids excessive padding and leads to improved training results.
+    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1
+
+    # Architecture / modeling.
+    # Vision backbone.
+    vision_backbone: str = "resnet18"
+    crop_shape: tuple[int, int] | None = (84, 84)
+    crop_is_random: bool = True
+    pretrained_backbone_weights: str | None = None
+    use_group_norm: bool = True
+    spatial_softmax_num_keypoints: int = 32
+    use_separate_rgb_encoder_per_camera: bool = False
+
+    # Diffusion Transformer (DiT) parameters.
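+    # These fields are forwarded to `_DiTNoiseNet` in modeling_dit_flow.py (frequency_embedding_dim -> time_dim,
+    # num_heads -> nhead, num_blocks -> number of `_DiTDecoder` layers); see `DiTFlowModel.__init__`.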
+    frequency_embedding_dim: int = 256
+    hidden_dim: int = 512
+    num_blocks: int = 6
+    num_heads: int = 16
+    dropout: float = 0.1
+    dim_feedforward: int = 4096
+    activation: str = "gelu"
+
+    # Noise scheduler.
+    training_noise_sampling: str = (
+        "uniform"  # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
+    )
+    clip_sample: bool = True
+    clip_sample_range: float = 1.0
+
+    # Inference
+    num_inference_steps: int | None = 100
+
+    # Loss computation
+    do_mask_loss_for_padding: bool = False
+
+    # Training presets
+    optimizer_lr: float = 1e-4
+    optimizer_betas: tuple = (0.95, 0.999)
+    optimizer_eps: float = 1e-8
+    optimizer_weight_decay: float = 1e-6
+    scheduler_name: str = "cosine"
+    scheduler_warmup_steps: int = 500
+
+    def __post_init__(self):
+        """Input validation (not exhaustive)."""
+        super().__post_init__()
+
+        if not self.vision_backbone.startswith("resnet"):
+            raise ValueError(
+                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
+            )
+
+        if self.training_noise_sampling not in ("uniform", "beta"):
+            raise ValueError(
+                f"`training_noise_sampling` must be either 'uniform' or 'beta'. Got {self.training_noise_sampling}."
+            )
+
+    def get_optimizer_preset(self) -> AdamConfig:
+        return AdamConfig(
+            lr=self.optimizer_lr,
+            betas=self.optimizer_betas,
+            eps=self.optimizer_eps,
+            weight_decay=self.optimizer_weight_decay,
+        )
+
+    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
+        return DiffuserSchedulerConfig(
+            name=self.scheduler_name,
+            num_warmup_steps=self.scheduler_warmup_steps,
+        )
+
+    def validate_features(self) -> None:
+        if len(self.image_features) == 0 and self.env_state_feature is None:
+            raise ValueError("You must provide at least one image or the environment state among the inputs.")
+
+        if self.crop_shape is not None:
+            for key, image_ft in self.image_features.items():
+                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
+                    raise ValueError(
+                        f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
+                        f"for `crop_shape` and {image_ft.shape} for "
+                        f"`{key}`."
+                    )
+
+        # Check that all input images have the same shape.
+        first_image_key, first_image_ft = next(iter(self.image_features.items()))
+        for key, image_ft in self.image_features.items():
+            if image_ft.shape != first_image_ft.shape:
+                raise ValueError(
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
+                )
+
+    @property
+    def observation_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1))
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
diff --git a/lerobot/common/policies/dit_flow/modeling_dit_flow.py b/lerobot/common/policies/dit_flow/modeling_dit_flow.py
new file mode 100644
index 00000000..dc015469
--- /dev/null
+++ b/lerobot/common/policies/dit_flow/modeling_dit_flow.py
@@ -0,0 +1,587 @@
+# Copyright 2025 Nur Muhammad Mahi Shafiullah,
+# and The HuggingFace Inc. team. All rights reserved.
+# Heavy inspiration taken from
+# * DETR by Meta AI (Carion et al.): https://github.com/facebookresearch/detr
+# * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT
+# * DiT Policy by Dasari et al.
: https://github.com/sudeepdasari/dit-policy + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +from collections import deque + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 + +from lerobot.common.constants import OBS_ENV, OBS_ROBOT +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionRgbEncoder +from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.utils import ( + get_device_from_parameters, + get_dtype_from_parameters, + populate_queues, +) + + +def _get_activation_fn(activation: str): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return nn.GELU(approximate="tanh") + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0) + + +class _TimeNetwork(nn.Module): + def __init__(self, frequency_embedding_dim, hidden_dim, learnable_w=False, max_period=1000): + assert frequency_embedding_dim % 2 == 0, "time_dim must be even!" + half_dim = int(frequency_embedding_dim // 2) + super().__init__() + + w = np.log(max_period) / (half_dim - 1) + w = torch.exp(torch.arange(half_dim) * -w).float() + self.register_parameter("w", nn.Parameter(w, requires_grad=learnable_w)) + + self.out_net = nn.Sequential( + nn.Linear(frequency_embedding_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim) + ) + + def forward(self, t): + assert len(t.shape) == 1, "assumes 1d input timestep array" + t = t[:, None] * self.w[None] + t = torch.cat((torch.cos(t), torch.sin(t)), dim=1) + return self.out_net(t) + + +class _ShiftScaleMod(nn.Module): + def __init__(self, dim): + super().__init__() + self.act = nn.SiLU() + self.scale = nn.Linear(dim, dim) + self.shift = nn.Linear(dim, dim) + + def forward(self, x, c): + c = self.act(c) + return x * (1 + self.scale(c)[None]) + self.shift(c)[None] + + def reset_parameters(self): + nn.init.zeros_(self.scale.weight) + nn.init.zeros_(self.shift.weight) + nn.init.zeros_(self.scale.bias) + nn.init.zeros_(self.shift.bias) + + +class _ZeroScaleMod(nn.Module): + def __init__(self, dim): + super().__init__() + self.act = nn.SiLU() + self.scale = nn.Linear(dim, dim) + + def forward(self, x, c): + c = self.act(c) + return x * self.scale(c)[None] + + def reset_parameters(self): + nn.init.zeros_(self.scale.weight) + nn.init.zeros_(self.scale.bias) + + +class _DiTDecoder(nn.Module): + def __init__(self, d_model=256, nhead=6, dim_feedforward=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=1e-6) + self.norm2 = nn.LayerNorm(d_model, eps=1e-6) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + # create mlp + self.mlp = nn.Sequential( + 
self.linear1, + self.activation, + self.dropout2, + self.linear2, + self.dropout3, + ) + + # create modulation layers + self.attn_modulate = _ShiftScaleMod(d_model) + self.attn_gate = _ZeroScaleMod(d_model) + self.mlp_modulate = _ShiftScaleMod(d_model) + self.mlp_gate = _ZeroScaleMod(d_model) + + def forward(self, x, t, cond): + # process the conditioning vector first + cond = cond + t + + x2 = self.attn_modulate(self.norm1(x), cond) + x2, _ = self.self_attn(x2, x2, x2, need_weights=False) + x = x + self.attn_gate(self.dropout1(x2), cond) + + x3 = self.mlp_modulate(self.norm2(x), cond) + x3 = self.mlp(x3) + x3 = self.mlp_gate(x3, cond) + return x + x3 + + def reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + for s in (self.attn_modulate, self.attn_gate, self.mlp_modulate, self.mlp_gate): + s.reset_parameters() + + +class _FinalLayer(nn.Module): + def __init__(self, hidden_size, out_size): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_size, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, t, cond): + # process the conditioning vector first + cond = cond + t + + shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1) + x = modulate(x, shift, scale) + x = self.linear(x) + return x + + def reset_parameters(self): + for p in self.parameters(): + nn.init.zeros_(p) + + +class _TransformerDecoder(nn.Module): + def __init__(self, base_module, num_layers): + super().__init__() + self.layers = nn.ModuleList([copy.deepcopy(base_module) for _ in range(num_layers)]) + + for layer in self.layers: + layer.reset_parameters() + + def forward(self, src, t, cond): + x = src + for layer in self.layers: + x = layer(x, t, cond) + return x + + +class _DiTNoiseNet(nn.Module): + def __init__( + self, + ac_dim, + ac_chunk, + cond_dim, + time_dim=256, + hidden_dim=256, + num_blocks=6, + dropout=0.1, + dim_feedforward=2048, + nhead=8, + activation="gelu", + clip_sample=False, + clip_sample_range=1.0, + ): + super().__init__() + self.ac_dim, self.ac_chunk = ac_dim, ac_chunk + + # positional encoding blocks + self.register_parameter( + "dec_pos", + nn.Parameter(torch.empty(ac_chunk, 1, hidden_dim), requires_grad=True), + ) + nn.init.xavier_uniform_(self.dec_pos.data) + + # input encoder mlps + self.time_net = _TimeNetwork(time_dim, hidden_dim) + self.ac_proj = nn.Sequential( + nn.Linear(ac_dim, ac_dim), + nn.GELU(approximate="tanh"), + nn.Linear(ac_dim, hidden_dim), + ) + self.cond_proj = nn.Linear(cond_dim, hidden_dim) + + # decoder blocks + decoder_module = _DiTDecoder( + hidden_dim, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + ) + self.decoder = _TransformerDecoder(decoder_module, num_blocks) + + # turns predicted tokens into epsilons + self.eps_out = _FinalLayer(hidden_dim, ac_dim) + + # clip the output samples + self.clip_sample = clip_sample + self.clip_sample_range = clip_sample_range + + print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6)) + + def forward(self, noisy_actions, time, global_cond): + c = self.cond_proj(global_cond) + time_enc = self.time_net(time) + + ac_tokens = self.ac_proj(noisy_actions) # [B, T, adim] -> [B, T, hidden_dim] + ac_tokens = ac_tokens.transpose(0, 1) # [B, T, hidden_dim] -> [T, B, hidden_dim] + + # Allow variable length action chunks + dec_in = 
ac_tokens + self.dec_pos[: ac_tokens.size(0)]  # [T, B, hidden_dim]
+
+        # apply decoder
+        dec_out = self.decoder(dec_in, time_enc, c)
+
+        # apply final epsilon prediction layer
+        eps_out = self.eps_out(dec_out, time_enc, c)  # [T, B, hidden_dim] -> [T, B, adim]
+        return eps_out.transpose(0, 1)  # [T, B, adim] -> [B, T, adim]
+
+    @torch.no_grad()
+    def sample(
+        self, condition: torch.Tensor, timesteps: int = 100, generator: torch.Generator | None = None
+    ) -> torch.Tensor:
+        # Use Euler integration to solve the ODE.
+        batch_size, device = condition.shape[0], condition.device
+        x_0 = self.sample_noise(batch_size, device, generator)
+        dt = 1.0 / timesteps
+        t_all = (
+            torch.arange(timesteps, device=device).float().unsqueeze(0).expand(batch_size, timesteps)
+            / timesteps
+        )
+
+        for k in range(timesteps):
+            t = t_all[:, k]
+            x_0 = x_0 + dt * self.forward(x_0, t, condition)
+            if self.clip_sample:
+                x_0 = torch.clamp(x_0, -self.clip_sample_range, self.clip_sample_range)
+        return x_0
+
+    def sample_noise(self, batch_size: int, device, generator: torch.Generator | None = None) -> torch.Tensor:
+        return torch.randn(batch_size, self.ac_chunk, self.ac_dim, device=device, generator=generator)
+
+
+class DiTFlowPolicy(PreTrainedPolicy):
+    """
+    DiT Flow Policy: a Diffusion Transformer (DiT) action head trained with flow matching, following the
+    overall structure of Diffusion Policy, "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
+    (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy),
+    and DiT Policy (https://github.com/sudeepdasari/dit-policy).
+    """
+
+    config_class = DiTFlowConfig
+    name = "ditflow"
+
+    def __init__(
+        self,
+        config: DiTFlowConfig,
+        dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    ):
+        """
+        Args:
+            config: Policy configuration class instance or None, in which case the default instantiation of
+                the configuration class is used.
+            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
+                that they will be passed with a call to `load_state_dict` before the policy is used.
+        """
+        super().__init__(config)
+        config.validate_features()
+        self.config = config
+
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_targets = Normalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+        self.unnormalize_outputs = Unnormalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+
+        # queues are populated during rollout of the policy, they contain the n latest observations and actions
+        self._queues = None
+
+        self.dit_flow = DiTFlowModel(config)
+
+        self.reset()
+
+    def get_optim_params(self) -> dict:
+        return self.dit_flow.parameters()
+
+    def reset(self):
+        """Clear observation and action queues. Should be called on `env.reset()`"""
+        self._queues = {
+            "observation.state": deque(maxlen=self.config.n_obs_steps),
+            "action": deque(maxlen=self.config.n_action_steps),
+        }
+        if self.config.image_features:
+            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
+        if self.config.env_state_feature:
+            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
+
+    @torch.no_grad
+    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Select a single action given environment observations.
+
+        This method handles caching a history of observations and an action trajectory generated by the
+        underlying flow model.
Here's how it works:
+          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
+            copied `n_obs_steps` times to fill the cache).
+          - The flow model generates `horizon` steps worth of actions.
+          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
+        Schematically this looks like:
+            ----------------------------------------------------------------------------------------------
+            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
+            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
+            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
+            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
+            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
+            ----------------------------------------------------------------------------------------------
+        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
+        "horizon" may not be the best name to describe what the variable actually means, because this period
+        is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
+        """
+        batch = self.normalize_inputs(batch)
+        if self.config.image_features:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = torch.stack(
+                [batch[key] for key in self.config.image_features], dim=-4
+            )
+        # Note: It's important that this happens after stacking the images into a single key.
+        self._queues = populate_queues(self._queues, batch)
+
+        if len(self._queues["action"]) == 0:
+            # stack n latest observations from the queue
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+            actions = self.dit_flow.generate_actions(batch)
+
+            # TODO(rcadene): make above methods return output dictionary?
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+
+            self._queues["action"].extend(actions.transpose(0, 1))
+
+        action = self._queues["action"].popleft()
+        return action
+
+    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        batch = self.normalize_inputs(batch)
+        if self.config.image_features:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = torch.stack(
+                [batch[key] for key in self.config.image_features], dim=-4
+            )
+        batch = self.normalize_targets(batch)
+        loss = self.dit_flow.compute_loss(batch)
+        return {"loss": loss}
+
+
+class DiTFlowModel(nn.Module):
+    def __init__(self, config: DiTFlowConfig):
+        super().__init__()
+        self.config = config
+
+        # Build observation encoders (depending on which observations are provided).
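+        # The global conditioning vector fed to the velocity network concatenates, per observation step, the
+        # robot state, the (per-camera) image features, and optionally the environment state. It is flattened
+        # over the `n_obs_steps` dimension, hence `cond_dim = global_cond_dim * n_obs_steps` below; see
+        # `_prepare_global_conditioning`.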
+ global_cond_dim = self.config.robot_state_feature.shape[0] + if self.config.image_features: + num_images = len(self.config.image_features) + if self.config.use_separate_rgb_encoder_per_camera: + encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)] + self.rgb_encoder = nn.ModuleList(encoders) + global_cond_dim += encoders[0].feature_dim * num_images + else: + self.rgb_encoder = DiffusionRgbEncoder(config) + global_cond_dim += self.rgb_encoder.feature_dim * num_images + if self.config.env_state_feature: + global_cond_dim += self.config.env_state_feature.shape[0] + + self.velocity_net = _DiTNoiseNet( + ac_dim=config.action_feature.shape[0], + ac_chunk=config.horizon, + cond_dim=global_cond_dim * config.n_obs_steps, + time_dim=config.frequency_embedding_dim, + hidden_dim=config.hidden_dim, + num_blocks=config.num_blocks, + dropout=config.dropout, + dim_feedforward=config.dim_feedforward, + nhead=config.num_heads, + activation=config.activation, + clip_sample=config.clip_sample, + clip_sample_range=config.clip_sample_range, + ) + + self.num_inference_steps = config.num_inference_steps or 100 + self.training_noise_sampling = config.training_noise_sampling + if config.training_noise_sampling == "uniform": + self.noise_distribution = torch.distributions.Uniform( + low=0, + high=1, + ) + elif config.training_noise_sampling == "beta": + # From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B. + # There, they say the PDF for the distribution they use is the following: + # $p(t) = Beta((s-t) / s; 1.5, 1)$ + # So, we first figure out the distribution over $t'$ and then transform it to $t = s - s * t'$. + s = 0.999 # constant from the paper + beta_dist = torch.distributions.Beta( + concentration1=1.5, # alpha + concentration0=1.0, # beta + ) + affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s) + self.noise_distribution = torch.distributions.TransformedDistribution( + beta_dist, [affine_transform] + ) + + # ========= inference ============ + def conditional_sample( + self, + batch_size: int, + global_cond: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + device = get_device_from_parameters(self) + dtype = get_dtype_from_parameters(self) + + # Expand global conditioning to the batch size. + if global_cond is not None: + global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype) + + # Sample prior. + sample = self.velocity_net.sample( + global_cond, timesteps=self.num_inference_steps, generator=generator + ) + return sample + + def _prepare_global_conditioning(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: + """Encode image features and concatenate them all together along with the state vector.""" + batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2] + global_cond_feats = [batch[OBS_ROBOT]] + # Extract image features. + if self.config.image_features: + if self.config.use_separate_rgb_encoder_per_camera: + # Combine batch and sequence dims while rearranging to make the camera index dimension first. + images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") + img_features_list = torch.cat( + [ + encoder(images) + for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True) + ] + ) + # Separate batch and sequence dims back out. The camera index dim gets absorbed into the + # feature dim (effectively concatenating the camera features). 
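+                # img_features_list: (num_cameras * batch * n_obs_steps, feature_dim)
+                #   -> img_features: (batch, n_obs_steps, num_cameras * feature_dim)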
+                img_features = einops.rearrange(
+                    img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
+                )
+            else:
+                # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
+                img_features = self.rgb_encoder(
+                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
+                )
+                # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
+                # feature dim (effectively concatenating the camera features).
+                img_features = einops.rearrange(
+                    img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
+                )
+            global_cond_feats.append(img_features)
+
+        if self.config.env_state_feature:
+            global_cond_feats.append(batch[OBS_ENV])
+
+        # Concatenate features then flatten to (B, global_cond_dim).
+        return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
+
+    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        This function expects `batch` to have:
+        {
+            "observation.state": (B, n_obs_steps, state_dim)
+
+            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
+                AND/OR
+            "observation.environment_state": (B, environment_dim)
+        }
+        """
+        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        assert n_obs_steps == self.config.n_obs_steps
+
+        # Encode image features and concatenate them all together along with the state vector.
+        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)
+
+        # run sampling
+        actions = self.conditional_sample(batch_size, global_cond=global_cond)
+
+        # Extract `n_action_steps` steps worth of actions (from the current observation).
+        start = n_obs_steps - 1
+        end = start + self.config.n_action_steps
+        actions = actions[:, start:end]
+
+        return actions
+
+    def compute_loss(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, n_obs_steps, state_dim)
+
+            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
+                AND/OR
+            "observation.environment_state": (B, environment_dim)
+
+            "action": (B, horizon, action_dim)
+            "action_is_pad": (B, horizon)
+        }
+        """
+        # Input validation.
+        assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
+        assert "observation.images" in batch or "observation.environment_state" in batch
+        n_obs_steps = batch["observation.state"].shape[1]
+        horizon = batch["action"].shape[1]
+        assert horizon == self.config.horizon
+        assert n_obs_steps == self.config.n_obs_steps
+
+        # Encode image features and concatenate them all together along with the state vector.
+        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)
+
+        # Forward process.
+        trajectory = batch["action"]
+        # Sample noise to add to the trajectory.
+        noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
+        # Sample a random noising timestep for each item in the batch.
+        timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device)
+        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
+        noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory
+
+        # Run the velocity prediction network.
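+        # The interpolant above is x_t = (1 - t) * noise + t * actions, so the ground-truth velocity is
+        # dx_t/dt = actions - noise, which is exactly the regression target below. At inference time,
+        # `_DiTNoiseNet.sample()` Euler-integrates this predicted velocity field from t=0 (noise) to t=1.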
+ pred = self.velocity_net(noisy_actions=noisy_trajectory, time=timesteps, global_cond=global_cond) + target = trajectory - noise + loss = F.mse_loss(pred, target, reduction="none") + + # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). + if self.config.do_mask_loss_for_padding: + if "action_is_pad" not in batch: + raise ValueError( + "You need to provide 'action_is_pad' in the batch when " + f"{self.config.do_mask_loss_for_padding=}." + ) + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound.unsqueeze(-1) + + return loss.mean() diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8def95a3..67fb769b 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -43,6 +44,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy return DiffusionPolicy + elif name == "ditflow": + from lerobot.common.policies.dit_flow.modeling_dit_flow import DiTFlowPolicy + + return DiTFlowPolicy elif name == "act": from lerobot.common.policies.act.modeling_act import ACTPolicy @@ -68,6 +73,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return TDMPCConfig(**kwargs) elif policy_type == "diffusion": return DiffusionConfig(**kwargs) + elif policy_type == "ditflow": + return DiTFlowConfig(**kwargs) elif policy_type == "act": return ACTConfig(**kwargs) elif policy_type == "vqbet":