# Copyright 2025 Nur Muhammad Mahi Shafiullah,
# and The HuggingFace Inc. team. All rights reserved.
# Heavy inspiration taken from
# * DETR by Meta AI (Carion et al.): https://github.com/facebookresearch/detr
# * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT
# * DiT Policy by Dasari et al.: https://dit-policy.github.io/

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
from collections import deque

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812

from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionRgbEncoder
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import (
    get_device_from_parameters,
    get_dtype_from_parameters,
    populate_queues,
)


def _get_activation_fn(activation: str):
    """Return an activation module given a string."""
    # Modules (rather than the functional forms) are returned so that the result can be used
    # directly inside nn.Sequential (see _DiTDecoder). Note that GLU halves the feature
    # dimension of its input.
    if activation == "relu":
        return nn.ReLU()
    if activation == "gelu":
        return nn.GELU(approximate="tanh")
    if activation == "glu":
        return nn.GLU()
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)


class _TimeNetwork(nn.Module):
    def __init__(self, frequency_embedding_dim, hidden_dim, learnable_w=False, max_period=1000):
        assert frequency_embedding_dim % 2 == 0, "frequency_embedding_dim must be even!"
        half_dim = int(frequency_embedding_dim // 2)
        super().__init__()

        w = np.log(max_period) / (half_dim - 1)
        w = torch.exp(torch.arange(half_dim) * -w).float()
        self.register_parameter("w", nn.Parameter(w, requires_grad=learnable_w))

        self.out_net = nn.Sequential(
            nn.Linear(frequency_embedding_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, t):
        assert len(t.shape) == 1, "assumes 1d input timestep array"
        t = t[:, None] * self.w[None]
        t = torch.cat((torch.cos(t), torch.sin(t)), dim=1)
        return self.out_net(t)


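# For reference: _TimeNetwork is a standard sinusoidal timestep embedding followed by an MLP.
# For a scalar flow time t and frequencies w_i = exp(-i * ln(max_period) / (half_dim - 1)),
# the embedding before the MLP is
#   emb(t) = [cos(t * w_0), ..., cos(t * w_{half_dim - 1}), sin(t * w_0), ..., sin(t * w_{half_dim - 1})]
# so a batch of timesteps of shape (B,) becomes (B, frequency_embedding_dim), and finally
# (B, hidden_dim) after out_net.

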
class _ShiftScaleMod(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act = nn.SiLU()
        self.scale = nn.Linear(dim, dim)
        self.shift = nn.Linear(dim, dim)

    def forward(self, x, c):
        c = self.act(c)
        return x * (1 + self.scale(c)[None]) + self.shift(c)[None]

    def reset_parameters(self):
        nn.init.zeros_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)
        nn.init.zeros_(self.scale.bias)
        nn.init.zeros_(self.shift.bias)


class _ZeroScaleMod(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act = nn.SiLU()
        self.scale = nn.Linear(dim, dim)

    def forward(self, x, c):
        c = self.act(c)
        return x * self.scale(c)[None]

    def reset_parameters(self):
        nn.init.zeros_(self.scale.weight)
        nn.init.zeros_(self.scale.bias)


class _DiTDecoder(nn.Module):
    def __init__(self, d_model=256, nhead=6, dim_feedforward=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        # create mlp
        self.mlp = nn.Sequential(
            self.linear1,
            self.activation,
            self.dropout2,
            self.linear2,
            self.dropout3,
        )

        # create modulation layers
        self.attn_modulate = _ShiftScaleMod(d_model)
        self.attn_gate = _ZeroScaleMod(d_model)
        self.mlp_modulate = _ShiftScaleMod(d_model)
        self.mlp_gate = _ZeroScaleMod(d_model)

    def forward(self, x, t, cond):
        # process the conditioning vector first
        cond = cond + t

        x2 = self.attn_modulate(self.norm1(x), cond)
        x2, _ = self.self_attn(x2, x2, x2, need_weights=False)
        x = x + self.attn_gate(self.dropout1(x2), cond)

        x3 = self.mlp_modulate(self.norm2(x), cond)
        # TODO: verify and then remove
        # x3 = self.linear2(self.dropout2(self.activation(self.linear1(x3))))
        # x3 = self.mlp_gate(self.dropout3(x3), cond)
        x3 = self.mlp(x3)
        x3 = self.mlp_gate(x3, cond)
        return x + x3

    def reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for s in (self.attn_modulate, self.attn_gate, self.mlp_modulate, self.mlp_gate):
            s.reset_parameters()


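# For reference, the block above follows the adaLN-Zero pattern from DiT: with conditioning
# c' = cond + t_embedding, each sub-block computes
#   x <- x + gate(c') * SelfAttn(modulate(LayerNorm(x), c'))
#   x <- x + gate(c') * MLP(modulate(LayerNorm(x), c'))
# where modulate applies a learned shift/scale (_ShiftScaleMod) and gate is a learned,
# zero-initialized scale (_ZeroScaleMod), so every block starts out as the identity mapping.

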
class _FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_size):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_size, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, t, cond):
        # process the conditioning vector first
        cond = cond + t

        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1)
        # Apply the (non-affine) final LayerNorm before modulation, as in DiT's FinalLayer.
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

    def reset_parameters(self):
        for p in self.parameters():
            nn.init.zeros_(p)


class _TransformerDecoder(nn.Module):
    def __init__(self, base_module, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(base_module) for _ in range(num_layers)])

        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, src, t, cond):
        x = src
        for layer in self.layers:
            x = layer(x, t, cond)
        return x


class _DiTNoiseNet(nn.Module):
    def __init__(
        self,
        ac_dim,
        ac_chunk,
        cond_dim,
        time_dim=256,
        hidden_dim=256,
        num_blocks=6,
        dropout=0.1,
        dim_feedforward=2048,
        nhead=6,
        activation="gelu",
        clip_sample=False,
        clip_sample_range=1.0,
    ):
        super().__init__()
        self.ac_dim, self.ac_chunk = ac_dim, ac_chunk

        # positional encoding blocks
        self.register_parameter(
            "dec_pos",
            nn.Parameter(torch.empty(ac_chunk, 1, hidden_dim), requires_grad=True),
        )
        nn.init.xavier_uniform_(self.dec_pos.data)

        # input encoder mlps
        self.time_net = _TimeNetwork(time_dim, hidden_dim)
        self.ac_proj = nn.Sequential(
            nn.Linear(ac_dim, ac_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ac_dim, hidden_dim),
        )
        self.cond_proj = nn.Linear(cond_dim, hidden_dim)

        # decoder blocks
        decoder_module = _DiTDecoder(
            hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.decoder = _TransformerDecoder(decoder_module, num_blocks)

        # turns predicted tokens into epsilons
        self.eps_out = _FinalLayer(hidden_dim, ac_dim)

        # clip the output samples
        self.clip_sample = clip_sample
        self.clip_sample_range = clip_sample_range

        print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6))

    def forward(self, noisy_actions, time, global_cond):
        c = self.cond_proj(global_cond)
        time_enc = self.time_net(time)

        ac_tokens = self.ac_proj(noisy_actions)  # [B, T, adim] -> [B, T, hidden_dim]
        ac_tokens = ac_tokens.transpose(0, 1)  # [B, T, hidden_dim] -> [T, B, hidden_dim]

        # Allow variable length action chunks
        dec_in = ac_tokens + self.dec_pos[: ac_tokens.size(0)]  # [T, B, hidden_dim]

        # apply decoder
        dec_out = self.decoder(dec_in, time_enc, c)

        # apply final epsilon prediction layer
        eps_out = self.eps_out(dec_out, time_enc, c)  # [T, B, hidden_dim] -> [T, B, adim]
        return eps_out.transpose(0, 1)  # [T, B, adim] -> [B, T, adim]

    @torch.no_grad()
    def sample(
        self, condition: torch.Tensor, timesteps: int = 100, generator: torch.Generator | None = None
    ) -> torch.Tensor:
        # Use Euler integration to solve the ODE.
        batch_size, device = condition.shape[0], condition.device
        x_0 = self.sample_noise(batch_size, device, generator)
        dt = 1.0 / timesteps
        t_all = (
            torch.arange(timesteps, device=device).float().unsqueeze(0).expand(batch_size, timesteps)
            / timesteps
        )

        for k in range(timesteps):
            t = t_all[:, k]
            x_0 = x_0 + dt * self.forward(x_0, t, condition)
            if self.clip_sample:
                x_0 = torch.clamp(x_0, -self.clip_sample_range, self.clip_sample_range)
        return x_0

    def sample_noise(self, batch_size: int, device, generator: torch.Generator | None = None) -> torch.Tensor:
        return torch.randn(batch_size, self.ac_chunk, self.ac_dim, device=device, generator=generator)


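# For reference: _DiTNoiseNet.sample performs Euler integration of the learned flow ODE
#   dx/dt = v_theta(x, t, cond),    x(0) ~ N(0, I),
# stepping x_{k+1} = x_k + (1 / timesteps) * v_theta(x_k, k / timesteps, cond) for
# k = 0, ..., timesteps - 1, so that x(1) is the predicted (normalized) action chunk.

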
class DiTFlowPolicy(PreTrainedPolicy):
    """
    DiT Flow Policy: a flow matching policy with a DiT-style transformer velocity network, adapted from
    "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
    (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
    """

    config_class = DiTFlowConfig
    name = "DiTFlow"

    def __init__(
        self,
        config: DiTFlowConfig,
        dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

        self.dit_flow = DiTFlowModel(config)

        self.reset()

    def get_optim_params(self) -> dict:
        return self.dit_flow.parameters()

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`"""
        self._queues = {
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.n_action_steps),
        }
        if self.config.image_features:
            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
        if self.config.env_state_feature:
            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad()
    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Select a single action given environment observations.

        This method handles caching a history of observations and an action trajectory generated by the
        underlying flow model. Here's how it works:
          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
            copied `n_obs_steps` times to fill the cache).
          - The flow model generates `horizon` steps worth of actions.
          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
        Schematically this looks like:
            ----------------------------------------------------------------------------------------------
            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
            ----------------------------------------------------------------------------------------------
        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
        "horizon" may not be the best name to describe what the variable actually means, because this period
        is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.dit_flow.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]

            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

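    # Usage sketch (illustrative only; `env` and `obs_to_batch` below are hypothetical helpers,
    # not part of this module):
    #
    #     policy = DiTFlowPolicy(config, dataset_stats=dataset_stats)
    #     policy.eval()
    #     policy.reset()  # call on every env.reset()
    #     obs, _ = env.reset()
    #     for _ in range(max_steps):
    #         batch = obs_to_batch(obs)  # maps env obs to "observation.*" tensors of shape (1, ...)
    #         action = policy.select_action(batch)  # returns a single (1, action_dim) action
    #         obs, reward, terminated, truncated, _ = env.step(action[0].cpu().numpy())
    #         if terminated or truncated:
    #             break
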
    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
        batch = self.normalize_targets(batch)
        loss = self.dit_flow.compute_loss(batch)
        return {"loss": loss}


class DiTFlowModel(nn.Module):
    def __init__(self, config: DiTFlowConfig):
        super().__init__()
        self.config = config

        # Build observation encoders (depending on which observations are provided).
        global_cond_dim = self.config.robot_state_feature.shape[0]
        if self.config.image_features:
            num_images = len(self.config.image_features)
            if self.config.use_separate_rgb_encoder_per_camera:
                encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
                self.rgb_encoder = nn.ModuleList(encoders)
                global_cond_dim += encoders[0].feature_dim * num_images
            else:
                self.rgb_encoder = DiffusionRgbEncoder(config)
                global_cond_dim += self.rgb_encoder.feature_dim * num_images
        if self.config.env_state_feature:
            global_cond_dim += self.config.env_state_feature.shape[0]

        self.velocity_net = _DiTNoiseNet(
            ac_dim=config.action_feature.shape[0],
            ac_chunk=config.horizon,
            cond_dim=global_cond_dim * config.n_obs_steps,
            time_dim=config.frequency_embedding_dim,
            hidden_dim=config.hidden_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
            dim_feedforward=config.dim_feedforward,
            nhead=config.num_heads,
            activation=config.activation,
            clip_sample=config.clip_sample,
            clip_sample_range=config.clip_sample_range,
        )

        self.num_inference_steps = config.num_inference_steps or 100
        self.training_noise_sampling = config.training_noise_sampling
        if config.training_noise_sampling == "uniform":
            self.noise_distribution = torch.distributions.Uniform(
                low=0,
                high=1,
            )
        elif config.training_noise_sampling == "beta":
            # From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
            # There, the PDF of the timestep distribution is given as:
            #   $p(t) = \mathrm{Beta}((s - t) / s; 1.5, 1)$
            # So we sample $u \sim \mathrm{Beta}(1.5, 1)$ and apply the affine map $t = s - s \cdot u$,
            # which yields exactly that density over $t$.
            s = 0.999  # constant from the paper
            beta_dist = torch.distributions.Beta(
                concentration1=1.5,  # alpha
                concentration0=1.0,  # beta
            )
            affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s)
            self.noise_distribution = torch.distributions.TransformedDistribution(
                beta_dist, [affine_transform]
            )
        else:
            raise ValueError(
                f"Unsupported `training_noise_sampling`: {config.training_noise_sampling}. "
                "Expected 'uniform' or 'beta'."
            )

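    # For reference: with the "beta" option above, u ~ Beta(1.5, 1) has density proportional to
    # sqrt(u), so u concentrates near 1 and t = s - s * u concentrates near 0. Since compute_loss
    # interpolates x_t = (1 - t) * noise + t * action, small t corresponds to mostly-noise inputs,
    # i.e. training puts extra weight on the high-noise end of the flow (e.g. u = 0.9 gives
    # t = 0.999 - 0.999 * 0.9 ≈ 0.1).
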
    # ========= inference ============
    def conditional_sample(
        self,
        batch_size: int,
        global_cond: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        device = get_device_from_parameters(self)
        dtype = get_dtype_from_parameters(self)

        # Expand global conditioning to the batch size.
        if global_cond is not None:
            global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype)

        # Sample noise from the prior and integrate the flow ODE with Euler steps.
        sample = self.velocity_net.sample(
            global_cond, timesteps=self.num_inference_steps, generator=generator
        )
        return sample

    def _prepare_global_conditioning(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode image features and concatenate them all together along with the state vector."""
        batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
        global_cond_feats = [batch[OBS_ROBOT]]
        # Extract image features.
        if self.config.image_features:
            if self.config.use_separate_rgb_encoder_per_camera:
                # Combine batch and sequence dims while rearranging to make the camera index dimension first.
                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
                img_features_list = torch.cat(
                    [
                        encoder(images)
                        for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
                    ]
                )
                # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
                img_features = einops.rearrange(
                    img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
                )
            else:
                # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
                img_features = self.rgb_encoder(
                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
                )
                # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
                img_features = einops.rearrange(
                    img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
                )
            global_cond_feats.append(img_features)

        if self.config.env_state_feature:
            global_cond_feats.append(batch[OBS_ENV])

        # Concatenate features then flatten to (B, global_cond_dim).
        return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)

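    # Shape example (hypothetical numbers): with n_obs_steps = 2, a 7-dim robot state, and two
    # cameras whose encoders each output a 64-dim feature, the per-step conditioning is
    # 7 + 2 * 64 = 135 dims, so the flattened output is (B, 2 * 135) = (B, 270). This matches
    # `cond_dim = global_cond_dim * config.n_obs_steps` passed to _DiTNoiseNet in __init__.
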
    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        This function expects `batch` to have:
        {
            "observation.state": (B, n_obs_steps, state_dim)

            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
                AND/OR
            "observation.environment_state": (B, environment_dim)
        }
        """
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Encode image features and concatenate them all together along with the state vector.
        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # run sampling
        actions = self.conditional_sample(batch_size, global_cond=global_cond)

        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
        actions = actions[:, start:end]

        return actions

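    # Worked example (hypothetical numbers): with n_obs_steps = 2, horizon = 16 and
    # n_action_steps = 8, the slice above is actions[:, 1:9], i.e. the 8 actions starting at the
    # current (most recently observed) timestep, consistent with the schematic in
    # DiTFlowPolicy.select_action.
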
    def compute_loss(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)

            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
                AND/OR
            "observation.environment_state": (B, environment_dim)

            "action": (B, horizon, action_dim)
            "action_is_pad": (B, horizon)
        }
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
        assert "observation.images" in batch or "observation.environment_state" in batch
        n_obs_steps = batch["observation.state"].shape[1]
        horizon = batch["action"].shape[1]
        assert horizon == self.config.horizon
        assert n_obs_steps == self.config.n_obs_steps

        # Encode image features and concatenate them all together along with the state vector.
        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # Flow matching: interpolate between noise and the clean trajectory.
        trajectory = batch["action"]
        # Sample noise to pair with the trajectory.
        noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
        # Sample a random flow timestep for each item in the batch.
        timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device)
        # Interpolate: at t=0 the input is pure noise, at t=1 it is the clean trajectory.
        noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory

        # Run the velocity network; the regression target is the flow velocity (trajectory - noise).
        pred = self.velocity_net(noisy_actions=noisy_trajectory, time=timesteps, global_cond=global_cond)
        target = trajectory - noise
        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
        if self.config.do_mask_loss_for_padding:
            if "action_is_pad" not in batch:
                raise ValueError(
                    "You need to provide 'action_is_pad' in the batch when "
                    f"{self.config.do_mask_loss_for_padding=}."
                )
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound.unsqueeze(-1)

        return loss.mean()
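

# Training sketch (illustrative only; `dataloader` yielding LeRobot-style batches and the optimizer
# settings are hypothetical, not part of this module). The objective minimized above is the
# flow matching loss
#   E_{t, a, eps} || v_theta((1 - t) * eps + t * a, t, cond) - (a - eps) ||^2
# with eps ~ N(0, I) and t drawn from `noise_distribution`.
#
#     policy = DiTFlowPolicy(config, dataset_stats=dataset_stats)
#     optimizer = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
#     policy.train()
#     for batch in dataloader:
#         loss = policy.forward(batch)["loss"]
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()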