lerobot/lerobot/common/policies/diffusion/modeling_diffusion.py

597 lines
24 KiB
Python

"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
"""
import math
from collections import deque
from typing import Callable
import einops
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
populate_queues,
)
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
    """
    Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
    (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
    """

    name = "diffusion"

    def __init__(
        self,
        config: DiffusionConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__()
        if config is None:
            config = DiffusionConfig()
        self.config = config
        self.normalize_inputs = Normalize(
            config.input_shapes, config.input_normalization_modes, dataset_stats
        )
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

        self.diffusion = DiffusionModel(config)

        # Create empty queues right away so that `select_action` works even if the caller never calls
        # `reset()` explicitly (previously `self._queues` stayed None until then).
        self.reset()

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`
        """
        self._queues = {
            "observation.image": deque(maxlen=self.config.n_obs_steps),
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.n_action_steps),
        }

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        This method handles caching a history of observations and an action trajectory generated by the
        underlying diffusion model. Here's how it works:
          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
            copied `n_obs_steps` times to fill the cache).
          - The diffusion model generates `horizon` steps worth of actions.
          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
        Schematically this looks like:
            ----------------------------------------------------------------------------------------------
            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
            ----------------------------------------------------------------------------------------------
        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
        "horizon" may not be the best name to describe what the variable actually means, because this period
        is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.

        Args:
            batch: Environment observations; must contain "observation.image" and "observation.state".
        Returns:
            The next action to execute, popped from the action queue.
        """
        assert "observation.image" in batch
        assert "observation.state" in batch

        batch = self.normalize_inputs(batch)
        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # Stack the n latest observations from the queues. Only keys that have a queue are stacked;
            # indexing `self._queues` with an un-queued key would raise a KeyError.
            batch = {
                key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues
            }
            actions = self.diffusion.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]

            # Queue the chunk one time step at a time: (B, n_action_steps, ...) -> n_action_steps x (B, ...).
            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation.

        Returns:
            A dict with a single "loss" entry.
        """
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        loss = self.diffusion.compute_loss(batch)
        return {"loss": loss}
