#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
from typing import Union

import torch
import torch.nn as nn

# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from transformers.modeling_utils import PreTrainedModel

from .configuration_unet_diffusion import UnetDiffusionPolicyConfig

# =================== UNet for Diffusion ==============


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal position embedding for the diffusion timestep (Transformer-style)."""

    def __init__(self, dim, dtype):
        super().__init__()
        self.dim = dim
        self.dtype = dtype

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb)
        emb = x[:, None] * emb[None, :]
        # Concatenate the sin and cos halves: (b,) -> (b, dim).
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


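# Shape sketch (illustrative only, not part of the model): with dim=8, a batch
# of 4 timesteps maps to a (4, 8) embedding, e.g.
#
#   emb = SinusoidalPosEmb(8, torch.float32)(torch.arange(4))
#   assert emb.shape == (4, 8)

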
class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Strided conv halves the temporal length: (b, dim, t) -> (b, dim, t // 2).
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Transposed conv doubles the temporal length: (b, dim, t) -> (b, dim, 2 * t).
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


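# Shape sketch (illustrative only): the two blocks are exact inverses in length,
#
#   Downsample1d(64)(torch.randn(2, 64, 16)).shape  # -> (2, 64, 8)
#   Upsample1d(64)(torch.randn(2, 64, 8)).shape     # -> (2, 64, 16)

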
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
                Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
            ]
        )

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels * 2
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
        )

        # make sure the dimensions are compatible
        self.residual_conv = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x, cond):
        """
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim ]

        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)

        # Split the conditioning embedding into per-channel scale and bias.
        embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
        scale = embed[:, 0, ...]
        bias = embed[:, 1, ...]
        out = scale * out + bias

        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


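# FiLM sketch (illustrative shapes, not part of the model): a conditioning
# vector of shape (b, cond_dim) is encoded into (b, 2 * out_channels, 1); the
# two halves act as per-channel scale and bias broadcast over the horizon axis:
#
#   block = ConditionalResidualBlock1D(in_channels=64, out_channels=128, cond_dim=32)
#   y = block(torch.randn(2, 64, 16), torch.randn(2, 32))
#   assert y.shape == (2, 128, 16)

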
class ConditionalUnet1D(PreTrainedModel):
    _no_split_modules = ["mid_modules", "down_modules", "up_modules"]

    config_class = UnetDiffusionPolicyConfig

    def __init__(self, config: UnetDiffusionPolicyConfig):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
            in addition to the diffusion step embedding. This is usually obs_horizon * obs_dim.
        diffusion_step_embed_dim: Size of the positional encoding for diffusion iteration k.
        down_dims: Channel size for each UNet level.
            The length of this array determines the number of levels.
        kernel_size: Conv kernel size.
        n_groups: Number of groups for GroupNorm.
        """

        super().__init__(config)
        all_dims = [config.input_dim] + list(config.down_dims)
        start_dim = config.down_dims[0]

        self.num_queries = config.prediction_horizon
        self.noise_samples = config.noise_samples
        # self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
        # self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim)
        self.norm_after_pool = nn.LayerNorm(config.global_cond_dim)
        self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim)
        dsed = config.diffusion_step_embed_dim
        # NOTE: the timestep embedding is hardcoded to bfloat16, presumably to
        # match the dtype the surrounding policy is trained in.
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed, torch.bfloat16),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + config.global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList(
            [
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=config.kernel_size,
                    n_groups=config.n_groups,
                ),
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=config.kernel_size,
                    n_groups=config.n_groups,
                ),
            ]
        )

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        ConditionalResidualBlock1D(
                            dim_out,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_out * 2,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=config.kernel_size,
                            n_groups=config.n_groups,
                        ),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
            nn.Conv1d(start_dim, config.input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters())))

        self.num_inference_timesteps = config.num_inference_timesteps
        # self.proj_to_action = nn.Identity()
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=config.num_train_timesteps,  # e.g. 100
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            set_alpha_to_one=True,
            steps_offset=0,
            prediction_type="epsilon",  # the network predicts the added noise
        )

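    # Usage sketch (shapes inferred from the code below): training computes a
    # denoising loss; inference runs the reverse DDIM loop instead.
    #
    #   loss = policy(actions, hidden_states, states, is_pad)["loss"]  # training
    #   traj = policy(None, hidden_states, states, None)               # inference
    #
    # `hidden_states` is squeezed along dim 1 in model_forward, so a single
    # conditioning token of shape (b, 1, global_cond_dim) is expected here.
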
    def forward(self, actions, hidden_states, states, is_pad):
        """
        Forward pass for the diffusion head.
        :param actions: target actions, shape [b, Ta, D]; e.g. D = 10 = 3 + 6 + 1
        :param hidden_states: hidden states from llava_pythia, used as the condition
            for the diffusion, shape [b, tokens, D], e.g. (8, 1200, 1024)
        :param states: robot states, shape [b, D]
        :return: a dict with the denoising loss at training time, or the denoised
            action trajectory at inference time
        """
        if actions is not None:  # training time
            b = actions.size(0)
            actions = copy.deepcopy(actions[:, : self.num_queries])
            is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
            num_noise_samples = self.noise_samples
            # sample noise to add to actions
            noise = torch.randn(
                [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
            )  # num_noise, b, Ta, D
            # sample a diffusion iteration for each data point
            timesteps = torch.randint(
                0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
            ).long()

            timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)

            # add noise to the clean actions according to the noise magnitude at each diffusion iteration
            # (this is the forward diffusion process)
            noisy_actions = torch.cat(
                [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
                dim=0,
            )  # [num_noise_samples * b, Ta, action_dim]

            noisy_actions = noisy_actions.to(dtype=actions.dtype)
            assert hidden_states.ndim == 3

            # repeat the conditioning to match the num_noise_samples * b batch
            hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
            timesteps = timesteps.repeat(num_noise_samples)
            is_pad = is_pad.repeat(num_noise_samples, 1)
            states = states.repeat(num_noise_samples, 1)

            noise_pred = self.model_forward(
                noisy_actions, timesteps, global_cond=hidden_states, states=states
            )
            # flatten the noise-sample axis into the batch axis to match noise_pred
            noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
            loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
            # mask out padded action steps before averaging
            loss = (loss * ~is_pad.unsqueeze(-1)).mean()
            return {"loss": loss}
        else:  # inference time
            b = 1
            tp = self.num_queries
            action_dim = 14  # NOTE: hardcoded here rather than read from config.input_dim

            # initialize action from Gaussian noise on the same device as the condition
            noisy_action = torch.randn((b, tp, action_dim), device=hidden_states.device)

            naction = noisy_action.to(dtype=hidden_states.dtype)
            # init scheduler
            self.noise_scheduler.set_timesteps(self.num_inference_timesteps)

            for k in self.noise_scheduler.timesteps:
                # predict noise
                noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)

                # inverse diffusion step (remove noise)
                naction = self.noise_scheduler.step(
                    model_output=noise_pred, timestep=k, sample=naction
                ).prev_sample

            return naction

    def model_forward(
        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
    ):
        """
        sample: (b, T, input_dim)
        timestep: (b,) or int, diffusion step
        global_cond: (b, 1, global_cond_dim)
        output: (b, T, input_dim)
        """
        # (b,t,c)
        sample = sample.moveaxis(-1, -2)
        # (b,c,t)
        # global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
        # NOTE: expects a single conditioning token along dim 1
        global_cond = global_cond.squeeze(1)

        global_cond = self.norm_after_pool(global_cond)
        # fuse the conditioning with the robot state, then project back to global_cond_dim
        global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
        global_cond = self.combine(global_cond)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], dim=-1)

        # UNet encoder: two conditional residual blocks per level, then downsample
        x = sample
        h = []
        for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # UNet decoder: concatenate skip connections, then upsample
        for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (b,c,t)
        x = x.moveaxis(-1, -2)
        # (b,t,c)
        return x
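

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The keyword
    # arguments below are assumptions about UnetDiffusionPolicyConfig and may
    # not match its real fields or defaults; adjust before relying on this.
    # Because of the relative import above, run as a module, e.g.
    # `python -m <package>.modeling_unet_diffusion` (package name assumed).
    config = UnetDiffusionPolicyConfig(
        input_dim=10,
        global_cond_dim=1024,
        state_dim=7,
        prediction_horizon=16,
        noise_samples=1,
        diffusion_step_embed_dim=256,
        down_dims=[256, 512, 1024],
        kernel_size=5,
        n_groups=8,
        num_train_timesteps=100,
        num_inference_timesteps=10,
    )
    # The timestep encoder is hardcoded to bfloat16, so cast the whole model.
    model = ConditionalUnet1D(config).to(torch.bfloat16)
    sample = torch.randn(2, config.prediction_horizon, config.input_dim, dtype=torch.bfloat16)
    cond = torch.randn(2, 1, config.global_cond_dim, dtype=torch.bfloat16)
    states = torch.randn(2, config.state_dim, dtype=torch.bfloat16)
    out = model.model_forward(sample, timestep=5, global_cond=cond, states=states)
    print(out.shape)  # expected: torch.Size([2, 16, 10])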