backup wip

Alexander Soare 2024-04-15 19:06:44 +01:00
parent 14f3ffb412
commit 5608e659e6
14 changed files with 1059 additions and 977 deletions

View File

@@ -11,7 +11,7 @@ import torch
from omegaconf import OmegaConf
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
output_directory = Path("outputs/train/example_pusht_diffusion")

View File

@@ -56,7 +56,7 @@ class ActionChunkingTransformerConfig:
# Inputs / output structure.
n_obs_steps: int = 1
camera_names: list[str] = field(default_factory=lambda: ["top"])
camera_names: tuple[str] = ("top",)
chunk_size: int = 100
n_action_steps: int = 100
@@ -101,7 +101,7 @@ class ActionChunkingTransformerConfig:
utd: int = 1
def __post_init__(self):
"""Input validation."""
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
if self.use_temporal_aggregation:

View File

@@ -163,7 +163,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.

View File

@@ -0,0 +1,83 @@
from dataclasses import dataclass
@dataclass
class DiffusionConfig:
"""Configuration class for Diffusion Policy.
Defaults are configured for training on PushT, which provides proprioceptive and single-camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `image_size`.
Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction horizon as detailed in the main policy documentation.
"""
# Environment.
# Inherit these from the environment config.
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
n_action_steps: int = 8
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] = (84, 84)
crop_is_random: bool = True
use_pretrained_backbone: bool = False
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
# Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
n_groups: int = 8
diffusion_step_embed_dim: int = 128
film_scale_modulation: bool = True
# Noise scheduler.
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
beta_end: float = 0.02
variance_type: str = "fixed_small"
prediction_type: str = "epsilon"
clip_sample: bool = True
# Inference
num_inference_steps: int = 100
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: int = 64
grad_clip_norm: int = 10
lr: float = 1.0e-4
lr_scheduler: str = "cosine"
lr_warmup_steps: int = 500
adam_betas: tuple[float, float] = (0.95, 0.999)
adam_eps: float = 1.0e-8
adam_weight_decay: float = 1.0e-6
utd: int = 1
use_ema: bool = True
ema_update_after_step: int = 0
ema_min_rate: float = 0.0
ema_max_rate: float = 0.9999
ema_inv_gamma: float = 1.0
ema_power: float = 0.75
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
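# Illustrative usage sketch (hypothetical `_example_config_usage` helper): all fields have
# defaults, so environment-specific values are passed as keyword overrides, and `__post_init__`
# performs the (non-exhaustive) validation above.
def _example_config_usage() -> DiffusionConfig:
    cfg = DiffusionConfig(crop_shape=(84, 84), num_inference_steps=10)
    assert cfg.horizon == 16 and cfg.n_action_steps == 8  # untouched defaults
    try:
        DiffusionConfig(vision_backbone="vit_b_16")  # rejected: only ResNet variants are allowed
    except ValueError:
        pass
    return cfg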

View File

@@ -1,306 +0,0 @@
import logging
import math
import einops
import torch
import torch.nn as nn
from torch import Tensor
logger = logging.getLogger(__name__)
class _SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class _Conv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class _ConditionalResidualBlock1D(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
self,
in_channels: int,
out_channels: int,
cond_dim: int,
kernel_size: int = 3,
n_groups: int = 8,
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
# FiLM just modulates bias).
film_scale_modulation: bool = False,
):
super().__init__()
self.film_scale_modulation = film_scale_modulation
self.out_channels = out_channels
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for matching the channel dimension of the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
"""
Args:
x: (B, in_channels, T)
cond: (B, cond_dim)
Returns:
(B, out_channels, T)
"""
out = self.conv1(x)
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
if self.film_scale_modulation:
# Treat the embedding as a list of scales and biases.
scale = cond_embed[:, : self.out_channels]
bias = cond_embed[:, self.out_channels :]
out = scale * out + bias
else:
# Treat the embedding as biases.
out = out + cond_embed
out = self.conv2(out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Two types of conditioning can be applied:
- Global: Conditioning information that is aggregated over the whole observation window. This is
incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
- Local: Conditioning information for each timestep in the observation window. This is incorporated
by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
outputs of the Unet's encoder/decoder.
"""
def __init__(
self,
input_dim: int,
local_cond_dim: int | None = None,
global_cond_dim: int | None = None,
diffusion_step_embed_dim: int = 256,
down_dims: int | None = None,
kernel_size: int = 3,
n_groups: int = 8,
film_scale_modulation: bool = False,
):
super().__init__()
if down_dims is None:
down_dims = [256, 512, 1024]
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
_SinusoidalPosEmb(diffusion_step_embed_dim),
nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
cond_dim = diffusion_step_embed_dim
if global_cond_dim is not None:
cond_dim += global_cond_dim
self.local_cond_down_encoder = None
self.local_cond_up_encoder = None
if local_cond_dim is not None:
# Encoder for the local conditioning. The output gets added to the Unet encoder input.
self.local_cond_down_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
# Encoder for the local conditioning. The output gets added to the Unet encoder output.
self.local_cond_up_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
# Unet encoder.
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
self.down_modules.append(
nn.ModuleList(
[
_ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
]
)
# Unet decoder.
self.up_modules = nn.ModuleList([])
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
self.up_modules.append(
nn.ModuleList(
[
_ConditionalResidualBlock1D(
dim_in * 2, # x2 as it takes the encoder's skip connection as well
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
_Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
nn.Conv1d(down_dims[0], input_dim, 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
"""
Args:
x: (B, T, input_dim) tensor for input to the Unet.
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
local_cond: (B, T, local_cond_dim)
global_cond: (B, global_cond_dim)
Returns:
(B, T, input_dim)
"""
# For 1D convolutions we'll need feature dimension first.
x = einops.rearrange(x, "b t d -> b d t")
if local_cond is not None:
if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
raise ValueError(
"`local_cond` was provided but the relevant encoders weren't built at initialization."
)
local_cond = einops.rearrange(local_cond, "b t d -> b d t")
timesteps_embed = self.diffusion_step_encoder(timestep)
# If there is a global conditioning feature, concatenate it to the timestep embedding.
if global_cond is not None:
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
else:
global_feature = timesteps_embed
encoder_skip_features: list[Tensor] = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
if idx == 0 and local_cond is not None:
x = x + self.local_cond_down_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
encoder_skip_features.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
x = resnet(x, global_feature)
# Note: The condition in the original implementation is:
# if idx == len(self.up_modules) and local_cond is not None:
# But as they mention in their comments, this is incorrect. We use the correct condition here.
if idx == (len(self.up_modules) - 1) and local_cond is not None:
x = x + self.local_cond_up_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b d t -> b t d")
return x
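# Shape sketch (hypothetical `_example_unet_forward` helper): denoise a (B, T, input_dim) action
# sequence conditioned on a flattened global feature. The dimensions below are arbitrary but
# consistent: 132 = (action_dim + image_feature_dim) * n_obs_steps for action_dim=2,
# image_feature_dim=64, n_obs_steps=2.
def _example_unet_forward() -> Tensor:
    unet = ConditionalUnet1D(input_dim=2, global_cond_dim=132)
    x = torch.randn(4, 16, 2)  # (B, T, input_dim)
    timestep = torch.randint(0, 100, (4,))  # one diffusion timestep index per batch element
    global_cond = torch.randn(4, 132)  # (B, global_cond_dim)
    out = unet(x, timestep, global_cond=global_cond)
    assert out.shape == (4, 16, 2)  # same shape as the input sequence
    return out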

View File

@@ -1,175 +0,0 @@
import einops
import torch
import torch.nn.functional as F # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder
from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters
class DiffusionUnetImagePolicy(nn.Module):
def __init__(
self,
cfg,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
film_scale_modulation=True,
):
super().__init__()
action_shape = shape_meta["action"]["shape"]
assert len(action_shape) == 1
action_dim = action_shape[0]
self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
self.unet = ConditionalUnet1D(
input_dim=action_dim,
global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
self.noise_scheduler = noise_scheduler
self.horizon = horizon
self.action_dim = action_dim
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
if num_inference_steps is None:
num_inference_steps = noise_scheduler.config.num_train_timesteps
self.num_inference_steps = num_inference_steps
# ========= inference ============
def conditional_sample(self, batch_size, global_cond=None, generator=None):
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
generator=generator,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
# Predict model output.
model_output = self.unet(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
)
# Compute the previous sample: x_t -> x_{t-1}
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample
def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
}
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.n_obs_steps
assert self.n_obs_steps == n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# run sampling
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
action = sample[..., : self.action_dim]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.n_action_steps
action = action[:, start:end]
return action
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
horizon = batch["action"].shape[1]
assert horizon == self.horizon
assert n_obs_steps == self.n_obs_steps
assert self.n_obs_steps == n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
trajectory = batch["action"]
# Forward diffusion.
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(trajectory.shape[0],),
device=trajectory.device,
).long()
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
# Compute the loss.
# The target is either the original trajectory or the noise.
pred_type = self.noise_scheduler.config.prediction_type
if pred_type == "epsilon":
target = eps
elif pred_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {pred_type}")
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
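# Worked slicing sketch (hypothetical `_example_action_slice` helper): with the PushT defaults
# n_obs_steps=2, horizon=16, n_action_steps=8, `generate_actions` keeps the 8 actions at indices
# 1..8 of the 16-step sample, i.e. the steps starting from the current observation.
def _example_action_slice() -> None:
    n_obs_steps, horizon, n_action_steps = 2, 16, 8
    start = n_obs_steps - 1
    end = start + n_action_steps
    assert (start, end) == (1, 9)
    assert end <= horizon  # i.e. n_action_steps <= horizon - n_obs_steps + 1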

View File

@@ -1,68 +0,0 @@
import torch
from torch.nn.modules.batchnorm import _BatchNorm
class EMAModel:
"""
Exponential Moving Average of model weights
"""
def __init__(
self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.alpha = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model):
self.alpha = self.get_decay(self.optimization_step)
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
# Iterate over immediate parameters only.
for param, ema_param in zip(
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
):
if isinstance(param, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, _BatchNorm) or not param.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.alpha)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
self.optimization_step += 1
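# Decay-schedule sketch (hypothetical `_example_ema_decay` helper): with the defaults
# inv_gamma=1.0 and power=2/3, the decay warms up from 0 and is clamped at max_value=0.9999.
def _example_ema_decay() -> None:
    ema = EMAModel(torch.nn.Linear(2, 2))
    assert ema.get_decay(0) == 0.0  # still in warmup
    assert 0.95 < ema.get_decay(100) < 0.96  # 1 - 100 ** (-2 / 3) ~= 0.954
    assert ema.get_decay(10**8) == ema.max_value  # clamped at 0.9999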

View File

@@ -1,147 +0,0 @@
from typing import Callable
import torch
import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from torchvision.transforms import CenterCrop, RandomCrop
class RgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
def __init__(
self,
input_shape: tuple[int, int, int],
norm_mean_std: tuple[float, float] = [1.0, 1.0],
crop_shape: tuple[int, int] | None = None,
random_crop: bool = False,
backbone_name: str = "resnet18",
pretrained_backbone: bool = False,
use_group_norm: bool = False,
relu: bool = True,
num_keypoints: int = 32,
):
"""
Args:
input_shape: channel-first input shape (C, H, W)
norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
(image - mean) / std.
crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
cropping is done.
random_crop: Whether the crop should be random at training time (it's always a center crop in
eval mode).
backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
pretrained_backbone: whether to use torchvision pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The number of groups is set to num_features // 16, giving roughly 16 channels per group.
relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
if input_shape[0] != 3:
raise ValueError("Only RGB images are handled")
if not backbone_name.startswith("resnet"):
raise ValueError(
"Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
)
# Set up optional preprocessing.
if norm_mean_std == [1.0, 1.0]:
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
if crop_shape is not None:
self.do_crop = True
self.center_crop = CenterCrop(crop_shape) # always use center crop for eval
if random_crop:
self.maybe_random_crop = RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if use_group_norm:
if pretrained_backbone:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.feature_dim = num_keypoints * 2
self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
self.maybe_relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, C, H, W) image tensor with pixel values in [0, 1].
Returns:
(B, D) image feature.
"""
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
else:
# Always use center crop for eval.
x = self.center_crop(x)
# Extract backbone feature.
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
# Final linear layer.
x = self.out(x)
# Maybe a final non-linearity.
x = self.maybe_relu(x)
return x
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
predicate: Takes a module as an argument and must return True if that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
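# Usage sketch (hypothetical `_example_replace_batchnorm` helper): swap every BatchNorm2d in a
# small CNN for GroupNorm, mirroring what `RgbEncoder` does when `use_group_norm=True`.
def _example_replace_batchnorm() -> nn.Module:
    net = nn.Sequential(nn.Conv2d(3, 32, 3), nn.BatchNorm2d(32), nn.ReLU())
    net = _replace_submodules(
        root_module=net,
        predicate=lambda m: isinstance(m, nn.BatchNorm2d),
        func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
    )
    assert not any(isinstance(m, nn.BatchNorm2d) for m in net.modules())
    return net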

View File

@@ -0,0 +1,878 @@
"""
TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler.
- Move EMA out of policy.
"""
import copy
import logging
import math
import time
from collections import deque
from typing import Callable
import einops
import hydra
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.optimization import get_scheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
populate_queues,
)
from lerobot.common.utils import get_safe_torch_device
logger = logging.getLogger(__name__)
class DiffusionPolicy(nn.Module):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
"""
name = "diffusion"
def __init__(
self,
cfg,
cfg_device,
cfg_noise_scheduler,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
film_scale_modulation=True,
**_,
):
super().__init__()
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
# Queues are populated during rollout of the policy; they contain the n latest observations and actions.
self._queues = None
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
self.diffusion = _DiffusionUnetImagePolicy(
cfg,
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
self.device = get_safe_torch_device(cfg_device)
self.diffusion.to(self.device)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = hydra.utils.instantiate(
cfg_ema,
model=self.ema_diffusion,
)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
params=self.diffusion.parameters(),
)
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but on steps
self.global_step = 0
# configure lr scheduler
self.lr_scheduler = get_scheduler(
cfg.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=cfg.offline_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=self.global_step - 1,
)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
underlying diffusion model. Here's how it works:
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
copied `n_obs_steps` times to fill the cache).
- The diffusion model generates `horizon` steps worth of actions.
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like:
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|observation is used | YES | YES | ..... | NO | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
Note that this means we require `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
weights.
"""
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
if not self.training and self.ema_diffusion is not None:
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
return action
def forward(self, batch, **_):
start_time = time.time()
self.diffusion.train()
loss = self.diffusion.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
"update_s": time.time() - start_time,
}
return info
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0
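# Rollout sketch (hypothetical `_example_rollout` helper, assuming a gym-style `env` whose
# observations are already formatted as the batch dict `select_action` expects): the policy is
# reset once per episode and queried once per environment step; the internal queues take care of
# stacking observations and spacing out calls to the diffusion model.
def _example_rollout(policy: DiffusionPolicy, env, max_steps: int = 300) -> None:
    policy.reset()  # clear the observation/action queues for the new episode
    observation = env.reset()  # e.g. {"observation.image": (1, C, H, W), "observation.state": (1, D)}
    for _ in range(max_steps):
        action = policy.select_action(observation)
        observation, _, done, _ = env.step(action)
        if done:
            break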
class _DiffusionUnetImagePolicy(nn.Module):
def __init__(
self,
cfg,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
film_scale_modulation=True,
):
super().__init__()
action_shape = shape_meta["action"]["shape"]
assert len(action_shape) == 1
action_dim = action_shape[0]
self.rgb_encoder = _RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
self.unet = _ConditionalUnet1D(
input_dim=action_dim,
global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
self.noise_scheduler = noise_scheduler
self.horizon = horizon
self.action_dim = action_dim
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
if num_inference_steps is None:
num_inference_steps = noise_scheduler.config.num_train_timesteps
self.num_inference_steps = num_inference_steps
# ========= inference ============
def conditional_sample(self, batch_size, global_cond=None, generator=None):
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
generator=generator,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
# Predict model output.
model_output = self.unet(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
)
# Compute the previous sample: x_t -> x_{t-1}
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample
def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
}
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.n_obs_steps
assert self.n_obs_steps == n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# run sampling
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
action = sample[..., : self.action_dim]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.n_action_steps
action = action[:, start:end]
return action
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
horizon = batch["action"].shape[1]
assert horizon == self.horizon
assert n_obs_steps == self.n_obs_steps
assert self.n_obs_steps == n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
trajectory = batch["action"]
# Forward diffusion.
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(trajectory.shape[0],),
device=trajectory.device,
).long()
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
# Compute the loss.
# The target is either the original trajectory or the noise.
pred_type = self.noise_scheduler.config.prediction_type
if pred_type == "epsilon":
target = eps
elif pred_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {pred_type}")
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
class _RgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
def __init__(
self,
input_shape: tuple[int, int, int],
norm_mean_std: tuple[float, float] = [1.0, 1.0],
crop_shape: tuple[int, int] | None = None,
random_crop: bool = False,
backbone_name: str = "resnet18",
pretrained_backbone: bool = False,
use_group_norm: bool = False,
num_keypoints: int = 32,
):
"""
Args:
input_shape: channel-first input shape (C, H, W)
norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
(image - mean) / std.
crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
cropping is done.
random_crop: Whether the crop should be random at training time (it's always a center crop in
eval mode).
backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
pretrained_backbone: whether to use torchvision pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The number of groups is set to num_features // 16, giving roughly 16 channels per group.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
if input_shape[0] != 3:
raise ValueError("Only RGB images are handled")
if not backbone_name.startswith("resnet"):
raise ValueError(
"Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
)
# Set up optional preprocessing.
if norm_mean_std == [1.0, 1.0]:
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
if crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if random_crop:
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if use_group_norm:
if pretrained_backbone:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.feature_dim = num_keypoints * 2
self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, C, H, W) image tensor with pixel values in [0, 1].
Returns:
(B, D) image feature.
"""
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
else:
# Always use center crop for eval.
x = self.center_crop(x)
# Extract backbone feature.
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
# Final linear layer with non-linearity.
x = self.relu(self.out(x))
return x
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
predicate: Takes a module as an argument and must return True if that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class _SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
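# Shape sketch (hypothetical `_example_sinusoidal_emb` helper): a batch of scalar diffusion
# timesteps maps to vectors of size `dim`, half sine and half cosine features.
def _example_sinusoidal_emb() -> Tensor:
    emb = _SinusoidalPosEmb(dim=128)(torch.arange(4))
    assert emb.shape == (4, 128)
    return emb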
class _Conv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class _ConditionalUnet1D(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Two types of conditioning can be applied:
- Global: Conditioning information that is aggregated over the whole observation window. This is
incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
- Local: Conditioning information for each timestep in the observation window. This is incorporated
by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
outputs of the Unet's encoder/decoder.
"""
def __init__(
self,
input_dim: int,
local_cond_dim: int | None = None,
global_cond_dim: int | None = None,
diffusion_step_embed_dim: int = 256,
down_dims: int | None = None,
kernel_size: int = 3,
n_groups: int = 8,
film_scale_modulation: bool = False,
):
super().__init__()
if down_dims is None:
down_dims = [256, 512, 1024]
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
_SinusoidalPosEmb(diffusion_step_embed_dim),
nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
cond_dim = diffusion_step_embed_dim
if global_cond_dim is not None:
cond_dim += global_cond_dim
self.local_cond_down_encoder = None
self.local_cond_up_encoder = None
if local_cond_dim is not None:
# Encoder for the local conditioning. The output gets added to the Unet encoder input.
self.local_cond_down_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
# Encoder for the local conditioning. The output gets added to the Unet encoder output.
self.local_cond_up_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
# Unet encoder.
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
self.down_modules.append(
nn.ModuleList(
[
_ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
]
)
# Unet decoder.
self.up_modules = nn.ModuleList([])
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
self.up_modules.append(
nn.ModuleList(
[
_ConditionalResidualBlock1D(
dim_in * 2, # x2 as it takes the encoder's skip connection as well
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
_Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
nn.Conv1d(down_dims[0], input_dim, 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
"""
Args:
x: (B, T, input_dim) tensor for input to the Unet.
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
local_cond: (B, T, local_cond_dim)
global_cond: (B, global_cond_dim)
Returns:
(B, T, input_dim)
"""
# For 1D convolutions we'll need feature dimension first.
x = einops.rearrange(x, "b t d -> b d t")
if local_cond is not None:
if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
raise ValueError(
"`local_cond` was provided but the relevant encoders weren't built at initialization."
)
local_cond = einops.rearrange(local_cond, "b t d -> b d t")
timesteps_embed = self.diffusion_step_encoder(timestep)
# If there is a global conditioning feature, concatenate it to the timestep embedding.
if global_cond is not None:
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
else:
global_feature = timesteps_embed
encoder_skip_features: list[Tensor] = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
if idx == 0 and local_cond is not None:
x = x + self.local_cond_down_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
encoder_skip_features.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
x = resnet(x, global_feature)
# Note: The condition in the original implementation is:
# if idx == len(self.up_modules) and local_cond is not None:
# But as they mention in their comments, this is incorrect. We use the correct condition here.
if idx == (len(self.up_modules) - 1) and local_cond is not None:
x = x + self.local_cond_up_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b d t -> b t d")
return x
class _ConditionalResidualBlock1D(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
self,
in_channels: int,
out_channels: int,
cond_dim: int,
kernel_size: int = 3,
n_groups: int = 8,
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
# FiLM just modulates bias).
film_scale_modulation: bool = False,
):
super().__init__()
self.film_scale_modulation = film_scale_modulation
self.out_channels = out_channels
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for matching the channel dimension of the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
"""
Args:
x: (B, in_channels, T)
cond: (B, cond_dim)
Returns:
(B, out_channels, T)
"""
out = self.conv1(x)
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
if self.film_scale_modulation:
# Treat the embedding as a list of scales and biases.
scale = cond_embed[:, : self.out_channels]
bias = cond_embed[:, self.out_channels :]
out = scale * out + bias
else:
# Treat the embedding as biases.
out = out + cond_embed
out = self.conv2(out)
out = out + self.residual_conv(x)
return out
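# FiLM sketch (hypothetical `_example_film_block` helper): a (B, cond_dim) conditioning vector
# modulates a (B, in_channels, T) sequence per channel; the temporal length is preserved.
def _example_film_block() -> Tensor:
    block = _ConditionalResidualBlock1D(
        in_channels=64, out_channels=128, cond_dim=32, film_scale_modulation=True
    )
    out = block(torch.randn(4, 64, 16), torch.randn(4, 32))
    assert out.shape == (4, 128, 16)
    return out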
class _EMA:
"""
Exponential Moving Average of model weights
"""
def __init__(
self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.alpha = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model):
self.alpha = self.get_decay(self.optimization_step)
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
# Iterate over immediate parameters only.
for param, ema_param in zip(
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
):
if isinstance(param, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, _BatchNorm) or not param.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.alpha)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
self.optimization_step += 1

View File

@@ -1,169 +0,0 @@
import copy
import logging
import time
from collections import deque
import hydra
import torch
from diffusers.optimization import get_scheduler
from torch import nn
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(
self,
cfg,
cfg_device,
cfg_noise_scheduler,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
film_scale_modulation=True,
**_,
):
super().__init__()
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
# Queues are populated during rollout of the policy; they contain the n latest observations and actions.
self._queues = None
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
self.diffusion = DiffusionUnetImagePolicy(
cfg,
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
self.device = get_safe_torch_device(cfg_device)
self.diffusion.to(self.device)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = hydra.utils.instantiate(
cfg_ema,
model=self.ema_diffusion,
)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
params=self.diffusion.parameters(),
)
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but on steps
self.global_step = 0
# configure lr scheduler
self.lr_scheduler = get_scheduler(
cfg.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=cfg.offline_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=self.global_step - 1,
)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
@torch.no_grad
def select_action(self, batch, **_):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
"""
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
if not self.training and self.ema_diffusion is not None:
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
return action
def forward(self, batch, **_):
start_time = time.time()
self.diffusion.train()
loss = self.diffusion.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
"update_s": time.time() - start_time,
}
return info
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
            logging.warning(
                "DiffusionPolicy.load expected EMA parameters in the loaded state dict but none were found."
            )
assert len(unexpected_keys) == 0
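
For orientation while this refactor is in flight, here is a minimal rollout sketch. It assumes `policy` is an already-constructed `DiffusionPolicy` (e.g. built via `make_policy` below) and uses dummy PushT-shaped observations; the tensor shapes and the 10-step loop are illustrative only, not part of the API.

import torch

policy.eval()
policy.reset()  # Clears the observation/action queues; call this whenever the env is reset.

for _ in range(10):  # illustrative rollout length
    # Dummy observations standing in for env outputs: one 96x96 RGB camera and a 2-D agent position.
    batch = {
        "observation.image": torch.rand(1, 3, 96, 96),
        "observation.state": torch.rand(1, 2),
    }
    batch = {k: v.to(policy.device) for k, v in batch.items()}
    # The first call generates a whole action trajectory; later calls pop queued actions one by one
    # until the queue empties and a new trajectory is generated.
    action = policy.select_action(batch)  # (batch, action_dim)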

View File

@ -1,59 +1,61 @@
import inspect
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from lerobot.common.utils import get_safe_torch_device
def make_policy(cfg):
if cfg.policy.name == "tdmpc":
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
    """Build a policy config dataclass from the `policy` node of a Hydra config, dropping extra keys."""
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
assert set(hydra_cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}"
policy_cfg = policy_cfg_class(
**{
k: v
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
if k in expected_kwargs
}
)
return policy_cfg
def make_policy(hydra_cfg: DictConfig):
if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPCPolicy(
cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
hydra_cfg.policy,
n_obs_steps=hydra_cfg.n_obs_steps,
n_action_steps=hydra_cfg.n_action_steps,
device=hydra_cfg.device,
)
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
elif hydra_cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
# n_obs_steps=cfg.n_obs_steps,
# n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
assert set(cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
policy_cfg = ActionChunkingTransformerConfig(
**{
k: v
for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
if k in expected_kwargs
}
)
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg)
policy.to(get_safe_torch_device(cfg.device))
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
raise ValueError(cfg.policy.name)
raise ValueError(hydra_cfg.policy.name)
if cfg.policy.pretrained_model_path:
if hydra_cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.policy.pretrained_model_path:
if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
if "offline" in hydra_cfg.policy.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.policy.pretrained_model_path:
elif "final" in hydra_cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path)
policy.load(hydra_cfg.policy.pretrained_model_path)
return policy
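
To make the intent of `_policy_cfg_from_hydra_cfg` concrete, here is a small self-contained sketch of the same key-filtering idea using a toy config dataclass (`ToyPolicyConfig` is hypothetical, not part of lerobot). The Hydra `policy` node is allowed to carry extra keys (e.g. `name`, optimizer settings) which are simply dropped when the dataclass is built.

import inspect
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class ToyPolicyConfig:  # hypothetical stand-in for DiffusionConfig / ActionChunkingTransformerConfig
    n_obs_steps: int = 2
    horizon: int = 16


hydra_cfg = OmegaConf.create({"policy": {"name": "toy", "n_obs_steps": 2, "horizon": 16, "lr": 1e-4}})

expected_kwargs = set(inspect.signature(ToyPolicyConfig).parameters)  # {"n_obs_steps", "horizon"}
assert set(hydra_cfg.policy).issuperset(expected_kwargs)
policy_cfg = ToyPolicyConfig(
    **{k: v for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items() if k in expected_kwargs}
)
print(policy_cfg)  # ToyPolicyConfig(n_obs_steps=2, horizon=16) -- `name` and `lr` were dropped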

View File

@ -18,7 +18,7 @@ policy:
pretrained_model_path:
# Environment.
# Inherit these from the environment.
# Inherit these from the environment config.
state_dim: ???
action_dim: ???

View File

@ -1,17 +1,5 @@
# @package _global_
shape_meta:
# acceptable types: rgb, low_dim
obs:
image:
shape: [3, 96, 96]
type: rgb
agent_pos:
shape: [2]
type: low_dim
action:
shape: [2]
seed: 100000
horizon: 16
n_obs_steps: 2
@ -33,75 +21,70 @@ offline_prioritized_sampler: true
policy:
name: diffusion
shape_meta: ${shape_meta}
pretrained_model_path:
horizon: ${horizon}
# Environment.
# Inherit these from the environment config.
state_dim: ???
action_dim: ???
image_size:
- ${env.image_size} # height
- ${env.image_size} # width
# Inputs / output structure.
n_obs_steps: ${n_obs_steps}
horizon: ${horizon}
n_action_steps: ${n_action_steps}
num_inference_steps: 100
# crop_shape: null
diffusion_step_embed_dim: 128
# Vision preprocessing.
image_normalization_mean: [0.5, 0.5, 0.5]
image_normalization_std: [0.5, 0.5, 0.5]
# Architecture / modeling.
# Vision backbone.
vision_backbone: resnet18
crop_shape: [84, 84]
  crop_is_random: True
use_pretrained_backbone: false
use_group_norm: True
spatial_softmax_num_keypoints: 32
# Unet.
down_dims: [512, 1024, 2048]
kernel_size: 5
n_groups: 8
diffusion_step_embed_dim: 128
film_scale_modulation: True
pretrained_model_path:
batch_size: 64
per_alpha: 0.6
per_beta: 0.4
balanced_sampling: false
utd: 1
offline_steps: ${offline_steps}
use_ema: true
lr_scheduler: cosine
lr_warmup_steps: 500
grad_clip_norm: 10
delta_timestamps:
observation.image: [-0.1, 0]
observation.state: [-0.1, 0]
action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
rgb_encoder:
backbone_name: resnet18
pretrained_backbone: false
use_group_norm: True
num_keypoints: 32
relu: true
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
crop_shape: [84, 84]
random_crop: True
noise_scheduler:
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
# Noise scheduler.
num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
  variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
clip_sample: True # required when predict_epsilon=False
prediction_type: epsilon # or sample
variance_type: fixed_small
prediction_type: epsilon # epsilon / sample
clip_sample: True
rgb_model:
pretrained: false
num_keypoints: 32
relu: true
# Inference
num_inference_steps: 100
ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
update_after_step: 0
inv_gamma: 1.0
power: 0.75
min_value: 0.0
max_value: 0.9999
optimizer:
_target_: torch.optim.AdamW
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: 64
grad_clip_norm: 10
lr: 1.0e-4
betas: [0.95, 0.999]
eps: 1.0e-8
weight_decay: 1.0e-6
lr_scheduler: cosine
lr_warmup_steps: 500
adam_betas: [0.95, 0.999]
adam_eps: 1.0e-8
adam_weight_decay: 1.0e-6
utd: 1
use_ema: true
ema_update_after_step: 0
ema_min_rate: 0.0
ema_max_rate: 0.9999
ema_inv_gamma: 1.0
ema_power: 0.75
delta_timestamps:
  observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"

View File

@ -19,7 +19,7 @@ from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy