import einops
import torch
import torch.nn.functional as F  # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn

from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder
from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters


class DiffusionUnetImagePolicy(nn.Module):
    """
    TODO(now): Add DDIM scheduler.

    Changes: TODO(now)
    - Uses a single image encoder for now instead of a generic obs_encoder. We may generalize again
      when/if needed. Code for a general observation encoder can be found at:
      https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
    - Uses the observation as global conditioning for the Unet by default.
    - Does not do any inpainting (which would be applicable if the observation were not used to condition
      the Unet).
    """

    def __init__(
        self,
        cfg,
        shape_meta: dict,
        noise_scheduler: DDPMScheduler,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        film_scale_modulation=True,
    ):
        super().__init__()
        action_shape = shape_meta["action"]["shape"]
        assert len(action_shape) == 1
        action_dim = action_shape[0]
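
        # Vision backbone: encodes each image observation into a flat feature vector of size
        # `self.rgb_encoder.feature_dim`.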
        self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
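
        # 1D conditional Unet over the action sequence. The observation (robot state and image features
        # for each of the `n_obs_steps` observation steps) enters as global conditioning
        # (FiLM-style scale/shift modulation, judging by `film_scale_modulation`).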
        self.unet = ConditionalUnet1D(
            input_dim=action_dim,
            global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            film_scale_modulation=film_scale_modulation,
        )

        self.noise_scheduler = noise_scheduler
        self.horizon = horizon
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps

        self.num_inference_steps = num_inference_steps

    # ========= inference ============
    def conditional_sample(
        self,
        condition_data,
        inpainting_mask,
        local_cond=None,
        global_cond=None,
        generator=None,
    ):
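        """Sample an action trajectory by running the full reverse diffusion (denoising) loop.

        Shapes below are inferred from the call site in `predict_action`: `condition_data` and
        `inpainting_mask` are (B, horizon, action_dim), with `inpainting_mask` a boolean mask. Entries
        of `condition_data` selected by the mask are re-imposed on the trajectory at every denoising
        step (a no-op here since the mask is all False when inpainting is not used).
        """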
        model = self.unet
        scheduler = self.noise_scheduler

        # Sample an initial trajectory of pure noise.
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator,
        )

        # Set the number of denoising steps.
        scheduler.set_timesteps(self.num_inference_steps)
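
        # Reverse diffusion: iterate over the scheduler's timesteps (from high noise down to t=0),
        # each step turning x_t into a slightly less noisy x_{t-1} based on the model's prediction.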
        for t in scheduler.timesteps:
            # 1. Apply conditioning.
            trajectory[inpainting_mask] = condition_data[inpainting_mask]

            # 2. Predict model output.
            model_output = model(
                trajectory,
                torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device),
                local_cond=local_cond,
                global_cond=global_cond,
            )

            # 3. Compute the previous sample: x_t -> x_t-1.
            trajectory = scheduler.step(
                model_output,
                t,
                trajectory,
                generator=generator,
            ).prev_sample

        # Finally, make sure conditioning is enforced.
        trajectory[inpainting_mask] = condition_data[inpainting_mask]

        return trajectory

    def predict_action(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)
            "observation.image": (B, n_obs_steps, C, H, W)
        }

        It returns the predicted actions with shape (B, n_action_steps, action_dim).
        """
        assert set(batch).issuperset({"observation.state", "observation.image"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.n_obs_steps

        # Build input.
        device = get_device_from_parameters(self)
        dtype = get_dtype_from_parameters(self)

        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
        # Separate batch and sequence dims.
        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
        # Concatenate state and image features then flatten to (B, global_cond_dim).
        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
        # Empty conditioning data and an all-False inpainting mask for the action trajectory
        # (no inpainting is used; the observation enters via `global_cond` instead).
        cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

        # Run sampling.
        nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)

        # `horizon` steps worth of actions (from the first observation).
        action = nsample[..., : self.action_dim]
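        # The sampled trajectory is aligned with the first observation step, so the action matching the
        # current (latest) observation sits at index `n_obs_steps - 1`.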
        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
        end = start + self.n_action_steps
        action = action[:, start:end]

        return action

    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)
            "observation.image": (B, n_obs_steps, C, H, W)
            "action": (B, horizon, action_dim)
            "action_is_pad": (B, horizon)  # TODO(now) maybe this is (B, horizon, 1)
        }
        """
        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        horizon = batch["action"].shape[1]
        assert horizon == self.horizon
        assert n_obs_steps == self.n_obs_steps

        # Handle different ways of passing the observation (only global conditioning is used here).
        local_cond = None
        global_cond = None
        trajectory = batch["action"]
        cond_data = trajectory

        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
        # Separate batch and sequence dims.
        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
        # Concatenate state and image features then flatten to (B, global_cond_dim).
        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

        # Sample noise that we'll add to the trajectory.
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        # Sample a random denoising timestep for each item in the batch.
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (trajectory.shape[0],),
            device=trajectory.device,
        ).long()
        # Add noise to the clean trajectories according to the noise magnitude at each timestep
        # (this is the forward diffusion process).
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)

        # Apply inpainting. TODO(now): implement?
        inpainting_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask]

        # Predict the noise residual.
        pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
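
        # The regression target depends on the scheduler configuration: with "epsilon" the network is
        # trained to predict the added noise; with "sample" it is trained to predict the clean trajectory.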
        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == "epsilon":
            target = noise
        elif pred_type == "sample":
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction="none")
        loss = loss * (~inpainting_mask)
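
        # Zero out the loss on action steps that are padding (i.e. fall outside the episode bounds),
        # so they do not contribute to the mean.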
        if "action_is_pad" in batch:
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)

        return loss.mean()