Nur Muhammad "Mahi" Shafiullah 2025-04-05 10:29:55 +08:00 committed by GitHub
commit 10a209e42d
3 changed files with 806 additions and 0 deletions

lerobot/common/policies/dit_flow/configuration_dit_flow.py

@@ -0,0 +1,212 @@
#!/usr/bin/env python
# Copyright 2025 Nur Muhammad Mahi Shafiullah,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("ditflow")
@dataclass
class DiTFlowConfig(PreTrainedConfig):
"""Configuration class for DiTFlowPolicy.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- Either:
- At least one key starting with "observation.image" is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views. Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: DiT-flow model action prediction size as detailed in `DiTFlowPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiTFlowPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary whose keys represent the modality (e.g. "observation.state"),
and whose values specify the normalization mode to apply. The two available modes are "mean_std",
which subtracts the mean and divides by the standard deviation, and "min_max", which rescales to the
[-1, 1] range.
output_normalization_modes: Similar dictionary to `input_normalization_modes`, but used to unnormalize outputs to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
frequency_embedding_dim: The embedding dimension for the time value embedding in the flow model.
num_blocks: The number of transformer blocks in the DiT flow model.
hidden_dim: The hidden dimension for the transformer blocks in the DiT flow model.
num_heads: The number of attention heads in the transformer blocks.
dropout: The dropout rate used inside the transformer blocks.
dim_feedforward: The expanded feedforward dimension in the MLPs used in the transformer block.
activation: The activation function used in the transformer blocks.
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
denoising step at inference time. WARNING: you will need to make sure your action-space is
normalized to fit within this range.
clip_sample_range: The magnitude of the clipping range as described above.
num_inference_steps: Number of Euler integration steps used to solve the flow ODE at inference time
(steps are evenly spaced).
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note that this defaults
to False, as the original Diffusion Policy implementation does the same.
"""
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
n_action_steps: int = 8
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
use_separate_rgb_encoder_per_camera: bool = False
# Diffusion Transformer (DiT) parameters.
frequency_embedding_dim: int = 256
hidden_dim: int = 512
num_blocks: int = 6
num_heads: int = 16
dropout: float = 0.1
dim_feedforward: int = 4096
activation: str = "gelu"
# Noise scheduler.
training_noise_sampling: str = (
"uniform" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
)
clip_sample: bool = True
clip_sample_range: float = 1.0
# Inference
num_inference_steps: int | None = 100
# Loss computation
do_mask_loss_for_padding: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 500
def __post_init__(self):
"""Input validation (not exhaustive)."""
super().__post_init__()
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.training_noise_sampling not in ("uniform", "beta"):
raise ValueError(
f"`training_noise_sampling` must be either 'uniform' or 'beta'. Got {self.training_noise_sampling}."
)
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None
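
As a quick illustration of how the defaults above fit together (not part of this commit; the import path comes from this diff, everything else is a hypothetical sketch):

from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig

cfg = DiTFlowConfig(n_obs_steps=2, horizon=16, n_action_steps=8)
# The executed chunk must fit inside the predicted horizon measured from the current step.
assert cfg.n_action_steps <= cfg.horizon - cfg.n_obs_steps + 1
print(cfg.observation_delta_indices)  # [-1, 0]
print(cfg.action_delta_indices)       # [-1, 0, 1, ..., 14]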

lerobot/common/policies/dit_flow/modeling_dit_flow.py

@@ -0,0 +1,587 @@
# Copyright 2025 Nur Muhammad Mahi Shafiullah,
# and The HuggingFace Inc. team. All rights reserved.
# Heavy inspiration taken from
# * DETR by Meta AI (Carion et. al.): https://github.com/facebookresearch/detr
# * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT
# * DiT Policy by Dasari et. al. : https://github.com/sudeepdasari/dit-policy
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
from collections import deque
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionRgbEncoder
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
populate_queues,
)
def _get_activation_fn(activation: str):
"""Return an activation module given a string (a module, so it can be used inside nn.Sequential)."""
if activation == "relu":
return nn.ReLU()
if activation == "gelu":
return nn.GELU(approximate="tanh")
if activation == "glu":
return nn.GLU()
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
# DiT-style adaLN modulation: `shift` and `scale` are per-sample vectors computed from the
# conditioning vector, broadcast over the leading (token) dimension of `x`.
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
class _TimeNetwork(nn.Module):
def __init__(self, frequency_embedding_dim, hidden_dim, learnable_w=False, max_period=1000):
assert frequency_embedding_dim % 2 == 0, "frequency_embedding_dim must be even!"
half_dim = int(frequency_embedding_dim // 2)
super().__init__()
w = np.log(max_period) / (half_dim - 1)
w = torch.exp(torch.arange(half_dim) * -w).float()
self.register_parameter("w", nn.Parameter(w, requires_grad=learnable_w))
self.out_net = nn.Sequential(
nn.Linear(frequency_embedding_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim)
)
def forward(self, t):
assert len(t.shape) == 1, "assumes 1d input timestep array"
t = t[:, None] * self.w[None]
t = torch.cat((torch.cos(t), torch.sin(t)), dim=1)
return self.out_net(t)
class _ShiftScaleMod(nn.Module):
def __init__(self, dim):
super().__init__()
self.act = nn.SiLU()
self.scale = nn.Linear(dim, dim)
self.shift = nn.Linear(dim, dim)
def forward(self, x, c):
c = self.act(c)
return x * (1 + self.scale(c)[None]) + self.shift(c)[None]
def reset_parameters(self):
nn.init.zeros_(self.scale.weight)
nn.init.zeros_(self.shift.weight)
nn.init.zeros_(self.scale.bias)
nn.init.zeros_(self.shift.bias)
class _ZeroScaleMod(nn.Module):
def __init__(self, dim):
super().__init__()
self.act = nn.SiLU()
self.scale = nn.Linear(dim, dim)
def forward(self, x, c):
c = self.act(c)
return x * self.scale(c)[None]
def reset_parameters(self):
nn.init.zeros_(self.scale.weight)
nn.init.zeros_(self.scale.bias)
class _DiTDecoder(nn.Module):
def __init__(self, d_model=256, nhead=6, dim_feedforward=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
# create mlp
self.mlp = nn.Sequential(
self.linear1,
self.activation,
self.dropout2,
self.linear2,
self.dropout3,
)
# create modulation layers
self.attn_modulate = _ShiftScaleMod(d_model)
self.attn_gate = _ZeroScaleMod(d_model)
self.mlp_modulate = _ShiftScaleMod(d_model)
self.mlp_gate = _ZeroScaleMod(d_model)
def forward(self, x, t, cond):
# process the conditioning vector first
cond = cond + t
x2 = self.attn_modulate(self.norm1(x), cond)
x2, _ = self.self_attn(x2, x2, x2, need_weights=False)
x = x + self.attn_gate(self.dropout1(x2), cond)
x3 = self.mlp_modulate(self.norm2(x), cond)
x3 = self.mlp(x3)
x3 = self.mlp_gate(x3, cond)
return x + x3
def reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for s in (self.attn_modulate, self.attn_gate, self.mlp_modulate, self.mlp_gate):
s.reset_parameters()
class _FinalLayer(nn.Module):
def __init__(self, hidden_size, out_size):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_size, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x, t, cond):
# process the conditioning vector first
cond = cond + t
shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1)
x = modulate(x, shift, scale)
x = self.linear(x)
return x
def reset_parameters(self):
for p in self.parameters():
nn.init.zeros_(p)
class _TransformerDecoder(nn.Module):
def __init__(self, base_module, num_layers):
super().__init__()
self.layers = nn.ModuleList([copy.deepcopy(base_module) for _ in range(num_layers)])
for layer in self.layers:
layer.reset_parameters()
def forward(self, src, t, cond):
x = src
for layer in self.layers:
x = layer(x, t, cond)
return x
class _DiTNoiseNet(nn.Module):
def __init__(
self,
ac_dim,
ac_chunk,
cond_dim,
time_dim=256,
hidden_dim=256,
num_blocks=6,
dropout=0.1,
dim_feedforward=2048,
nhead=8,
activation="gelu",
clip_sample=False,
clip_sample_range=1.0,
):
super().__init__()
self.ac_dim, self.ac_chunk = ac_dim, ac_chunk
# positional encoding blocks
self.register_parameter(
"dec_pos",
nn.Parameter(torch.empty(ac_chunk, 1, hidden_dim), requires_grad=True),
)
nn.init.xavier_uniform_(self.dec_pos.data)
# input encoder mlps
self.time_net = _TimeNetwork(time_dim, hidden_dim)
self.ac_proj = nn.Sequential(
nn.Linear(ac_dim, ac_dim),
nn.GELU(approximate="tanh"),
nn.Linear(ac_dim, hidden_dim),
)
self.cond_proj = nn.Linear(cond_dim, hidden_dim)
# decoder blocks
decoder_module = _DiTDecoder(
hidden_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
)
self.decoder = _TransformerDecoder(decoder_module, num_blocks)
# turns predicted tokens into epsilons
self.eps_out = _FinalLayer(hidden_dim, ac_dim)
# clip the output samples
self.clip_sample = clip_sample
self.clip_sample_range = clip_sample_range
print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6))
def forward(self, noisy_actions, time, global_cond):
c = self.cond_proj(global_cond)
time_enc = self.time_net(time)
ac_tokens = self.ac_proj(noisy_actions) # [B, T, adim] -> [B, T, hidden_dim]
ac_tokens = ac_tokens.transpose(0, 1) # [B, T, hidden_dim] -> [T, B, hidden_dim]
# Allow variable length action chunks
dec_in = ac_tokens + self.dec_pos[: ac_tokens.size(0)] # [T, B, hidden_dim]
# apply decoder
dec_out = self.decoder(dec_in, time_enc, c)
# apply final epsilon prediction layer
eps_out = self.eps_out(dec_out, time_enc, c) # [T, B, hidden_dim] -> [T, B, adim]
return eps_out.transpose(0, 1) # [T, B, adim] -> [B, T, adim]
@torch.no_grad()
def sample(
self, condition: torch.Tensor, timesteps: int = 100, generator: torch.Generator | None = None
) -> torch.Tensor:
# Use Euler integration to solve the ODE.
batch_size, device = condition.shape[0], condition.device
x_0 = self.sample_noise(batch_size, device, generator)
dt = 1.0 / timesteps
t_all = (
torch.arange(timesteps, device=device).float().unsqueeze(0).expand(batch_size, timesteps)
/ timesteps
)
for k in range(timesteps):
t = t_all[:, k]
x_0 = x_0 + dt * self.forward(x_0, t, condition)
if self.clip_sample:
x_0 = torch.clamp(x_0, -self.clip_sample_range, self.clip_sample_range)
return x_0
def sample_noise(self, batch_size: int, device, generator: torch.Generator | None = None) -> torch.Tensor:
return torch.randn(batch_size, self.ac_chunk, self.ac_dim, device=device, generator=generator)
class DiTFlowPolicy(PreTrainedPolicy):
"""
DiT Flow Policy: a DiT-style transformer trained with flow matching for action chunk prediction.
Adapted from "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy),
with the noise-prediction U-Net replaced by a DiT velocity network (see the header references above).
"""
config_class = DiTFlowConfig
name = "DiTFlow"
def __init__(
self,
config: DiTFlowConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.dit_flow = DiTFlowModel(config)
self.reset()
def get_optim_params(self) -> dict:
return self.dit_flow.parameters()
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
underlying flow model. Here's how it works:
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
copied `n_obs_steps` times to fill the cache).
- The flow model generates `horizon` steps worth of actions.
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like:
----------------------------------------------------------------------------------------------
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.dit_flow.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
loss = self.dit_flow.compute_loss(batch)
return {"loss": loss}
class DiTFlowModel(nn.Module):
def __init__(self, config: DiTFlowConfig):
super().__init__()
self.config = config
# Build observation encoders (depending on which observations are provided).
global_cond_dim = self.config.robot_state_feature.shape[0]
if self.config.image_features:
num_images = len(self.config.image_features)
if self.config.use_separate_rgb_encoder_per_camera:
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
self.rgb_encoder = nn.ModuleList(encoders)
global_cond_dim += encoders[0].feature_dim * num_images
else:
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.velocity_net = _DiTNoiseNet(
ac_dim=config.action_feature.shape[0],
ac_chunk=config.horizon,
cond_dim=global_cond_dim * config.n_obs_steps,
time_dim=config.frequency_embedding_dim,
hidden_dim=config.hidden_dim,
num_blocks=config.num_blocks,
dropout=config.dropout,
dim_feedforward=config.dim_feedforward,
nhead=config.num_heads,
activation=config.activation,
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
)
self.num_inference_steps = config.num_inference_steps or 100
self.training_noise_sampling = config.training_noise_sampling
if config.training_noise_sampling == "uniform":
self.noise_distribution = torch.distributions.Uniform(
low=0,
high=1,
)
elif config.training_noise_sampling == "beta":
# From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
# There, they say the PDF for the distribution they use is the following:
# $p(t) = Beta((s-t) / s; 1.5, 1)$
# So, we first figure out the distribution over $t'$ and then transform it to $t = s - s * t'$.
s = 0.999 # constant from the paper
beta_dist = torch.distributions.Beta(
concentration1=1.5, # alpha
concentration0=1.0, # beta
)
affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s)
self.noise_distribution = torch.distributions.TransformedDistribution(
beta_dist, [affine_transform]
)
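# Note: with s = 0.999 the transformed samples t = s - s * t' lie in (0, s); since Beta(1.5, 1)
# concentrates mass near t' = 1, the resulting flow times concentrate near t = 0, i.e. training
# emphasizes the noisier end of the interpolation.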
# ========= inference ============
def conditional_sample(
self,
batch_size: int,
global_cond: torch.Tensor | None = None,
generator: torch.Generator | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Expand global conditioning to the batch size.
if global_cond is not None:
global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype)
# Sample prior.
sample = self.velocity_net.sample(
global_cond, timesteps=self.num_inference_steps, generator=generator
)
return sample
def _prepare_global_conditioning(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
global_cond_feats = [batch[OBS_ROBOT]]
# Extract image features.
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
global_cond_feats.append(img_features)
if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
"""
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond)
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def compute_loss(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
assert "observation.images" in batch or "observation.environment_state" in batch
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# Flow matching: interpolate between noise and the clean action trajectory.
trajectory = batch["action"]
# Sample the noise endpoint of the flow.
noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
# Sample a random flow time for each item in the batch.
timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device)
# Linearly interpolate: x_t = (1 - t) * noise + t * trajectory.
noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory
# Run the velocity network to predict the flow velocity at (x_t, t).
pred = self.velocity_net(noisy_actions=noisy_trajectory, time=timesteps, global_cond=global_cond)
# The flow matching target is the constant velocity from noise to the clean trajectory.
target = trajectory - noise
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if self.config.do_mask_loss_for_padding:
if "action_is_pad" not in batch:
raise ValueError(
"You need to provide 'action_is_pad' in the batch when "
f"{self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
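
A minimal rollout sketch showing how the policy above is driven at inference time (illustrative only; `policy` is assumed to be a DiTFlowPolicy built with features and dataset stats from a dataset, and the observation tensors, environment, and `episode_length` are placeholders):

policy.eval()
policy.reset()  # clear the observation/action queues; call this on env.reset()
for _ in range(episode_length):  # episode_length: placeholder
    batch = {
        "observation.state": state,  # (B, state_dim) tensor, placeholder
        "observation.image": image,  # (B, 3, 96, 96) tensor, placeholder
    }
    action = policy.select_action(batch)  # one action per call; a new chunk is sampled once the queue empties
    # step the environment with `action` here, then read the next observation into `state` / `image`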

lerobot/common/policies/factory.py

@@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -43,6 +44,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy
elif name == "ditflow":
from lerobot.common.policies.dit_flow.modeling_dit_flow import DiTFlowPolicy
return DiTFlowPolicy
elif name == "act": elif name == "act":
from lerobot.common.policies.act.modeling_act import ACTPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
@ -68,6 +73,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return TDMPCConfig(**kwargs) return TDMPCConfig(**kwargs)
elif policy_type == "diffusion": elif policy_type == "diffusion":
return DiffusionConfig(**kwargs) return DiffusionConfig(**kwargs)
elif policy_type == "ditflow":
return DiTFlowConfig(**kwargs)
elif policy_type == "act": elif policy_type == "act":
return ACTConfig(**kwargs) return ACTConfig(**kwargs)
elif policy_type == "vqbet": elif policy_type == "vqbet":