class DiffusionModel(nn.Module):
    """Denoising model of the diffusion policy.

    Combines an RGB encoder (observation features), a conditional 1D UNet (denoising network) and a DDPM
    noise scheduler. It is used to generate action chunks at inference time and to compute the diffusion
    training loss.
    """

    def __init__(self, config: DiffusionConfig):
        super().__init__()
        self.config = config

        self.rgb_encoder = DiffusionRgbEncoder(config)
        # The global conditioning vector (see `_prepare_global_conditioning`) concatenates, for each of the
        # `n_obs_steps` observation steps, the robot *state* with the image features. Its size therefore
        # depends on the state dimension, not the action dimension. The two only coincide for some robots.
        self.unet = DiffusionConditionalUnet1d(
            config,
            global_cond_dim=(config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim)
            * config.n_obs_steps,
        )

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=config.num_train_timesteps,
            beta_start=config.beta_start,
            beta_end=config.beta_end,
            beta_schedule=config.beta_schedule,
            variance_type="fixed_small",
            clip_sample=config.clip_sample,
            clip_sample_range=config.clip_sample_range,
            prediction_type=config.prediction_type,
        )

        if config.num_inference_steps is None:
            self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
        else:
            self.num_inference_steps = config.num_inference_steps

    # ========= inference ============
    def conditional_sample(
        self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
    ) -> Tensor:
        """Run the reverse diffusion process starting from Gaussian noise.

        Returns:
            (batch_size, horizon, action_dim) denoised sample.
        """
        device = get_device_from_parameters(self)
        dtype = get_dtype_from_parameters(self)

        # Sample prior.
        sample = torch.randn(
            size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
            dtype=dtype,
            device=device,
            generator=generator,
        )

        self.noise_scheduler.set_timesteps(self.num_inference_steps)

        for t in self.noise_scheduler.timesteps:
            # Predict model output.
            model_output = self.unet(
                sample,
                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                global_cond=global_cond,
            )
            # Compute previous image: x_t -> x_t-1
            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample

        return sample

    def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
        """Encode the images and concatenate them with the state to build the UNet's global conditioning.

        Expects "observation.state": (B, n_obs_steps, state_dim) and
        "observation.image": (B, n_obs_steps, C, H, W). Returns a (B, global_cond_dim) tensor.
        """
        batch_size = batch["observation.state"].shape[0]
        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
        # Separate batch and sequence dims.
        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
        # Concatenate state and image features then flatten to (B, global_cond_dim).
        return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)
            "observation.image": (B, n_obs_steps, C, H, W)
        }

        Returns:
            (B, n_action_steps, action_dim) actions, starting from the current (latest) observation step.
        """
        assert set(batch).issuperset({"observation.state", "observation.image"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        global_cond = self._prepare_global_conditioning(batch)

        # run sampling
        sample = self.conditional_sample(batch_size, global_cond=global_cond)

        # `horizon` steps worth of actions (from the first observation).
        actions = sample[..., : self.config.output_shapes["action"][0]]
        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
        actions = actions[:, start:end]

        return actions

    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)
            "observation.image": (B, n_obs_steps, C, H, W)
            "action": (B, horizon, action_dim)
            "action_is_pad": (B, horizon)
        }

        Returns:
            Scalar MSE loss between the UNet prediction and the target (noise or clean trajectory), with
            padded action steps zeroed out before averaging.
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
        n_obs_steps = batch["observation.state"].shape[1]
        horizon = batch["action"].shape[1]
        assert horizon == self.config.horizon
        assert n_obs_steps == self.config.n_obs_steps

        global_cond = self._prepare_global_conditioning(batch)

        trajectory = batch["action"]

        # Forward diffusion.
        # Sample noise to add to the trajectory (same shape, dtype and device as the trajectory).
        eps = torch.randn_like(trajectory)
        # Sample a random noising timestep for each item in the batch.
        timesteps = torch.randint(
            low=0,
            high=self.noise_scheduler.config.num_train_timesteps,
            size=(trajectory.shape[0],),
            device=trajectory.device,
        ).long()
        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)

        # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)

        # Compute the loss.
        # The target is either the original trajectory, or the noise.
        if self.config.prediction_type == "epsilon":
            target = eps
        elif self.config.prediction_type == "sample":
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
        if "action_is_pad" in batch:
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound.unsqueeze(-1)

        return loss.mean()
class DiffusionRgbEncoder(nn.Module):
    """Encodes an RGB image into a 1D feature vector.

    Includes the ability to crop the image first: a random crop in training mode when `crop_is_random` is
    set, otherwise a center crop.
    """

    def __init__(self, config: DiffusionConfig):
        super().__init__()
        # Set up optional preprocessing.
        if config.crop_shape is not None:
            self.do_crop = True
            # Always use center crop for eval
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

        # Set up backbone.
        backbone_model = getattr(torchvision.models, config.vision_backbone)(
            weights=config.pretrained_backbone_weights
        )
        # Note: This assumes that the layer4 feature map is children()[-3]
        # TODO(alexander-soare): Use a safer alternative.
        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
        if config.use_group_norm:
            if config.pretrained_backbone_weights:
                raise ValueError(
                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                )
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape. When cropping is enabled, `forward` only ever feeds
        # cropped images to the backbone, so the dummy input must have the crop's spatial size.
        num_channels, *image_hw = config.input_shapes["observation.image"]
        if config.crop_shape is None:
            dummy_hw = tuple(image_hw)
        elif isinstance(config.crop_shape, int):
            # torchvision crops treat an int as a square crop.
            dummy_hw = (config.crop_shape, config.crop_shape)
        else:
            dummy_hw = tuple(config.crop_shape)
        dummy_input = torch.zeros(size=(1, num_channels, *dummy_hw))
        # Run the dry run in eval mode so BatchNorm layers (if any) do not fold the all-zeros dummy input
        # into their running statistics, then restore the previous mode.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.inference_mode():
            feat_map_shape = tuple(self.backbone(dummy_input).shape[1:])
        self.backbone.train(was_training)

        self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor. NOTE(review): within this file it arrives after the policy's input
                normalization; confirm the expected value range.
        Returns:
            (B, D) image feature, with D = 2 * spatial_softmax_num_keypoints.
        """
        # Preprocess: maybe crop (if it was set up in the __init__).
        if self.do_crop:
            if self.training:  # noqa: SIM108
                x = self.maybe_random_crop(x)
            else:
                # Always use center crop for eval.
                x = self.center_crop(x)
        # Extract backbone feature.
        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
        # Final linear layer with non-linearity.
        x = self.relu(self.out(x))
        return x
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class DiffusionSinusoidalPosEmb(nn.Module):
    """1D sinusoidal positional embeddings as in "Attention is All You Need".

    Maps a tensor of positions (here: diffusion timesteps) of shape (...,) to (..., 2 * (dim // 2)),
    with the sine half first and the cosine half second.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        half = self.dim // 2
        # Frequencies form a geometric progression from 1 down to 1 / 10000.
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -log_step)
        angles = x[..., None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class DiffusionConv1dBlock(nn.Module):
    """A 1D convolution followed by GroupNorm and a Mish activation.

    The convolution pads `kernel_size // 2` on each side with stride 1, so for odd kernel sizes the
    temporal length of the input is preserved.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        activation = nn.Mish()
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.block(x)
class DiffusionConditionalUnet1d(nn.Module):
    """A 1D convolutional UNet with FiLM modulation for conditioning.

    Note: this removes local conditioning as compared to the original diffusion policy code.

    Note: each downsampling conv (kernel 3, stride 2, padding 1) maps a temporal length L to ceil(L / 2),
    and each upsampling transposed conv (kernel 4, stride 2, padding 1) maps L to 2 * L. For the skip
    concatenations to line up, the temporal length T should therefore be divisible by
    2 ** (len(down_dims) - 1).
    """

    def __init__(self, config: DiffusionConfig, global_cond_dim: int):
        super().__init__()

        self.config = config

        # Encoder for the diffusion timestep.
        self.diffusion_step_encoder = nn.Sequential(
            DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
            nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
            nn.Mish(),
            nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
        )

        # The FiLM conditioning dimension.
        cond_dim = config.diffusion_step_embed_dim + global_cond_dim

        # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
        # just reverse these.
        in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
            zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
        )

        # Unet encoder.
        common_res_block_kwargs = {
            "cond_dim": cond_dim,
            "kernel_size": config.kernel_size,
            "n_groups": config.n_groups,
            "use_film_scale_modulation": config.use_film_scale_modulation,
        }
        self.down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            self.down_modules.append(
                nn.ModuleList(
                    [
                        DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
                        # Downsample as long as it is not the last block.
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        # Processing in the middle of the auto-encoder.
        self.mid_modules = nn.ModuleList(
            [
                DiffusionConditionalResidualBlock1d(
                    config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
                ),
                DiffusionConditionalResidualBlock1d(
                    config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
                ),
            ]
        )

        # Unet decoder. There is one decoder stage per encoder downsampling step (the last encoder stage
        # does not downsample), so every decoder stage upsamples.
        self.up_modules = nn.ModuleList([])
        for dim_out, dim_in in reversed(in_out[1:]):
            self.up_modules.append(
                nn.ModuleList(
                    [
                        # dim_in * 2, because it takes the encoder's skip connection as well
                        DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
            nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
        )

    def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
        """
        Args:
            x: (B, T, input_dim) tensor for input to the Unet.
            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1). A Python int or a 0-d tensor is
                also accepted and is broadcast to every item in the batch.
            global_cond: (B, global_cond_dim)
        Returns:
            (B, T, input_dim) diffusion model prediction.
        """
        # For 1D convolutions we'll need feature dimension first.
        x = einops.rearrange(x, "b t d -> b d t")

        # Normalize the timestep to a (B,) tensor so scalar timesteps work as the signature advertises.
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], dtype=torch.long, device=x.device)
        elif timestep.ndim == 0:
            timestep = timestep[None].to(x.device)
        timestep = timestep.expand(x.shape[0])

        timesteps_embed = self.diffusion_step_encoder(timestep)

        # If there is a global conditioning feature, concatenate it to the timestep embedding.
        if global_cond is not None:
            global_feature = torch.cat([timesteps_embed, global_cond], dim=-1)
        else:
            global_feature = timesteps_embed

        # Run encoder, keeping track of skip features to pass to the decoder. Note that there are
        # len(down_dims) skip features but only len(down_dims) - 1 decoder stages, so the first
        # (full-resolution) skip feature is never consumed.
        encoder_skip_features: list[Tensor] = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            encoder_skip_features.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # Run decoder, using the skip features from the encoder.
        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, "b d t -> b t d")
        return x
class DiffusionConditionalResidualBlock1d(nn.Module):
    """ResNet-style 1D convolutional block whose hidden activations are FiLM-modulated by a conditioning vector.

    Computes conv2(FiLM(conv1(x), cond)) + residual(x), where the residual path is a 1x1 convolution when
    the channel counts differ and the identity otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        kernel_size: int = 3,
        n_groups: int = 8,
        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
        # FiLM just modulates bias).
        use_film_scale_modulation: bool = False,
    ):
        super().__init__()

        self.use_film_scale_modulation = use_film_scale_modulation
        self.out_channels = out_channels

        self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)

        # FiLM (https://arxiv.org/abs/1709.07871): one bias per output channel, plus one scale per output
        # channel when scale modulation is enabled.
        film_channels = out_channels * (2 if use_film_scale_modulation else 1)
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, film_channels))

        self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)

        # 1x1 convolution to project the residual only when the channel counts differ.
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """
        Args:
            x: (B, in_channels, T)
            cond: (B, cond_dim)
        Returns:
            (B, out_channels, T)
        """
        hidden = self.conv1(x)

        # (B, film_channels) -> (B, film_channels, 1) so it broadcasts over the time dimension.
        film = self.cond_encoder(cond).unsqueeze(-1)
        if not self.use_film_scale_modulation:
            # The embedding is a per-channel bias only.
            hidden = hidden + film
        else:
            # The first `out_channels` entries are scales, the remaining ones biases.
            scale, bias = film[:, : self.out_channels], film[:, self.out_channels :]
            hidden = scale * hidden + bias

        return self.conv2(hidden) + self.residual_conv(x)