#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import Union

import torch
import torch.nn as nn

# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from transformers.modeling_utils import PreTrainedModel

from .configuration_unet_diffusion import UnetDiffusionPolicyConfig


# =================== UNet for Diffusion ==============
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for the scalar diffusion timestep."""

    def __init__(self, dim, dtype):
        super().__init__()
        self.dim = dim
        self.dtype = dtype

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb)
        emb = x[:, None] * emb[None, :]
        # Concatenate the sin/cos halves: (b,) -> (b, dim).
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
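
# Quick shape sketch for SinusoidalPosEmb (illustrative only):
#
#     pos_emb = SinusoidalPosEmb(dim=256, dtype=torch.float32)
#     t = torch.arange(8)        # (b,) integer diffusion timesteps
#     pos_emb(t).shape           # torch.Size([8, 256])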


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Strided conv halves the temporal axis: (b, dim, t) -> (b, dim, t // 2).
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Transposed conv doubles the temporal axis: (b, dim, t) -> (b, dim, 2 * t).
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
                Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
            ]
        )

        # FiLM modulation (https://arxiv.org/abs/1709.07871):
        # predicts a per-channel scale and bias from the conditioning vector.
        cond_channels = out_channels * 2
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
        )

        # 1x1 conv on the residual path so the channel dimensions stay compatible.
        self.residual_conv = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x, cond):
        """
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim ]

        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)
        # Split the conditioning embedding into a per-channel scale and bias.
        embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
        scale = embed[:, 0, ...]
        bias = embed[:, 1, ...]
        out = scale * out + bias
        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out
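
# A minimal FiLM-conditioning sketch (illustrative values):
#
#     block = ConditionalResidualBlock1D(in_channels=64, out_channels=128, cond_dim=32)
#     x = torch.randn(2, 64, 16)     # (b, in_channels, horizon)
#     cond = torch.randn(2, 32)      # (b, cond_dim)
#     block(x, cond).shape           # torch.Size([2, 128, 16])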


class ConditionalUnet1D(PreTrainedModel):
    _no_split_modules = ["mid_modules", "down_modules", "up_modules"]
    config_class = UnetDiffusionPolicyConfig

    def __init__(self, config: UnetDiffusionPolicyConfig):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
            in addition to the diffusion step embedding. This is usually obs_horizon * obs_dim.
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k.
        down_dims: Channel size for each UNet level.
            The length of this array determines the number of levels.
        kernel_size: Conv kernel size.
        n_groups: Number of groups for GroupNorm.
        """
        super().__init__(config)
        all_dims = [config.input_dim] + list(config.down_dims)
        start_dim = config.down_dims[0]

        self.num_queries = config.prediction_horizon
        self.noise_samples = config.noise_samples
        # self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
        # self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim)
        self.norm_after_pool = nn.LayerNorm(config.global_cond_dim)
        self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim)

        dsed = config.diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            # NOTE: the timestep embedding is hardcoded to bfloat16.
            SinusoidalPosEmb(dsed, torch.bfloat16),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + config.global_cond_dim

        # Consecutive (dim_in, dim_out) channel pairs, one per UNet level.
        in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
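        # For example (hypothetical values), down_dims=(256, 512, 1024) with
        # input_dim=10 gives all_dims=[10, 256, 512, 1024] and
        # in_out=[(10, 256), (256, 512), (512, 1024)].
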
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList(
            [
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=config.kernel_size,
                    n_groups=config.n_groups,
                ),
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=config.kernel_size,
                    n_groups=config.n_groups,
                ),
            ]
        )

        # Encoder path: two FiLM-conditioned residual blocks per level, with a
        # downsample between levels (skipped at the deepest level).
        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        ConditionalResidualBlock1D(
                            dim_out,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        # Decoder path mirrors the encoder; each level consumes the matching
        # skip connection, hence the dim_out * 2 input channels.
        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_out * 2,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
            nn.Conv1d(start_dim, config.input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters())))

        self.num_inference_timesteps = config.num_inference_timesteps
        # self.proj_to_action = nn.Identity()
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=config.num_train_timesteps,  # e.g. 100
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            set_alpha_to_one=True,
            steps_offset=0,
            prediction_type="epsilon",  # the network predicts the added noise
        )

    def forward(self, actions, hidden_states, states, is_pad):
        """
        Forward pass for the diffusion head.
        :param actions: target actions, shape [b, Ta, D]; e.g. D = 10 = 3 + 6 + 1
        :param hidden_states: hidden states from the VLM backbone (e.g. llava_pythia), used as the
            condition for the diffusion model, shape [b, Tokens, D] (e.g. 8 x 1200 x 1024)
        :param states: robot proprioceptive states, shape [b, D]
        :param is_pad: padding mask for the action chunk, shape [b, Ta] (unused at inference)
        :return: a dict with the training loss, or the denoised actions at inference time
        """
        if actions is not None:  # training time
            b = actions.size(0)
            actions = copy.deepcopy(actions[:, : self.num_queries])
            is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
            num_noise_samples = self.noise_samples

            # Sample noise to add to the actions.
            noise = torch.randn(
                [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
            )  # num_noise, b, Ta, D

            # Sample a diffusion iteration for each data point.
            timesteps = torch.randint(
                0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
            ).long()
            timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)

            # Add noise to the clean actions according to the noise magnitude at each
            # diffusion iteration (this is the forward diffusion process).
            noisy_actions = torch.cat(
                [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
                dim=0,
            )  # [num_noise_samples * b, Ta, action_dim]
            noisy_actions = noisy_actions.to(dtype=actions.dtype)
            assert hidden_states.ndim == 3

            # Repeat the conditioning so every noise sample sees the same context.
            hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
            timesteps = timesteps.repeat(num_noise_samples)
            is_pad = is_pad.repeat(num_noise_samples, 1)
            states = states.repeat(num_noise_samples, 1)

            noise_pred = self.model_forward(
                noisy_actions, timesteps, global_cond=hidden_states, states=states
            )

            # Flatten [num_noise, b, Ta, D] -> [num_noise * b, Ta, D] to match noise_pred.
            noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
            loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
            # Mask out padded action steps before averaging.
            loss = (loss * ~is_pad.unsqueeze(-1)).mean()
            return {"loss": loss}
        else:  # inference time
            b = hidden_states.size(0)
            tp = self.num_queries
            action_dim = self.config.input_dim

            # Initialize the action chunk from Gaussian noise on the conditioning's device.
            noisy_action = torch.randn((b, tp, action_dim), device=hidden_states.device)
            naction = noisy_action.to(dtype=hidden_states.dtype)

            # Initialize the scheduler, then iteratively denoise.
            self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
            for k in self.noise_scheduler.timesteps:
                # Predict the noise residual.
                noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)
                # Inverse diffusion step (remove noise).
                naction = self.noise_scheduler.step(
                    model_output=noise_pred, timestep=k, sample=naction
                ).prev_sample
            return naction
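
    # Typical call patterns (illustrative; `head` is a hypothetical instance name):
    #   training:  head(actions, hidden_states, states, is_pad)  -> {"loss": ...}
    #   inference: head(None, hidden_states, states, None)       -> (b, Ta, input_dim) actions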

    def model_forward(
        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
    ):
        """
        sample: (b, T, input_dim) noisy actions
        timestep: (b,) tensor or int, diffusion step
        global_cond: (b, 1, global_cond_dim) conditioning from the VLM
        states: (b, state_dim) robot states, fused into the conditioning
        output: (b, T, input_dim)
        """
        # (b, t, c) -> (b, c, t)
        sample = sample.moveaxis(-1, -2)

        # global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
        global_cond = global_cond.squeeze(1)
        global_cond = self.norm_after_pool(global_cond)
        # Fuse the robot state into the global conditioning vector.
        global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
        global_cond = self.combine(global_cond)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML.
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)
        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], dim=-1)

        # 2. down path: store skip connections, halving the temporal axis per level.
        x = sample
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        # 3. bottleneck
        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # 4. up path: concatenate the matching skip connection, then upsample.
        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (b, c, t) -> (b, t, c)
        x = x.moveaxis(-1, -2)
        return x
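

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, assuming UnetDiffusionPolicyConfig
    # accepts these fields as keyword arguments (the values below are
    # hypothetical, not DexVLA's shipped defaults).
    config = UnetDiffusionPolicyConfig(
        input_dim=10,
        global_cond_dim=2048,
        state_dim=7,
        prediction_horizon=16,
        noise_samples=1,
        diffusion_step_embed_dim=256,
        down_dims=[256, 512, 1024],
        kernel_size=5,
        n_groups=8,
        num_train_timesteps=100,
        num_inference_timesteps=10,
    )
    # The timestep encoder is hardcoded to bfloat16, so cast the whole head.
    head = ConditionalUnet1D(config).to(torch.bfloat16)

    b, ta = 2, config.prediction_horizon
    actions = torch.randn(b, ta, config.input_dim, dtype=torch.bfloat16)
    hidden_states = torch.randn(b, 1, config.global_cond_dim, dtype=torch.bfloat16)
    states = torch.randn(b, config.state_dim, dtype=torch.bfloat16)
    is_pad = torch.zeros(b, ta, dtype=torch.bool)

    # Training-style call returns a loss dict; inference-style call denoises
    # a Gaussian-initialized action chunk.
    print(head(actions, hidden_states, states, is_pad))
    print(head(None, hidden_states, states, None).shape)  # (b, ta, input_dim)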