Merge ef8579eacd into 1c873df5c0
commit 10a209e42d

@@ -0,0 +1,212 @@
#!/usr/bin/env python

# Copyright 2025 Nur Muhammad Mahi Shafiullah,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode


@PreTrainedConfig.register_subclass("ditflow")
@dataclass
class DiTFlowConfig(PreTrainedConfig):
"""Configuration class for DiTFlowPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
- Either:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
AND/OR
|
||||
- The key "observation.environment_state" is required as input.
|
||||
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- "action" is required as an output key.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
horizon: DiT-flow model action prediction size as detailed in `DiTFlowPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiTFlowPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
|
||||
frequency_embedding_dim: The embedding dimension for the time value embedding in the flow model.
|
||||
num_blocks: The number of transformer blocks in the DiT flow model.
|
||||
hidden_dim: The hidden dimension for the transformer blocks in the DiT flow model.
|
||||
num_heads: The number of attention heads in the transformer blocks.
|
||||
dropout: The dropout rate used inside the transformer blocks.
|
||||
dim_feedforward: The expanded feedforward dimension in the MLPs used in the transformer block.
|
||||
activation: The activation function used in the transformer blocks.
|
||||
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
|
||||
denoising step at inference time. WARNING: you will need to make sure your action-space is
|
||||
normalized to fit within this range.
|
||||
clip_sample_range: The magnitude of the clipping range as described above.
|
||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||
spaced).
|
||||
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
|
||||
to False as the original Diffusion Policy implementation does the same.
|
||||
"""
|
||||
|
||||
    # Inputs / output structure.
    n_obs_steps: int = 2
    horizon: int = 16
    n_action_steps: int = 8

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # The original implementation doesn't sample frames for the last 7 steps,
    # which avoids excessive padding and leads to improved training results.
    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1

    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    use_separate_rgb_encoder_per_camera: bool = False

    # Diffusion Transformer (DiT) parameters.
    frequency_embedding_dim: int = 256
    hidden_dim: int = 512
    num_blocks: int = 6
    num_heads: int = 16
    dropout: float = 0.1
    dim_feedforward: int = 4096
    activation: str = "gelu"

    # Noise scheduler.
    training_noise_sampling: str = (
        "uniform"  # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
    )
    clip_sample: bool = True
    clip_sample_range: float = 1.0

    # Inference
    num_inference_steps: int | None = 100

    # Loss computation
    do_mask_loss_for_padding: bool = False

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-6
    scheduler_name: str = "cosine"
    scheduler_warmup_steps: int = 500

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        super().__post_init__()

        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

        if self.training_noise_sampling not in ("uniform", "beta"):
            raise ValueError(
                f"`training_noise_sampling` must be either 'uniform' or 'beta'. Got {self.training_noise_sampling}."
            )

    def get_optimizer_preset(self) -> AdamConfig:
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
        return DiffuserSchedulerConfig(
            name=self.scheduler_name,
            num_warmup_steps=self.scheduler_warmup_steps,
        )

    def validate_features(self) -> None:
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
                        f"`{key}`."
                    )

        # Check that all input images have the same shape (skipped when only the environment state is used).
        if len(self.image_features) > 0:
            first_image_key, first_image_ft = next(iter(self.image_features.items()))
            for key, image_ft in self.image_features.items():
                if image_ft.shape != first_image_ft.shape:
                    raise ValueError(
                        f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                    )

    @property
    def observation_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

    @property
    def reward_delta_indices(self) -> None:
        return None
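
A minimal usage sketch of the configuration above, assuming the lerobot packages imported at the top of the file are installed and the module path matches the import used by the modeling file below:

from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig

# Defaults target PushT: n_obs_steps=2, horizon=16, n_action_steps=8, so the chunking
# constraint n_action_steps <= horizon - n_obs_steps + 1 holds (8 <= 15), and
# drop_n_last_frames = horizon - n_action_steps - n_obs_steps + 1 = 7.
cfg = DiTFlowConfig(hidden_dim=256, num_heads=8)  # override a couple of DiT sizes

print(cfg.observation_delta_indices)  # [-1, 0]
print(cfg.action_delta_indices)       # [-1, 0, 1, ..., 14]
print(cfg.get_optimizer_preset())     # AdamConfig with lr=1e-4, betas=(0.95, 0.999), ...
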
@@ -0,0 +1,587 @@
# Copyright 2025 Nur Muhammad Mahi Shafiullah,
# and The HuggingFace Inc. team. All rights reserved.
# Heavy inspiration taken from
# * DETR by Meta AI (Carion et al.): https://github.com/facebookresearch/detr
# * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT
# * DiT Policy by Dasari et al.: https://github.com/sudeepdasari/dit-policy

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
from collections import deque

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812

from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionRgbEncoder
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import (
    get_device_from_parameters,
    get_dtype_from_parameters,
    populate_queues,
)


def _get_activation_fn(activation: str) -> nn.Module:
    """Return an activation module given a string."""
    # Module instances (rather than functionals) so the result can be placed inside nn.Sequential.
    if activation == "relu":
        return nn.ReLU()
    if activation == "gelu":
        return nn.GELU(approximate="tanh")
    if activation == "glu":
        return nn.GLU()
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: (T, B, D); shift and scale: (B, D), broadcast over the token dimension.
    return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)


class _TimeNetwork(nn.Module):
    def __init__(self, frequency_embedding_dim, hidden_dim, learnable_w=False, max_period=1000):
        assert frequency_embedding_dim % 2 == 0, "frequency_embedding_dim must be even!"
        half_dim = int(frequency_embedding_dim // 2)
        super().__init__()

        w = np.log(max_period) / (half_dim - 1)
        w = torch.exp(torch.arange(half_dim) * -w).float()
        self.register_parameter("w", nn.Parameter(w, requires_grad=learnable_w))

        self.out_net = nn.Sequential(
            nn.Linear(frequency_embedding_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, t):
        assert len(t.shape) == 1, "assumes 1d input timestep array"
        t = t[:, None] * self.w[None]
        t = torch.cat((torch.cos(t), torch.sin(t)), dim=1)
        return self.out_net(t)


class _ShiftScaleMod(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act = nn.SiLU()
        self.scale = nn.Linear(dim, dim)
        self.shift = nn.Linear(dim, dim)

    def forward(self, x, c):
        c = self.act(c)
        return x * (1 + self.scale(c)[None]) + self.shift(c)[None]

    def reset_parameters(self):
        nn.init.zeros_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)
        nn.init.zeros_(self.scale.bias)
        nn.init.zeros_(self.shift.bias)


class _ZeroScaleMod(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act = nn.SiLU()
        self.scale = nn.Linear(dim, dim)

    def forward(self, x, c):
        c = self.act(c)
        return x * self.scale(c)[None]

    def reset_parameters(self):
        nn.init.zeros_(self.scale.weight)
        nn.init.zeros_(self.scale.bias)


class _DiTDecoder(nn.Module):
    def __init__(self, d_model=256, nhead=6, dim_feedforward=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        # create mlp
        self.mlp = nn.Sequential(
            self.linear1,
            self.activation,
            self.dropout2,
            self.linear2,
            self.dropout3,
        )

        # create modulation layers
        self.attn_modulate = _ShiftScaleMod(d_model)
        self.attn_gate = _ZeroScaleMod(d_model)
        self.mlp_modulate = _ShiftScaleMod(d_model)
        self.mlp_gate = _ZeroScaleMod(d_model)

    def forward(self, x, t, cond):
        # process the conditioning vector first
        cond = cond + t

        x2 = self.attn_modulate(self.norm1(x), cond)
        x2, _ = self.self_attn(x2, x2, x2, need_weights=False)
        x = x + self.attn_gate(self.dropout1(x2), cond)

        x3 = self.mlp_modulate(self.norm2(x), cond)
        x3 = self.mlp(x3)
        x3 = self.mlp_gate(x3, cond)
        return x + x3

    def reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for s in (self.attn_modulate, self.attn_gate, self.mlp_modulate, self.mlp_gate):
            s.reset_parameters()


class _FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_size):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_size, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, t, cond):
        # process the conditioning vector first
        cond = cond + t

        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1)
        x = modulate(x, shift, scale)
        x = self.linear(x)
        return x

    def reset_parameters(self):
        for p in self.parameters():
            nn.init.zeros_(p)


class _TransformerDecoder(nn.Module):
    def __init__(self, base_module, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(base_module) for _ in range(num_layers)])

        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, src, t, cond):
        x = src
        for layer in self.layers:
            x = layer(x, t, cond)
        return x


class _DiTNoiseNet(nn.Module):
    def __init__(
        self,
        ac_dim,
        ac_chunk,
        cond_dim,
        time_dim=256,
        hidden_dim=256,
        num_blocks=6,
        dropout=0.1,
        dim_feedforward=2048,
        nhead=8,
        activation="gelu",
        clip_sample=False,
        clip_sample_range=1.0,
    ):
        super().__init__()
        self.ac_dim, self.ac_chunk = ac_dim, ac_chunk

        # positional encoding blocks
        self.register_parameter(
            "dec_pos",
            nn.Parameter(torch.empty(ac_chunk, 1, hidden_dim), requires_grad=True),
        )
        nn.init.xavier_uniform_(self.dec_pos.data)

        # input encoder mlps
        self.time_net = _TimeNetwork(time_dim, hidden_dim)
        self.ac_proj = nn.Sequential(
            nn.Linear(ac_dim, ac_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ac_dim, hidden_dim),
        )
        self.cond_proj = nn.Linear(cond_dim, hidden_dim)

        # decoder blocks
        decoder_module = _DiTDecoder(
            hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.decoder = _TransformerDecoder(decoder_module, num_blocks)

        # turns predicted tokens into velocity predictions (kept under the historical name `eps_out`)
        self.eps_out = _FinalLayer(hidden_dim, ac_dim)

        # clip the output samples
        self.clip_sample = clip_sample
        self.clip_sample_range = clip_sample_range

        print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6))

    def forward(self, noisy_actions, time, global_cond):
        c = self.cond_proj(global_cond)
        time_enc = self.time_net(time)

        ac_tokens = self.ac_proj(noisy_actions)  # [B, T, adim] -> [B, T, hidden_dim]
        ac_tokens = ac_tokens.transpose(0, 1)  # [B, T, hidden_dim] -> [T, B, hidden_dim]

        # Allow variable length action chunks
        dec_in = ac_tokens + self.dec_pos[: ac_tokens.size(0)]  # [T, B, hidden_dim]

        # apply decoder
        dec_out = self.decoder(dec_in, time_enc, c)

        # apply final velocity prediction layer
        eps_out = self.eps_out(dec_out, time_enc, c)  # [T, B, hidden_dim] -> [T, B, adim]
        return eps_out.transpose(0, 1)  # [T, B, adim] -> [B, T, adim]

    @torch.no_grad()
    def sample(
        self, condition: torch.Tensor, timesteps: int = 100, generator: torch.Generator | None = None
    ) -> torch.Tensor:
        # Use Euler integration to solve the ODE.
        batch_size, device = condition.shape[0], condition.device
        x_0 = self.sample_noise(batch_size, device, generator)
        dt = 1.0 / timesteps
        t_all = (
            torch.arange(timesteps, device=device).float().unsqueeze(0).expand(batch_size, timesteps)
            / timesteps
        )

        for k in range(timesteps):
            t = t_all[:, k]
            x_0 = x_0 + dt * self.forward(x_0, t, condition)
            if self.clip_sample:
                x_0 = torch.clamp(x_0, -self.clip_sample_range, self.clip_sample_range)
        return x_0

    def sample_noise(self, batch_size: int, device, generator: torch.Generator | None = None) -> torch.Tensor:
        return torch.randn(batch_size, self.ac_chunk, self.ac_dim, device=device, generator=generator)


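# Data flow through _DiTNoiseNet.forward for a batch of B noisy action chunks of length
# T = ac_chunk and dimension ac_dim, conditioned on a (B, cond_dim) vector:
#   ac_proj:   (B, T, ac_dim)     -> (B, T, hidden_dim) -> transposed to (T, B, hidden_dim)
#   time_net:  (B,)               -> (B, hidden_dim)
#   cond_proj: (B, cond_dim)      -> (B, hidden_dim)
#   decoder:   (T, B, hidden_dim) -> (T, B, hidden_dim)  (adaLN-modulated self-attention blocks)
#   eps_out:   (T, B, hidden_dim) -> (T, B, ac_dim)      -> transposed back to (B, T, ac_dim)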
class DiTFlowPolicy(PreTrainedPolicy):
    """
    DiT flow-matching policy: a Diffusion Transformer (DiT) action head trained with flow matching,
    following the action-chunking recipe of Diffusion Policy ("Diffusion Policy: Visuomotor Policy
    Learning via Action Diffusion", paper: https://arxiv.org/abs/2303.04137,
    code: https://github.com/real-stanford/diffusion_policy) and the DiT references in the file header.
    """

    config_class = DiTFlowConfig
    name = "ditflow"

    def __init__(
        self,
        config: DiTFlowConfig,
        dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

        self.dit_flow = DiTFlowModel(config)

        self.reset()

    def get_optim_params(self) -> dict:
        return self.dit_flow.parameters()

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`"""
        self._queues = {
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.n_action_steps),
        }
        if self.config.image_features:
            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
        if self.config.env_state_feature:
            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad
    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Select a single action given environment observations.

        This method handles caching a history of observations and an action trajectory generated by the
        underlying flow model. Here's how it works:
          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
            copied `n_obs_steps` times to fill the cache).
          - The flow model generates `horizon` steps worth of actions.
          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
        Schematically this looks like:
            ----------------------------------------------------------------------------------------------
            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
            ----------------------------------------------------------------------------------------------
        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
        "horizon" may not be the best name to describe what the variable actually means, because this period
        is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.dit_flow.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]

            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
        batch = self.normalize_targets(batch)
        loss = self.dit_flow.compute_loss(batch)
        return {"loss": loss}


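# Worked example of the chunking schedule documented in `select_action` with the default
# config (n_obs_steps=2, horizon=16, n_action_steps=8): at environment step n the policy
# stacks observations {n-1, n}, the flow model generates a 16-step chunk starting at step
# n-1, and the 8 actions covering steps n..n+7 are queued for execution. The requirement
# n_action_steps <= horizon - n_obs_steps + 1 (here 8 <= 15) guarantees those actions all
# fall inside the generated chunk.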
class DiTFlowModel(nn.Module):
    def __init__(self, config: DiTFlowConfig):
        super().__init__()
        self.config = config

        # Build observation encoders (depending on which observations are provided).
        global_cond_dim = self.config.robot_state_feature.shape[0]
        if self.config.image_features:
            num_images = len(self.config.image_features)
            if self.config.use_separate_rgb_encoder_per_camera:
                encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
                self.rgb_encoder = nn.ModuleList(encoders)
                global_cond_dim += encoders[0].feature_dim * num_images
            else:
                self.rgb_encoder = DiffusionRgbEncoder(config)
                global_cond_dim += self.rgb_encoder.feature_dim * num_images
        if self.config.env_state_feature:
            global_cond_dim += self.config.env_state_feature.shape[0]

        self.velocity_net = _DiTNoiseNet(
            ac_dim=config.action_feature.shape[0],
            ac_chunk=config.horizon,
            cond_dim=global_cond_dim * config.n_obs_steps,
            time_dim=config.frequency_embedding_dim,
            hidden_dim=config.hidden_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
            dim_feedforward=config.dim_feedforward,
            nhead=config.num_heads,
            activation=config.activation,
            clip_sample=config.clip_sample,
            clip_sample_range=config.clip_sample_range,
        )

        self.num_inference_steps = config.num_inference_steps or 100
        self.training_noise_sampling = config.training_noise_sampling
        if config.training_noise_sampling == "uniform":
            self.noise_distribution = torch.distributions.Uniform(
                low=0,
                high=1,
            )
        elif config.training_noise_sampling == "beta":
            # From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
            # There, they say the PDF for the distribution they use is the following:
            # $p(t) = Beta((s-t) / s; 1.5, 1)$
            # So, we first figure out the distribution over $t'$ and then transform it to $t = s - s * t'$.
            s = 0.999  # constant from the paper
            beta_dist = torch.distributions.Beta(
                concentration1=1.5,  # alpha
                concentration0=1.0,  # beta
            )
            affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s)
            self.noise_distribution = torch.distributions.TransformedDistribution(
                beta_dist, [affine_transform]
            )

    # ========= inference ============
    def conditional_sample(
        self,
        batch_size: int,
        global_cond: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        device = get_device_from_parameters(self)
        dtype = get_dtype_from_parameters(self)

        # Expand global conditioning to the batch size.
        if global_cond is not None:
            global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype)

        # Sample a noise trajectory from the prior and integrate the learned velocity field (Euler).
        sample = self.velocity_net.sample(
            global_cond, timesteps=self.num_inference_steps, generator=generator
        )
        return sample

    def _prepare_global_conditioning(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode image features and concatenate them all together along with the state vector."""
        batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
        global_cond_feats = [batch[OBS_ROBOT]]
        # Extract image features.
        if self.config.image_features:
            if self.config.use_separate_rgb_encoder_per_camera:
                # Combine batch and sequence dims while rearranging to make the camera index dimension first.
                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
                img_features_list = torch.cat(
                    [
                        encoder(images)
                        for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
                    ]
                )
                # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
                img_features = einops.rearrange(
                    img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
                )
            else:
                # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
                img_features = self.rgb_encoder(
                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
                )
                # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
                img_features = einops.rearrange(
                    img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
                )
            global_cond_feats.append(img_features)

        if self.config.env_state_feature:
            global_cond_feats.append(batch[OBS_ENV])

        # Concatenate features then flatten to (B, global_cond_dim).
        return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)

    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        This function expects `batch` to have:
        {
            "observation.state": (B, n_obs_steps, state_dim)

            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
                AND/OR
            "observation.environment_state": (B, environment_dim)
        }
        """
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Encode image features and concatenate them all together along with the state vector.
        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # run sampling
        actions = self.conditional_sample(batch_size, global_cond=global_cond)

        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
        actions = actions[:, start:end]

        return actions

    def compute_loss(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)

            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
                AND/OR
            "observation.environment_state": (B, environment_dim)

            "action": (B, horizon, action_dim)
            "action_is_pad": (B, horizon)
        }
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
        assert "observation.images" in batch or "observation.environment_state" in batch
        n_obs_steps = batch["observation.state"].shape[1]
        horizon = batch["action"].shape[1]
        assert horizon == self.config.horizon
        assert n_obs_steps == self.config.n_obs_steps

        # Encode image features and concatenate them all together along with the state vector.
        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # Forward process (flow matching): interpolate between noise and the clean trajectory.
        trajectory = batch["action"]
        # Sample noise (the t=0 endpoint of the interpolation).
        noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
        # Sample a random flow timestep t in [0, 1] for each item in the batch.
        timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device)
        # Linearly interpolate between noise and the clean trajectory at each sampled timestep.
        noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory

        # Run the velocity network: it predicts the velocity that transports noise onto the data.
        pred = self.velocity_net(noisy_actions=noisy_trajectory, time=timesteps, global_cond=global_cond)
        target = trajectory - noise
        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
        if self.config.do_mask_loss_for_padding:
            if "action_is_pad" not in batch:
                raise ValueError(
                    "You need to provide 'action_is_pad' in the batch when "
                    f"{self.config.do_mask_loss_for_padding=}."
                )
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound.unsqueeze(-1)

        return loss.mean()
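
The loss and sampler above use a linear interpolation path x_t = (1 - t) * noise + t * action with velocity target action - noise, integrated with Euler steps at inference. A self-contained toy check of that parameterization in plain PyTorch (independent of the classes above, with an exact hand-built velocity field standing in for velocity_net):

import torch

batch_size, horizon, action_dim = 4, 16, 2
action = torch.randn(batch_size, horizon, action_dim)  # stands in for a clean action chunk
noise = torch.randn_like(action)                        # x_0 ~ N(0, I), as in sample_noise()

velocity = action - noise  # the regression target used in compute_loss()

# Euler integration of dx/dt = velocity from t=0 to t=1, mirroring _DiTNoiseNet.sample()
# but with the exact (constant) velocity field instead of the learned network.
num_steps = 100
x = noise.clone()
dt = 1.0 / num_steps
for _ in range(num_steps):
    x = x + dt * velocity

print(torch.allclose(x, action, atol=1e-4))  # True: the flow transports noise onto the action chunk
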
@@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -43,6 +44,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

        return DiffusionPolicy
    elif name == "ditflow":
        from lerobot.common.policies.dit_flow.modeling_dit_flow import DiTFlowPolicy

        return DiTFlowPolicy
    elif name == "act":
        from lerobot.common.policies.act.modeling_act import ACTPolicy

@@ -68,6 +73,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return TDMPCConfig(**kwargs)
    elif policy_type == "diffusion":
        return DiffusionConfig(**kwargs)
    elif policy_type == "ditflow":
        return DiTFlowConfig(**kwargs)
    elif policy_type == "act":
        return ACTConfig(**kwargs)
    elif policy_type == "vqbet":
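
With these factory entries in place, the policy is reachable by its registered type string. A rough usage sketch, assuming the helpers shown in the hunks above live in lerobot.common.policies.factory and keep the signatures shown:

from lerobot.common.policies.factory import get_policy_class, make_policy_config

cfg = make_policy_config("ditflow", n_action_steps=8, num_inference_steps=50)
policy_cls = get_policy_class("ditflow")  # resolves to DiTFlowPolicy

# Constructing the policy also needs per-feature statistics for normalization:
# policy = policy_cls(cfg, dataset_stats=dataset_stats)  # dataset_stats from your LeRobotDataset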