diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index d2fff13b..64804d8f 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -11,7 +11,7 @@ import torch from omegaconf import OmegaConf from lerobot.common.datasets.factory import make_dataset -from lerobot.common.policies.diffusion.policy import DiffusionPolicy +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.utils import init_hydra_config output_directory = Path("outputs/train/example_pusht_diffusion") diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 74ed270e..72d35eb3 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -56,7 +56,7 @@ class ActionChunkingTransformerConfig: # Inputs / output structure. n_obs_steps: int = 1 - camera_names: list[str] = field(default_factory=lambda: ["top"]) + camera_names: tuple[str] = ("top",) chunk_size: int = 100 n_action_steps: int = 100 @@ -101,7 +101,7 @@ class ActionChunkingTransformerConfig: utd: int = 1 def __post_init__(self): - """Input validation.""" + """Input validation (not exhaustive).""" if not self.vision_backbone.startswith("resnet"): raise ValueError("`vision_backbone` must be one of the ResNet variants.") if self.use_temporal_aggregation: diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 1361e071..18ea3377 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -163,7 +163,8 @@ class ActionChunkingTransformerPolicy(nn.Module): @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: - """ + """Select a single action given environment observations. + This method wraps `select_actions` in order to return one action at a time for execution in the environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py new file mode 100644 index 00000000..272c80ea --- /dev/null +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass + + +@dataclass +class DiffusionConfig: + """Configuration class for Diffusion Policy. + + Defaults are configured for training with PushT providing proprioceptive and single camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `state_dim`, `action_dim` and `image_size`. + + Args: + state_dim: Dimensionality of the observation state space (excluding images). + action_dim: Dimensionality of the action space. + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + horizon: Diffusion model action prediction horizon as detailed in the main policy documentation. + """ + + # Environment. + # Inherit these from the environment config. + state_dim: int = 2 + action_dim: int = 2 + image_size: tuple[int, int] = (96, 96) + + # Inputs / output structure. + n_obs_steps: int = 2 + horizon: int = 16 + n_action_steps: int = 8 + + # Vision preprocessing. 
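+    # Per-channel stats intended for normalizing input images as (image - mean) / std (see `_RgbEncoder`).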
+    image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
+    image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
+
+    # Architecture / modeling.
+    # Vision backbone.
+    vision_backbone: str = "resnet18"
+    crop_shape: tuple[int, int] = (84, 84)
+    crop_is_random: bool = True
+    use_pretrained_backbone: bool = False
+    use_group_norm: bool = True
+    spatial_softmax_num_keypoints: int = 32
+    # Unet.
+    down_dims: tuple[int, ...] = (512, 1024, 2048)
+    kernel_size: int = 5
+    n_groups: int = 8
+    diffusion_step_embed_dim: int = 128
+    film_scale_modulation: bool = True
+    # Noise scheduler.
+    num_train_timesteps: int = 100
+    beta_schedule: str = "squaredcos_cap_v2"
+    beta_start: float = 0.0001
+    beta_end: float = 0.02
+    variance_type: str = "fixed_small"
+    prediction_type: str = "epsilon"
+    clip_sample: bool = True
+
+    # Inference
+    num_inference_steps: int = 100
+
+    # ---
+    # TODO(alexander-soare): Remove these from the policy config.
+    batch_size: int = 64
+    grad_clip_norm: int = 10
+    lr: float = 1.0e-4
+    lr_scheduler: str = "cosine"
+    lr_warmup_steps: int = 500
+    adam_betas: tuple[float, float] = (0.95, 0.999)
+    adam_eps: float = 1.0e-8
+    adam_weight_decay: float = 1.0e-6
+    utd: int = 1
+    use_ema: bool = True
+    ema_update_after_step: int = 0
+    ema_min_rate: float = 0.0
+    ema_max_rate: float = 0.9999
+    ema_inv_gamma: float = 1.0
+    ema_power: float = 0.75
+
+    def __post_init__(self):
+        """Input validation (not exhaustive)."""
+        if not self.vision_backbone.startswith("resnet"):
+            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
deleted file mode 100644
index c3dcc198..00000000
--- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py
+++ /dev/null
@@ -1,306 +0,0 @@
-import logging
-import math
-
-import einops
-import torch
-import torch.nn as nn
-from torch import Tensor
-
-logger = logging.getLogger(__name__)
-
-
-class _SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-class _Conv1dBlock(nn.Module):
-    """Conv1d --> GroupNorm --> Mish"""
-
-    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
-        super().__init__()
-
-        self.block = nn.Sequential(
-            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            nn.GroupNorm(n_groups, out_channels),
-            nn.Mish(),
-        )
-
-    def forward(self, x):
-        return self.block(x)
-
-
-class _ConditionalResidualBlock1D(nn.Module):
-    """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        cond_dim: int,
-        kernel_size: int = 3,
-        n_groups: int = 8,
-        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
-        # FiLM just modulates bias).
-        film_scale_modulation: bool = False,
-    ):
-        super().__init__()
-
-        self.film_scale_modulation = film_scale_modulation
-        self.out_channels = out_channels
-
-        self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
-
-        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
- cond_channels = out_channels * 2 if film_scale_modulation else out_channels - self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) - - self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) - - # A final convolution for dimension matching the residual (if needed). - self.residual_conv = ( - nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() - ) - - def forward(self, x: Tensor, cond: Tensor) -> Tensor: - """ - Args: - x: (B, in_channels, T) - cond: (B, cond_dim) - Returns: - (B, out_channels, T) - """ - out = self.conv1(x) - - # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1). - cond_embed = self.cond_encoder(cond).unsqueeze(-1) - if self.film_scale_modulation: - # Treat the embedding as a list of scales and biases. - scale = cond_embed[:, : self.out_channels] - bias = cond_embed[:, self.out_channels :] - out = scale * out + bias - else: - # Treat the embedding as biases. - out = out + cond_embed - - out = self.conv2(out) - out = out + self.residual_conv(x) - return out - - -class ConditionalUnet1D(nn.Module): - """A 1D convolutional UNet with FiLM modulation for conditioning. - - Two types of conditioning can be applied: - - Global: Conditioning information that is aggregated over the whole observation window. This is - incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder. - - Local: Conditioning information for each timestep in the observation window. This is incorporated - by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and - outputs of the Unet's encoder/decoder. - """ - - def __init__( - self, - input_dim: int, - local_cond_dim: int | None = None, - global_cond_dim: int | None = None, - diffusion_step_embed_dim: int = 256, - down_dims: int | None = None, - kernel_size: int = 3, - n_groups: int = 8, - film_scale_modulation: bool = False, - ): - super().__init__() - - if down_dims is None: - down_dims = [256, 512, 1024] - - # Encoder for the diffusion timestep. - self.diffusion_step_encoder = nn.Sequential( - _SinusoidalPosEmb(diffusion_step_embed_dim), - nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4), - nn.Mish(), - nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim), - ) - - # The FiLM conditioning dimension. - cond_dim = diffusion_step_embed_dim - if global_cond_dim is not None: - cond_dim += global_cond_dim - - self.local_cond_down_encoder = None - self.local_cond_up_encoder = None - if local_cond_dim is not None: - # Encoder for the local conditioning. The output gets added to the Unet encoder input. - self.local_cond_down_encoder = _ConditionalResidualBlock1D( - local_cond_dim, - down_dims[0], - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ) - # Encoder for the local conditioning. The output gets added to the Unet encoder output. - self.local_cond_up_encoder = _ConditionalResidualBlock1D( - local_cond_dim, - down_dims[0], - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ) - - # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we - # just reverse these. - in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True)) - - # Unet encoder. 
- self.down_modules = nn.ModuleList([]) - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (len(in_out) - 1) - self.down_modules.append( - nn.ModuleList( - [ - _ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ), - _ConditionalResidualBlock1D( - dim_out, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ), - # Downsample as long as it is not the last block. - nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), - ] - ) - ) - - # Processing in the middle of the auto-encoder. - self.mid_modules = nn.ModuleList( - [ - _ConditionalResidualBlock1D( - down_dims[-1], - down_dims[-1], - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ), - _ConditionalResidualBlock1D( - down_dims[-1], - down_dims[-1], - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ), - ] - ) - - # Unet decoder. - self.up_modules = nn.ModuleList([]) - for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])): - is_last = ind >= (len(in_out) - 1) - self.up_modules.append( - nn.ModuleList( - [ - _ConditionalResidualBlock1D( - dim_in * 2, # x2 as it takes the encoder's skip connection as well - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ), - _ConditionalResidualBlock1D( - dim_out, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ), - # Upsample as long as it is not the last block. - nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), - ] - ) - ) - - self.final_conv = nn.Sequential( - _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size), - nn.Conv1d(down_dims[0], input_dim, 1), - ) - - def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor: - """ - Args: - x: (B, T, input_dim) tensor for input to the Unet. - timestep: (B,) tensor of (timestep_we_are_denoising_from - 1). - local_cond: (B, T, local_cond_dim) - global_cond: (B, global_cond_dim) - output: (B, T, input_dim) - Returns: - (B, T, input_dim) - """ - # For 1D convolutions we'll need feature dimension first. - x = einops.rearrange(x, "b t d -> b d t") - if local_cond is not None: - if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None: - raise ValueError( - "`local_cond` was provided but the relevant encoders weren't built at initialization." - ) - local_cond = einops.rearrange(local_cond, "b t d -> b d t") - - timesteps_embed = self.diffusion_step_encoder(timestep) - - # If there is a global conditioning feature, concatenate it to the timestep embedding. 
- if global_cond is not None: - global_feature = torch.cat([timesteps_embed, global_cond], axis=-1) - else: - global_feature = timesteps_embed - - encoder_skip_features: list[Tensor] = [] - for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): - x = resnet(x, global_feature) - if idx == 0 and local_cond is not None: - x = x + self.local_cond_down_encoder(local_cond, global_feature) - x = resnet2(x, global_feature) - encoder_skip_features.append(x) - x = downsample(x) - - for mid_module in self.mid_modules: - x = mid_module(x, global_feature) - - for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): - x = torch.cat((x, encoder_skip_features.pop()), dim=1) - x = resnet(x, global_feature) - # Note: The condition in the original implementation is: - # if idx == len(self.up_modules) and local_cond is not None: - # But as they mention in their comments, this is incorrect. We use the correct condition here. - if idx == (len(self.up_modules) - 1) and local_cond is not None: - x = x + self.local_cond_up_encoder(local_cond, global_feature) - x = resnet2(x, global_feature) - x = upsample(x) - - x = self.final_conv(x) - - x = einops.rearrange(x, "b d t -> b t d") - return x diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py deleted file mode 100644 index 3e7727f3..00000000 --- a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py +++ /dev/null @@ -1,175 +0,0 @@ -import einops -import torch -import torch.nn.functional as F # noqa: N812 -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler -from torch import Tensor, nn - -from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D -from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder -from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters - - -class DiffusionUnetImagePolicy(nn.Module): - def __init__( - self, - cfg, - shape_meta: dict, - noise_scheduler: DDPMScheduler, - horizon, - n_action_steps, - n_obs_steps, - num_inference_steps=None, - diffusion_step_embed_dim=256, - down_dims=(256, 512, 1024), - kernel_size=5, - n_groups=8, - film_scale_modulation=True, - ): - super().__init__() - action_shape = shape_meta["action"]["shape"] - assert len(action_shape) == 1 - action_dim = action_shape[0] - - self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder) - - self.unet = ConditionalUnet1D( - input_dim=action_dim, - global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps, - diffusion_step_embed_dim=diffusion_step_embed_dim, - down_dims=down_dims, - kernel_size=kernel_size, - n_groups=n_groups, - film_scale_modulation=film_scale_modulation, - ) - - self.noise_scheduler = noise_scheduler - self.horizon = horizon - self.action_dim = action_dim - self.n_action_steps = n_action_steps - self.n_obs_steps = n_obs_steps - - if num_inference_steps is None: - num_inference_steps = noise_scheduler.config.num_train_timesteps - - self.num_inference_steps = num_inference_steps - - # ========= inference ============ - def conditional_sample(self, batch_size, global_cond=None, generator=None): - device = get_device_from_parameters(self) - dtype = get_dtype_from_parameters(self) - - # Sample prior. 
- sample = torch.randn( - size=(batch_size, self.horizon, self.action_dim), - dtype=dtype, - device=device, - generator=generator, - ) - - self.noise_scheduler.set_timesteps(self.num_inference_steps) - - for t in self.noise_scheduler.timesteps: - # Predict model output. - model_output = self.unet( - sample, - torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), - global_cond=global_cond, - ) - # Compute previous image: x_t -> x_t-1 - sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample - - return sample - - def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """ - This function expects `batch` to have (at least): - { - "observation.state": (B, n_obs_steps, state_dim) - "observation.image": (B, n_obs_steps, C, H, W) - } - """ - assert set(batch).issuperset({"observation.state", "observation.image"}) - batch_size, n_obs_steps = batch["observation.state"].shape[:2] - assert n_obs_steps == self.n_obs_steps - assert self.n_obs_steps == n_obs_steps - - # Extract image feature (first combine batch and sequence dims). - img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) - # Separate batch and sequence dims. - img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) - # Concatenate state and image features then flatten to (B, global_cond_dim). - global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) - - # run sampling - sample = self.conditional_sample(batch_size, global_cond=global_cond) - - # `horizon` steps worth of actions (from the first observation). - action = sample[..., : self.action_dim] - # Extract `n_action_steps` steps worth of actions (from the current observation). - start = n_obs_steps - 1 - end = start + self.n_action_steps - action = action[:, start:end] - - return action - - def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: - """ - This function expects `batch` to have (at least): - { - "observation.state": (B, n_obs_steps, state_dim) - "observation.image": (B, n_obs_steps, C, H, W) - "action": (B, horizon, action_dim) - "action_is_pad": (B, horizon) - } - """ - # Input validation. - assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) - batch_size, n_obs_steps = batch["observation.state"].shape[:2] - horizon = batch["action"].shape[1] - assert horizon == self.horizon - assert n_obs_steps == self.n_obs_steps - assert self.n_obs_steps == n_obs_steps - - # Extract image feature (first combine batch and sequence dims). - img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) - # Separate batch and sequence dims. - img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) - # Concatenate state and image features then flatten to (B, global_cond_dim). - global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) - - trajectory = batch["action"] - - # Forward diffusion. - # Sample noise to add to the trajectory. - eps = torch.randn(trajectory.shape, device=trajectory.device) - # Sample a random noising timestep for each item in the batch. - timesteps = torch.randint( - low=0, - high=self.noise_scheduler.config.num_train_timesteps, - size=(trajectory.shape[0],), - device=trajectory.device, - ).long() - # Add noise to the clean trajectories according to the noise magnitude at each timestep. 
- noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps) - - # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise). - pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond) - - # Compute the loss. - # The targe is either the original trajectory, or the noise. - pred_type = self.noise_scheduler.config.prediction_type - if pred_type == "epsilon": - target = eps - elif pred_type == "sample": - target = batch["action"] - else: - raise ValueError(f"Unsupported prediction type {pred_type}") - - loss = F.mse_loss(pred, target, reduction="none") - - # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). - if "action_is_pad" in batch: - in_episode_bound = ~batch["action_is_pad"] - loss = loss * in_episode_bound.unsqueeze(-1) - - return loss.mean() diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py deleted file mode 100644 index 1e3447f3..00000000 --- a/lerobot/common/policies/diffusion/model/ema_model.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -from torch.nn.modules.batchnorm import _BatchNorm - - -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999 - ): - """ - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. - """ - - self.averaged_model = model - self.averaged_model.eval() - self.averaged_model.requires_grad_(False) - - self.update_after_step = update_after_step - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value - - self.alpha = 0.0 - self.optimization_step = 0 - - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. - """ - step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power - - if step <= 0: - return 0.0 - - return max(self.min_value, min(value, self.max_value)) - - @torch.no_grad() - def step(self, new_model): - self.alpha = self.get_decay(self.optimization_step) - - for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True): - # Iterate over immediate parameters only. - for param, ema_param in zip( - module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True - ): - if isinstance(param, dict): - raise RuntimeError("Dict parameter not supported") - if isinstance(module, _BatchNorm) or not param.requires_grad: - # Copy BatchNorm parameters, and non-trainable parameters directly. 
- ema_param.copy_(param.to(dtype=ema_param.dtype).data) - else: - ema_param.mul_(self.alpha) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha) - - self.optimization_step += 1 diff --git a/lerobot/common/policies/diffusion/model/rgb_encoder.py b/lerobot/common/policies/diffusion/model/rgb_encoder.py deleted file mode 100644 index 2a5edf45..00000000 --- a/lerobot/common/policies/diffusion/model/rgb_encoder.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import Callable - -import torch -import torchvision -from robomimic.models.base_nets import SpatialSoftmax -from torch import Tensor, nn -from torchvision.transforms import CenterCrop, RandomCrop - - -class RgbEncoder(nn.Module): - """Encoder an RGB image into a 1D feature vector. - - Includes the ability to normalize and crop the image first. - """ - - def __init__( - self, - input_shape: tuple[int, int, int], - norm_mean_std: tuple[float, float] = [1.0, 1.0], - crop_shape: tuple[int, int] | None = None, - random_crop: bool = False, - backbone_name: str = "resnet18", - pretrained_backbone: bool = False, - use_group_norm: bool = False, - relu: bool = True, - num_keypoints: int = 32, - ): - """ - Args: - input_shape: channel-first input shape (C, H, W) - norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as - (image - mean) / std. - crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no - cropping is done. - random_crop: Whether the crop should be random at training time (it's always a center crop in - eval mode). - backbone_name: The name of one of the available resnet models from torchvision (eg resnet18). - pretrained_backbone: whether to use timm pretrained weights. - use_group_norm: Whether to replace batch normalization with group normalization in the backbone. - The group sizes are set to be about 16 (to be precise, feature_dim // 16). - relu: whether to use relu as a final step. - num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). - """ - super().__init__() - if input_shape[0] != 3: - raise ValueError("Only RGB images are handled") - if not backbone_name.startswith("resnet"): - raise ValueError( - "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)" - ) - - # Set up optional preprocessing. - if norm_mean_std == [1.0, 1.0]: - self.normalizer = nn.Identity() - else: - self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1]) - - if crop_shape is not None: - self.do_crop = True - self.center_crop = CenterCrop(crop_shape) # always use center crop for eval - if random_crop: - self.maybe_random_crop = RandomCrop(crop_shape) - else: - self.maybe_random_crop = self.center_crop - else: - self.do_crop = False - - # Set up backbone. - backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone) - # Note: This assumes that the layer4 feature map is children()[-3] - # TODO(alexander-soare): Use a safer alternative. - self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) - if use_group_norm: - if pretrained_backbone: - raise ValueError( - "You can't replace BatchNorm in a pretrained model without ruining the weights!" 
- ) - self.backbone = _replace_submodules( - root_module=self.backbone, - predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), - ) - - # Set up pooling and final layers. - # Use a dry run to get the feature map shape. - with torch.inference_mode(): - feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) - self.feature_dim = num_keypoints * 2 - self.out = nn.Linear(num_keypoints * 2, self.feature_dim) - self.maybe_relu = nn.ReLU() if relu else nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: (B, C, H, W) image tensor with pixel values in [0, 1]. - Returns: - (B, D) image feature. - """ - # Preprocess: normalize and maybe crop (if it was set up in the __init__). - x = self.normalizer(x) - if self.do_crop: - if self.training: # noqa: SIM108 - x = self.maybe_random_crop(x) - else: - # Always use center crop for eval. - x = self.center_crop(x) - # Extract backbone feature. - x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) - # Final linear layer. - x = self.out(x) - # Maybe a final non-linearity. - x = self.maybe_relu(x) - return x - - -def _replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] -) -> nn.Module: - """ - Args: - root_module: The module for which the submodules need to be replaced - predicate: Takes a module as an argument and must return True if the that module is to be replaced. - func: Takes a module as an argument and returns a new module to replace it with. - Returns: - The root module with its submodules replaced. - """ - if predicate(root_module): - return func(root_module) - - replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] - for *parents, k in replace_list: - parent_module = root_module - if len(parents) > 0: - parent_module = root_module.get_submodule(".".join(parents)) - if isinstance(parent_module, nn.Sequential): - src_module = parent_module[int(k)] - else: - src_module = getattr(parent_module, k) - tgt_module = func(src_module) - if isinstance(parent_module, nn.Sequential): - parent_module[int(k)] = tgt_module - else: - setattr(parent_module, k, tgt_module) - # verify that all BN are replaced - assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) - return root_module diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py new file mode 100644 index 00000000..4853dbcf --- /dev/null +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -0,0 +1,878 @@ +""" +TODO(alexander-soare): + - Remove reliance on Robomimic for SpatialSoftmax. + - Remove reliance on diffusers for DDPMScheduler. + - Move EMA out of policy. 
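+
+(Note: for now, the EMA wrapper, optimizer and LR scheduler are all constructed inside `DiffusionPolicy.__init__`.)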
+""" + +import copy +import logging +import math +import time +from collections import deque +from typing import Callable + +import einops +import hydra +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from diffusers.optimization import get_scheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from robomimic.models.base_nets import SpatialSoftmax +from torch import Tensor, nn +from torch.nn.modules.batchnorm import _BatchNorm + +from lerobot.common.policies.utils import ( + get_device_from_parameters, + get_dtype_from_parameters, + populate_queues, +) +from lerobot.common.utils import get_safe_torch_device + +logger = logging.getLogger(__name__) + + +class DiffusionPolicy(nn.Module): + """ + Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" + (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). + """ + + name = "diffusion" + + def __init__( + self, + cfg, + cfg_device, + cfg_noise_scheduler, + cfg_optimizer, + cfg_ema, + shape_meta: dict, + horizon, + n_action_steps, + n_obs_steps, + num_inference_steps=None, + diffusion_step_embed_dim=256, + down_dims=(256, 512, 1024), + kernel_size=5, + n_groups=8, + film_scale_modulation=True, + **_, + ): + super().__init__() + self.cfg = cfg + self.n_obs_steps = n_obs_steps + self.n_action_steps = n_action_steps + + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None + + noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) + + self.diffusion = _DiffusionUnetImagePolicy( + cfg, + shape_meta=shape_meta, + noise_scheduler=noise_scheduler, + horizon=horizon, + n_action_steps=n_action_steps, + n_obs_steps=n_obs_steps, + num_inference_steps=num_inference_steps, + diffusion_step_embed_dim=diffusion_step_embed_dim, + down_dims=down_dims, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ) + + self.device = get_safe_torch_device(cfg_device) + self.diffusion.to(self.device) + + # TODO(alexander-soare): This should probably be managed outside of the policy class. + self.ema_diffusion = None + self.ema = None + if self.cfg.use_ema: + self.ema_diffusion = copy.deepcopy(self.diffusion) + self.ema = hydra.utils.instantiate( + cfg_ema, + model=self.ema_diffusion, + ) + + self.optimizer = hydra.utils.instantiate( + cfg_optimizer, + params=self.diffusion.parameters(), + ) + + # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps + self.global_step = 0 + + # configure lr scheduler + self.lr_scheduler = get_scheduler( + cfg.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.offline_steps, + # pytorch assumes stepping LRScheduler every epoch + # however huggingface diffusers steps it every batch + last_epoch=self.global_step - 1, + ) + + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=self.n_obs_steps), + "observation.state": deque(maxlen=self.n_obs_steps), + "action": deque(maxlen=self.n_action_steps), + } + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + """Select a single action given environment observations. + + This method handles caching a history of observations and an action trajectory generated by the + underlying diffusion model. 
Here's how it works: + - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is + copied `n_obs_steps` times to fill the cache). + - The diffusion model generates `horizon` steps worth of actions. + - `n_action_steps` worth of actions are actually kept for execution, starting from the current step. + Schematically this looks like: + (legend: o = n_obs_steps, h = horizon, a = n_action_steps) + |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h| + |observation is used | YES | YES | ..... | NO | NO | NO | NO | NO | NO | + |action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES | + |action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO | + Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that + "horizon" may not the best name to describe what the variable actually means, because this period is + actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. + + Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model + weights. + """ + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 + + self._queues = populate_queues(self._queues, batch) + + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + if not self.training and self.ema_diffusion is not None: + actions = self.ema_diffusion.generate_actions(batch) + else: + actions = self.diffusion.generate_actions(batch) + self._queues["action"].extend(actions.transpose(0, 1)) + + action = self._queues["action"].popleft() + return action + + def forward(self, batch, **_): + start_time = time.time() + + self.diffusion.train() + + loss = self.diffusion.compute_loss(batch) + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.diffusion.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + + if self.ema is not None: + self.ema.step(self.diffusion) + + info = { + "loss": loss.item(), + "grad_norm": float(grad_norm), + "lr": self.lr_scheduler.get_last_lr()[0], + "update_s": time.time() - start_time, + } + + return info + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) + if len(missing_keys) > 0: + assert all(k.startswith("ema_diffusion.") for k in missing_keys) + logging.warning( + "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found." 
+ ) + assert len(unexpected_keys) == 0 + + +class _DiffusionUnetImagePolicy(nn.Module): + def __init__( + self, + cfg, + shape_meta: dict, + noise_scheduler: DDPMScheduler, + horizon, + n_action_steps, + n_obs_steps, + num_inference_steps=None, + diffusion_step_embed_dim=256, + down_dims=(256, 512, 1024), + kernel_size=5, + n_groups=8, + film_scale_modulation=True, + ): + super().__init__() + action_shape = shape_meta["action"]["shape"] + assert len(action_shape) == 1 + action_dim = action_shape[0] + + self.rgb_encoder = _RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder) + + self.unet = _ConditionalUnet1D( + input_dim=action_dim, + global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps, + diffusion_step_embed_dim=diffusion_step_embed_dim, + down_dims=down_dims, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ) + + self.noise_scheduler = noise_scheduler + self.horizon = horizon + self.action_dim = action_dim + self.n_action_steps = n_action_steps + self.n_obs_steps = n_obs_steps + + if num_inference_steps is None: + num_inference_steps = noise_scheduler.config.num_train_timesteps + + self.num_inference_steps = num_inference_steps + + # ========= inference ============ + def conditional_sample(self, batch_size, global_cond=None, generator=None): + device = get_device_from_parameters(self) + dtype = get_dtype_from_parameters(self) + + # Sample prior. + sample = torch.randn( + size=(batch_size, self.horizon, self.action_dim), + dtype=dtype, + device=device, + generator=generator, + ) + + self.noise_scheduler.set_timesteps(self.num_inference_steps) + + for t in self.noise_scheduler.timesteps: + # Predict model output. + model_output = self.unet( + sample, + torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), + global_cond=global_cond, + ) + # Compute previous image: x_t -> x_t-1 + sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample + + return sample + + def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + } + """ + assert set(batch).issuperset({"observation.state", "observation.image"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.n_obs_steps + assert self.n_obs_steps == n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + + # run sampling + sample = self.conditional_sample(batch_size, global_cond=global_cond) + + # `horizon` steps worth of actions (from the first observation). + action = sample[..., : self.action_dim] + # Extract `n_action_steps` steps worth of actions (from the current observation). 
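+        # The current observation is the last one in the window, i.e. index (n_obs_steps - 1) of the horizon.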
+ start = n_obs_steps - 1 + end = start + self.n_action_steps + action = action[:, start:end] + + return action + + def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + "action": (B, horizon, action_dim) + "action_is_pad": (B, horizon) + } + """ + # Input validation. + assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + horizon = batch["action"].shape[1] + assert horizon == self.horizon + assert n_obs_steps == self.n_obs_steps + assert self.n_obs_steps == n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + + trajectory = batch["action"] + + # Forward diffusion. + # Sample noise to add to the trajectory. + eps = torch.randn(trajectory.shape, device=trajectory.device) + # Sample a random noising timestep for each item in the batch. + timesteps = torch.randint( + low=0, + high=self.noise_scheduler.config.num_train_timesteps, + size=(trajectory.shape[0],), + device=trajectory.device, + ).long() + # Add noise to the clean trajectories according to the noise magnitude at each timestep. + noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps) + + # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise). + pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond) + + # Compute the loss. + # The targe is either the original trajectory, or the noise. + pred_type = self.noise_scheduler.config.prediction_type + if pred_type == "epsilon": + target = eps + elif pred_type == "sample": + target = batch["action"] + else: + raise ValueError(f"Unsupported prediction type {pred_type}") + + loss = F.mse_loss(pred, target, reduction="none") + + # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). + if "action_is_pad" in batch: + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound.unsqueeze(-1) + + return loss.mean() + + +class _RgbEncoder(nn.Module): + """Encoder an RGB image into a 1D feature vector. + + Includes the ability to normalize and crop the image first. + """ + + def __init__( + self, + input_shape: tuple[int, int, int], + norm_mean_std: tuple[float, float] = [1.0, 1.0], + crop_shape: tuple[int, int] | None = None, + random_crop: bool = False, + backbone_name: str = "resnet18", + pretrained_backbone: bool = False, + use_group_norm: bool = False, + num_keypoints: int = 32, + ): + """ + Args: + input_shape: channel-first input shape (C, H, W) + norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as + (image - mean) / std. + crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no + cropping is done. + random_crop: Whether the crop should be random at training time (it's always a center crop in + eval mode). 
+ backbone_name: The name of one of the available resnet models from torchvision (eg resnet18). + pretrained_backbone: whether to use timm pretrained weights. + use_group_norm: Whether to replace batch normalization with group normalization in the backbone. + The group sizes are set to be about 16 (to be precise, feature_dim // 16). + num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). + """ + super().__init__() + if input_shape[0] != 3: + raise ValueError("Only RGB images are handled") + if not backbone_name.startswith("resnet"): + raise ValueError( + "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)" + ) + + # Set up optional preprocessing. + if norm_mean_std == [1.0, 1.0]: + self.normalizer = nn.Identity() + else: + self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1]) + + if crop_shape is not None: + self.do_crop = True + # Always use center crop for eval + self.center_crop = torchvision.transforms.CenterCrop(crop_shape) + if random_crop: + self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + # Set up backbone. + backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone) + # Note: This assumes that the layer4 feature map is children()[-3] + # TODO(alexander-soare): Use a safer alternative. + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + if use_group_norm: + if pretrained_backbone: + raise ValueError( + "You can't replace BatchNorm in a pretrained model without ruining the weights!" + ) + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + ) + + # Set up pooling and final layers. + # Use a dry run to get the feature map shape. + with torch.inference_mode(): + feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) + self.feature_dim = num_keypoints * 2 + self.out = nn.Linear(num_keypoints * 2, self.feature_dim) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B, C, H, W) image tensor with pixel values in [0, 1]. + Returns: + (B, D) image feature. + """ + # Preprocess: normalize and maybe crop (if it was set up in the __init__). + x = self.normalizer(x) + if self.do_crop: + if self.training: # noqa: SIM108 + x = self.maybe_random_crop(x) + else: + # Always use center crop for eval. + x = self.center_crop(x) + # Extract backbone feature. + x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) + # Final linear layer with non-linearity. + x = self.relu(self.out(x)) + return x + + +def _replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + Args: + root_module: The module for which the submodules need to be replaced + predicate: Takes a module as an argument and must return True if the that module is to be replaced. + func: Takes a module as an argument and returns a new module to replace it with. + Returns: + The root module with its submodules replaced. 
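+    Note: used here to swap `nn.BatchNorm2d` layers for `nn.GroupNorm` in the vision backbone when `use_group_norm` is set.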
+ """ + if predicate(root_module): + return func(root_module) + + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parents, k in replace_list: + parent_module = root_module + if len(parents) > 0: + parent_module = root_module.get_submodule(".".join(parents)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + return root_module + + +class _SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class _Conv1dBlock(nn.Module): + """Conv1d --> GroupNorm --> Mish""" + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class _ConditionalUnet1D(nn.Module): + """A 1D convolutional UNet with FiLM modulation for conditioning. + + Two types of conditioning can be applied: + - Global: Conditioning information that is aggregated over the whole observation window. This is + incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder. + - Local: Conditioning information for each timestep in the observation window. This is incorporated + by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and + outputs of the Unet's encoder/decoder. + """ + + def __init__( + self, + input_dim: int, + local_cond_dim: int | None = None, + global_cond_dim: int | None = None, + diffusion_step_embed_dim: int = 256, + down_dims: int | None = None, + kernel_size: int = 3, + n_groups: int = 8, + film_scale_modulation: bool = False, + ): + super().__init__() + + if down_dims is None: + down_dims = [256, 512, 1024] + + # Encoder for the diffusion timestep. + self.diffusion_step_encoder = nn.Sequential( + _SinusoidalPosEmb(diffusion_step_embed_dim), + nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4), + nn.Mish(), + nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim), + ) + + # The FiLM conditioning dimension. + cond_dim = diffusion_step_embed_dim + if global_cond_dim is not None: + cond_dim += global_cond_dim + + self.local_cond_down_encoder = None + self.local_cond_up_encoder = None + if local_cond_dim is not None: + # Encoder for the local conditioning. The output gets added to the Unet encoder input. + self.local_cond_down_encoder = _ConditionalResidualBlock1D( + local_cond_dim, + down_dims[0], + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ) + # Encoder for the local conditioning. The output gets added to the Unet encoder output. 
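+            # (This is only applied in the final decoder block; see `forward`.)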
+ self.local_cond_up_encoder = _ConditionalResidualBlock1D( + local_cond_dim, + down_dims[0], + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ) + + # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we + # just reverse these. + in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True)) + + # Unet encoder. + self.down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + self.down_modules.append( + nn.ModuleList( + [ + _ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + _ConditionalResidualBlock1D( + dim_out, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + # Downsample as long as it is not the last block. + nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + # Processing in the middle of the auto-encoder. + self.mid_modules = nn.ModuleList( + [ + _ConditionalResidualBlock1D( + down_dims[-1], + down_dims[-1], + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + _ConditionalResidualBlock1D( + down_dims[-1], + down_dims[-1], + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + ] + ) + + # Unet decoder. + self.up_modules = nn.ModuleList([]) + for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + self.up_modules.append( + nn.ModuleList( + [ + _ConditionalResidualBlock1D( + dim_in * 2, # x2 as it takes the encoder's skip connection as well + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + _ConditionalResidualBlock1D( + dim_out, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + # Upsample as long as it is not the last block. + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + self.final_conv = nn.Sequential( + _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size), + nn.Conv1d(down_dims[0], input_dim, 1), + ) + + def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor: + """ + Args: + x: (B, T, input_dim) tensor for input to the Unet. + timestep: (B,) tensor of (timestep_we_are_denoising_from - 1). + local_cond: (B, T, local_cond_dim) + global_cond: (B, global_cond_dim) + output: (B, T, input_dim) + Returns: + (B, T, input_dim) + """ + # For 1D convolutions we'll need feature dimension first. + x = einops.rearrange(x, "b t d -> b d t") + if local_cond is not None: + if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None: + raise ValueError( + "`local_cond` was provided but the relevant encoders weren't built at initialization." + ) + local_cond = einops.rearrange(local_cond, "b t d -> b d t") + + timesteps_embed = self.diffusion_step_encoder(timestep) + + # If there is a global conditioning feature, concatenate it to the timestep embedding. 
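+        # The resulting global feature conditions every residual block in the Unet via FiLM.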
+ if global_cond is not None: + global_feature = torch.cat([timesteps_embed, global_cond], axis=-1) + else: + global_feature = timesteps_embed + + encoder_skip_features: list[Tensor] = [] + for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + x = resnet(x, global_feature) + if idx == 0 and local_cond is not None: + x = x + self.local_cond_down_encoder(local_cond, global_feature) + x = resnet2(x, global_feature) + encoder_skip_features.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + x = torch.cat((x, encoder_skip_features.pop()), dim=1) + x = resnet(x, global_feature) + # Note: The condition in the original implementation is: + # if idx == len(self.up_modules) and local_cond is not None: + # But as they mention in their comments, this is incorrect. We use the correct condition here. + if idx == (len(self.up_modules) - 1) and local_cond is not None: + x = x + self.local_cond_up_encoder(local_cond, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + x = einops.rearrange(x, "b d t -> b t d") + return x + + +class _ConditionalResidualBlock1D(nn.Module): + """ResNet style 1D convolutional block with FiLM modulation for conditioning.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + cond_dim: int, + kernel_size: int = 3, + n_groups: int = 8, + # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning + # FiLM just modulates bias). + film_scale_modulation: bool = False, + ): + super().__init__() + + self.film_scale_modulation = film_scale_modulation + self.out_channels = out_channels + + self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + + # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. + cond_channels = out_channels * 2 if film_scale_modulation else out_channels + self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) + + self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + + # A final convolution for dimension matching the residual (if needed). + self.residual_conv = ( + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + ) + + def forward(self, x: Tensor, cond: Tensor) -> Tensor: + """ + Args: + x: (B, in_channels, T) + cond: (B, cond_dim) + Returns: + (B, out_channels, T) + """ + out = self.conv1(x) + + # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1). + cond_embed = self.cond_encoder(cond).unsqueeze(-1) + if self.film_scale_modulation: + # Treat the embedding as a list of scales and biases. + scale = cond_embed[:, : self.out_channels] + bias = cond_embed[:, self.out_channels :] + out = scale * out + bias + else: + # Treat the embedding as biases. + out = out + cond_embed + + out = self.conv2(out) + out = out + self.residual_conv(x) + return out + + +class _EMA: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999 + ): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. 
gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_value (float): The minimum EMA decay rate. Default: 0. + """ + + self.averaged_model = model + self.averaged_model.eval() + self.averaged_model.requires_grad_(False) + + self.update_after_step = update_after_step + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + + self.alpha = 0.0 + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + + @torch.no_grad() + def step(self, new_model): + self.alpha = self.get_decay(self.optimization_step) + + for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True): + # Iterate over immediate parameters only. + for param, ema_param in zip( + module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True + ): + if isinstance(param, dict): + raise RuntimeError("Dict parameter not supported") + if isinstance(module, _BatchNorm) or not param.requires_grad: + # Copy BatchNorm parameters, and non-trainable parameters directly. + ema_param.copy_(param.to(dtype=ema_param.dtype).data) + else: + ema_param.mul_(self.alpha) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha) + + self.optimization_step += 1 diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py deleted file mode 100644 index f88e2f25..00000000 --- a/lerobot/common/policies/diffusion/policy.py +++ /dev/null @@ -1,169 +0,0 @@ -import copy -import logging -import time -from collections import deque - -import hydra -import torch -from diffusers.optimization import get_scheduler -from torch import nn - -from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy -from lerobot.common.policies.utils import populate_queues -from lerobot.common.utils import get_safe_torch_device - - -class DiffusionPolicy(nn.Module): - name = "diffusion" - - def __init__( - self, - cfg, - cfg_device, - cfg_noise_scheduler, - cfg_optimizer, - cfg_ema, - shape_meta: dict, - horizon, - n_action_steps, - n_obs_steps, - num_inference_steps=None, - diffusion_step_embed_dim=256, - down_dims=(256, 512, 1024), - kernel_size=5, - n_groups=8, - film_scale_modulation=True, - **_, - ): - super().__init__() - self.cfg = cfg - self.n_obs_steps = n_obs_steps - self.n_action_steps = n_action_steps - - # queues are populated during rollout of the policy, they contain the n latest observations and actions - self._queues = None - - noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) - - self.diffusion = DiffusionUnetImagePolicy( - cfg, - shape_meta=shape_meta, - noise_scheduler=noise_scheduler, - horizon=horizon, - n_action_steps=n_action_steps, - n_obs_steps=n_obs_steps, - num_inference_steps=num_inference_steps, - diffusion_step_embed_dim=diffusion_step_embed_dim, - 
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
-        )
-
-        self.device = get_safe_torch_device(cfg_device)
-        self.diffusion.to(self.device)
-
-        # TODO(alexander-soare): This should probably be managed outside of the policy class.
-        self.ema_diffusion = None
-        self.ema = None
-        if self.cfg.use_ema:
-            self.ema_diffusion = copy.deepcopy(self.diffusion)
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=self.ema_diffusion,
-            )
-
-        self.optimizer = hydra.utils.instantiate(
-            cfg_optimizer,
-            params=self.diffusion.parameters(),
-        )
-
-        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
-        self.global_step = 0
-
-        # configure lr scheduler
-        self.lr_scheduler = get_scheduler(
-            cfg.lr_scheduler,
-            optimizer=self.optimizer,
-            num_warmup_steps=cfg.lr_warmup_steps,
-            num_training_steps=cfg.offline_steps,
-            # pytorch assumes stepping LRScheduler every epoch
-            # however huggingface diffusers steps it every batch
-            last_epoch=self.global_step - 1,
-        )
-
-    def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
-        self._queues = {
-            "observation.image": deque(maxlen=self.n_obs_steps),
-            "observation.state": deque(maxlen=self.n_obs_steps),
-            "action": deque(maxlen=self.n_action_steps),
-        }
-
-    @torch.no_grad
-    def select_action(self, batch, **_):
-        """
-        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
-        """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-        assert len(batch) == 2
-
-        self._queues = populate_queues(self._queues, batch)
-
-        if len(self._queues["action"]) == 0:
-            # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-            if not self.training and self.ema_diffusion is not None:
-                actions = self.ema_diffusion.generate_actions(batch)
-            else:
-                actions = self.diffusion.generate_actions(batch)
-            self._queues["action"].extend(actions.transpose(0, 1))
-
-        action = self._queues["action"].popleft()
-        return action
-
-    def forward(self, batch, **_):
-        start_time = time.time()
-
-        self.diffusion.train()
-
-        loss = self.diffusion.compute_loss(batch)
-        loss.backward()
-
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.diffusion.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.lr_scheduler.get_last_lr()[0],
-            "update_s": time.time() - start_time,
-        }
-
-        return info
-
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
-        if len(missing_keys) > 0:
-            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
-            logging.warning(
-                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
-            )
-        assert len(unexpected_keys) == 0
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index f0454b8e..0b185e86 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,59 +1,61 @@
 import inspect

-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

 from lerobot.common.utils import get_safe_torch_device


-def make_policy(cfg):
-    if cfg.policy.name == "tdmpc":
+def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
+    expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
+    assert set(hydra_cfg.policy).issuperset(
+        expected_kwargs
+    ), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}"
+    policy_cfg = policy_cfg_class(
+        **{
+            k: v
+            for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
+            if k in expected_kwargs
+        }
+    )
+    return policy_cfg
+
+
+def make_policy(hydra_cfg: DictConfig):
+    if hydra_cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy

         policy = TDMPCPolicy(
-            cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
+            hydra_cfg.policy,
+            n_obs_steps=hydra_cfg.n_obs_steps,
+            n_action_steps=hydra_cfg.n_action_steps,
+            device=hydra_cfg.device,
         )
-    elif cfg.policy.name == "diffusion":
-        from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+    elif hydra_cfg.policy.name == "diffusion":
+        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

-        policy = DiffusionPolicy(
-            cfg=cfg.policy,
-            cfg_device=cfg.device,
-            cfg_noise_scheduler=cfg.noise_scheduler,
-            cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
-            # n_obs_steps=cfg.n_obs_steps,
-            # n_action_steps=cfg.n_action_steps,
-            **cfg.policy,
-        )
-    elif cfg.policy.name == "act":
+        policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
+        policy = DiffusionPolicy(policy_cfg)
+        policy.to(get_safe_torch_device(hydra_cfg.device))
+    elif hydra_cfg.policy.name == "act":
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
         from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

-        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
-        assert set(cfg.policy).issuperset(
-            expected_kwargs
-        ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
-        policy_cfg = ActionChunkingTransformerConfig(
-            **{
-                k: v
-                for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
-                if k in expected_kwargs
-            }
-        )
+        policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
         policy = ActionChunkingTransformerPolicy(policy_cfg)
-        policy.to(get_safe_torch_device(cfg.device))
+        policy.to(get_safe_torch_device(hydra_cfg.device))
     else:
-        raise ValueError(cfg.policy.name)
+        raise ValueError(hydra_cfg.policy.name)

-    if cfg.policy.pretrained_model_path:
+    if hydra_cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
-        if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.policy.pretrained_model_path:
+        if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
+            if "offline" in hydra_cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.policy.pretrained_model_path:
+            elif "final" in hydra_cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
-        policy.load(cfg.policy.pretrained_model_path)
+        policy.load(hydra_cfg.policy.pretrained_model_path)

     return policy
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index bd883613..5dd70d71 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -18,7 +18,7 @@ policy:
   pretrained_model_path:

   # Environment.
-  # Inherit these from the environment.
+  # Inherit these from the environment config.
   state_dim: ???
   action_dim: ???

diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 005b0517..c8bdb0c4 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -1,17 +1,5 @@
 # @package _global_

-shape_meta:
-  # acceptable types: rgb, low_dim
-  obs:
-    image:
-      shape: [3, 96, 96]
-      type: rgb
-    agent_pos:
-      shape: [2]
-      type: low_dim
-  action:
-    shape: [2]
-
 seed: 100000
 horizon: 16
 n_obs_steps: 2
@@ -33,75 +21,70 @@ offline_prioritized_sampler: true

 policy:
   name: diffusion
-  shape_meta: ${shape_meta}
+  pretrained_model_path:

-  horizon: ${horizon}
+  # Environment.
+  # Inherit these from the environment config.
+  state_dim: ???
+  action_dim: ???
+  image_size:
+    - ${env.image_size} # height
+    - ${env.image_size} # width
+
+  # Inputs / output structure.
   n_obs_steps: ${n_obs_steps}
+  horizon: ${horizon}
   n_action_steps: ${n_action_steps}
-  num_inference_steps: 100
-  # crop_shape: null
-  diffusion_step_embed_dim: 128
+
+  # Vision preprocessing.
+  image_normalization_mean: [0.5, 0.5, 0.5]
+  image_normalization_std: [0.5, 0.5, 0.5]
+
+  # Architecture / modeling.
+  # Vision backbone.
+  vision_backbone: resnet18
+  crop_shape: [84, 84]
+  crop_is_random: True
+  use_pretrained_backbone: false
+  use_group_norm: True
+  spatial_softmax_num_keypoints: 32
+  # Unet.
   down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
+  diffusion_step_embed_dim: 128
   film_scale_modulation: True
-
-  pretrained_model_path:
-
-  batch_size: 64
-
-  per_alpha: 0.6
-  per_beta: 0.4
-
-  balanced_sampling: false
-  utd: 1
-  offline_steps: ${offline_steps}
-  use_ema: true
-  lr_scheduler: cosine
-  lr_warmup_steps: 500
-  grad_clip_norm: 10
-
-  delta_timestamps:
-    observation.image: [-0.1, 0]
-    observation.state: [-0.1, 0]
-    action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
-
-  rgb_encoder:
-    backbone_name: resnet18
-    pretrained_backbone: false
-    use_group_norm: True
-    num_keypoints: 32
-    relu: true
-    norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
-    crop_shape: [84, 84]
-    random_crop: True
-
-noise_scheduler:
-  _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  # Noise scheduler.
   num_train_timesteps: 100
+  beta_schedule: squaredcos_cap_v2
   beta_start: 0.0001
   beta_end: 0.02
-  beta_schedule: squaredcos_cap_v2
-  variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
-  clip_sample: True # required when predict_epsilon=False
-  prediction_type: epsilon # or sample
+  variance_type: fixed_small
+  prediction_type: epsilon # epsilon / sample
+  clip_sample: True

-rgb_model:
-  pretrained: false
-  num_keypoints: 32
-  relu: true
+  # Inference
+  num_inference_steps: 100

-ema:
-  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
-  update_after_step: 0
-  inv_gamma: 1.0
-  power: 0.75
-  min_value: 0.0
-  max_value: 0.9999
-
-optimizer:
-  _target_: torch.optim.AdamW
+  # ---
+  # TODO(alexander-soare): Remove these from the policy config.
+  batch_size: 64
+  grad_clip_norm: 10
   lr: 1.0e-4
-  betas: [0.95, 0.999]
-  eps: 1.0e-8
-  weight_decay: 1.0e-6
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  adam_betas: [0.95, 0.999]
+  adam_eps: 1.0e-8
+  adam_weight_decay: 1.0e-6
+  utd: 1
+  use_ema: true
+  ema_update_after_step: 0
+  ema_min_rate: 0.0
+  ema_max_rate: 0.9999
+  ema_inv_gamma: 1.0
+  ema_power: 0.75
+
+  delta_timestamps:
+    observation.images: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+    observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+    action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
diff --git a/tests/test_available.py b/tests/test_available.py
index b25a921f..373cc1a7 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -19,7 +19,7 @@
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
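
A minimal sketch of how the templated delta_timestamps entries in diffusion.yaml above resolve once Hydra interpolates them, assuming fps=10, n_obs_steps=2, and horizon=16 (the real values are supplied by the env and policy configs at runtime). Under those assumptions the resulting lists reproduce the hard-coded ones that this patch removes from the config.

# Sketch only; fps, n_obs_steps, and horizon are assumed values, not taken from the patch.
fps = 10
n_obs_steps = 2
horizon = 16

# "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
obs_timestamps = [i / fps for i in range(1 - n_obs_steps, 1)]
# "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
action_timestamps = [i / fps for i in range(1 - n_obs_steps, 1 - n_obs_steps + horizon)]

print(obs_timestamps)     # [-0.1, 0.0]
print(action_timestamps)  # [-0.1, 0.0, 0.1, ..., 1.4]  (16 values, one per horizon step)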