backup wip
parent 14f3ffb412
commit 5608e659e6
@@ -11,7 +11,7 @@ import torch
 from omegaconf import OmegaConf
 
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
 
 output_directory = Path("outputs/train/example_pusht_diffusion")
@@ -56,7 +56,7 @@ class ActionChunkingTransformerConfig:
 
     # Inputs / output structure.
     n_obs_steps: int = 1
-    camera_names: list[str] = field(default_factory=lambda: ["top"])
+    camera_names: tuple[str] = ("top",)
     chunk_size: int = 100
     n_action_steps: int = 100
 
@@ -101,7 +101,7 @@ class ActionChunkingTransformerConfig:
     utd: int = 1
 
     def __post_init__(self):
-        """Input validation."""
+        """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError("`vision_backbone` must be one of the ResNet variants.")
         if self.use_temporal_aggregation:
@@ -163,7 +163,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """
+        """Select a single action given environment observations.
+
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
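The queueing pattern described in this docstring can be sketched on its own (a minimal illustration, not LeRobot's actual API; `plan` is a hypothetical stand-in for `select_actions`):

from collections import deque

def make_select_action(plan, n_action_steps: int):
    """Return a closure that serves one action per call, replanning only when the queue is empty."""
    queue = deque(maxlen=n_action_steps)

    def select_action(observation):
        if not queue:
            # `plan` proposes a whole chunk of future actions; keep only `n_action_steps` of them.
            queue.extend(plan(observation)[:n_action_steps])
        return queue.popleft()

    return select_action

# Example with a dummy planner that always proposes the same 3-step chunk.
select_action = make_select_action(lambda obs: [0.1, 0.2, 0.3], n_action_steps=3)
assert [select_action(None) for _ in range(4)] == [0.1, 0.2, 0.3, 0.1]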
@@ -0,0 +1,83 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class DiffusionConfig:
+    """Configuration class for Diffusion Policy.
+
+    Defaults are configured for training with PushT providing proprioceptive and single camera observations.
+
+    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
+    Those are: `state_dim`, `action_dim` and `image_size`.
+
+    Args:
+        state_dim: Dimensionality of the observation state space (excluding images).
+        action_dim: Dimensionality of the action space.
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
+            current step and additional steps going back).
+        horizon: Diffusion model action prediction horizon as detailed in the main policy documentation.
+    """
+
+    # Environment.
+    # Inherit these from the environment config.
+    state_dim: int = 2
+    action_dim: int = 2
+    image_size: tuple[int, int] = (96, 96)
+
+    # Inputs / output structure.
+    n_obs_steps: int = 2
+    horizon: int = 16
+    n_action_steps: int = 8
+
+    # Vision preprocessing.
+    image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
+    image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
+
+    # Architecture / modeling.
+    # Vision backbone.
+    vision_backbone: str = "resnet18"
+    crop_shape: tuple[int, int] = (84, 84)
+    crop_is_random: bool = True
+    use_pretrained_backbone: bool = False
+    use_group_norm: bool = True
+    spatial_softmax_num_keypoints: int = 32
+    # Unet.
+    down_dims: tuple[int, ...] = (512, 1024, 2048)
+    kernel_size: int = 5
+    n_groups: int = 8
+    diffusion_step_embed_dim: int = 128
+    film_scale_modulation: bool = True
+    # Noise scheduler.
+    num_train_timesteps: int = 100
+    beta_schedule: str = "squaredcos_cap_v2"
+    beta_start: float = 0.0001
+    beta_end: float = 0.02
+    variance_type: str = "fixed_small"
+    prediction_type: str = "epsilon"
+    clip_sample: bool = True
+
+    # Inference
+    num_inference_steps: int = 100
+
+    # ---
+    # TODO(alexander-soare): Remove these from the policy config.
+    batch_size: int = 64
+    grad_clip_norm: int = 10
+    lr: float = 1.0e-4
+    lr_scheduler: str = "cosine"
+    lr_warmup_steps: int = 500
+    adam_betas: tuple[float, float] = (0.95, 0.999)
+    adam_eps: float = 1.0e-8
+    adam_weight_decay: float = 1.0e-6
+    utd: int = 1
+    use_ema: bool = True
+    ema_update_after_step: int = 0
+    ema_min_rate: float = 0.0
+    ema_max_rate: float = 0.9999
+    ema_inv_gamma: float = 1.0
+    ema_power: float = 0.75
+
+    def __post_init__(self):
+        """Input validation (not exhaustive)."""
+        if not self.vision_backbone.startswith("resnet"):
+            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
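A minimal usage sketch for the new config dataclass above (assuming `DiffusionConfig` is importable from the module added in this commit; the exact path is not shown here):

# Defaults target PushT: 2-D state, 2-D action, a single 96x96 camera.
cfg = DiffusionConfig()
assert cfg.horizon == 16 and cfg.n_obs_steps == 2 and cfg.n_action_steps == 8

# Environment-dependent fields are overridden per robot / sensor suite (hypothetical values).
cfg = DiffusionConfig(state_dim=7, action_dim=7, image_size=(240, 320), crop_shape=(216, 288))

# `__post_init__` rejects unsupported vision backbones.
try:
    DiffusionConfig(vision_backbone="vit_b_16")
except ValueError as e:
    print(e)  # `vision_backbone` must be one of the ResNet variants.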
@@ -1,306 +0,0 @@
-import logging
-import math
-
-import einops
-import torch
-import torch.nn as nn
-from torch import Tensor
-
-logger = logging.getLogger(__name__)
-
-
-class _SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-class _Conv1dBlock(nn.Module):
-    """Conv1d --> GroupNorm --> Mish"""
-
-    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
-        super().__init__()
-
-        self.block = nn.Sequential(
-            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            nn.GroupNorm(n_groups, out_channels),
-            nn.Mish(),
-        )
-
-    def forward(self, x):
-        return self.block(x)
-
-
-class _ConditionalResidualBlock1D(nn.Module):
-    """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        cond_dim: int,
-        kernel_size: int = 3,
-        n_groups: int = 8,
-        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
-        # FiLM just modulates bias).
-        film_scale_modulation: bool = False,
-    ):
-        super().__init__()
-
-        self.film_scale_modulation = film_scale_modulation
-        self.out_channels = out_channels
-
-        self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
-
-        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
-        cond_channels = out_channels * 2 if film_scale_modulation else out_channels
-        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
-
-        self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
-
-        # A final convolution for dimension matching the residual (if needed).
-        self.residual_conv = (
-            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
-        )
-
-    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
-        """
-        Args:
-            x: (B, in_channels, T)
-            cond: (B, cond_dim)
-        Returns:
-            (B, out_channels, T)
-        """
-        out = self.conv1(x)
-
-        # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
-        cond_embed = self.cond_encoder(cond).unsqueeze(-1)
-        if self.film_scale_modulation:
-            # Treat the embedding as a list of scales and biases.
-            scale = cond_embed[:, : self.out_channels]
-            bias = cond_embed[:, self.out_channels :]
-            out = scale * out + bias
-        else:
-            # Treat the embedding as biases.
-            out = out + cond_embed
-
-        out = self.conv2(out)
-        out = out + self.residual_conv(x)
-        return out
-
-
-class ConditionalUnet1D(nn.Module):
-    """A 1D convolutional UNet with FiLM modulation for conditioning.
-
-    Two types of conditioning can be applied:
-    - Global: Conditioning information that is aggregated over the whole observation window. This is
-      incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
-    - Local: Conditioning information for each timestep in the observation window. This is incorporated
-      by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
-      outputs of the Unet's encoder/decoder.
-    """
-
-    def __init__(
-        self,
-        input_dim: int,
-        local_cond_dim: int | None = None,
-        global_cond_dim: int | None = None,
-        diffusion_step_embed_dim: int = 256,
-        down_dims: int | None = None,
-        kernel_size: int = 3,
-        n_groups: int = 8,
-        film_scale_modulation: bool = False,
-    ):
-        super().__init__()
-
-        if down_dims is None:
-            down_dims = [256, 512, 1024]
-
-        # Encoder for the diffusion timestep.
-        self.diffusion_step_encoder = nn.Sequential(
-            _SinusoidalPosEmb(diffusion_step_embed_dim),
-            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
-            nn.Mish(),
-            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
-        )
-
-        # The FiLM conditioning dimension.
-        cond_dim = diffusion_step_embed_dim
-        if global_cond_dim is not None:
-            cond_dim += global_cond_dim
-
-        self.local_cond_down_encoder = None
-        self.local_cond_up_encoder = None
-        if local_cond_dim is not None:
-            # Encoder for the local conditioning. The output gets added to the Unet encoder input.
-            self.local_cond_down_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
-            # Encoder for the local conditioning. The output gets added to the Unet encoder output.
-            self.local_cond_up_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
-
-        # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
-        # just reverse these.
-        in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
-
-        # Unet encoder.
-        self.down_modules = nn.ModuleList([])
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (len(in_out) - 1)
-            self.down_modules.append(
-                nn.ModuleList(
-                    [
-                        _ConditionalResidualBlock1D(
-                            dim_in,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        # Downsample as long as it is not the last block.
-                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        # Processing in the middle of the auto-encoder.
-        self.mid_modules = nn.ModuleList(
-            [
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
-            ]
-        )
-
-        # Unet decoder.
-        self.up_modules = nn.ModuleList([])
-        for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (len(in_out) - 1)
-            self.up_modules.append(
-                nn.ModuleList(
-                    [
-                        _ConditionalResidualBlock1D(
-                            dim_in * 2,  # x2 as it takes the encoder's skip connection as well
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        # Upsample as long as it is not the last block.
-                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        self.final_conv = nn.Sequential(
-            _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
-            nn.Conv1d(down_dims[0], input_dim, 1),
-        )
-
-    def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
-        """
-        Args:
-            x: (B, T, input_dim) tensor for input to the Unet.
-            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
-            local_cond: (B, T, local_cond_dim)
-            global_cond: (B, global_cond_dim)
-            output: (B, T, input_dim)
-        Returns:
-            (B, T, input_dim)
-        """
-        # For 1D convolutions we'll need feature dimension first.
-        x = einops.rearrange(x, "b t d -> b d t")
-        if local_cond is not None:
-            if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
-                raise ValueError(
-                    "`local_cond` was provided but the relevant encoders weren't built at initialization."
-                )
-            local_cond = einops.rearrange(local_cond, "b t d -> b d t")
-
-        timesteps_embed = self.diffusion_step_encoder(timestep)
-
-        # If there is a global conditioning feature, concatenate it to the timestep embedding.
-        if global_cond is not None:
-            global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
-        else:
-            global_feature = timesteps_embed
-
-        encoder_skip_features: list[Tensor] = []
-        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
-            x = resnet(x, global_feature)
-            if idx == 0 and local_cond is not None:
-                x = x + self.local_cond_down_encoder(local_cond, global_feature)
-            x = resnet2(x, global_feature)
-            encoder_skip_features.append(x)
-            x = downsample(x)
-
-        for mid_module in self.mid_modules:
-            x = mid_module(x, global_feature)
-
-        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
-            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
-            x = resnet(x, global_feature)
-            # Note: The condition in the original implementation is:
-            # if idx == len(self.up_modules) and local_cond is not None:
-            # But as they mention in their comments, this is incorrect. We use the correct condition here.
-            if idx == (len(self.up_modules) - 1) and local_cond is not None:
-                x = x + self.local_cond_up_encoder(local_cond, global_feature)
-            x = resnet2(x, global_feature)
-            x = upsample(x)
-
-        x = self.final_conv(x)
-
-        x = einops.rearrange(x, "b d t -> b t d")
-        return x
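As a sanity check on the shapes documented in `ConditionalUnet1D.forward` above, a quick sketch using the global-conditioning path only (the pre-commit import path is taken from the policy file removed below; dimensions follow the new config defaults: action_dim=2, 32 keypoints giving 64 image features, n_obs_steps=2):

import torch

from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D

unet = ConditionalUnet1D(input_dim=2, global_cond_dim=(2 + 64) * 2, down_dims=[256, 512, 1024])

batch_size, horizon = 8, 16
noisy_actions = torch.randn(batch_size, horizon, 2)   # (B, T, input_dim)
timesteps = torch.randint(0, 100, (batch_size,))      # (B,)
global_cond = torch.randn(batch_size, (2 + 64) * 2)   # (B, global_cond_dim)

out = unet(noisy_actions, timesteps, global_cond=global_cond)
assert out.shape == (batch_size, horizon, 2)           # (B, T, input_dim)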
@@ -1,175 +0,0 @@
-import einops
-import torch
-import torch.nn.functional as F  # noqa: N812
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from torch import Tensor, nn
-
-from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
-from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder
-from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters
-
-
-class DiffusionUnetImagePolicy(nn.Module):
-    def __init__(
-        self,
-        cfg,
-        shape_meta: dict,
-        noise_scheduler: DDPMScheduler,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        film_scale_modulation=True,
-    ):
-        super().__init__()
-        action_shape = shape_meta["action"]["shape"]
-        assert len(action_shape) == 1
-        action_dim = action_shape[0]
-
-        self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
-
-        self.unet = ConditionalUnet1D(
-            input_dim=action_dim,
-            global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
-        )
-
-        self.noise_scheduler = noise_scheduler
-        self.horizon = horizon
-        self.action_dim = action_dim
-        self.n_action_steps = n_action_steps
-        self.n_obs_steps = n_obs_steps
-
-        if num_inference_steps is None:
-            num_inference_steps = noise_scheduler.config.num_train_timesteps
-
-        self.num_inference_steps = num_inference_steps
-
-    # ========= inference ============
-    def conditional_sample(self, batch_size, global_cond=None, generator=None):
-        device = get_device_from_parameters(self)
-        dtype = get_dtype_from_parameters(self)
-
-        # Sample prior.
-        sample = torch.randn(
-            size=(batch_size, self.horizon, self.action_dim),
-            dtype=dtype,
-            device=device,
-            generator=generator,
-        )
-
-        self.noise_scheduler.set_timesteps(self.num_inference_steps)
-
-        for t in self.noise_scheduler.timesteps:
-            # Predict model output.
-            model_output = self.unet(
-                sample,
-                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
-                global_cond=global_cond,
-            )
-            # Compute previous image: x_t -> x_t-1
-            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
-
-        return sample
-
-    def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
-        }
-        """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
-
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-
-        # run sampling
-        sample = self.conditional_sample(batch_size, global_cond=global_cond)
-
-        # `horizon` steps worth of actions (from the first observation).
-        action = sample[..., : self.action_dim]
-        # Extract `n_action_steps` steps worth of actions (from the current observation).
-        start = n_obs_steps - 1
-        end = start + self.n_action_steps
-        action = action[:, start:end]
-
-        return action
-
-    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
-        """
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
-            "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon)
-        }
-        """
-        # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        horizon = batch["action"].shape[1]
-        assert horizon == self.horizon
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
-
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-
-        trajectory = batch["action"]
-
-        # Forward diffusion.
-        # Sample noise to add to the trajectory.
-        eps = torch.randn(trajectory.shape, device=trajectory.device)
-        # Sample a random noising timestep for each item in the batch.
-        timesteps = torch.randint(
-            low=0,
-            high=self.noise_scheduler.config.num_train_timesteps,
-            size=(trajectory.shape[0],),
-            device=trajectory.device,
-        ).long()
-        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
-
-        # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
-        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
-
-        # Compute the loss.
-        # The target is either the original trajectory, or the noise.
-        pred_type = self.noise_scheduler.config.prediction_type
-        if pred_type == "epsilon":
-            target = eps
-        elif pred_type == "sample":
-            target = batch["action"]
-        else:
-            raise ValueError(f"Unsupported prediction type {pred_type}")
-
-        loss = F.mse_loss(pred, target, reduction="none")
-
-        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
-        if "action_is_pad" in batch:
-            in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound.unsqueeze(-1)
-
-        return loss.mean()
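The slicing at the end of `generate_actions` is easier to see with concrete numbers (a standalone sketch using the defaults from the new config: n_obs_steps=2, horizon=16, n_action_steps=8):

horizon, n_obs_steps, n_action_steps = 16, 2, 8

# The sampled trajectory covers `horizon` steps measured from the first (oldest) observation.
# The current step sits at index n_obs_steps - 1, so execution starts there.
start = n_obs_steps - 1        # 1
end = start + n_action_steps   # 9
executed_steps = list(range(horizon))[start:end]
assert executed_steps == [1, 2, 3, 4, 5, 6, 7, 8]
# Mirrors the requirement stated in the new policy docstring.
assert n_action_steps < horizon - n_obs_steps + 1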
@@ -1,68 +0,0 @@
-import torch
-from torch.nn.modules.batchnorm import _BatchNorm
-
-
-class EMAModel:
-    """
-    Exponential Moving Average of model weights
-    """
-
-    def __init__(
-        self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
-    ):
-        """
-        @crowsonkb's notes on EMA Warmup:
-        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-        at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
-        """
-
-        self.averaged_model = model
-        self.averaged_model.eval()
-        self.averaged_model.requires_grad_(False)
-
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        self.alpha = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_value, min(value, self.max_value))
-
-    @torch.no_grad()
-    def step(self, new_model):
-        self.alpha = self.get_decay(self.optimization_step)
-
-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
-            # Iterate over immediate parameters only.
-            for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
-            ):
-                if isinstance(param, dict):
-                    raise RuntimeError("Dict parameter not supported")
-                if isinstance(module, _BatchNorm) or not param.requires_grad:
-                    # Copy BatchNorm parameters, and non-trainable parameters directly.
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                else:
-                    ema_param.mul_(self.alpha)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
-
-        self.optimization_step += 1
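The EMA warmup schedule in `get_decay` can be checked against the numbers quoted in the docstring (a standalone sketch of the same formula, with the `update_after_step` offset left out):

def ema_decay(step, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

print(round(ema_decay(31_600), 4))               # ~0.999 with power=2/3, as the docstring notes
print(round(ema_decay(1_000_000), 4))            # ~0.9999 (also the max_value clamp)
print(round(ema_decay(10_000, power=3 / 4), 4))  # ~0.999 with power=3/4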
@@ -1,147 +0,0 @@
-from typing import Callable
-
-import torch
-import torchvision
-from robomimic.models.base_nets import SpatialSoftmax
-from torch import Tensor, nn
-from torchvision.transforms import CenterCrop, RandomCrop
-
-
-class RgbEncoder(nn.Module):
-    """Encode an RGB image into a 1D feature vector.
-
-    Includes the ability to normalize and crop the image first.
-    """
-
-    def __init__(
-        self,
-        input_shape: tuple[int, int, int],
-        norm_mean_std: tuple[float, float] = [1.0, 1.0],
-        crop_shape: tuple[int, int] | None = None,
-        random_crop: bool = False,
-        backbone_name: str = "resnet18",
-        pretrained_backbone: bool = False,
-        use_group_norm: bool = False,
-        relu: bool = True,
-        num_keypoints: int = 32,
-    ):
-        """
-        Args:
-            input_shape: channel-first input shape (C, H, W)
-            norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
-                (image - mean) / std.
-            crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
-                cropping is done.
-            random_crop: Whether the crop should be random at training time (it's always a center crop in
-                eval mode).
-            backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
-            pretrained_backbone: whether to use timm pretrained weights.
-            use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
-                The group sizes are set to be about 16 (to be precise, feature_dim // 16).
-            relu: whether to use relu as a final step.
-            num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
-        """
-        super().__init__()
-        if input_shape[0] != 3:
-            raise ValueError("Only RGB images are handled")
-        if not backbone_name.startswith("resnet"):
-            raise ValueError(
-                "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
-            )
-
-        # Set up optional preprocessing.
-        if norm_mean_std == [1.0, 1.0]:
-            self.normalizer = nn.Identity()
-        else:
-            self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
-
-        if crop_shape is not None:
-            self.do_crop = True
-            self.center_crop = CenterCrop(crop_shape)  # always use center crop for eval
-            if random_crop:
-                self.maybe_random_crop = RandomCrop(crop_shape)
-            else:
-                self.maybe_random_crop = self.center_crop
-        else:
-            self.do_crop = False
-
-        # Set up backbone.
-        backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
-        # Note: This assumes that the layer4 feature map is children()[-3]
-        # TODO(alexander-soare): Use a safer alternative.
-        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
-        if use_group_norm:
-            if pretrained_backbone:
-                raise ValueError(
-                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
-                )
-            self.backbone = _replace_submodules(
-                root_module=self.backbone,
-                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
-                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
-            )
-
-        # Set up pooling and final layers.
-        # Use a dry run to get the feature map shape.
-        with torch.inference_mode():
-            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
-        self.feature_dim = num_keypoints * 2
-        self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
-        self.maybe_relu = nn.ReLU() if relu else nn.Identity()
-
-    def forward(self, x: Tensor) -> Tensor:
-        """
-        Args:
-            x: (B, C, H, W) image tensor with pixel values in [0, 1].
-        Returns:
-            (B, D) image feature.
-        """
-        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
-        x = self.normalizer(x)
-        if self.do_crop:
-            if self.training:  # noqa: SIM108
-                x = self.maybe_random_crop(x)
-            else:
-                # Always use center crop for eval.
-                x = self.center_crop(x)
-        # Extract backbone feature.
-        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
-        # Final linear layer.
-        x = self.out(x)
-        # Maybe a final non-linearity.
-        x = self.maybe_relu(x)
-        return x
-
-
-def _replace_submodules(
-    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
-) -> nn.Module:
-    """
-    Args:
-        root_module: The module for which the submodules need to be replaced
-        predicate: Takes a module as an argument and must return True if that module is to be replaced.
-        func: Takes a module as an argument and returns a new module to replace it with.
-    Returns:
-        The root module with its submodules replaced.
-    """
-    if predicate(root_module):
-        return func(root_module)
-
-    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
-    for *parents, k in replace_list:
-        parent_module = root_module
-        if len(parents) > 0:
-            parent_module = root_module.get_submodule(".".join(parents))
-        if isinstance(parent_module, nn.Sequential):
-            src_module = parent_module[int(k)]
-        else:
-            src_module = getattr(parent_module, k)
-        tgt_module = func(src_module)
-        if isinstance(parent_module, nn.Sequential):
-            parent_module[int(k)] = tgt_module
-        else:
-            setattr(parent_module, k, tgt_module)
-    # verify that all BN are replaced
-    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
-    return root_module
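A small sketch of how `_replace_submodules` above is used to swap BatchNorm for GroupNorm in a torchvision ResNet (the same predicate/func pair as in `RgbEncoder.__init__`; assumes `_replace_submodules` is in scope):

import torch.nn as nn
import torchvision

# Same truncation as the encoder: drop the average pool and classifier head.
backbone = nn.Sequential(*list(torchvision.models.resnet18(pretrained=False).children())[:-2])
backbone = _replace_submodules(
    root_module=backbone,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())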
@ -0,0 +1,878 @@
|
||||||
|
"""
|
||||||
|
TODO(alexander-soare):
|
||||||
|
- Remove reliance on Robomimic for SpatialSoftmax.
|
||||||
|
- Remove reliance on diffusers for DDPMScheduler.
|
||||||
|
- Move EMA out of policy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import hydra
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
import torchvision
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
from robomimic.models.base_nets import SpatialSoftmax
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from lerobot.common.policies.utils import (
|
||||||
|
get_device_from_parameters,
|
||||||
|
get_dtype_from_parameters,
|
||||||
|
populate_queues,
|
||||||
|
)
|
||||||
|
from lerobot.common.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionPolicy(nn.Module):
|
||||||
|
"""
|
||||||
|
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||||
|
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "diffusion"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
cfg_device,
|
||||||
|
cfg_noise_scheduler,
|
||||||
|
cfg_optimizer,
|
||||||
|
cfg_ema,
|
||||||
|
shape_meta: dict,
|
||||||
|
horizon,
|
||||||
|
n_action_steps,
|
||||||
|
n_obs_steps,
|
||||||
|
num_inference_steps=None,
|
||||||
|
diffusion_step_embed_dim=256,
|
||||||
|
down_dims=(256, 512, 1024),
|
||||||
|
kernel_size=5,
|
||||||
|
n_groups=8,
|
||||||
|
film_scale_modulation=True,
|
||||||
|
**_,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
|
||||||
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
|
self._queues = None
|
||||||
|
|
||||||
|
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||||
|
|
||||||
|
self.diffusion = _DiffusionUnetImagePolicy(
|
||||||
|
cfg,
|
||||||
|
shape_meta=shape_meta,
|
||||||
|
noise_scheduler=noise_scheduler,
|
||||||
|
horizon=horizon,
|
||||||
|
n_action_steps=n_action_steps,
|
||||||
|
n_obs_steps=n_obs_steps,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||||
|
down_dims=down_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
n_groups=n_groups,
|
||||||
|
film_scale_modulation=film_scale_modulation,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.device = get_safe_torch_device(cfg_device)
|
||||||
|
self.diffusion.to(self.device)
|
||||||
|
|
||||||
|
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
||||||
|
self.ema_diffusion = None
|
||||||
|
self.ema = None
|
||||||
|
if self.cfg.use_ema:
|
||||||
|
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||||
|
self.ema = hydra.utils.instantiate(
|
||||||
|
cfg_ema,
|
||||||
|
model=self.ema_diffusion,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer = hydra.utils.instantiate(
|
||||||
|
cfg_optimizer,
|
||||||
|
params=self.diffusion.parameters(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
|
||||||
|
self.global_step = 0
|
||||||
|
|
||||||
|
# configure lr scheduler
|
||||||
|
self.lr_scheduler = get_scheduler(
|
||||||
|
cfg.lr_scheduler,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
num_warmup_steps=cfg.lr_warmup_steps,
|
||||||
|
num_training_steps=cfg.offline_steps,
|
||||||
|
# pytorch assumes stepping LRScheduler every epoch
|
||||||
|
# however huggingface diffusers steps it every batch
|
||||||
|
last_epoch=self.global_step - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Clear observation and action queues. Should be called on `env.reset()`
|
||||||
|
"""
|
||||||
|
self._queues = {
|
||||||
|
"observation.image": deque(maxlen=self.n_obs_steps),
|
||||||
|
"observation.state": deque(maxlen=self.n_obs_steps),
|
||||||
|
"action": deque(maxlen=self.n_action_steps),
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad
|
||||||
|
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||||
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
This method handles caching a history of observations and an action trajectory generated by the
|
||||||
|
underlying diffusion model. Here's how it works:
|
||||||
|
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||||
|
copied `n_obs_steps` times to fill the cache).
|
||||||
|
- The diffusion model generates `horizon` steps worth of actions.
|
||||||
|
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||||
|
Schematically this looks like:
|
||||||
|
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||||
|
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|
||||||
|
|observation is used | YES | YES | ..... | NO | NO | NO | NO | NO | NO |
|
||||||
|
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||||
|
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||||
|
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
||||||
|
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||||
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
|
|
||||||
|
Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
|
||||||
|
weights.
|
||||||
|
"""
|
||||||
|
assert "observation.image" in batch
|
||||||
|
assert "observation.state" in batch
|
||||||
|
assert len(batch) == 2
|
||||||
|
|
||||||
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
|
if len(self._queues["action"]) == 0:
|
||||||
|
# stack n latest observations from the queue
|
||||||
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||||
|
if not self.training and self.ema_diffusion is not None:
|
||||||
|
actions = self.ema_diffusion.generate_actions(batch)
|
||||||
|
else:
|
||||||
|
actions = self.diffusion.generate_actions(batch)
|
||||||
|
self._queues["action"].extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
|
action = self._queues["action"].popleft()
|
||||||
|
return action
|
||||||
|
|
||||||
|
def forward(self, batch, **_):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
self.diffusion.train()
|
||||||
|
|
||||||
|
loss = self.diffusion.compute_loss(batch)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
self.diffusion.parameters(),
|
||||||
|
self.cfg.grad_clip_norm,
|
||||||
|
error_if_nonfinite=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
|
||||||
|
if self.ema is not None:
|
||||||
|
self.ema.step(self.diffusion)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"loss": loss.item(),
|
||||||
|
"grad_norm": float(grad_norm),
|
||||||
|
"lr": self.lr_scheduler.get_last_lr()[0],
|
||||||
|
"update_s": time.time() - start_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
def save(self, fp):
|
||||||
|
torch.save(self.state_dict(), fp)
|
||||||
|
|
||||||
|
def load(self, fp):
|
||||||
|
d = torch.load(fp)
|
||||||
|
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
|
||||||
|
logging.warning(
|
||||||
|
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
|
||||||
|
)
|
||||||
|
assert len(unexpected_keys) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class _DiffusionUnetImagePolicy(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
shape_meta: dict,
|
||||||
|
noise_scheduler: DDPMScheduler,
|
||||||
|
horizon,
|
||||||
|
n_action_steps,
|
||||||
|
n_obs_steps,
|
||||||
|
num_inference_steps=None,
|
||||||
|
diffusion_step_embed_dim=256,
|
||||||
|
down_dims=(256, 512, 1024),
|
||||||
|
kernel_size=5,
|
||||||
|
n_groups=8,
|
||||||
|
film_scale_modulation=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
action_shape = shape_meta["action"]["shape"]
|
||||||
|
assert len(action_shape) == 1
|
||||||
|
action_dim = action_shape[0]
|
||||||
|
|
||||||
|
self.rgb_encoder = _RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
|
||||||
|
|
||||||
|
self.unet = _ConditionalUnet1D(
|
||||||
|
input_dim=action_dim,
|
||||||
|
global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
|
||||||
|
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||||
|
down_dims=down_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
n_groups=n_groups,
|
||||||
|
film_scale_modulation=film_scale_modulation,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.noise_scheduler = noise_scheduler
|
||||||
|
self.horizon = horizon
|
||||||
|
self.action_dim = action_dim
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
|
||||||
|
if num_inference_steps is None:
|
||||||
|
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||||
|
|
||||||
|
self.num_inference_steps = num_inference_steps
|
||||||
|
|
||||||
|
# ========= inference ============
|
||||||
|
def conditional_sample(self, batch_size, global_cond=None, generator=None):
|
||||||
|
device = get_device_from_parameters(self)
|
||||||
|
dtype = get_dtype_from_parameters(self)
|
||||||
|
|
||||||
|
# Sample prior.
|
||||||
|
sample = torch.randn(
|
||||||
|
size=(batch_size, self.horizon, self.action_dim),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
generator=generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
|
for t in self.noise_scheduler.timesteps:
|
||||||
|
# Predict model output.
|
||||||
|
model_output = self.unet(
|
||||||
|
sample,
|
||||||
|
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||||
|
global_cond=global_cond,
|
||||||
|
)
|
||||||
|
# Compute previous image: x_t -> x_t-1
|
||||||
|
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
"""
|
||||||
|
This function expects `batch` to have (at least):
|
||||||
|
{
|
||||||
|
"observation.state": (B, n_obs_steps, state_dim)
|
||||||
|
"observation.image": (B, n_obs_steps, C, H, W)
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||||
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
|
assert n_obs_steps == self.n_obs_steps
|
||||||
|
assert self.n_obs_steps == n_obs_steps
|
||||||
|
|
||||||
|
# Extract image feature (first combine batch and sequence dims).
|
||||||
|
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||||
|
# Separate batch and sequence dims.
|
||||||
|
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||||
|
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||||
|
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||||
|
|
||||||
|
# run sampling
|
||||||
|
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||||
|
|
||||||
|
# `horizon` steps worth of actions (from the first observation).
|
||||||
|
action = sample[..., : self.action_dim]
|
||||||
|
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||||
|
start = n_obs_steps - 1
|
||||||
|
end = start + self.n_action_steps
|
||||||
|
action = action[:, start:end]
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
"""
|
||||||
|
This function expects `batch` to have (at least):
|
||||||
|
{
|
||||||
|
"observation.state": (B, n_obs_steps, state_dim)
|
||||||
|
"observation.image": (B, n_obs_steps, C, H, W)
|
||||||
|
"action": (B, horizon, action_dim)
|
||||||
|
"action_is_pad": (B, horizon)
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Input validation.
|
||||||
|
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
|
||||||
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
|
horizon = batch["action"].shape[1]
|
||||||
|
assert horizon == self.horizon
|
||||||
|
assert n_obs_steps == self.n_obs_steps
|
||||||
|
assert self.n_obs_steps == n_obs_steps
|
||||||
|
|
||||||
|
# Extract image feature (first combine batch and sequence dims).
|
||||||
|
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||||
|
# Separate batch and sequence dims.
|
||||||
|
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||||
|
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||||
|
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||||
|
|
||||||
|
trajectory = batch["action"]
|
||||||
|
|
||||||
|
# Forward diffusion.
|
||||||
|
# Sample noise to add to the trajectory.
|
||||||
|
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||||
|
# Sample a random noising timestep for each item in the batch.
|
||||||
|
timesteps = torch.randint(
|
||||||
|
low=0,
|
||||||
|
high=self.noise_scheduler.config.num_train_timesteps,
|
||||||
|
size=(trajectory.shape[0],),
|
||||||
|
device=trajectory.device,
|
||||||
|
).long()
|
||||||
|
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
|
||||||
|
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
|
||||||
|
|
||||||
|
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
|
||||||
|
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
|
||||||
|
|
||||||
|
# Compute the loss.
|
||||||
|
# The targe is either the original trajectory, or the noise.
|
||||||
|
pred_type = self.noise_scheduler.config.prediction_type
|
||||||
|
if pred_type == "epsilon":
|
||||||
|
target = eps
|
||||||
|
elif pred_type == "sample":
|
||||||
|
target = batch["action"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported prediction type {pred_type}")
|
||||||
|
|
||||||
|
loss = F.mse_loss(pred, target, reduction="none")
|
||||||
|
|
||||||
|
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
|
||||||
|
if "action_is_pad" in batch:
|
||||||
|
in_episode_bound = ~batch["action_is_pad"]
|
||||||
|
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||||
|
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
class _RgbEncoder(nn.Module):
|
||||||
|
"""Encoder an RGB image into a 1D feature vector.
|
||||||
|
|
||||||
|
Includes the ability to normalize and crop the image first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_shape: tuple[int, int, int],
|
||||||
|
norm_mean_std: tuple[float, float] = [1.0, 1.0],
|
||||||
|
crop_shape: tuple[int, int] | None = None,
|
||||||
|
random_crop: bool = False,
|
||||||
|
backbone_name: str = "resnet18",
|
||||||
|
pretrained_backbone: bool = False,
|
||||||
|
use_group_norm: bool = False,
|
||||||
|
num_keypoints: int = 32,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_shape: channel-first input shape (C, H, W)
|
||||||
|
norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
|
||||||
|
(image - mean) / std.
|
||||||
|
crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
|
||||||
|
cropping is done.
|
||||||
|
random_crop: Whether the crop should be random at training time (it's always a center crop in
|
||||||
|
eval mode).
|
||||||
|
backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
|
||||||
|
pretrained_backbone: whether to use timm pretrained weights.
|
||||||
|
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||||
|
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||||
|
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if input_shape[0] != 3:
|
||||||
|
raise ValueError("Only RGB images are handled")
|
||||||
|
if not backbone_name.startswith("resnet"):
|
||||||
|
raise ValueError(
|
||||||
|
"Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up optional preprocessing.
|
||||||
|
if norm_mean_std == [1.0, 1.0]:
|
||||||
|
self.normalizer = nn.Identity()
|
||||||
|
else:
|
||||||
|
self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
|
||||||
|
|
||||||
|
if crop_shape is not None:
|
||||||
|
self.do_crop = True
|
||||||
|
# Always use center crop for eval
|
||||||
|
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
|
||||||
|
if random_crop:
|
||||||
|
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
|
||||||
|
else:
|
||||||
|
self.maybe_random_crop = self.center_crop
|
||||||
|
else:
|
||||||
|
self.do_crop = False
|
||||||
|
|
||||||
|
# Set up backbone.
|
||||||
|
backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
|
||||||
|
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||||
|
# TODO(alexander-soare): Use a safer alternative.
|
||||||
|
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||||
|
if use_group_norm:
|
||||||
|
if pretrained_backbone:
|
||||||
|
raise ValueError(
|
||||||
|
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||||
|
)
|
||||||
|
self.backbone = _replace_submodules(
|
||||||
|
root_module=self.backbone,
|
||||||
|
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||||
|
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up pooling and final layers.
|
||||||
|
# Use a dry run to get the feature map shape.
|
||||||
|
with torch.inference_mode():
|
||||||
|
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
|
||||||
|
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
|
||||||
|
self.feature_dim = num_keypoints * 2
|
||||||
|
self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C, H, W) image tensor with pixel values in [0, 1].
|
||||||
|
Returns:
|
||||||
|
(B, D) image feature.
|
||||||
|
"""
|
||||||
|
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
|
||||||
|
x = self.normalizer(x)
|
||||||
|
if self.do_crop:
|
||||||
|
if self.training: # noqa: SIM108
|
||||||
|
x = self.maybe_random_crop(x)
|
||||||
|
else:
|
||||||
|
# Always use center crop for eval.
|
||||||
|
x = self.center_crop(x)
|
||||||
|
# Extract backbone feature.
|
||||||
|
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
|
||||||
|
# Final linear layer with non-linearity.
|
||||||
|
x = self.relu(self.out(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_submodules(
    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
    """
    Args:
        root_module: The module for which the submodules need to be replaced.
        predicate: Takes a module as an argument and must return True if that module is to be replaced.
        func: Takes a module as an argument and returns a new module to replace it with.
    Returns:
        The root module with its submodules replaced.
    """
    if predicate(root_module):
        return func(root_module)

    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    for *parents, k in replace_list:
        parent_module = root_module
        if len(parents) > 0:
            parent_module = root_module.get_submodule(".".join(parents))
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # Verify that all modules matching the predicate (e.g. BatchNorm) have been replaced.
    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
    return root_module
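# --- Illustrative example (not part of this commit) ---
# Using _replace_submodules to swap every BatchNorm2d in a small module for GroupNorm, as is
# done for the vision backbone above.
from torch import nn

module = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
module = _replace_submodules(
    root_module=module,
    predicate=lambda x: isinstance(x, nn.BatchNorm2d),
    func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
assert not any(isinstance(m, nn.BatchNorm2d) for m in module.modules())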
class _SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding, used here to encode the diffusion timestep."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
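# --- Illustrative example (not part of this commit) ---
# A batch of timesteps maps to a (batch, dim) embedding, half sine / half cosine.
import torch

emb = _SinusoidalPosEmb(128)(torch.arange(4))
assert emb.shape == (4, 128)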
class _Conv1dBlock(nn.Module):
    """Conv1d --> GroupNorm --> Mish"""

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)
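# --- Illustrative example (not part of this commit) ---
# With padding=kernel_size // 2, the block preserves the temporal length for odd kernel sizes.
import torch

block = _Conv1dBlock(inp_channels=2, out_channels=16, kernel_size=5)
assert block(torch.randn(4, 2, 16)).shape == (4, 16, 16)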
class _ConditionalUnet1D(nn.Module):
    """A 1D convolutional UNet with FiLM modulation for conditioning.

    Two types of conditioning can be applied:
    - Global: Conditioning information that is aggregated over the whole observation window. This is
      incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
    - Local: Conditioning information for each timestep in the observation window. This is incorporated
      by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
      outputs of the Unet's encoder/decoder.
    """

    def __init__(
        self,
        input_dim: int,
        local_cond_dim: int | None = None,
        global_cond_dim: int | None = None,
        diffusion_step_embed_dim: int = 256,
        down_dims: list[int] | None = None,
        kernel_size: int = 3,
        n_groups: int = 8,
        film_scale_modulation: bool = False,
    ):
        super().__init__()

        if down_dims is None:
            down_dims = [256, 512, 1024]

        # Encoder for the diffusion timestep.
        self.diffusion_step_encoder = nn.Sequential(
            _SinusoidalPosEmb(diffusion_step_embed_dim),
            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
            nn.Mish(),
            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
        )

        # The FiLM conditioning dimension.
        cond_dim = diffusion_step_embed_dim
        if global_cond_dim is not None:
            cond_dim += global_cond_dim

        self.local_cond_down_encoder = None
        self.local_cond_up_encoder = None
        if local_cond_dim is not None:
            # Encoder for the local conditioning. The output gets added to the Unet encoder input.
            self.local_cond_down_encoder = _ConditionalResidualBlock1D(
                local_cond_dim,
                down_dims[0],
                cond_dim=cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                film_scale_modulation=film_scale_modulation,
            )
            # Encoder for the local conditioning. The output gets added to the Unet encoder output.
            self.local_cond_up_encoder = _ConditionalResidualBlock1D(
                local_cond_dim,
                down_dims[0],
                cond_dim=cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                film_scale_modulation=film_scale_modulation,
            )

        # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
        # just reverse these.
        in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))

        # Unet encoder.
        self.down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            self.down_modules.append(
                nn.ModuleList(
                    [
                        _ConditionalResidualBlock1D(
                            dim_in,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            film_scale_modulation=film_scale_modulation,
                        ),
                        _ConditionalResidualBlock1D(
                            dim_out,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            film_scale_modulation=film_scale_modulation,
                        ),
                        # Downsample as long as it is not the last block.
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        # Processing in the middle of the auto-encoder.
        self.mid_modules = nn.ModuleList(
            [
                _ConditionalResidualBlock1D(
                    down_dims[-1],
                    down_dims[-1],
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                    film_scale_modulation=film_scale_modulation,
                ),
                _ConditionalResidualBlock1D(
                    down_dims[-1],
                    down_dims[-1],
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                    film_scale_modulation=film_scale_modulation,
                ),
            ]
        )

        # Unet decoder.
        self.up_modules = nn.ModuleList([])
        for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            self.up_modules.append(
                nn.ModuleList(
                    [
                        _ConditionalResidualBlock1D(
                            dim_in * 2,  # x2 as it takes the encoder's skip connection as well
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            film_scale_modulation=film_scale_modulation,
                        ),
                        _ConditionalResidualBlock1D(
                            dim_out,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            film_scale_modulation=film_scale_modulation,
                        ),
                        # Upsample as long as it is not the last block.
                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
            nn.Conv1d(down_dims[0], input_dim, 1),
        )
    def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
        """
        Args:
            x: (B, T, input_dim) tensor for input to the Unet.
            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
            local_cond: (B, T, local_cond_dim)
            global_cond: (B, global_cond_dim)
        Returns:
            (B, T, input_dim) output tensor (same shape as the input).
        """
        # For 1D convolutions we'll need feature dimension first.
        x = einops.rearrange(x, "b t d -> b d t")
        if local_cond is not None:
            if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
                raise ValueError(
                    "`local_cond` was provided but the relevant encoders weren't built at initialization."
                )
            local_cond = einops.rearrange(local_cond, "b t d -> b d t")

        timesteps_embed = self.diffusion_step_encoder(timestep)

        # If there is a global conditioning feature, concatenate it to the timestep embedding.
        if global_cond is not None:
            global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
        else:
            global_feature = timesteps_embed

        encoder_skip_features: list[Tensor] = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            if idx == 0 and local_cond is not None:
                x = x + self.local_cond_down_encoder(local_cond, global_feature)
            x = resnet2(x, global_feature)
            encoder_skip_features.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
            x = resnet(x, global_feature)
            # Note: The condition in the original implementation is:
            # if idx == len(self.up_modules) and local_cond is not None:
            # But as they mention in their comments, this is incorrect. We use the correct condition here.
            if idx == (len(self.up_modules) - 1) and local_cond is not None:
                x = x + self.local_cond_up_encoder(local_cond, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, "b d t -> b t d")
        return x
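# --- Illustrative example (not part of this commit) ---
# Minimal smoke test for _ConditionalUnet1D. With two entries in down_dims there is a single
# downsampling stage, so the horizon only needs to be divisible by 2; sizes below are arbitrary.
import torch

unet = _ConditionalUnet1D(input_dim=2, global_cond_dim=10, down_dims=[64, 128])
x = torch.randn(4, 8, 2)  # (B, T, input_dim) noisy action sequence
timestep = torch.randint(0, 100, (4,))  # (B,) diffusion step indices
global_cond = torch.randn(4, 10)  # (B, global_cond_dim) observation encoding
out = unet(x, timestep, global_cond=global_cond)
assert out.shape == x.shape  # the Unet predicts a tensor of the same shape as its input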
class _ConditionalResidualBlock1D(nn.Module):
    """ResNet style 1D convolutional block with FiLM modulation for conditioning."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        kernel_size: int = 3,
        n_groups: int = 8,
        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
        # FiLM just modulates bias).
        film_scale_modulation: bool = False,
    ):
        super().__init__()

        self.film_scale_modulation = film_scale_modulation
        self.out_channels = out_channels

        self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)

        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
        cond_channels = out_channels * 2 if film_scale_modulation else out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))

        self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)

        # A final convolution for dimension matching the residual (if needed).
        self.residual_conv = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """
        Args:
            x: (B, in_channels, T)
            cond: (B, cond_dim)
        Returns:
            (B, out_channels, T)
        """
        out = self.conv1(x)

        # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
        cond_embed = self.cond_encoder(cond).unsqueeze(-1)
        if self.film_scale_modulation:
            # Treat the embedding as a list of scales and biases.
            scale = cond_embed[:, : self.out_channels]
            bias = cond_embed[:, self.out_channels :]
            out = scale * out + bias
        else:
            # Treat the embedding as biases.
            out = out + cond_embed

        out = self.conv2(out)
        out = out + self.residual_conv(x)
        return out
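# --- Illustrative example (not part of this commit) ---
# The conditioning vector produces a per-channel bias (and, with film_scale_modulation=True,
# a per-channel scale) that modulates the output of the first conv block.
import torch

block = _ConditionalResidualBlock1D(in_channels=8, out_channels=16, cond_dim=32, film_scale_modulation=True)
out = block(torch.randn(4, 8, 10), torch.randn(4, 32))
assert out.shape == (4, 16, 10)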
class _EMA:
    """Exponential moving average of a model's weights."""

    def __init__(
        self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
            you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
            at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
            at 10K steps, 0.9999 at 215.4k steps).
        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """

        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.alpha = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """Compute the decay factor for the exponential moving average."""
        step = max(0, optimization_step - self.update_after_step - 1)
        value = 1 - (1 + step / self.inv_gamma) ** -self.power

        if step <= 0:
            return 0.0

        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        self.alpha = self.get_decay(self.optimization_step)

        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
            # Iterate over immediate parameters only.
            for param, ema_param in zip(
                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
            ):
                if isinstance(param, dict):
                    raise RuntimeError("Dict parameter not supported")
                if isinstance(module, _BatchNorm) or not param.requires_grad:
                    # Copy BatchNorm parameters, and non-trainable parameters directly.
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    ema_param.mul_(self.alpha)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)

        self.optimization_step += 1
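# --- Illustrative example (not part of this commit) ---
# EMA warmup sanity check: with update_after_step=0, inv_gamma=1.0 and power=0.75, the decay
# rate ramps from 0 toward max_value over training (values below are approximate).
import torch

ema = _EMA(torch.nn.Linear(2, 2), inv_gamma=1.0, power=0.75)
print([round(ema.get_decay(step), 4) for step in (1, 100, 1000, 10000)])
# -> [0.0, 0.9684, 0.9944, 0.999]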
@@ -1,169 +0,0 @@
import copy
import logging
import time
from collections import deque

import hydra
import torch
from diffusers.optimization import get_scheduler
from torch import nn

from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device


class DiffusionPolicy(nn.Module):
    name = "diffusion"

    def __init__(
        self,
        cfg,
        cfg_device,
        cfg_noise_scheduler,
        cfg_optimizer,
        cfg_ema,
        shape_meta: dict,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        film_scale_modulation=True,
        **_,
    ):
        super().__init__()
        self.cfg = cfg
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)

        self.diffusion = DiffusionUnetImagePolicy(
            cfg,
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            film_scale_modulation=film_scale_modulation,
        )

        self.device = get_safe_torch_device(cfg_device)
        self.diffusion.to(self.device)

        # TODO(alexander-soare): This should probably be managed outside of the policy class.
        self.ema_diffusion = None
        self.ema = None
        if self.cfg.use_ema:
            self.ema_diffusion = copy.deepcopy(self.diffusion)
            self.ema = hydra.utils.instantiate(
                cfg_ema,
                model=self.ema_diffusion,
            )

        self.optimizer = hydra.utils.instantiate(
            cfg_optimizer,
            params=self.diffusion.parameters(),
        )

        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
        self.global_step = 0

        # configure lr scheduler
        self.lr_scheduler = get_scheduler(
            cfg.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.lr_warmup_steps,
            num_training_steps=cfg.offline_steps,
            # pytorch assumes stepping LRScheduler every epoch
            # however huggingface diffusers steps it every batch
            last_epoch=self.global_step - 1,
        )

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`
        """
        self._queues = {
            "observation.image": deque(maxlen=self.n_obs_steps),
            "observation.state": deque(maxlen=self.n_obs_steps),
            "action": deque(maxlen=self.n_action_steps),
        }

    @torch.no_grad
    def select_action(self, batch, **_):
        """
        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
        """
        assert "observation.image" in batch
        assert "observation.state" in batch
        assert len(batch) == 2

        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
            if not self.training and self.ema_diffusion is not None:
                actions = self.ema_diffusion.generate_actions(batch)
            else:
                actions = self.diffusion.generate_actions(batch)
            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch, **_):
        start_time = time.time()

        self.diffusion.train()

        loss = self.diffusion.compute_loss(batch)
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.diffusion.parameters(),
            self.cfg.grad_clip_norm,
            error_if_nonfinite=False,
        )

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.lr_scheduler.step()

        if self.ema is not None:
            self.ema.step(self.diffusion)

        info = {
            "loss": loss.item(),
            "grad_norm": float(grad_norm),
            "lr": self.lr_scheduler.get_last_lr()[0],
            "update_s": time.time() - start_time,
        }

        return info

    def save(self, fp):
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        d = torch.load(fp)
        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
        if len(missing_keys) > 0:
            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
            logging.warning(
                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
            )
        assert len(unexpected_keys) == 0
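# --- Illustrative example (not part of this commit) ---
# Sketch of the rolling action queue used by select_action above: the policy generates a chunk
# of actions when its queue is empty and then pops one action per environment step.
from collections import deque

import torch

action_queue = deque(maxlen=8)
if len(action_queue) == 0:
    # Stand-in for generate_actions(): (n_action_steps, action_dim) -> queue of per-step actions.
    action_queue.extend(torch.arange(8 * 2).reshape(8, 2))
action = action_queue.popleft()  # one action per call; a new chunk is generated when the queue empties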
@@ -1,59 +1,61 @@
 import inspect
 
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 from lerobot.common.utils import get_safe_torch_device
 
 
-def make_policy(cfg):
-    if cfg.policy.name == "tdmpc":
+def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
+    expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
+    assert set(hydra_cfg.policy).issuperset(
+        expected_kwargs
+    ), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}"
+    policy_cfg = policy_cfg_class(
+        **{
+            k: v
+            for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
+            if k in expected_kwargs
+        }
+    )
+    return policy_cfg
+
+
+def make_policy(hydra_cfg: DictConfig):
+    if hydra_cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
         policy = TDMPCPolicy(
-            cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
+            hydra_cfg.policy,
+            n_obs_steps=hydra_cfg.n_obs_steps,
+            n_action_steps=hydra_cfg.n_action_steps,
+            device=hydra_cfg.device,
         )
-    elif cfg.policy.name == "diffusion":
-        from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+    elif hydra_cfg.policy.name == "diffusion":
+        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
-        policy = DiffusionPolicy(
-            cfg=cfg.policy,
-            cfg_device=cfg.device,
-            cfg_noise_scheduler=cfg.noise_scheduler,
-            cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
-            # n_obs_steps=cfg.n_obs_steps,
-            # n_action_steps=cfg.n_action_steps,
-            **cfg.policy,
-        )
-    elif cfg.policy.name == "act":
+        policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
+        policy = DiffusionPolicy(policy_cfg)
+        policy.to(get_safe_torch_device(hydra_cfg.device))
+    elif hydra_cfg.policy.name == "act":
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
         from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 
-        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
-        assert set(cfg.policy).issuperset(
-            expected_kwargs
-        ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
-        policy_cfg = ActionChunkingTransformerConfig(
-            **{
-                k: v
-                for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
-                if k in expected_kwargs
-            }
-        )
+        policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
         policy = ActionChunkingTransformerPolicy(policy_cfg)
-        policy.to(get_safe_torch_device(cfg.device))
+        policy.to(get_safe_torch_device(hydra_cfg.device))
     else:
-        raise ValueError(cfg.policy.name)
+        raise ValueError(hydra_cfg.policy.name)
 
-    if cfg.policy.pretrained_model_path:
+    if hydra_cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
-        if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.policy.pretrained_model_path:
+        if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
+            if "offline" in hydra_cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.policy.pretrained_model_path:
+            elif "final" in hydra_cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
-        policy.load(cfg.policy.pretrained_model_path)
+        policy.load(hydra_cfg.policy.pretrained_model_path)
 
     return policy
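Illustrative sketch of what the new `_policy_cfg_from_hydra_cfg` helper does: keep only the Hydra keys that match the config dataclass' constructor and build the dataclass from them (the dataclass and keys below are made up for the demonstration).

import inspect
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class DummyConfig:
    state_dim: int = 2
    action_dim: int = 2


hydra_policy = OmegaConf.create({"name": "dummy", "state_dim": 3, "action_dim": 4})
expected_kwargs = set(inspect.signature(DummyConfig).parameters)
cfg = DummyConfig(
    **{k: v for k, v in OmegaConf.to_container(hydra_policy, resolve=True).items() if k in expected_kwargs}
)
print(cfg)  # DummyConfig(state_dim=3, action_dim=4); the extra "name" key is simply ignored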
@@ -18,7 +18,7 @@ policy:
   pretrained_model_path:
 
   # Environment.
-  # Inherit these from the environment.
+  # Inherit these from the environment config.
   state_dim: ???
   action_dim: ???
 
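As an aside, `???` marks a mandatory value in OmegaConf: it has to be filled in (here by the environment config) before it is accessed. A minimal sketch, assuming only OmegaConf:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"policy": {"state_dim": "???"}})
print(OmegaConf.is_missing(cfg.policy, "state_dim"))  # True; reading cfg.policy.state_dim would raise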
@@ -1,17 +1,5 @@
 # @package _global_
 
-shape_meta:
-  # acceptable types: rgb, low_dim
-  obs:
-    image:
-      shape: [3, 96, 96]
-      type: rgb
-    agent_pos:
-      shape: [2]
-      type: low_dim
-  action:
-    shape: [2]
-
 seed: 100000
 horizon: 16
 n_obs_steps: 2
@@ -33,75 +21,70 @@ offline_prioritized_sampler: true
 policy:
   name: diffusion
 
-  shape_meta: ${shape_meta}
+  pretrained_model_path:
 
-  horizon: ${horizon}
+  # Environment.
+  # Inherit these from the environment config.
+  state_dim: ???
+  action_dim: ???
+  image_size:
+    - ${env.image_size} # height
+    - ${env.image_size} # width
+
+  # Inputs / output structure.
   n_obs_steps: ${n_obs_steps}
+  horizon: ${horizon}
   n_action_steps: ${n_action_steps}
-  num_inference_steps: 100
 
-  # crop_shape: null
-  diffusion_step_embed_dim: 128
+  # Vision preprocessing.
+  image_normalization_mean: [0.5, 0.5, 0.5]
+  image_normalization_std: [0.5, 0.5, 0.5]
+
+  # Architecture / modeling.
+  # Vision backbone.
+  vision_backbone: resnet18
+  crop_shape: [84, 84]
+  random_crop: True
+  use_pretrained_backbone: false
+  use_group_norm: True
+  spatial_softmax_num_keypoints: 32
+  # Unet.
   down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
+  diffusion_step_embed_dim: 128
   film_scale_modulation: True
 
-  pretrained_model_path:
-
-  batch_size: 64
-
-  per_alpha: 0.6
-  per_beta: 0.4
-
-  balanced_sampling: false
-  utd: 1
-  offline_steps: ${offline_steps}
-  use_ema: true
-  lr_scheduler: cosine
-  lr_warmup_steps: 500
-  grad_clip_norm: 10
-
-  delta_timestamps:
-    observation.image: [-0.1, 0]
-    observation.state: [-0.1, 0]
-    action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
-
-  rgb_encoder:
-    backbone_name: resnet18
-    pretrained_backbone: false
-    use_group_norm: True
-    num_keypoints: 32
-    relu: true
-    norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
-    crop_shape: [84, 84]
-    random_crop: True
-
-  noise_scheduler:
-    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
-    num_train_timesteps: 100
-    beta_start: 0.0001
-    beta_end: 0.02
-    beta_schedule: squaredcos_cap_v2
-    variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
-    clip_sample: True # required when predict_epsilon=False
-    prediction_type: epsilon # or sample
-
-  rgb_model:
-    pretrained: false
-    num_keypoints: 32
-    relu: true
-
-  ema:
-    _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
-    update_after_step: 0
-    inv_gamma: 1.0
-    power: 0.75
-    min_value: 0.0
-    max_value: 0.9999
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    lr: 1.0e-4
-    betas: [0.95, 0.999]
-    eps: 1.0e-8
-    weight_decay: 1.0e-6
+  # Noise scheduler.
+  num_train_timesteps: 100
+  beta_schedule: squaredcos_cap_v2
+  beta_start: 0.0001
+  beta_end: 0.02
+  variance_type: fixed_small
+  prediction_type: epsilon # epsilon / sample
+  clip_sample: True
+
+  # Inference
+  num_inference_steps: 100
+
+  # ---
+  # TODO(alexander-soare): Remove these from the policy config.
+  batch_size: 64
+  grad_clip_norm: 10
+  lr: 1.0e-4
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  adam_betas: [0.95, 0.999]
+  adam_eps: 1.0e-8
+  adam_weight_decay: 1.0e-6
+  utd: 1
+  use_ema: true
+  ema_update_after_step: 0
+  ema_min_rate: 0.0
+  ema_max_rate: 0.9999
+  ema_inv_gamma: 1.0
+  ema_power: 0.75
+
+delta_timestamps:
+  observation.images: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+  observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+  action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
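How the delta_timestamps expressions above resolve (illustrative): with fps=10, n_obs_steps=2 and policy.horizon=16 they reproduce the explicit lists from the old config.

fps, n_obs_steps, horizon = 10, 2, 16
observation = [i / fps for i in range(1 - n_obs_steps, 1)]
action = [i / fps for i in range(1 - n_obs_steps, 1 - n_obs_steps + horizon)]
print(observation)  # [-0.1, 0.0]
print(action)  # [-0.1, 0.0, 0.1, ..., 1.3, 1.4] (16 values spanning the prediction horizon)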
@@ -19,7 +19,7 @@ from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
 
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
 