From 976a197f9851be45906bbbd158cb0ca058e223d2 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Thu, 11 Apr 2024 17:51:35 +0100
Subject: [PATCH 1/8] backup wip

---
 .../_diffusion_policy_replay_buffer.py}       |   5 +
 lerobot/common/datasets/pusht.py              |   4 +-
 lerobot/common/policies/act/policy.py         |  11 +-
 .../diffusion/diffusion_unet_image_policy.py  | 315 ------
 .../diffusion/model/conditional_unet1d.py     | 401 ++++----
 .../diffusion/model/conv1d_components.py      |  47 -
 .../diffusion/model/crop_randomizer.py        | 294 ------
 .../diffusion/model/dict_of_tensor_mixin.py   |  41 -
 .../model/diffusion_unet_image_policy.py      | 220 ++++
 .../policies/diffusion/model/ema_model.py     |  13 -
 .../policies/diffusion/model/lr_scheduler.py  |  46 -
 .../diffusion/model/mask_generator.py         |  65 --
 .../diffusion/model/module_attr_mixin.py      |  15 -
 .../model/multi_image_obs_encoder.py          | 214 ----
 .../policies/diffusion/model/normalizer.py    | 358 -------
 .../diffusion/model/positional_embedding.py   |  19 -
 .../policies/diffusion/model/rgb_encoder.py   | 147 +++
 .../policies/diffusion/model/tensor_utils.py  | 972 ------------------
 lerobot/common/policies/diffusion/policy.py   |  78 +-
 .../policies/diffusion/pytorch_utils.py       |  76 --
 lerobot/common/policies/factory.py            |   2 -
 lerobot/common/policies/utils.py              |  22 +
 lerobot/common/utils.py                       |   1 +
 lerobot/configs/policy/diffusion.yaml         |  24 +-
 lerobot/scripts/eval.py                       |   2 +-
 lerobot/scripts/train.py                      |   2 +-
 26 files changed, 661 insertions(+), 2733 deletions(-)
 rename lerobot/common/{policies/diffusion/replay_buffer.py => datasets/_diffusion_policy_replay_buffer.py} (99%)
 delete mode 100644 lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
 delete mode 100644 lerobot/common/policies/diffusion/model/conv1d_components.py
 delete mode 100644 lerobot/common/policies/diffusion/model/crop_randomizer.py
 delete mode 100644 lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py
 create mode 100644 lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
 delete mode 100644 lerobot/common/policies/diffusion/model/lr_scheduler.py
 delete mode 100644 lerobot/common/policies/diffusion/model/mask_generator.py
 delete mode 100644 lerobot/common/policies/diffusion/model/module_attr_mixin.py
 delete mode 100644 lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
 delete mode 100644 lerobot/common/policies/diffusion/model/normalizer.py
 delete mode 100644 lerobot/common/policies/diffusion/model/positional_embedding.py
 create mode 100644 lerobot/common/policies/diffusion/model/rgb_encoder.py
 delete mode 100644 lerobot/common/policies/diffusion/model/tensor_utils.py
 delete mode 100644 lerobot/common/policies/diffusion/pytorch_utils.py

diff --git a/lerobot/common/policies/diffusion/replay_buffer.py b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py
similarity index 99%
rename from lerobot/common/policies/diffusion/replay_buffer.py
rename to lerobot/common/datasets/_diffusion_policy_replay_buffer.py
index 7fccf74d..1697f9fc 100644
--- a/lerobot/common/policies/diffusion/replay_buffer.py
+++ b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py
@@ -1,3 +1,8 @@
+"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
+
+Copied from the original Diffusion Policy repository.
+"""
+
 from __future__ import annotations
 
 import math
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index b468637e..9af6f3a1 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -8,8 +8,10 @@ import torch
 import tqdm
 from gym_pusht.envs.pusht import pymunk_to_shapely
 
+from lerobot.common.datasets._diffusion_policy_replay_buffer import (
+    ReplayBuffer as DiffusionPolicyReplayBuffer,
+)
 from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
-from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index 25b814ed..821b0196 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -176,7 +176,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.n_action_steps)
 
-    def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
+    @torch.no_grad
+    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
         """
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -188,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
             self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
         return self._action_queue.popleft()
 
-    @torch.no_grad()
+    @torch.no_grad
     def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """Use the action chunking transformer to generate a sequence of actions."""
         self.eval()
@@ -223,8 +224,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         {
             "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
             "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
-            "action": (B, H, J) tensor of actions (positional target for robot joint configuration)
-            "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
         }
         """
         if add_obs_steps_dim:
@@ -244,7 +243,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
         # the image index dimension.
 
-    def update(self, batch, *_, **__) -> dict:
+    def update(self, batch, **_) -> dict:
+        """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
         self._preprocess_batch(batch)
 
@@ -278,6 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return info
 
     def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
+        """A forward pass through the DNN part of this policy with optional loss computation."""
         images = self.image_normalizer(batch["observation.images.top"])
 
         if return_loss:  # training time
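
The `select_action` wrapper patched above relies on ACT's action-chunking queue: the policy predicts a chunk of `n_action_steps` actions in one forward pass, caches them in a deque, and pops one action per environment step, only re-running inference once the queue is empty. A minimal sketch of that pattern, with an illustrative class and tensor shapes that are not part of this patch:

```python
from collections import deque

import torch
from torch import Tensor


class QueueingPolicySketch:
    """Illustrative stand-in for the chunk-and-queue logic in `select_action`."""

    def __init__(self, n_action_steps: int = 4, action_dim: int = 2):
        self.n_action_steps = n_action_steps
        self.action_dim = action_dim
        self._action_queue: deque[Tensor] = deque([], maxlen=n_action_steps)

    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
        # Stand-in for the model: predict a (B, n_action_steps, action_dim) chunk.
        batch_size = batch["observation.state"].shape[0]
        return torch.zeros(batch_size, self.n_action_steps, self.action_dim)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        # Refill the queue only when it is empty, then hand out one action per call.
        if len(self._action_queue) == 0:
            # transpose(0, 1) so the queue iterates over the chunk's time dimension.
            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
        return self._action_queue.popleft()


if __name__ == "__main__":
    policy = QueueingPolicySketch()
    obs = {"observation.state": torch.zeros(1, 7)}
    for step in range(6):  # inference only runs on steps 0 and 4
        action = policy.select_action(obs)
        print(step, action.shape)  # e.g. 0 torch.Size([1, 2])
```
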
diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
deleted file mode 100644
index f7432db3..00000000
--- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
+++ /dev/null
@@ -1,315 +0,0 @@
-"""Code from the original diffusion policy project.
-
-Notes on how to load a checkpoint from the original repository:
-
-In the original repository, run the eval and use a breakpoint to extract the policy weights.
-
-```
-torch.save(policy.state_dict(), "weights.pt")
-```
-
-In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
-
-```
-loaded = torch.load("weights.pt")
-aligned = {}
-their_prefix = "obs_encoder.obs_nets.image.backbone"
-our_prefix = "obs_encoder.key_model_map.image.backbone"
-aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
-their_prefix = "obs_encoder.obs_nets.image.pool"
-our_prefix = "obs_encoder.key_model_map.image.pool"
-aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
-their_prefix = "obs_encoder.obs_nets.image.nets.3"
-our_prefix = "obs_encoder.key_model_map.image.out"
-aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
-aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
-# Note: here you are loading into the ema model.
-missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
-assert all('_dummy_variable' in k for k in missing_keys)
-assert len(unexpected_keys) == 0
-```
-
-Then in that same runtime you can also save the weights with the new aligned state_dict:
-
-```
-policy.save("weights.pt")
-```
-
-Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
-
-"""
-
-from typing import Dict
-
-import torch
-import torch.nn.functional as F  # noqa: N812
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from einops import reduce
-
-from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
-from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
-from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
-from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
-from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
-from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
-
-
-class BaseImagePolicy(ModuleAttrMixin):
-    # init accepts keyword argument shape_meta, see config/task/*_image.yaml
-
-    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """
-        obs_dict:
-            str: B,To,*
-        return: B,Ta,Da
-        """
-        raise NotImplementedError()
-
-    # reset state for stateful policies
-    def reset(self):
-        pass
-
-    # ========== training ===========
-    # no standard training interface except setting normalizer
-    def set_normalizer(self, normalizer: LinearNormalizer):
-        raise NotImplementedError()
-
-
-class DiffusionUnetImagePolicy(BaseImagePolicy):
-    def __init__(
-        self,
-        shape_meta: dict,
-        noise_scheduler: DDPMScheduler,
-        obs_encoder: MultiImageObsEncoder,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        obs_as_global_cond=True,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        cond_predict_scale=True,
-        # parameters passed to step
-        **kwargs,
-    ):
-        super().__init__()
-
-        # parse shapes
-        action_shape = shape_meta["action"]["shape"]
-        assert len(action_shape) == 1
-        action_dim = action_shape[0]
-        # get feature dim
-        obs_feature_dim = obs_encoder.output_shape()[0]
-
-        # create diffusion model
-        input_dim = action_dim + obs_feature_dim
-        global_cond_dim = None
-        if obs_as_global_cond:
-            input_dim = action_dim
-            global_cond_dim = obs_feature_dim * n_obs_steps
-
-        model = ConditionalUnet1D(
-            input_dim=input_dim,
-            local_cond_dim=None,
-            global_cond_dim=global_cond_dim,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            cond_predict_scale=cond_predict_scale,
-        )
-
-        self.obs_encoder = obs_encoder
-        self.model = model
-        self.noise_scheduler = noise_scheduler
-        self.mask_generator = LowdimMaskGenerator(
-            action_dim=action_dim,
-            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
-            max_n_obs_steps=n_obs_steps,
-            fix_obs_steps=True,
-            action_visible=False,
-        )
-        self.horizon = horizon
-        self.obs_feature_dim = obs_feature_dim
-        self.action_dim = action_dim
-        self.n_action_steps = n_action_steps
-        self.n_obs_steps = n_obs_steps
-        self.obs_as_global_cond = obs_as_global_cond
-        self.kwargs = kwargs
-
-        if num_inference_steps is None:
-            num_inference_steps = noise_scheduler.config.num_train_timesteps
-        self.num_inference_steps = num_inference_steps
-
-    # ========= inference  ============
-    def conditional_sample(
-        self,
-        condition_data,
-        condition_mask,
-        local_cond=None,
-        global_cond=None,
-        generator=None,
-        # keyword arguments to scheduler.step
-        **kwargs,
-    ):
-        model = self.model
-        scheduler = self.noise_scheduler
-
-        trajectory = torch.randn(
-            size=condition_data.shape,
-            dtype=condition_data.dtype,
-            device=condition_data.device,
-            generator=generator,
-        )
-
-        # set step values
-        scheduler.set_timesteps(self.num_inference_steps)
-
-        for t in scheduler.timesteps:
-            # 1. apply conditioning
-            trajectory[condition_mask] = condition_data[condition_mask]
-
-            # 2. predict model output
-            model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
-
-            # 3. compute previous image: x_t -> x_t-1
-            trajectory = scheduler.step(
-                model_output,
-                t,
-                trajectory,
-                generator=generator,
-                # **kwargs  # TODO(rcadene): in diffusion_policy, expected to be {}
-            ).prev_sample
-
-        # finally make sure conditioning is enforced
-        trajectory[condition_mask] = condition_data[condition_mask]
-
-        return trajectory
-
-    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """
-        obs_dict: must include "obs" key
-        result: must include "action" key
-        """
-        assert "past_action" not in obs_dict  # not implemented yet
-        nobs = obs_dict
-        value = next(iter(nobs.values()))
-        bsize, n_obs_steps = value.shape[:2]
-        horizon = self.horizon
-        action_dim = self.action_dim
-        obs_dim = self.obs_feature_dim
-        assert self.n_obs_steps == n_obs_steps
-
-        # build input
-        device = self.device
-        dtype = self.dtype
-
-        # handle different ways of passing observation
-        local_cond = None
-        global_cond = None
-        if self.obs_as_global_cond:
-            # condition through global feature
-            this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
-            nobs_features = self.obs_encoder(this_nobs)
-            # reshape back to B, Do
-            global_cond = nobs_features.reshape(bsize, -1)
-            # empty data for action
-            cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
-            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
-        else:
-            # condition through impainting
-            this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
-            nobs_features = self.obs_encoder(this_nobs)
-            # reshape back to B, T, Do
-            nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
-            cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
-            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
-            cond_data[:, :n_obs_steps, action_dim:] = nobs_features
-            cond_mask[:, :n_obs_steps, action_dim:] = True
-
-        # run sampling
-        nsample = self.conditional_sample(
-            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
-        )
-
-        action_pred = nsample[..., :action_dim]
-        # get action
-        start = n_obs_steps - 1
-        end = start + self.n_action_steps
-        action = action_pred[:, start:end]
-
-        result = {"action": action, "action_pred": action_pred}
-        return result
-
-    def compute_loss(self, batch):
-        nobs = {
-            "image": batch["observation.image"],
-            "agent_pos": batch["observation.state"],
-        }
-        nactions = batch["action"]
-        batch_size = nactions.shape[0]
-        horizon = nactions.shape[1]
-
-        # handle different ways of passing observation
-        local_cond = None
-        global_cond = None
-        trajectory = nactions
-        cond_data = trajectory
-        if self.obs_as_global_cond:
-            # reshape B, T, ... to B*T
-            this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
-            nobs_features = self.obs_encoder(this_nobs)
-            # reshape back to B, Do
-            global_cond = nobs_features.reshape(batch_size, -1)
-        else:
-            # reshape B, T, ... to B*T
-            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
-            nobs_features = self.obs_encoder(this_nobs)
-            # reshape back to B, T, Do
-            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
-            cond_data = torch.cat([nactions, nobs_features], dim=-1)
-            trajectory = cond_data.detach()
-
-        # generate impainting mask
-        condition_mask = self.mask_generator(trajectory.shape)
-
-        # Sample noise that we'll add to the images
-        noise = torch.randn(trajectory.shape, device=trajectory.device)
-        bsz = trajectory.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(
-            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
-        ).long()
-        # Add noise to the clean images according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
-
-        # compute loss mask
-        loss_mask = ~condition_mask
-
-        # apply conditioning
-        noisy_trajectory[condition_mask] = cond_data[condition_mask]
-
-        # Predict the noise residual
-        pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
-
-        pred_type = self.noise_scheduler.config.prediction_type
-        if pred_type == "epsilon":
-            target = noise
-        elif pred_type == "sample":
-            target = trajectory
-        else:
-            raise ValueError(f"Unsupported prediction type {pred_type}")
-
-        loss = F.mse_loss(pred, target, reduction="none")
-        loss = loss * loss_mask.type(loss.dtype)
-
-        if "action_is_pad" in batch:
-            in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
-
-        loss = reduce(loss, "b t c -> b", "mean", b=batch_size)
-        loss = loss.mean()
-        return loss
diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
index d2971d38..5c43d488 100644
--- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py
+++ b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
@@ -1,286 +1,307 @@
 import logging
-from typing import Union
+import math
 
 import einops
 import torch
 import torch.nn as nn
-from einops.layers.torch import Rearrange
-
-from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
-from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
+from torch import Tensor
 
 logger = logging.getLogger(__name__)
 
 
-class ConditionalResidualBlock1D(nn.Module):
+class _SinusoidalPosEmb(nn.Module):
+    # TODO(now): consolidate?
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class _Conv1dBlock(nn.Module):
+    """Conv1d --> GroupNorm --> Mish"""
+
+    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+            nn.GroupNorm(n_groups, out_channels),
+            nn.Mish(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class _ConditionalResidualBlock1D(nn.Module):
+    """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
+
     def __init__(
-        self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        cond_dim: int,
+        kernel_size: int = 3,
+        n_groups: int = 8,
+        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
+        # FiLM just modulates bias).
+        film_scale_modulation: bool = False,
     ):
         super().__init__()
 
-        self.blocks = nn.ModuleList(
-            [
-                Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
-                Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
-            ]
-        )
-
-        # FiLM modulation https://arxiv.org/abs/1709.07871
-        # predicts per-channel scale and bias
-        cond_channels = out_channels
-        if cond_predict_scale:
-            cond_channels = out_channels * 2
-        self.cond_predict_scale = cond_predict_scale
+        self.film_scale_modulation = film_scale_modulation
         self.out_channels = out_channels
-        self.cond_encoder = nn.Sequential(
-            nn.Mish(),
-            nn.Linear(cond_dim, cond_channels),
-            Rearrange("batch t -> batch t 1"),
-        )
 
-        # make sure dimensions compatible
+        self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
+
+        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
+        cond_channels = out_channels * 2 if film_scale_modulation else out_channels
+        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
+
+        self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
+
+        # A final convolution for dimension matching the residual (if needed).
         self.residual_conv = (
             nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
         )
 
-    def forward(self, x, cond):
+    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         """
-        x : [ batch_size x in_channels x horizon ]
-        cond : [ batch_size x cond_dim]
+        Args:
+            x: (B, in_channels, T)
+            cond: (B, cond_dim)
+        Returns:
+            (B, out_channels, T)
+        """
+        out = self.conv1(x)
 
-        returns:
-        out : [ batch_size x out_channels x horizon ]
-        """
-        out = self.blocks[0](x)
-        embed = self.cond_encoder(cond)
-        if self.cond_predict_scale:
-            embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
-            scale = embed[:, 0, ...]
-            bias = embed[:, 1, ...]
+        # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
+        cond_embed = self.cond_encoder(cond).unsqueeze(-1)
+        if self.film_scale_modulation:
+            # Treat the embedding as a list of scales and biases.
+            scale = cond_embed[:, : self.out_channels]
+            bias = cond_embed[:, self.out_channels :]
             out = scale * out + bias
         else:
-            out = out + embed
-        out = self.blocks[1](out)
+            # Treat the embedding as biases.
+            out = out + cond_embed
+
+        out = self.conv2(out)
         out = out + self.residual_conv(x)
         return out
 
 
 class ConditionalUnet1D(nn.Module):
+    """A 1D convolutional UNet with FiLM modulation for conditioning.
+
+    Two types of conditioning can be applied:
+    - Global: Conditioning information that is aggregated over the whole observation window. This is
+        incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
+    - Local: Conditioning information for each timestep in the observation window. This is incorporated
+        by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
+        outputs of the Unet's encoder/decoder.
+    """
+
     def __init__(
         self,
-        input_dim,
-        local_cond_dim=None,
-        global_cond_dim=None,
-        diffusion_step_embed_dim=256,
-        down_dims=None,
-        kernel_size=3,
-        n_groups=8,
-        cond_predict_scale=False,
+        input_dim: int,
+        local_cond_dim: int | None = None,
+        global_cond_dim: int | None = None,
+        diffusion_step_embed_dim: int = 256,
+        down_dims: list[int] | None = None,
+        kernel_size: int = 3,
+        n_groups: int = 8,
+        film_scale_modulation: bool = False,
     ):
         super().__init__()
+
         if down_dims is None:
             down_dims = [256, 512, 1024]
 
-        all_dims = [input_dim] + list(down_dims)
-        start_dim = down_dims[0]
-
-        dsed = diffusion_step_embed_dim
-        diffusion_step_encoder = nn.Sequential(
-            SinusoidalPosEmb(dsed),
-            nn.Linear(dsed, dsed * 4),
+        # Encoder for the diffusion timestep.
+        self.diffusion_step_encoder = nn.Sequential(
+            _SinusoidalPosEmb(diffusion_step_embed_dim),
+            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
             nn.Mish(),
-            nn.Linear(dsed * 4, dsed),
+            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
         )
-        cond_dim = dsed
+
+        # The FiLM conditioning dimension.
+        cond_dim = diffusion_step_embed_dim
         if global_cond_dim is not None:
             cond_dim += global_cond_dim
 
-        in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
-
-        local_cond_encoder = None
+        self.local_cond_down_encoder = None
+        self.local_cond_up_encoder = None
         if local_cond_dim is not None:
-            _, dim_out = in_out[0]
-            dim_in = local_cond_dim
-            local_cond_encoder = nn.ModuleList(
-                [
-                    # down encoder
-                    ConditionalResidualBlock1D(
-                        dim_in,
-                        dim_out,
-                        cond_dim=cond_dim,
-                        kernel_size=kernel_size,
-                        n_groups=n_groups,
-                        cond_predict_scale=cond_predict_scale,
-                    ),
-                    # up encoder
-                    ConditionalResidualBlock1D(
-                        dim_in,
-                        dim_out,
-                        cond_dim=cond_dim,
-                        kernel_size=kernel_size,
-                        n_groups=n_groups,
-                        cond_predict_scale=cond_predict_scale,
-                    ),
-                ]
+            # Encoder for the local conditioning. The output gets added to the Unet encoder input.
+            self.local_cond_down_encoder = _ConditionalResidualBlock1D(
+                local_cond_dim,
+                down_dims[0],
+                cond_dim=cond_dim,
+                kernel_size=kernel_size,
+                n_groups=n_groups,
+                film_scale_modulation=film_scale_modulation,
+            )
+            # Encoder for the local conditioning. The output gets added to the Unet encoder output.
+            self.local_cond_up_encoder = _ConditionalResidualBlock1D(
+                local_cond_dim,
+                down_dims[0],
+                cond_dim=cond_dim,
+                kernel_size=kernel_size,
+                n_groups=n_groups,
+                film_scale_modulation=film_scale_modulation,
             )
 
-        mid_dim = all_dims[-1]
+        # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
+        # just reverse these.
+        in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
+
+        # Unet encoder.
+        self.down_modules = nn.ModuleList([])
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (len(in_out) - 1)
+            self.down_modules.append(
+                nn.ModuleList(
+                    [
+                        _ConditionalResidualBlock1D(
+                            dim_in,
+                            dim_out,
+                            cond_dim=cond_dim,
+                            kernel_size=kernel_size,
+                            n_groups=n_groups,
+                            film_scale_modulation=film_scale_modulation,
+                        ),
+                        _ConditionalResidualBlock1D(
+                            dim_out,
+                            dim_out,
+                            cond_dim=cond_dim,
+                            kernel_size=kernel_size,
+                            n_groups=n_groups,
+                            film_scale_modulation=film_scale_modulation,
+                        ),
+                        # Downsample as long as it is not the last block.
+                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        # Processing in the middle of the auto-encoder.
         self.mid_modules = nn.ModuleList(
             [
-                ConditionalResidualBlock1D(
-                    mid_dim,
-                    mid_dim,
+                _ConditionalResidualBlock1D(
+                    down_dims[-1],
+                    down_dims[-1],
                     cond_dim=cond_dim,
                     kernel_size=kernel_size,
                     n_groups=n_groups,
-                    cond_predict_scale=cond_predict_scale,
+                    film_scale_modulation=film_scale_modulation,
                 ),
-                ConditionalResidualBlock1D(
-                    mid_dim,
-                    mid_dim,
+                _ConditionalResidualBlock1D(
+                    down_dims[-1],
+                    down_dims[-1],
                     cond_dim=cond_dim,
                     kernel_size=kernel_size,
                     n_groups=n_groups,
-                    cond_predict_scale=cond_predict_scale,
+                    film_scale_modulation=film_scale_modulation,
                 ),
             ]
         )
 
-        down_modules = nn.ModuleList([])
-        for ind, (dim_in, dim_out) in enumerate(in_out):
+        # Unet decoder.
+        self.up_modules = nn.ModuleList([])
+        for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (len(in_out) - 1)
-            down_modules.append(
+            self.up_modules.append(
                 nn.ModuleList(
                     [
-                        ConditionalResidualBlock1D(
-                            dim_in,
+                        _ConditionalResidualBlock1D(
+                            dim_in * 2,  # x2 as it takes the encoder's skip connection as well
                             dim_out,
                             cond_dim=cond_dim,
                             kernel_size=kernel_size,
                             n_groups=n_groups,
-                            cond_predict_scale=cond_predict_scale,
+                            film_scale_modulation=film_scale_modulation,
                         ),
-                        ConditionalResidualBlock1D(
+                        _ConditionalResidualBlock1D(
                             dim_out,
                             dim_out,
                             cond_dim=cond_dim,
                             kernel_size=kernel_size,
                             n_groups=n_groups,
-                            cond_predict_scale=cond_predict_scale,
+                            film_scale_modulation=film_scale_modulation,
                         ),
-                        Downsample1d(dim_out) if not is_last else nn.Identity(),
+                        # Upsample as long as it is not the last block.
+                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                     ]
                 )
             )
 
-        up_modules = nn.ModuleList([])
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (len(in_out) - 1)
-            up_modules.append(
-                nn.ModuleList(
-                    [
-                        ConditionalResidualBlock1D(
-                            dim_out * 2,
-                            dim_in,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            cond_predict_scale=cond_predict_scale,
-                        ),
-                        ConditionalResidualBlock1D(
-                            dim_in,
-                            dim_in,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            cond_predict_scale=cond_predict_scale,
-                        ),
-                        Upsample1d(dim_in) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        final_conv = nn.Sequential(
-            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
-            nn.Conv1d(start_dim, input_dim, 1),
+        self.final_conv = nn.Sequential(
+            _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
+            nn.Conv1d(down_dims[0], input_dim, 1),
         )
 
-        self.diffusion_step_encoder = diffusion_step_encoder
-        self.local_cond_encoder = local_cond_encoder
-        self.up_modules = up_modules
-        self.down_modules = down_modules
-        self.final_conv = final_conv
-
-        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
-
-    def forward(
-        self,
-        sample: torch.Tensor,
-        timestep: Union[torch.Tensor, float, int],
-        local_cond=None,
-        global_cond=None,
-        **kwargs,
-    ):
+    def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
         """
-        x: (B,T,input_dim)
-        timestep: (B,) or int, diffusion step
-        local_cond: (B,T,local_cond_dim)
-        global_cond: (B,global_cond_dim)
-        output: (B,T,input_dim)
+        Args:
+            x: (B, T, input_dim) tensor for input to the Unet.
+            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
+            local_cond: (B, T, local_cond_dim)
+            global_cond: (B, global_cond_dim)
+            output: (B, T, input_dim)
+        Returns:
+            (B, T, input_dim)
         """
-        sample = einops.rearrange(sample, "b h t -> b t h")
-
-        # 1. time
-        timesteps = timestep
-        if not torch.is_tensor(timesteps):
-            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
-        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(sample.device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(sample.shape[0])
-
-        global_feature = self.diffusion_step_encoder(timesteps)
-
-        if global_cond is not None:
-            global_feature = torch.cat([global_feature, global_cond], axis=-1)
-
-        # encode local features
-        h_local = []
+        # For 1D convolutions we'll need the feature dimension first.
+        x = einops.rearrange(x, "b t d -> b d t")
         if local_cond is not None:
-            local_cond = einops.rearrange(local_cond, "b h t -> b t h")
-            resnet, resnet2 = self.local_cond_encoder
-            x = resnet(local_cond, global_feature)
-            h_local.append(x)
-            x = resnet2(local_cond, global_feature)
-            h_local.append(x)
+            if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
+                raise ValueError(
+                    "`local_cond` was provided but the relevant encoders weren't built at initialization."
+                )
+            local_cond = einops.rearrange(local_cond, "b t d -> b d t")
 
-        x = sample
-        h = []
+        timesteps_embed = self.diffusion_step_encoder(timestep)
+
+        # If there is a global conditioning feature, concatenate it to the timestep embedding.
+        if global_cond is not None:
+            global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
+        else:
+            global_feature = timesteps_embed
+
+        encoder_skip_features: list[Tensor] = []
         for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
             x = resnet(x, global_feature)
-            if idx == 0 and len(h_local) > 0:
-                x = x + h_local[0]
+            if idx == 0 and local_cond is not None:
+                x = x + self.local_cond_down_encoder(local_cond, global_feature)
             x = resnet2(x, global_feature)
-            h.append(x)
+            encoder_skip_features.append(x)
             x = downsample(x)
 
         for mid_module in self.mid_modules:
             x = mid_module(x, global_feature)
 
         for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
-            x = torch.cat((x, h.pop()), dim=1)
+            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
             x = resnet(x, global_feature)
-            # The correct condition should be:
-            # if idx == (len(self.up_modules)-1) and len(h_local) > 0:
-            # However this change will break compatibility with published checkpoints.
-            # Therefore it is left as a comment.
-            if idx == len(self.up_modules) and len(h_local) > 0:
-                x = x + h_local[1]
+            # Note: The condition in the original implementation is:
+            # if idx == len(self.up_modules) and local_cond is not None:
+            # But as they mention in their comments, this is incorrect. We use the correct condition here.
+            if idx == (len(self.up_modules) - 1) and local_cond is not None:
+                x = x + self.local_cond_up_encoder(local_cond, global_feature)
             x = resnet2(x, global_feature)
             x = upsample(x)
 
         x = self.final_conv(x)
 
-        x = einops.rearrange(x, "b t h -> b h t")
+        x = einops.rearrange(x, "b d t -> b t d")
         return x
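
The refactored `ConditionalUnet1D` keeps the (B, T, input_dim) convention documented in its new `forward` docstring, with FiLM conditioning driven by the diffusion-step embedding plus any global conditioning vector. A rough shape smoke test, assuming the module as modified by this patch (the sizes below are arbitrary placeholders, not values taken from the diffusion config):

```python
import torch

from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D

batch_size, horizon, action_dim, global_cond_dim = 2, 16, 2, 132

unet = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=global_cond_dim,
    down_dims=[64, 128],  # small channel counts just for the smoke test
    kernel_size=5,
    n_groups=8,
    film_scale_modulation=True,
)

noisy_trajectory = torch.randn(batch_size, horizon, action_dim)
timesteps = torch.randint(0, 100, (batch_size,), dtype=torch.long)
global_cond = torch.randn(batch_size, global_cond_dim)

# Output retains the input's (B, T, input_dim) shape, as per the docstring.
denoising_output = unet(noisy_trajectory, timesteps, global_cond=global_cond)
assert denoising_output.shape == (batch_size, horizon, action_dim)
```
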
diff --git a/lerobot/common/policies/diffusion/model/conv1d_components.py b/lerobot/common/policies/diffusion/model/conv1d_components.py
deleted file mode 100644
index 3c21eaf6..00000000
--- a/lerobot/common/policies/diffusion/model/conv1d_components.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch.nn as nn
-
-# from einops.layers.torch import Rearrange
-
-
-class Downsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class Upsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class Conv1dBlock(nn.Module):
-    """
-    Conv1d --> GroupNorm --> Mish
-    """
-
-    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
-        super().__init__()
-
-        self.block = nn.Sequential(
-            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            # Rearrange('batch channels horizon -> batch channels 1 horizon'),
-            nn.GroupNorm(n_groups, out_channels),
-            # Rearrange('batch channels 1 horizon -> batch channels horizon'),
-            nn.Mish(),
-        )
-
-    def forward(self, x):
-        return self.block(x)
-
-
-# def test():
-#     cb = Conv1dBlock(256, 128, kernel_size=3)
-#     x = torch.zeros((1,256,16))
-#     o = cb(x)
diff --git a/lerobot/common/policies/diffusion/model/crop_randomizer.py b/lerobot/common/policies/diffusion/model/crop_randomizer.py
deleted file mode 100644
index 2e60f353..00000000
--- a/lerobot/common/policies/diffusion/model/crop_randomizer.py
+++ /dev/null
@@ -1,294 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision.transforms.functional as ttf
-
-import lerobot.common.policies.diffusion.model.tensor_utils as tu
-
-
-class CropRandomizer(nn.Module):
-    """
-    Randomly sample crops at input, and then average across crop features at output.
-    """
-
-    def __init__(
-        self,
-        input_shape,
-        crop_height,
-        crop_width,
-        num_crops=1,
-        pos_enc=False,
-    ):
-        """
-        Args:
-            input_shape (tuple, list): shape of input (not including batch dimension)
-            crop_height (int): crop height
-            crop_width (int): crop width
-            num_crops (int): number of random crops to take
-            pos_enc (bool): if True, add 2 channels to the output to encode the spatial
-                location of the cropped pixels in the source image
-        """
-        super().__init__()
-
-        assert len(input_shape) == 3  # (C, H, W)
-        assert crop_height < input_shape[1]
-        assert crop_width < input_shape[2]
-
-        self.input_shape = input_shape
-        self.crop_height = crop_height
-        self.crop_width = crop_width
-        self.num_crops = num_crops
-        self.pos_enc = pos_enc
-
-    def output_shape_in(self, input_shape=None):
-        """
-        Function to compute output shape from inputs to this module. Corresponds to
-        the @forward_in operation, where raw inputs (usually observation modalities)
-        are passed in.
-
-        Args:
-            input_shape (iterable of int): shape of input. Does not include batch dimension.
-                Some modules may not need this argument, if their output does not depend
-                on the size of the input, or if they assume fixed size input.
-
-        Returns:
-            out_shape ([int]): list of integers corresponding to output shape
-        """
-
-        # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
-        # the number of crops are reshaped into the batch dimension, increasing the batch
-        # size from B to B * N
-        out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
-        return [out_c, self.crop_height, self.crop_width]
-
-    def output_shape_out(self, input_shape=None):
-        """
-        Function to compute output shape from inputs to this module. Corresponds to
-        the @forward_out operation, where processed inputs (usually encoded observation
-        modalities) are passed in.
-
-        Args:
-            input_shape (iterable of int): shape of input. Does not include batch dimension.
-                Some modules may not need this argument, if their output does not depend
-                on the size of the input, or if they assume fixed size input.
-
-        Returns:
-            out_shape ([int]): list of integers corresponding to output shape
-        """
-
-        # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
-        # and then pools to result in [B, ...], only the batch dimension changes,
-        # and so the other dimensions retain their shape.
-        return list(input_shape)
-
-    def forward_in(self, inputs):
-        """
-        Samples N random crops for each input in the batch, and then reshapes
-        inputs to [B * N, ...].
-        """
-        assert len(inputs.shape) >= 3  # must have at least (C, H, W) dimensions
-        if self.training:
-            # generate random crops
-            out, _ = sample_random_image_crops(
-                images=inputs,
-                crop_height=self.crop_height,
-                crop_width=self.crop_width,
-                num_crops=self.num_crops,
-                pos_enc=self.pos_enc,
-            )
-            # [B, N, ...] -> [B * N, ...]
-            return tu.join_dimensions(out, 0, 1)
-        else:
-            # take center crop during eval
-            out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
-            if self.num_crops > 1:
-                B, C, H, W = out.shape  # noqa: N806
-                out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
-                # [B * N, ...]
-            return out
-
-    def forward_out(self, inputs):
-        """
-        Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
-        to result in shape [B, ...] to make sure the network output is consistent with
-        what would have happened if there were no randomization.
-        """
-        if self.num_crops <= 1:
-            return inputs
-        else:
-            batch_size = inputs.shape[0] // self.num_crops
-            out = tu.reshape_dimensions(
-                inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
-            )
-            return out.mean(dim=1)
-
-    def forward(self, inputs):
-        return self.forward_in(inputs)
-
-    def __repr__(self):
-        """Pretty print network."""
-        header = "{}".format(str(self.__class__.__name__))
-        msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
-            self.input_shape, self.crop_height, self.crop_width, self.num_crops
-        )
-        return msg
-
-
-def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
-    """
-    Crops images at the locations specified by @crop_indices. Crops will be
-    taken across all channels.
-
-    Args:
-        images (torch.Tensor): batch of images of shape [..., C, H, W]
-
-        crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
-            N is the number of crops to take per image and each entry corresponds
-            to the pixel height and width of where to take the crop. Note that
-            the indices can also be of shape [..., 2] if only 1 crop should
-            be taken per image. Leading dimensions must be consistent with
-            @images argument. Each index specifies the top left of the crop.
-            Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
-            H and W are the height and width of @images and CH and CW are
-            @crop_height and @crop_width.
-
-        crop_height (int): height of crop to take
-
-        crop_width (int): width of crop to take
-
-    Returns:
-        crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
-    """
-
-    # make sure length of input shapes is consistent
-    assert crop_indices.shape[-1] == 2
-    ndim_im_shape = len(images.shape)
-    ndim_indices_shape = len(crop_indices.shape)
-    assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
-
-    # maybe pad so that @crop_indices is shape [..., N, 2]
-    is_padded = False
-    if ndim_im_shape == ndim_indices_shape + 2:
-        crop_indices = crop_indices.unsqueeze(-2)
-        is_padded = True
-
-    # make sure leading dimensions between images and indices are consistent
-    assert images.shape[:-3] == crop_indices.shape[:-2]
-
-    device = images.device
-    image_c, image_h, image_w = images.shape[-3:]
-    num_crops = crop_indices.shape[-2]
-
-    # make sure @crop_indices are in valid range
-    assert (crop_indices[..., 0] >= 0).all().item()
-    assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
-    assert (crop_indices[..., 1] >= 0).all().item()
-    assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
-
-    # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
-
-    # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
-    crop_ind_grid_h = torch.arange(crop_height).to(device)
-    crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
-    # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
-    crop_ind_grid_w = torch.arange(crop_width).to(device)
-    crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
-    # combine into shape [CH, CW, 2]
-    crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
-
-    # Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
-    # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
-    # shape array that tells us which pixels from the corresponding source image to grab.
-    grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
-    all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
-
-    # For using @torch.gather, convert to flat indices from 2D indices, and also
-    # repeat across the channel dimension. To get flat index of each pixel to grab for
-    # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
-    all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1]  # shape [..., N, CH, CW]
-    all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3)  # shape [..., N, C, CH, CW]
-    all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2)  # shape [..., N, C, CH * CW]
-
-    # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
-    images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
-    images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
-    crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
-    # [..., N, C, CH * CW] -> [..., N, C, CH, CW]
-    reshape_axis = len(crops.shape) - 1
-    crops = tu.reshape_dimensions(
-        crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
-    )
-
-    if is_padded:
-        # undo padding -> [..., C, CH, CW]
-        crops = crops.squeeze(-4)
-    return crops
-
-
-def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
-    """
-    For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
-    @images.
-
-    Args:
-        images (torch.Tensor): batch of images of shape [..., C, H, W]
-
-        crop_height (int): height of crop to take
-
-        crop_width (int): width of crop to take
-
-        num_crops (n): number of crops to sample
-
-        pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
-            encoding of the original source pixel locations. This means that the
-            output crops will contain information about where in the source image
-            it was sampled from.
-
-    Returns:
-        crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
-            if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
-
-        crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
-    """
-    device = images.device
-
-    # maybe add 2 channels of spatial encoding to the source image
-    source_im = images
-    if pos_enc:
-        # spatial encoding [y, x] in [0, 1]
-        h, w = source_im.shape[-2:]
-        pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
-        pos_y = pos_y.float().to(device) / float(h)
-        pos_x = pos_x.float().to(device) / float(w)
-        position_enc = torch.stack((pos_y, pos_x))  # shape [C, H, W]
-
-        # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
-        leading_shape = source_im.shape[:-3]
-        position_enc = position_enc[(None,) * len(leading_shape)]
-        position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
-
-        # concat across channel dimension with input
-        source_im = torch.cat((source_im, position_enc), dim=-3)
-
-    # make sure sample boundaries ensure crops are fully within the images
-    image_c, image_h, image_w = source_im.shape[-3:]
-    max_sample_h = image_h - crop_height
-    max_sample_w = image_w - crop_width
-
-    # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
-    # Each gets @num_crops samples - typically this will just be the batch dimension (B), so
-    # we will sample [B, N] indices, but this supports having more than one leading dimension,
-    # or possibly no leading dimension.
-    #
-    # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
-    crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
-    crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
-    crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1)  # shape [..., N, 2]
-
-    crops = crop_image_from_indices(
-        images=source_im,
-        crop_indices=crop_inds,
-        crop_height=crop_height,
-        crop_width=crop_width,
-    )
-
-    return crops, crop_inds
diff --git a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py b/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py
deleted file mode 100644
index d1356006..00000000
--- a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class DictOfTensorMixin(nn.Module):
-    def __init__(self, params_dict=None):
-        super().__init__()
-        if params_dict is None:
-            params_dict = nn.ParameterDict()
-        self.params_dict = params_dict
-
-    @property
-    def device(self):
-        return next(iter(self.parameters())).device
-
-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        def dfs_add(dest, keys, value: torch.Tensor):
-            if len(keys) == 1:
-                dest[keys[0]] = value
-                return
-
-            if keys[0] not in dest:
-                dest[keys[0]] = nn.ParameterDict()
-            dfs_add(dest[keys[0]], keys[1:], value)
-
-        def load_dict(state_dict, prefix):
-            out_dict = nn.ParameterDict()
-            for key, value in state_dict.items():
-                value: torch.Tensor
-                if key.startswith(prefix):
-                    param_keys = key[len(prefix) :].split(".")[1:]
-                    # if len(param_keys) == 0:
-                    #     import pdb; pdb.set_trace()
-                    dfs_add(out_dict, param_keys, value.clone())
-            return out_dict
-
-        self.params_dict = load_dict(state_dict, prefix + "params_dict")
-        self.params_dict.requires_grad_(False)
-        return
diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
new file mode 100644
index 00000000..b6b78925
--- /dev/null
+++ b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
@@ -0,0 +1,220 @@
+import einops
+import torch
+import torch.nn.functional as F  # noqa: N812
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from torch import Tensor, nn
+
+from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
+from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder
+from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters
+
+
+class DiffusionUnetImagePolicy(nn.Module):
+    """
+    TODO(now): Add DDIM scheduler.
+
+    Changes:  TODO(now)
+      - Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
+        needed. Code for a general observation encoder can be found at:
+        https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
+      - Uses the observation as global conditioning for the Unet by default.
+      - Does not do any inpainting (which would be applicable if the observation were not used to condition
+        the Unet).
+    """
+
+    def __init__(
+        self,
+        cfg,
+        shape_meta: dict,
+        noise_scheduler: DDPMScheduler,
+        horizon,
+        n_action_steps,
+        n_obs_steps,
+        num_inference_steps=None,
+        diffusion_step_embed_dim=256,
+        down_dims=(256, 512, 1024),
+        kernel_size=5,
+        n_groups=8,
+        film_scale_modulation=True,
+    ):
+        super().__init__()
+        action_shape = shape_meta["action"]["shape"]
+        assert len(action_shape) == 1
+        action_dim = action_shape[0]
+
+        self.rgb_encoder = RgbEncoder(input_shape=shape_meta["obs"]["image"]["shape"], **cfg.rgb_encoder)
+
+        self.unet = ConditionalUnet1D(
+            input_dim=action_dim,
+            global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
+            diffusion_step_embed_dim=diffusion_step_embed_dim,
+            down_dims=down_dims,
+            kernel_size=kernel_size,
+            n_groups=n_groups,
+            film_scale_modulation=film_scale_modulation,
+        )
+
+        self.noise_scheduler = noise_scheduler
+        self.horizon = horizon
+        self.action_dim = action_dim
+        self.n_action_steps = n_action_steps
+        self.n_obs_steps = n_obs_steps
+
+        if num_inference_steps is None:
+            num_inference_steps = noise_scheduler.config.num_train_timesteps
+
+        self.num_inference_steps = num_inference_steps
+
+    # ========= inference  ============
+    def conditional_sample(
+        self,
+        condition_data,
+        inpainting_mask,
+        local_cond=None,
+        global_cond=None,
+        generator=None,
+    ):
+        model = self.unet
+        scheduler = self.noise_scheduler
+
+        trajectory = torch.randn(
+            size=condition_data.shape,
+            dtype=condition_data.dtype,
+            device=condition_data.device,
+            generator=generator,
+        )
+
+        # set step values
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        for t in scheduler.timesteps:
+            # 1. apply conditioning
+            trajectory[inpainting_mask] = condition_data[inpainting_mask]
+
+            # 2. predict model output
+            model_output = model(
+                trajectory,
+                torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device),
+                local_cond=local_cond,
+                global_cond=global_cond,
+            )
+
+            # 3. compute previous image: x_t -> x_t-1
+            trajectory = scheduler.step(
+                model_output,
+                t,
+                trajectory,
+                generator=generator,
+            ).prev_sample
+
+        # finally make sure conditioning is enforced
+        trajectory[inpainting_mask] = condition_data[inpainting_mask]
+
+        return trajectory
+
+    def predict_action(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, n_obs_steps, state_dim)
+            "observation.image": (B, n_obs_steps, C, H, W)
+        }
+        """
+        assert set(batch).issuperset({"observation.state", "observation.image"})
+        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        assert n_obs_steps == self.n_obs_steps
+
+        # build input
+        device = get_device_from_parameters(self)
+        dtype = get_dtype_from_parameters(self)
+
+        # Extract image feature (first combine batch and sequence dims).
+        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
+        # Separate batch and sequence dims.
+        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
+        # Concatenate state and image features then flatten to (B, global_cond_dim).
+        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+        # Set up an all-zero action trajectory and an all-False inpainting mask: no inpainting is done
+        # because the observation is passed via global conditioning instead.
+        cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype)
+        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
+
+        # run sampling
+        nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)
+
+        # The sample contains `horizon` steps worth of actions (starting from the first observation).
+        action = nsample[..., : self.action_dim]
+        # Extract `n_action_steps` steps worth of actions (starting from the current observation).
+        start = n_obs_steps - 1
+        end = start + self.n_action_steps
+        action = action[:, start:end]
+
+        return action
+
+    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
+        """
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, n_obs_steps, state_dim)
+            "observation.image": (B, n_obs_steps, C, H, W)
+            "action": (B, horizon, action_dim)
+            "action_is_pad": (B, horizon)  # TODO(now) maybe this is (B, horizon, 1)
+        }
+        """
+        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
+        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        horizon = batch["action"].shape[1]
+        assert horizon == self.horizon
+        assert n_obs_steps == self.n_obs_steps
+
+        # The observation is used only for global conditioning of the Unet (no local conditioning).
+        local_cond = None
+        trajectory = batch["action"]
+        cond_data = trajectory
+
+        # Extract image feature (first combine batch and sequence dims).
+        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
+        # Separate batch and sequence dims.
+        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
+        # Concatenate state and image features then flatten to (B, global_cond_dim).
+        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+
+        # Sample noise to add to the trajectory.
+        noise = torch.randn(trajectory.shape, device=trajectory.device)
+        # Sample a random noising timestep for each item in the batch.
+        timesteps = torch.randint(
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (trajectory.shape[0],),
+            device=trajectory.device,
+        ).long()
+        # Add noise to the clean trajectories according to the noise magnitude at each timestep
+        # (this is the forward diffusion process).
+        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
+
+        # Apply inpainting. TODO(now): implement?
+        inpainting_mask = torch.zeros_like(trajectory, dtype=bool)
+        noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask]
+
+        # Predict the noise residual
+        pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
+
+        pred_type = self.noise_scheduler.config.prediction_type
+        if pred_type == "epsilon":
+            target = noise
+        elif pred_type == "sample":
+            target = trajectory
+        else:
+            raise ValueError(f"Unsupported prediction type {pred_type}")
+
+        loss = F.mse_loss(pred, target, reduction="none")
+        loss = loss * (~inpainting_mask)
+
+        if "action_is_pad" in batch:
+            in_episode_bound = ~batch["action_is_pad"]
+            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
+
+        return loss.mean()
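
For reference, here is a minimal standalone sketch (not part of this patch) of the diffusers DDPMScheduler pattern that compute_loss and conditional_sample rely on, followed by the temporal slicing done in predict_action. A toy denoiser stands in for ConditionalUnet1D, and the batch/horizon/action sizes are arbitrary illustrative values.

import torch
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

batch_size, horizon, action_dim = 2, 16, 2
scheduler = DDPMScheduler(num_train_timesteps=100, prediction_type="epsilon")

def toy_denoiser(sample: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Stand-in for self.unet(trajectory, t, global_cond=...): predicts the noise residual.
    return torch.zeros_like(sample)

# Training side (compute_loss): forward diffusion of the clean action trajectory.
clean_trajectory = torch.randn(batch_size, horizon, action_dim)
noise = torch.randn_like(clean_trajectory)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,))
noisy_trajectory = scheduler.add_noise(clean_trajectory, noise, timesteps)
# With prediction_type="epsilon", the MSE regression target is `noise`.

# Inference side (conditional_sample): reverse diffusion starting from pure noise.
trajectory = torch.randn(batch_size, horizon, action_dim)
scheduler.set_timesteps(10)  # num_inference_steps may be smaller than num_train_timesteps
for t in scheduler.timesteps:
    model_output = toy_denoiser(trajectory, torch.full((batch_size,), int(t), dtype=torch.long))
    trajectory = scheduler.step(model_output, t, trajectory).prev_sample

# predict_action then keeps n_action_steps actions starting at the current observation.
n_obs_steps, n_action_steps = 2, 8
action = trajectory[:, n_obs_steps - 1 : n_obs_steps - 1 + n_action_steps]
assert action.shape == (batch_size, n_action_steps, action_dim)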
diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py
index 6dc128de..3cb1dfbd 100644
--- a/lerobot/common/policies/diffusion/model/ema_model.py
+++ b/lerobot/common/policies/diffusion/model/ema_model.py
@@ -51,13 +51,6 @@ class EMAModel:
     def step(self, new_model):
         self.decay = self.get_decay(self.optimization_step)
 
-        # old_all_dataptrs = set()
-        # for param in new_model.parameters():
-        #     data_ptr = param.data_ptr()
-        #     if data_ptr != 0:
-        #         old_all_dataptrs.add(data_ptr)
-
-        # all_dataptrs = set()
         for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
             for param, ema_param in zip(
                 module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
@@ -66,10 +59,6 @@ class EMAModel:
                 if isinstance(param, dict):
                     raise RuntimeError("Dict parameter not supported")
 
-                # data_ptr = param.data_ptr()
-                # if data_ptr != 0:
-                #     all_dataptrs.add(data_ptr)
-
                 if isinstance(module, _BatchNorm):
                     # skip batchnorms
                     ema_param.copy_(param.to(dtype=ema_param.dtype).data)
@@ -79,6 +68,4 @@ class EMAModel:
                     ema_param.mul_(self.decay)
                     ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
 
-        # verify that iterating over module and then parameters is identical to parameters recursively.
-        # assert old_all_dataptrs == all_dataptrs
         self.optimization_step += 1
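
For clarity, a minimal standalone sketch (not part of this patch) of the per-parameter exponential moving average update that EMAModel.step applies, i.e. ema_param <- decay * ema_param + (1 - decay) * param, using a toy module in place of the policy.

import copy

import torch
from torch import nn

model = nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
decay = 0.99  # EMAModel anneals this value via get_decay(optimization_step)

with torch.no_grad():
    for param, ema_param in zip(model.parameters(), ema_model.parameters(), strict=True):
        ema_param.mul_(decay)
        ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - decay)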
diff --git a/lerobot/common/policies/diffusion/model/lr_scheduler.py b/lerobot/common/policies/diffusion/model/lr_scheduler.py
deleted file mode 100644
index 084b3a36..00000000
--- a/lerobot/common/policies/diffusion/model/lr_scheduler.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union
-
-
-def get_scheduler(
-    name: Union[str, SchedulerType],
-    optimizer: Optimizer,
-    num_warmup_steps: Optional[int] = None,
-    num_training_steps: Optional[int] = None,
-    **kwargs,
-):
-    """
-    Added kwargs vs diffuser's original implementation
-
-    Unified API to get any scheduler from its name.
-
-    Args:
-        name (`str` or `SchedulerType`):
-            The name of the scheduler to use.
-        optimizer (`torch.optim.Optimizer`):
-            The optimizer that will be used during training.
-        num_warmup_steps (`int`, *optional*):
-            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_training_steps (`int``, *optional*):
-            The number of training steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-    """
-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT:
-        return schedule_func(optimizer, **kwargs)
-
-    # All other schedulers require `num_warmup_steps`
-    if num_warmup_steps is None:
-        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
-
-    if name == SchedulerType.CONSTANT_WITH_WARMUP:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
-
-    # All other schedulers require `num_training_steps`
-    if num_training_steps is None:
-        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
-
-    return schedule_func(
-        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
-    )
diff --git a/lerobot/common/policies/diffusion/model/mask_generator.py b/lerobot/common/policies/diffusion/model/mask_generator.py
deleted file mode 100644
index 63306dea..00000000
--- a/lerobot/common/policies/diffusion/model/mask_generator.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch
-
-from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
-
-
-class LowdimMaskGenerator(ModuleAttrMixin):
-    def __init__(
-        self,
-        action_dim,
-        obs_dim,
-        # obs mask setup
-        max_n_obs_steps=2,
-        fix_obs_steps=True,
-        # action mask
-        action_visible=False,
-    ):
-        super().__init__()
-        self.action_dim = action_dim
-        self.obs_dim = obs_dim
-        self.max_n_obs_steps = max_n_obs_steps
-        self.fix_obs_steps = fix_obs_steps
-        self.action_visible = action_visible
-
-    @torch.no_grad()
-    def forward(self, shape, seed=None):
-        device = self.device
-        B, T, D = shape  # noqa: N806
-        assert (self.action_dim + self.obs_dim) == D
-
-        # create all tensors on this device
-        rng = torch.Generator(device=device)
-        if seed is not None:
-            rng = rng.manual_seed(seed)
-
-        # generate dim mask
-        dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
-        is_action_dim = dim_mask.clone()
-        is_action_dim[..., : self.action_dim] = True
-        is_obs_dim = ~is_action_dim
-
-        # generate obs mask
-        if self.fix_obs_steps:
-            obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
-        else:
-            obs_steps = torch.randint(
-                low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
-            )
-
-        steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
-        obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
-        obs_mask = obs_mask & is_obs_dim
-
-        # generate action mask
-        if self.action_visible:
-            action_steps = torch.maximum(
-                obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
-            )
-            action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
-            action_mask = action_mask & is_action_dim
-
-        mask = obs_mask
-        if self.action_visible:
-            mask = mask | action_mask
-
-        return mask
diff --git a/lerobot/common/policies/diffusion/model/module_attr_mixin.py b/lerobot/common/policies/diffusion/model/module_attr_mixin.py
deleted file mode 100644
index 5d2cf4ea..00000000
--- a/lerobot/common/policies/diffusion/model/module_attr_mixin.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import torch.nn as nn
-
-
-class ModuleAttrMixin(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self._dummy_variable = nn.Parameter()
-
-    @property
-    def device(self):
-        return next(iter(self.parameters())).device
-
-    @property
-    def dtype(self):
-        return next(iter(self.parameters())).dtype
diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
deleted file mode 100644
index d724cd49..00000000
--- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
+++ /dev/null
@@ -1,214 +0,0 @@
-import copy
-from typing import Dict, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torchvision
-from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
-
-from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
-from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
-from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
-
-
-class RgbEncoder(nn.Module):
-    """Following `VisualCore` from Robomimic 0.2.0."""
-
-    def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
-        """
-        input_shape: channel-first input shape (C, H, W)
-        resnet_name: a timm model name.
-        pretrained: whether to use timm pretrained weights.
-        relu: whether to use relu as a final step.
-        num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
-        """
-        super().__init__()
-        self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
-        # Figure out the feature map shape.
-        with torch.inference_mode():
-            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
-        self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
-        self.relu = nn.ReLU() if relu else nn.Identity()
-
-    def forward(self, x):
-        return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
-
-
-class MultiImageObsEncoder(ModuleAttrMixin):
-    def __init__(
-        self,
-        shape_meta: dict,
-        rgb_model: Union[nn.Module, Dict[str, nn.Module]],
-        resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
-        crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
-        random_crop: bool = True,
-        # replace BatchNorm with GroupNorm
-        use_group_norm: bool = False,
-        # use single rgb model for all rgb inputs
-        share_rgb_model: bool = False,
-        # renormalize rgb input with imagenet normalization
-        # assuming input in [0,1]
-        norm_mean_std: Optional[tuple[float, float]] = None,
-    ):
-        """
-        Assumes rgb input: B,C,H,W
-        Assumes low_dim input: B,D
-        """
-        super().__init__()
-
-        rgb_keys = []
-        low_dim_keys = []
-        key_model_map = nn.ModuleDict()
-        key_transform_map = nn.ModuleDict()
-        key_shape_map = {}
-
-        # handle sharing vision backbone
-        if share_rgb_model:
-            assert isinstance(rgb_model, nn.Module)
-            key_model_map["rgb"] = rgb_model
-
-        obs_shape_meta = shape_meta["obs"]
-        for key, attr in obs_shape_meta.items():
-            shape = tuple(attr["shape"])
-            type = attr.get("type", "low_dim")
-            key_shape_map[key] = shape
-            if type == "rgb":
-                rgb_keys.append(key)
-                # configure model for this key
-                this_model = None
-                if not share_rgb_model:
-                    if isinstance(rgb_model, dict):
-                        # have provided model for each key
-                        this_model = rgb_model[key]
-                    else:
-                        assert isinstance(rgb_model, nn.Module)
-                        # have a copy of the rgb model
-                        this_model = copy.deepcopy(rgb_model)
-
-                if this_model is not None:
-                    if use_group_norm:
-                        this_model = replace_submodules(
-                            root_module=this_model,
-                            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
-                            func=lambda x: nn.GroupNorm(
-                                num_groups=x.num_features // 16, num_channels=x.num_features
-                            ),
-                        )
-                    key_model_map[key] = this_model
-
-                # configure resize
-                input_shape = shape
-                this_resizer = nn.Identity()
-                if resize_shape is not None:
-                    if isinstance(resize_shape, dict):
-                        h, w = resize_shape[key]
-                    else:
-                        h, w = resize_shape
-                    this_resizer = torchvision.transforms.Resize(size=(h, w))
-                    input_shape = (shape[0], h, w)
-
-                # configure randomizer
-                this_randomizer = nn.Identity()
-                if crop_shape is not None:
-                    if isinstance(crop_shape, dict):
-                        h, w = crop_shape[key]
-                    else:
-                        h, w = crop_shape
-                    if random_crop:
-                        this_randomizer = CropRandomizer(
-                            input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
-                        )
-                    else:
-                        this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
-                # configure normalizer
-                this_normalizer = nn.Identity()
-                if norm_mean_std is not None:
-                    this_normalizer = torchvision.transforms.Normalize(
-                        mean=norm_mean_std[0], std=norm_mean_std[1]
-                    )
-
-                this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
-                key_transform_map[key] = this_transform
-            elif type == "low_dim":
-                low_dim_keys.append(key)
-            else:
-                raise RuntimeError(f"Unsupported obs type: {type}")
-        rgb_keys = sorted(rgb_keys)
-        low_dim_keys = sorted(low_dim_keys)
-
-        self.shape_meta = shape_meta
-        self.key_model_map = key_model_map
-        self.key_transform_map = key_transform_map
-        self.share_rgb_model = share_rgb_model
-        self.rgb_keys = rgb_keys
-        self.low_dim_keys = low_dim_keys
-        self.key_shape_map = key_shape_map
-
-    def forward(self, obs_dict):
-        batch_size = None
-        features = []
-
-        # process lowdim input
-        for key in self.low_dim_keys:
-            data = obs_dict[key]
-            if batch_size is None:
-                batch_size = data.shape[0]
-            else:
-                assert batch_size == data.shape[0]
-            assert data.shape[1:] == self.key_shape_map[key]
-            features.append(data)
-
-        # process rgb input
-        if self.share_rgb_model:
-            # pass all rgb obs to rgb model
-            imgs = []
-            for key in self.rgb_keys:
-                img = obs_dict[key]
-                if batch_size is None:
-                    batch_size = img.shape[0]
-                else:
-                    assert batch_size == img.shape[0]
-                assert img.shape[1:] == self.key_shape_map[key]
-                img = self.key_transform_map[key](img)
-                imgs.append(img)
-            # (N*B,C,H,W)
-            imgs = torch.cat(imgs, dim=0)
-            # (N*B,D)
-            feature = self.key_model_map["rgb"](imgs)
-            # (N,B,D)
-            feature = feature.reshape(-1, batch_size, *feature.shape[1:])
-            # (B,N,D)
-            feature = torch.moveaxis(feature, 0, 1)
-            # (B,N*D)
-            feature = feature.reshape(batch_size, -1)
-            features.append(feature)
-        else:
-            # run each rgb obs to independent models
-            for key in self.rgb_keys:
-                img = obs_dict[key]
-                if batch_size is None:
-                    batch_size = img.shape[0]
-                else:
-                    assert batch_size == img.shape[0]
-                assert img.shape[1:] == self.key_shape_map[key]
-                img = self.key_transform_map[key](img)
-                feature = self.key_model_map[key](img)
-                features.append(feature)
-
-        # concatenate all features
-        result = torch.cat(features, dim=-1)
-        return result
-
-    @torch.no_grad()
-    def output_shape(self):
-        example_obs_dict = {}
-        obs_shape_meta = self.shape_meta["obs"]
-        batch_size = 1
-        for key, attr in obs_shape_meta.items():
-            shape = tuple(attr["shape"])
-            this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
-            example_obs_dict[key] = this_obs
-        example_output = self.forward(example_obs_dict)
-        output_shape = example_output.shape[1:]
-        return output_shape
diff --git a/lerobot/common/policies/diffusion/model/normalizer.py b/lerobot/common/policies/diffusion/model/normalizer.py
deleted file mode 100644
index 0e4d79ab..00000000
--- a/lerobot/common/policies/diffusion/model/normalizer.py
+++ /dev/null
@@ -1,358 +0,0 @@
-from typing import Dict, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import zarr
-
-from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
-from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
-
-
-class LinearNormalizer(DictOfTensorMixin):
-    avaliable_modes = ["limits", "gaussian"]
-
-    @torch.no_grad()
-    def fit(
-        self,
-        data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
-        last_n_dims=1,
-        dtype=torch.float32,
-        mode="limits",
-        output_max=1.0,
-        output_min=-1.0,
-        range_eps=1e-4,
-        fit_offset=True,
-    ):
-        if isinstance(data, dict):
-            for key, value in data.items():
-                self.params_dict[key] = _fit(
-                    value,
-                    last_n_dims=last_n_dims,
-                    dtype=dtype,
-                    mode=mode,
-                    output_max=output_max,
-                    output_min=output_min,
-                    range_eps=range_eps,
-                    fit_offset=fit_offset,
-                )
-        else:
-            self.params_dict["_default"] = _fit(
-                data,
-                last_n_dims=last_n_dims,
-                dtype=dtype,
-                mode=mode,
-                output_max=output_max,
-                output_min=output_min,
-                range_eps=range_eps,
-                fit_offset=fit_offset,
-            )
-
-    def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
-        return self.normalize(x)
-
-    def __getitem__(self, key: str):
-        return SingleFieldLinearNormalizer(self.params_dict[key])
-
-    def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
-        self.params_dict[key] = value.params_dict
-
-    def _normalize_impl(self, x, forward=True):
-        if isinstance(x, dict):
-            result = {}
-            for key, value in x.items():
-                params = self.params_dict[key]
-                result[key] = _normalize(value, params, forward=forward)
-            return result
-        else:
-            if "_default" not in self.params_dict:
-                raise RuntimeError("Not initialized")
-            params = self.params_dict["_default"]
-            return _normalize(x, params, forward=forward)
-
-    def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
-        return self._normalize_impl(x, forward=True)
-
-    def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
-        return self._normalize_impl(x, forward=False)
-
-    def get_input_stats(self) -> Dict:
-        if len(self.params_dict) == 0:
-            raise RuntimeError("Not initialized")
-        if len(self.params_dict) == 1 and "_default" in self.params_dict:
-            return self.params_dict["_default"]["input_stats"]
-
-        result = {}
-        for key, value in self.params_dict.items():
-            if key != "_default":
-                result[key] = value["input_stats"]
-        return result
-
-    def get_output_stats(self, key="_default"):
-        input_stats = self.get_input_stats()
-        if "min" in input_stats:
-            # no dict
-            return dict_apply(input_stats, self.normalize)
-
-        result = {}
-        for key, group in input_stats.items():
-            this_dict = {}
-            for name, value in group.items():
-                this_dict[name] = self.normalize({key: value})[key]
-            result[key] = this_dict
-        return result
-
-
-class SingleFieldLinearNormalizer(DictOfTensorMixin):
-    avaliable_modes = ["limits", "gaussian"]
-
-    @torch.no_grad()
-    def fit(
-        self,
-        data: Union[torch.Tensor, np.ndarray, zarr.Array],
-        last_n_dims=1,
-        dtype=torch.float32,
-        mode="limits",
-        output_max=1.0,
-        output_min=-1.0,
-        range_eps=1e-4,
-        fit_offset=True,
-    ):
-        self.params_dict = _fit(
-            data,
-            last_n_dims=last_n_dims,
-            dtype=dtype,
-            mode=mode,
-            output_max=output_max,
-            output_min=output_min,
-            range_eps=range_eps,
-            fit_offset=fit_offset,
-        )
-
-    @classmethod
-    def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
-        obj = cls()
-        obj.fit(data, **kwargs)
-        return obj
-
-    @classmethod
-    def create_manual(
-        cls,
-        scale: Union[torch.Tensor, np.ndarray],
-        offset: Union[torch.Tensor, np.ndarray],
-        input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
-    ):
-        def to_tensor(x):
-            if not isinstance(x, torch.Tensor):
-                x = torch.from_numpy(x)
-            x = x.flatten()
-            return x
-
-        # check
-        for x in [offset] + list(input_stats_dict.values()):
-            assert x.shape == scale.shape
-            assert x.dtype == scale.dtype
-
-        params_dict = nn.ParameterDict(
-            {
-                "scale": to_tensor(scale),
-                "offset": to_tensor(offset),
-                "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
-            }
-        )
-        return cls(params_dict)
-
-    @classmethod
-    def create_identity(cls, dtype=torch.float32):
-        scale = torch.tensor([1], dtype=dtype)
-        offset = torch.tensor([0], dtype=dtype)
-        input_stats_dict = {
-            "min": torch.tensor([-1], dtype=dtype),
-            "max": torch.tensor([1], dtype=dtype),
-            "mean": torch.tensor([0], dtype=dtype),
-            "std": torch.tensor([1], dtype=dtype),
-        }
-        return cls.create_manual(scale, offset, input_stats_dict)
-
-    def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
-        return _normalize(x, self.params_dict, forward=True)
-
-    def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
-        return _normalize(x, self.params_dict, forward=False)
-
-    def get_input_stats(self):
-        return self.params_dict["input_stats"]
-
-    def get_output_stats(self):
-        return dict_apply(self.params_dict["input_stats"], self.normalize)
-
-    def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
-        return self.normalize(x)
-
-
-def _fit(
-    data: Union[torch.Tensor, np.ndarray, zarr.Array],
-    last_n_dims=1,
-    dtype=torch.float32,
-    mode="limits",
-    output_max=1.0,
-    output_min=-1.0,
-    range_eps=1e-4,
-    fit_offset=True,
-):
-    assert mode in ["limits", "gaussian"]
-    assert last_n_dims >= 0
-    assert output_max > output_min
-
-    # convert data to torch and type
-    if isinstance(data, zarr.Array):
-        data = data[:]
-    if isinstance(data, np.ndarray):
-        data = torch.from_numpy(data)
-    if dtype is not None:
-        data = data.type(dtype)
-
-    # convert shape
-    dim = 1
-    if last_n_dims > 0:
-        dim = np.prod(data.shape[-last_n_dims:])
-    data = data.reshape(-1, dim)
-
-    # compute input stats min max mean std
-    input_min, _ = data.min(axis=0)
-    input_max, _ = data.max(axis=0)
-    input_mean = data.mean(axis=0)
-    input_std = data.std(axis=0)
-
-    # compute scale and offset
-    if mode == "limits":
-        if fit_offset:
-            # unit scale
-            input_range = input_max - input_min
-            ignore_dim = input_range < range_eps
-            input_range[ignore_dim] = output_max - output_min
-            scale = (output_max - output_min) / input_range
-            offset = output_min - scale * input_min
-            offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
-            # ignore dims scaled to mean of output max and min
-        else:
-            # use this when data is pre-zero-centered.
-            assert output_max > 0
-            assert output_min < 0
-            # unit abs
-            output_abs = min(abs(output_min), abs(output_max))
-            input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
-            ignore_dim = input_abs < range_eps
-            input_abs[ignore_dim] = output_abs
-            # don't scale constant channels
-            scale = output_abs / input_abs
-            offset = torch.zeros_like(input_mean)
-    elif mode == "gaussian":
-        ignore_dim = input_std < range_eps
-        scale = input_std.clone()
-        scale[ignore_dim] = 1
-        scale = 1 / scale
-
-        offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)
-
-    # save
-    this_params = nn.ParameterDict(
-        {
-            "scale": scale,
-            "offset": offset,
-            "input_stats": nn.ParameterDict(
-                {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
-            ),
-        }
-    )
-    for p in this_params.parameters():
-        p.requires_grad_(False)
-    return this_params
-
-
-def _normalize(x, params, forward=True):
-    assert "scale" in params
-    if isinstance(x, np.ndarray):
-        x = torch.from_numpy(x)
-    scale = params["scale"]
-    offset = params["offset"]
-    x = x.to(device=scale.device, dtype=scale.dtype)
-    src_shape = x.shape
-    x = x.reshape(-1, scale.shape[0])
-    x = x * scale + offset if forward else (x - offset) / scale
-    x = x.reshape(src_shape)
-    return x
-
-
-def test():
-    data = torch.zeros((100, 10, 9, 2)).uniform_()
-    data[..., 0, 0] = 0
-
-    normalizer = SingleFieldLinearNormalizer()
-    normalizer.fit(data, mode="limits", last_n_dims=2)
-    datan = normalizer.normalize(data)
-    assert datan.shape == data.shape
-    assert np.allclose(datan.max(), 1.0)
-    assert np.allclose(datan.min(), -1.0)
-    dataun = normalizer.unnormalize(datan)
-    assert torch.allclose(data, dataun, atol=1e-7)
-
-    _ = normalizer.get_input_stats()
-    _ = normalizer.get_output_stats()
-
-    normalizer = SingleFieldLinearNormalizer()
-    normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
-    datan = normalizer.normalize(data)
-    assert datan.shape == data.shape
-    assert np.allclose(datan.max(), 1.0, atol=1e-3)
-    assert np.allclose(datan.min(), 0.0, atol=1e-3)
-    dataun = normalizer.unnormalize(datan)
-    assert torch.allclose(data, dataun, atol=1e-7)
-
-    data = torch.zeros((100, 10, 9, 2)).uniform_()
-    normalizer = SingleFieldLinearNormalizer()
-    normalizer.fit(data, mode="gaussian", last_n_dims=0)
-    datan = normalizer.normalize(data)
-    assert datan.shape == data.shape
-    assert np.allclose(datan.mean(), 0.0, atol=1e-3)
-    assert np.allclose(datan.std(), 1.0, atol=1e-3)
-    dataun = normalizer.unnormalize(datan)
-    assert torch.allclose(data, dataun, atol=1e-7)
-
-    # dict
-    data = torch.zeros((100, 10, 9, 2)).uniform_()
-    data[..., 0, 0] = 0
-
-    normalizer = LinearNormalizer()
-    normalizer.fit(data, mode="limits", last_n_dims=2)
-    datan = normalizer.normalize(data)
-    assert datan.shape == data.shape
-    assert np.allclose(datan.max(), 1.0)
-    assert np.allclose(datan.min(), -1.0)
-    dataun = normalizer.unnormalize(datan)
-    assert torch.allclose(data, dataun, atol=1e-7)
-
-    _ = normalizer.get_input_stats()
-    _ = normalizer.get_output_stats()
-
-    data = {
-        "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
-        "action": torch.zeros((1000, 128, 2)).uniform_() * 512,
-    }
-    normalizer = LinearNormalizer()
-    normalizer.fit(data)
-    datan = normalizer.normalize(data)
-    dataun = normalizer.unnormalize(datan)
-    for key in data:
-        assert torch.allclose(data[key], dataun[key], atol=1e-4)
-
-    _ = normalizer.get_input_stats()
-    _ = normalizer.get_output_stats()
-
-    state_dict = normalizer.state_dict()
-    n = LinearNormalizer()
-    n.load_state_dict(state_dict)
-    datan = n.normalize(data)
-    dataun = n.unnormalize(datan)
-    for key in data:
-        assert torch.allclose(data[key], dataun[key], atol=1e-4)
diff --git a/lerobot/common/policies/diffusion/model/positional_embedding.py b/lerobot/common/policies/diffusion/model/positional_embedding.py
deleted file mode 100644
index 65fc97bd..00000000
--- a/lerobot/common/policies/diffusion/model/positional_embedding.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import math
-
-import torch
-import torch.nn as nn
-
-
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
diff --git a/lerobot/common/policies/diffusion/model/rgb_encoder.py b/lerobot/common/policies/diffusion/model/rgb_encoder.py
new file mode 100644
index 00000000..2a5edf45
--- /dev/null
+++ b/lerobot/common/policies/diffusion/model/rgb_encoder.py
@@ -0,0 +1,147 @@
+from typing import Callable
+
+import torch
+import torchvision
+from robomimic.models.base_nets import SpatialSoftmax
+from torch import Tensor, nn
+from torchvision.transforms import CenterCrop, RandomCrop
+
+
+class RgbEncoder(nn.Module):
+    """Encoder an RGB image into a 1D feature vector.
+
+    Includes the ability to normalize and crop the image first.
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        norm_mean_std: tuple[float, float] = (1.0, 1.0),
+        crop_shape: tuple[int, int] | None = None,
+        random_crop: bool = False,
+        backbone_name: str = "resnet18",
+        pretrained_backbone: bool = False,
+        use_group_norm: bool = False,
+        relu: bool = True,
+        num_keypoints: int = 32,
+    ):
+        """
+        Args:
+            input_shape: channel-first input shape (C, H, W)
+            norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
+                (image - mean) / std.
+            crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
+                cropping is done.
+            random_crop: Whether the crop should be random at training time (it's always a center crop in
+                eval mode).
+            backbone_name: The name of one of the available resnet models from torchvision (e.g. resnet18).
+            pretrained_backbone: whether to use torchvision's pretrained ImageNet weights.
+            use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
+                The group sizes are set to be about 16 (to be precise, feature_dim // 16).
+            relu: whether to use relu as a final step.
+            num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
+        """
+        super().__init__()
+        if input_shape[0] != 3:
+            raise ValueError("Only RGB images are handled")
+        if not backbone_name.startswith("resnet"):
+            raise ValueError(
+                "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
+            )
+
+        # Set up optional preprocessing.
+        if tuple(norm_mean_std) == (1.0, 1.0):
+            self.normalizer = nn.Identity()
+        else:
+            self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
+
+        if crop_shape is not None:
+            self.do_crop = True
+            self.center_crop = CenterCrop(crop_shape)  # always use center crop for eval
+            if random_crop:
+                self.maybe_random_crop = RandomCrop(crop_shape)
+            else:
+                self.maybe_random_crop = self.center_crop
+        else:
+            self.do_crop = False
+
+        # Set up backbone.
+        backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
+        # Note: This assumes that the layer4 feature map is children()[-3]
+        # TODO(alexander-soare): Use a safer alternative.
+        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
+        if use_group_norm:
+            if pretrained_backbone:
+                raise ValueError(
+                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
+                )
+            self.backbone = _replace_submodules(
+                root_module=self.backbone,
+                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
+                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
+            )
+
+        # Set up pooling and final layers.
+        # Use a dry run to get the feature map shape.
+        with torch.inference_mode():
+            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
+        self.feature_dim = num_keypoints * 2
+        self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
+        self.maybe_relu = nn.ReLU() if relu else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: (B, C, H, W) image tensor with pixel values in [0, 1].
+        Returns:
+            (B, D) image feature.
+        """
+        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
+        x = self.normalizer(x)
+        if self.do_crop:
+            if self.training:  # noqa: SIM108
+                x = self.maybe_random_crop(x)
+            else:
+                # Always use center crop for eval.
+                x = self.center_crop(x)
+        # Extract backbone feature.
+        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
+        # Final linear layer.
+        x = self.out(x)
+        # Maybe a final non-linearity.
+        x = self.maybe_relu(x)
+        return x
+
+
+def _replace_submodules(
+    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
+) -> nn.Module:
+    """
+    Args:
+        root_module: The module for which the submodules need to be replaced
+        predicate: Takes a module as an argument and must return True if that module is to be replaced.
+        func: Takes a module as an argument and returns a new module to replace it with.
+    Returns:
+        The root module with its submodules replaced.
+    """
+    if predicate(root_module):
+        return func(root_module)
+
+    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
+    for *parents, k in replace_list:
+        parent_module = root_module
+        if len(parents) > 0:
+            parent_module = root_module.get_submodule(".".join(parents))
+        if isinstance(parent_module, nn.Sequential):
+            src_module = parent_module[int(k)]
+        else:
+            src_module = getattr(parent_module, k)
+        tgt_module = func(src_module)
+        if isinstance(parent_module, nn.Sequential):
+            parent_module[int(k)] = tgt_module
+        else:
+            setattr(parent_module, k, tgt_module)
+    # Verify that all modules matching the predicate have been replaced.
+    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
+    return root_module
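
As a usage illustration (not part of this patch), the sketch below mirrors what RgbEncoder does when use_group_norm=True: _replace_submodules swaps every BatchNorm2d in a small CNN for a GroupNorm with num_features // 16 groups. It assumes this branch (and its robomimic dependency) is installed so the new module path resolves.

import torch
from torch import nn

from lerobot.common.policies.diffusion.model.rgb_encoder import _replace_submodules

cnn = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

cnn = _replace_submodules(
    root_module=cnn,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)

assert not any(isinstance(m, nn.BatchNorm2d) for m in cnn.modules())
out = cnn(torch.zeros(1, 3, 32, 32))  # the swap preserves tensor shapes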
diff --git a/lerobot/common/policies/diffusion/model/tensor_utils.py b/lerobot/common/policies/diffusion/model/tensor_utils.py
deleted file mode 100644
index df9a568a..00000000
--- a/lerobot/common/policies/diffusion/model/tensor_utils.py
+++ /dev/null
@@ -1,972 +0,0 @@
-"""
-A collection of utilities for working with nested tensor structures consisting
-of numpy arrays and torch tensors.
-"""
-
-import collections
-
-import numpy as np
-import torch
-
-
-def recursive_dict_list_tuple_apply(x, type_func_dict):
-    """
-    Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
-    {data_type: function_to_apply}.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        type_func_dict (dict): a mapping from data types to the functions to be
-            applied for each data type.
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    assert list not in type_func_dict
-    assert tuple not in type_func_dict
-    assert dict not in type_func_dict
-
-    if isinstance(x, (dict, collections.OrderedDict)):
-        new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
-        for k, v in x.items():
-            new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
-        return new_x
-    elif isinstance(x, (list, tuple)):
-        ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
-        if isinstance(x, tuple):
-            ret = tuple(ret)
-        return ret
-    else:
-        for t, f in type_func_dict.items():
-            if isinstance(x, t):
-                return f(x)
-        else:
-            raise NotImplementedError("Cannot handle data type %s" % str(type(x)))
-
-
-def map_tensor(x, func):
-    """
-    Apply function @func to torch.Tensor objects in a nested dictionary or
-    list or tuple.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        func (function): function to apply to each tensor
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: func,
-            type(None): lambda x: x,
-        },
-    )
-
-
-def map_ndarray(x, func):
-    """
-    Apply function @func to np.ndarray objects in a nested dictionary or
-    list or tuple.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        func (function): function to apply to each array
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            np.ndarray: func,
-            type(None): lambda x: x,
-        },
-    )
-
-
-def map_tensor_ndarray(x, tensor_func, ndarray_func):
-    """
-    Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
-    np.ndarray objects in a nested dictionary or list or tuple.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        tensor_func (function): function to apply to each tensor
-        ndarray_Func (function): function to apply to each array
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: tensor_func,
-            np.ndarray: ndarray_func,
-            type(None): lambda x: x,
-        },
-    )
-
-
-def clone(x):
-    """
-    Clones all torch tensors and numpy arrays in nested dictionary or list
-    or tuple and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x.clone(),
-            np.ndarray: lambda x: x.copy(),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def detach(x):
-    """
-    Detaches all torch tensors in nested dictionary or list
-    or tuple and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x.detach(),
-        },
-    )
-
-
-def to_batch(x):
-    """
-    Introduces a leading batch dimension of 1 for all torch tensors and numpy
-    arrays in nested dictionary or list or tuple and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x[None, ...],
-            np.ndarray: lambda x: x[None, ...],
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_sequence(x):
-    """
-    Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
-    arrays in nested dictionary or list or tuple and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x[:, None, ...],
-            np.ndarray: lambda x: x[:, None, ...],
-            type(None): lambda x: x,
-        },
-    )
-
-
-def index_at_time(x, ind):
-    """
-    Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
-    nested dictionary or list or tuple and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        ind (int): index
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x[:, ind, ...],
-            np.ndarray: lambda x: x[:, ind, ...],
-            type(None): lambda x: x,
-        },
-    )
-
-
-def unsqueeze(x, dim):
-    """
-    Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
-    in nested dictionary or list or tuple and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        dim (int): dimension
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x.unsqueeze(dim=dim),
-            np.ndarray: lambda x: np.expand_dims(x, axis=dim),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def contiguous(x):
-    """
-    Makes all torch tensors and numpy arrays contiguous in nested dictionary or
-    list or tuple and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x.contiguous(),
-            np.ndarray: lambda x: np.ascontiguousarray(x),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_device(x, device):
-    """
-    Sends all torch tensors in nested dictionary or list or tuple to device
-    @device, and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        device (torch.Device): device to send tensors to
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x, d=device: x.to(d),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_tensor(x):
-    """
-    Converts all numpy arrays in nested dictionary or list or tuple to
-    torch tensors (and leaves existing torch Tensors as-is), and returns
-    a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x,
-            np.ndarray: lambda x: torch.from_numpy(x),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_numpy(x):
-    """
-    Converts all torch tensors in nested dictionary or list or tuple to
-    numpy (and leaves existing numpy arrays as-is), and returns
-    a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-
-    def f(tensor):
-        if tensor.is_cuda:
-            return tensor.detach().cpu().numpy()
-        else:
-            return tensor.detach().numpy()
-
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: f,
-            np.ndarray: lambda x: x,
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_list(x):
-    """
-    Converts all torch tensors and numpy arrays in nested dictionary or list
-    or tuple to a list, and returns a new nested structure. Useful for
-    json encoding.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-
-    def f(tensor):
-        if tensor.is_cuda:
-            return tensor.detach().cpu().numpy().tolist()
-        else:
-            return tensor.detach().numpy().tolist()
-
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: f,
-            np.ndarray: lambda x: x.tolist(),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_float(x):
-    """
-    Converts all torch tensors and numpy arrays in nested dictionary or list
-    or tuple to float type entries, and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x.float(),
-            np.ndarray: lambda x: x.astype(np.float32),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_uint8(x):
-    """
-    Converts all torch tensors and numpy arrays in nested dictionary or list
-    or tuple to uint8 type entries, and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x.byte(),
-            np.ndarray: lambda x: x.astype(np.uint8),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def to_torch(x, device):
-    """
-    Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
-    torch tensors on device @device and returns a new nested structure.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        device (torch.Device): device to send tensors to
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return to_device(to_float(to_tensor(x)), device)
-
-
-def to_one_hot_single(tensor, num_class):
-    """
-    Convert tensor to one-hot representation, assuming a certain number of total class labels.
-
-    Args:
-        tensor (torch.Tensor): tensor containing integer labels
-        num_class (int): number of classes
-
-    Returns:
-        x (torch.Tensor): tensor containing one-hot representation of labels
-    """
-    x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
-    x.scatter_(-1, tensor.unsqueeze(-1), 1)
-    return x
-
-
-def to_one_hot(tensor, num_class):
-    """
-    Convert all tensors in nested dictionary or list or tuple to one-hot representation,
-    assuming a certain number of total class labels.
-
-    Args:
-        tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
-        num_class (int): number of classes
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
-
-
-def flatten_single(x, begin_axis=1):
-    """
-    Flatten a tensor in all dimensions from @begin_axis onwards.
-
-    Args:
-        x (torch.Tensor): tensor to flatten
-        begin_axis (int): which axis to flatten from
-
-    Returns:
-        y (torch.Tensor): flattened tensor
-    """
-    fixed_size = x.size()[:begin_axis]
-    _s = list(fixed_size) + [-1]
-    return x.reshape(*_s)
-
-
-def flatten(x, begin_axis=1):
-    """
-    Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        begin_axis (int): which axis to flatten from
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
-        },
-    )
-
-
-def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
-    """
-    Reshape selected dimensions in a tensor to a target dimension.
-
-    Args:
-        x (torch.Tensor): tensor to reshape
-        begin_axis (int): begin dimension
-        end_axis (int): end dimension
-        target_dims (tuple or list): target shape for the range of dimensions
-            (@begin_axis, @end_axis)
-
-    Returns:
-        y (torch.Tensor): reshaped tensor
-    """
-    assert begin_axis <= end_axis
-    assert begin_axis >= 0
-    assert end_axis < len(x.shape)
-    assert isinstance(target_dims, (tuple, list))
-    s = x.shape
-    final_s = []
-    for i in range(len(s)):
-        if i == begin_axis:
-            final_s.extend(target_dims)
-        elif i < begin_axis or i > end_axis:
-            final_s.append(s[i])
-    return x.reshape(*final_s)
-
-
-def reshape_dimensions(x, begin_axis, end_axis, target_dims):
-    """
-    Reshape selected dimensions for all tensors in nested dictionary or list or tuple
-    to a target dimension.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        begin_axis (int): begin dimension
-        end_axis (int): end dimension
-        target_dims (tuple or list): target shape for the range of dimensions
-            (@begin_axis, @end_axis)
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
-                x, begin_axis=b, end_axis=e, target_dims=t
-            ),
-            np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
-                x, begin_axis=b, end_axis=e, target_dims=t
-            ),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def join_dimensions(x, begin_axis, end_axis):
-    """
-    Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
-    all tensors in nested dictionary or list or tuple.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        begin_axis (int): begin dimension
-        end_axis (int): end dimension
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
-                x, begin_axis=b, end_axis=e, target_dims=[-1]
-            ),
-            np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
-                x, begin_axis=b, end_axis=e, target_dims=[-1]
-            ),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def expand_at_single(x, size, dim):
-    """
-    Expand a tensor at a single dimension @dim by @size
-
-    Args:
-        x (torch.Tensor): input tensor
-        size (int): size to expand
-        dim (int): dimension to expand
-
-    Returns:
-        y (torch.Tensor): expanded tensor
-    """
-    assert dim < x.ndimension()
-    assert x.shape[dim] == 1
-    expand_dims = [-1] * x.ndimension()
-    expand_dims[dim] = size
-    return x.expand(*expand_dims)
-
-
-def expand_at(x, size, dim):
-    """
-    Expand all tensors in nested dictionary or list or tuple at a single
-    dimension @dim by @size.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        size (int): size to expand
-        dim (int): dimension to expand
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
-
-
-def unsqueeze_expand_at(x, size, dim):
-    """
-    Unsqueeze and expand a tensor at a dimension @dim by @size.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        size (int): size to expand
-        dim (int): dimension to unsqueeze and expand
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    x = unsqueeze(x, dim)
-    return expand_at(x, size, dim)
-
-
-def repeat_by_expand_at(x, repeats, dim):
-    """
-    Repeat a dimension by combining expand and reshape operations.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        repeats (int): number of times to repeat the target dimension
-        dim (int): dimension to repeat on
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    x = unsqueeze_expand_at(x, repeats, dim + 1)
-    return join_dimensions(x, dim, dim + 1)
-
-
-def named_reduce_single(x, reduction, dim):
-    """
-    Reduce tensor at a dimension by named reduction functions.
-
-    Args:
-        x (torch.Tensor): tensor to be reduced
-        reduction (str): one of ["sum", "max", "mean", "flatten"]
-        dim (int): dimension to be reduced (or begin axis for flatten)
-
-    Returns:
-        y (torch.Tensor): reduced tensor
-    """
-    assert x.ndimension() > dim
-    assert reduction in ["sum", "max", "mean", "flatten"]
-    if reduction == "flatten":
-        x = flatten(x, begin_axis=dim)
-    elif reduction == "max":
-        x = torch.max(x, dim=dim)[0]  # [B, D]
-    elif reduction == "sum":
-        x = torch.sum(x, dim=dim)
-    else:
-        x = torch.mean(x, dim=dim)
-    return x
-
-
-def named_reduce(x, reduction, dim):
-    """
-    Reduces all tensors in nested dictionary or list or tuple at a dimension
-    using a named reduction function.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        reduction (str): one of ["sum", "max", "mean", "flatten"]
-        dim (int): dimension to be reduced (or begin axis for flatten)
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
-
-
-def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
-    """
-    This function indexes out a target dimension of a tensor in a structured way,
-    by allowing a different value to be selected for each member of a flat index
-    tensor (@indices) corresponding to a source dimension. This can be interpreted
-    as moving along the source dimension, using the corresponding index value
-    in @indices to select values for all other dimensions outside of the
-    source and target dimensions. A common use case is to gather values
-    in target dimension 1 for each batch member (target dimension 0).
-
-    Args:
-        x (torch.Tensor): tensor to gather values for
-        target_dim (int): dimension to gather values along
-        source_dim (int): dimension to hold constant and use for gathering values
-            from the other dimensions
-        indices (torch.Tensor): flat index tensor with same shape as tensor @x along
-            @source_dim
-
-    Returns:
-        y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
-    """
-    assert len(indices.shape) == 1
-    assert x.shape[source_dim] == indices.shape[0]
-
-    # unsqueeze in all dimensions except the source dimension
-    new_shape = [1] * x.ndimension()
-    new_shape[source_dim] = -1
-    indices = indices.reshape(*new_shape)
-
-    # repeat in all dimensions - but preserve shape of source dimension,
-    # and make sure target_dimension has singleton dimension
-    expand_shape = list(x.shape)
-    expand_shape[source_dim] = -1
-    expand_shape[target_dim] = 1
-    indices = indices.expand(*expand_shape)
-
-    out = x.gather(dim=target_dim, index=indices)
-    return out.squeeze(target_dim)
-
-
-def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
-    """
-    Apply @gather_along_dim_with_dim_single to all tensors in a nested
-    dictionary or list or tuple.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        target_dim (int): dimension to gather values along
-        source_dim (int): dimension to hold constant and use for gathering values
-            from the other dimensions
-        indices (torch.Tensor): flat index tensor with same shape as tensor @x along
-            @source_dim
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple
-    """
-    return map_tensor(
-        x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
-    )
-
-
-def gather_sequence_single(seq, indices):
-    """
-    Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
-    the batch given an index for each sequence.
-
-    Args:
-        seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
-        indices (torch.Tensor): tensor indices of shape [B]
-
-    Return:
-        y (torch.Tensor): indexed tensor of shape [B, ....]
-    """
-    return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
-
-
-def gather_sequence(seq, indices):
-    """
-    Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
-    for tensors with leading dimensions [B, T, ...].
-
-    Args:
-        seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
-            of leading dimensions [B, T, ...]
-        indices (torch.Tensor): tensor indices of shape [B]
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
-    """
-    return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
-
-
-def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
-    """
-    Pad input tensor or array @seq in the time dimension (dimension 1).
-
-    Args:
-        seq (np.ndarray or torch.Tensor): sequence to be padded
-        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
-        batched (bool): if sequence has the batch dimension
-        pad_same (bool): if pad by duplicating
-        pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
-
-    Returns:
-        padded sequence (np.ndarray or torch.Tensor)
-    """
-    assert isinstance(seq, (np.ndarray, torch.Tensor))
-    assert pad_same or pad_values is not None
-    if pad_values is not None:
-        assert isinstance(pad_values, float)
-    repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
-    concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
-    ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
-    seq_dim = 1 if batched else 0
-
-    begin_pad = []
-    end_pad = []
-
-    if padding[0] > 0:
-        pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
-        begin_pad.append(repeat_func(pad, padding[0], seq_dim))
-    if padding[1] > 0:
-        pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
-        end_pad.append(repeat_func(pad, padding[1], seq_dim))
-
-    return concat_func(begin_pad + [seq] + end_pad, seq_dim)
-
-
-def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
-    """
-    Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
-
-    Args:
-        seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
-            of leading dimensions [B, T, ...]
-        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
-        batched (bool): if sequence has the batch dimension
-        pad_same (bool): if pad by duplicating
-        pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
-
-    Returns:
-        padded sequence (dict or list or tuple)
-    """
-    return recursive_dict_list_tuple_apply(
-        seq,
-        {
-            torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
-                x, p, b, ps, pv
-            ),
-            np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
-                x, p, b, ps, pv
-            ),
-            type(None): lambda x: x,
-        },
-    )
-
-
-def assert_size_at_dim_single(x, size, dim, msg):
-    """
-    Ensure that array or tensor @x has size @size in dim @dim.
-
-    Args:
-        x (np.ndarray or torch.Tensor): input array or tensor
-        size (int): size that tensors should have at @dim
-        dim (int): dimension to check
-        msg (str): text to display if assertion fails
-    """
-    assert x.shape[dim] == size, msg
-
-
-def assert_size_at_dim(x, size, dim, msg):
-    """
-    Ensure that arrays and tensors in nested dictionary or list or tuple have
-    size @size in dim @dim.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-        size (int): size that tensors should have at @dim
-        dim (int): dimension to check
-    """
-    map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
-
-
-def get_shape(x):
-    """
-    Get all shapes of arrays and tensors in nested dictionary or list or tuple.
-
-    Args:
-        x (dict or list or tuple): a possibly nested dictionary or list or tuple
-
-    Returns:
-        y (dict or list or tuple): new nested dict-list-tuple that contains each array or
-            tensor's shape
-    """
-    return recursive_dict_list_tuple_apply(
-        x,
-        {
-            torch.Tensor: lambda x: x.shape,
-            np.ndarray: lambda x: x.shape,
-            type(None): lambda x: x,
-        },
-    )
-
-
-def list_of_flat_dict_to_dict_of_list(list_of_dict):
-    """
-    Helper function to go from a list of flat dictionaries to a dictionary of lists.
-    By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
-    floats, etc.
-
-    Args:
-        list_of_dict (list): list of flat dictionaries
-
-    Returns:
-        dict_of_list (dict): dictionary of lists
-    """
-    assert isinstance(list_of_dict, list)
-    dic = collections.OrderedDict()
-    for i in range(len(list_of_dict)):
-        for k in list_of_dict[i]:
-            if k not in dic:
-                dic[k] = []
-            dic[k].append(list_of_dict[i][k])
-    return dic
-
-
-def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
-    """
-    Flatten a nested dict or list to a list.
-
-    For example, given a dict
-    {
-        a: 1
-        b: {
-            c: 2
-        }
-        c: 3
-    }
-
-    the function would return [(a, 1), (b_c, 2), (c, 3)]
-
-    Args:
-        d (dict, list): a nested dict or list to be flattened
-        parent_key (str): recursion helper
-        sep (str): separator for nesting keys
-        item_key (str): recursion helper
-    Returns:
-        list: a list of (key, value) tuples
-    """
-    items = []
-    if isinstance(d, (tuple, list)):
-        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
-        for i, v in enumerate(d):
-            items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
-        return items
-    elif isinstance(d, dict):
-        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
-        for k, v in d.items():
-            assert isinstance(k, str)
-            items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
-        return items
-    else:
-        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
-        return [(new_key, d)]
-
-
-def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
-    """
-    Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
-    batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
-    Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
-    outputs to [B, T, ...].
-
-    Args:
-        inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
-            of leading dimensions [B, T, ...]
-        op: a layer op that accepts inputs
-        activation: activation to apply at the output
-        inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
-        inputs_as_args (bool) whether to feed input as a args list to the op
-        kwargs (dict): other kwargs to supply to the op
-
-    Returns:
-        outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
-    """
-    batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
-    inputs = join_dimensions(inputs, 0, 1)
-    if inputs_as_kwargs:
-        outputs = op(**inputs, **kwargs)
-    elif inputs_as_args:
-        outputs = op(*inputs, **kwargs)
-    else:
-        outputs = op(inputs, **kwargs)
-
-    if activation is not None:
-        outputs = map_tensor(outputs, activation)
-    outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
-    return outputs
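For reference, nearly all of the helpers removed above are thin wrappers around a single recursive-apply pattern: walk a nested dict/list/tuple and dispatch on the type of each leaf. A minimal sketch of that pattern (a simplified, illustrative stand-in for the removed `recursive_dict_list_tuple_apply`, not the deleted implementation itself):

import numpy as np
import torch


def apply_to_leaves(x, type_func_dict):
    # Walk nested dicts/lists/tuples and apply the function registered for each leaf's type.
    if isinstance(x, dict):
        return {k: apply_to_leaves(v, type_func_dict) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return type(x)(apply_to_leaves(v, type_func_dict) for v in x)
    for leaf_type, func in type_func_dict.items():
        if isinstance(x, leaf_type):
            return func(x)
    raise TypeError(f"Unsupported leaf type: {type(x)}")


# Roughly what the removed `to_torch` composition did: numpy -> float tensors, None passed through.
batch = {"observation": {"image": np.zeros((3, 96, 96), dtype=np.uint8)}, "info": None}
batch = apply_to_leaves(
    batch,
    {
        torch.Tensor: lambda t: t.float(),
        np.ndarray: lambda a: torch.from_numpy(a).float(),
        type(None): lambda v: v,
    },
)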
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index 9785358b..fca89d46 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -5,11 +5,10 @@ from collections import deque
 
 import hydra
 import torch
-from torch import nn
+from diffusers.optimization import get_scheduler
+from torch import Tensor, nn
 
-from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
-from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
-from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
+from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.utils import populate_queues
 from lerobot.common.utils import get_safe_torch_device
 
@@ -22,8 +21,6 @@ class DiffusionPolicy(nn.Module):
         cfg,
         cfg_device,
         cfg_noise_scheduler,
-        cfg_rgb_model,
-        cfg_obs_encoder,
         cfg_optimizer,
         cfg_ema,
         shape_meta: dict,
@@ -31,53 +28,43 @@ class DiffusionPolicy(nn.Module):
         n_action_steps,
         n_obs_steps,
         num_inference_steps=None,
-        obs_as_global_cond=True,
         diffusion_step_embed_dim=256,
         down_dims=(256, 512, 1024),
         kernel_size=5,
         n_groups=8,
-        cond_predict_scale=True,
-        # parameters passed to step
-        **kwargs,
+        film_scale_modulation=True,
+        **_,
     ):
         super().__init__()
         self.cfg = cfg
         self.n_obs_steps = n_obs_steps
         self.n_action_steps = n_action_steps
+
+        # Queues are populated during rollout of the policy; they contain the n latest observations and actions.
         self._queues = None
 
+        # TODO(now): In-house this.
         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
-        rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
-        if cfg_obs_encoder.crop_shape is not None:
-            rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
-        rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
-        obs_encoder = MultiImageObsEncoder(
-            rgb_model=rgb_model,
-            **cfg_obs_encoder,
-        )
 
         self.diffusion = DiffusionUnetImagePolicy(
+            cfg,
             shape_meta=shape_meta,
             noise_scheduler=noise_scheduler,
-            obs_encoder=obs_encoder,
             horizon=horizon,
             n_action_steps=n_action_steps,
             n_obs_steps=n_obs_steps,
             num_inference_steps=num_inference_steps,
-            obs_as_global_cond=obs_as_global_cond,
             diffusion_step_embed_dim=diffusion_step_embed_dim,
             down_dims=down_dims,
             kernel_size=kernel_size,
             n_groups=n_groups,
-            cond_predict_scale=cond_predict_scale,
-            # parameters passed to step
-            **kwargs,
+            film_scale_modulation=film_scale_modulation,
         )
 
         self.device = get_safe_torch_device(cfg_device)
         self.diffusion.to(self.device)
 
+        # TODO(alexander-soare): This should probably be managed outside of the policy class.
         self.ema_diffusion = None
         self.ema = None
         if self.cfg.use_ema:
@@ -116,42 +103,45 @@ class DiffusionPolicy(nn.Module):
             "action": deque(maxlen=self.n_action_steps),
         }
 
-    @torch.no_grad()
-    def select_action(self, batch, step):
+    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
+        """A forward pass through the DNN part of this policy with optional loss computation."""
+        return self.select_action(batch)
+
+    @torch.no_grad
+    def select_action(self, batch, **_):
         """
         Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
+        # TODO(now): Handle a batch
         """
-        # TODO(rcadene): remove unused step
-        del step
         assert "observation.image" in batch
         assert "observation.state" in batch
-        assert len(batch) == 2
+        assert len(batch) == 2  # TODO(now): Does this not have a batch dim?
 
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-
-            obs_dict = {
-                "image": batch["observation.image"],
-                "agent_pos": batch["observation.state"],
-            }
-            if self.training:
-                out = self.diffusion.predict_action(obs_dict)
-            else:
-                out = self.ema_diffusion.predict_action(obs_dict)
-            self._queues["action"].extend(out["action"].transpose(0, 1))
+            actions = self._generate_actions(batch)
+            self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
         return action
 
-    def forward(self, batch, step):
+    def _generate_actions(self, batch):
+        if not self.training and self.ema_diffusion is not None:
+            return self.ema_diffusion.predict_action(batch)
+        else:
+            return self.diffusion.predict_action(batch)
+
+    def update(self, batch, **_):
+        """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
 
         self.diffusion.train()
 
-        loss = self.diffusion.compute_loss(batch)
+        loss = self.compute_loss(batch)
+
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -174,13 +164,11 @@ class DiffusionPolicy(nn.Module):
             "update_s": time.time() - start_time,
         }
 
-        # TODO(rcadene): remove hardcoding
-        # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
-        if step % 168 == 0:
-            self.global_step += 1
-
         return info
 
+    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
+        return self.diffusion.compute_loss(batch)
+
     def save(self, fp):
         torch.save(self.state_dict(), fp)
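The deque-based rollout above is how a chunk-predicting model gets served one action at a time: observation queues hold the `n_obs_steps` latest observations, and the action queue is refilled with `n_action_steps` actions whenever it runs empty. A standalone sketch of that control flow with a stubbed model (the fill-with-copies behaviour on the first call is an assumption, since `populate_queues` is only partially visible in this diff):

from collections import deque

import torch

n_obs_steps, n_action_steps, action_dim = 2, 8, 2
queues = {
    "observation.image": deque(maxlen=n_obs_steps),
    "observation.state": deque(maxlen=n_obs_steps),
    "action": deque(maxlen=n_action_steps),
}


def generate_actions(stacked_batch):
    # Stand-in for the diffusion model: returns a chunk of shape (B, n_action_steps, action_dim).
    batch_size = stacked_batch["observation.state"].shape[0]
    return torch.zeros(batch_size, n_action_steps, action_dim)


def select_action(observation):
    # Keep the n_obs_steps latest observations; on the first call the queues are filled
    # with copies of the current observation so that stacking always yields n_obs_steps.
    for key, value in observation.items():
        while len(queues[key]) < queues[key].maxlen:
            queues[key].append(value)
        queues[key].append(value)
    # Only query the model when the action queue has been exhausted.
    if len(queues["action"]) == 0:
        stacked = {key: torch.stack(list(queues[key]), dim=1) for key in observation}
        actions = generate_actions(stacked)
        queues["action"].extend(actions.transpose(0, 1))  # one queue entry per timestep
    return queues["action"].popleft()


observation = {"observation.image": torch.zeros(1, 3, 96, 96), "observation.state": torch.zeros(1, 2)}
action = select_action(observation)  # (1, action_dim)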
 
diff --git a/lerobot/common/policies/diffusion/pytorch_utils.py b/lerobot/common/policies/diffusion/pytorch_utils.py
deleted file mode 100644
index ed5dc23a..00000000
--- a/lerobot/common/policies/diffusion/pytorch_utils.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from typing import Callable, Dict
-
-import torch
-import torch.nn as nn
-import torchvision
-
-
-def get_resnet(name, weights=None, **kwargs):
-    """
-    name: resnet18, resnet34, resnet50
-    weights: "IMAGENET1K_V1", "r3m"
-    """
-    # load r3m weights
-    if (weights == "r3m") or (weights == "R3M"):
-        return get_r3m(name=name, **kwargs)
-
-    func = getattr(torchvision.models, name)
-    resnet = func(weights=weights, **kwargs)
-    resnet.fc = torch.nn.Identity()
-    return resnet
-
-
-def get_r3m(name, **kwargs):
-    """
-    name: resnet18, resnet34, resnet50
-    """
-    import r3m
-
-    r3m.device = "cpu"
-    model = r3m.load_r3m(name)
-    r3m_model = model.module
-    resnet_model = r3m_model.convnet
-    resnet_model = resnet_model.to("cpu")
-    return resnet_model
-
-
-def dict_apply(
-    x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
-) -> Dict[str, torch.Tensor]:
-    result = {}
-    for key, value in x.items():
-        if isinstance(value, dict):
-            result[key] = dict_apply(value, func)
-        else:
-            result[key] = func(value)
-    return result
-
-
-def replace_submodules(
-    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
-) -> nn.Module:
-    """
-    predicate: Return true if the module is to be replaced.
-    func: Return new module to use.
-    """
-    if predicate(root_module):
-        return func(root_module)
-
-    bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
-    for *parent, k in bn_list:
-        parent_module = root_module
-        if len(parent) > 0:
-            parent_module = root_module.get_submodule(".".join(parent))
-        if isinstance(parent_module, nn.Sequential):
-            src_module = parent_module[int(k)]
-        else:
-            src_module = getattr(parent_module, k)
-        tgt_module = func(src_module)
-        if isinstance(parent_module, nn.Sequential):
-            parent_module[int(k)] = tgt_module
-        else:
-            setattr(parent_module, k, tgt_module)
-    # verify that all BN are replaced
-    bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
-    assert len(bn_list) == 0
-    return root_module
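`replace_submodules`, deleted above, was the generic mechanism for swapping every BatchNorm layer in a vision backbone for GroupNorm when training from scratch. A simplified sketch of the same idea (the `num_features // 16` group count is an assumption for illustration, and this version omits the removed helper's post-replacement verification):

import torch.nn as nn
import torchvision


def replace_modules(root: nn.Module, predicate, func) -> nn.Module:
    # Same idea as the removed `replace_submodules`: find every submodule matching
    # `predicate` and swap it in place for `func(submodule)`.
    targets = [name for name, m in root.named_modules() if name and predicate(m)]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = root.get_submodule(parent_name) if parent_name else root
        setattr(parent, child_name, func(root.get_submodule(name)))
    return root


# Typical use: a ResNet trained from scratch with GroupNorm instead of BatchNorm.
backbone = torchvision.models.resnet18(weights=None)
backbone = replace_modules(
    backbone,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(m.num_features // 16, m.num_features),
)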
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index a1cbea9a..325b5608 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -12,8 +12,6 @@ def make_policy(cfg):
             cfg=cfg.policy,
             cfg_device=cfg.device,
             cfg_noise_scheduler=cfg.noise_scheduler,
-            cfg_rgb_model=cfg.rgb_model,
-            cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
             # n_obs_steps=cfg.n_obs_steps,
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
index b0503fe4..9d4b42f0 100644
--- a/lerobot/common/policies/utils.py
+++ b/lerobot/common/policies/utils.py
@@ -1,3 +1,7 @@
+import torch
+from torch import nn
+
+
 def populate_queues(queues, batch):
     for key in batch:
         if len(queues[key]) != queues[key].maxlen:
@@ -8,3 +12,21 @@ def populate_queues(queues, batch):
             # add latest observation to the queue
             queues[key].append(batch[key])
     return queues
+
+
+def get_device_from_parameters(module: nn.Module) -> torch.device:
+    """Get a module's device by checking one of its parameters.
+
+    Note: assumes that all parameters have the same device
+    TODO(now): Add test.
+    """
+    return next(iter(module.parameters())).device
+
+
+def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
+    """Get a module's parameter dtype by checking one of its parameters.
+
+    Note: assumes that all parameters have the same dtype.
+    TODO(now): Add test.
+    """
+    return next(iter(module.parameters())).dtype
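These helpers let code deep inside the model (e.g. the sampling loop introduced later in this series) allocate tensors that match the module's parameters without threading device/dtype arguments through every call. A trivial usage sketch:

import torch
from torch import nn

from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters

module = nn.Linear(4, 4)
device = get_device_from_parameters(module)
dtype = get_dtype_from_parameters(module)
# e.g. drawing a diffusion prior on the same device and with the same dtype as the model:
prior = torch.randn(1, 16, 2, device=device, dtype=dtype)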
diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py
index e3e22832..28270bac 100644
--- a/lerobot/common/utils.py
+++ b/lerobot/common/utils.py
@@ -11,6 +11,7 @@ from omegaconf import DictConfig
 
 
 def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
+    """Given a string, return a torch.device with checks on whether the device is available."""
     match cfg_device:
         case "cuda":
             assert torch.cuda.is_available()
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 811ee824..005b0517 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -19,7 +19,6 @@ n_action_steps: 8
 dataset_obs_steps: ${n_obs_steps}
 past_action_visible: False
 keypoint_visible_rate: 1.0
-obs_as_global_cond: True
 
 eval_episodes: 50
 eval_freq: 5000
@@ -40,13 +39,12 @@ policy:
   n_obs_steps: ${n_obs_steps}
   n_action_steps: ${n_action_steps}
   num_inference_steps: 100
-  obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
   diffusion_step_embed_dim: 128
   down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
-  cond_predict_scale: True
+  film_scale_modulation: True
 
   pretrained_model_path:
 
@@ -68,6 +66,16 @@ policy:
     observation.state: [-0.1, 0]
     action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
 
+  rgb_encoder:
+    backbone_name: resnet18
+    pretrained_backbone: false
+    use_group_norm: True
+    num_keypoints: 32
+    relu: true
+    norm_mean_std: [0.5, 0.5]  # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
+    crop_shape: [84, 84]
+    random_crop: True
+
 noise_scheduler:
   _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
   num_train_timesteps: 100
@@ -78,16 +86,6 @@ noise_scheduler:
   clip_sample: True # required when predict_epsilon=False
   prediction_type: epsilon # or sample
 
-obs_encoder:
-  shape_meta: ${shape_meta}
-  # resize_shape: null
-  crop_shape: [84, 84]
-  # constant center crop
-  random_crop: True
-  use_group_norm: True
-  share_rgb_model: False
-  norm_mean_std: [0.5, 0.5]  # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
-
 rgb_model:
   pretrained: false
   num_keypoints: 32
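On the `cond_predict_scale` -> `film_scale_modulation` rename: the flag controls FiLM-style conditioning in the U-Net's residual blocks, where the global conditioning vector predicts a per-channel bias and, when enabled, a per-channel scale as well. The actual block is not part of this diff, so the following is only a rough sketch of the concept, not the repository's module:

import torch
from torch import nn


class FiLMConditioning(nn.Module):
    """Illustrative only: a conditioning vector modulates a (B, C, T) feature map per channel."""

    def __init__(self, channels: int, cond_dim: int, film_scale_modulation: bool = True):
        super().__init__()
        self.film_scale_modulation = film_scale_modulation
        out_features = 2 * channels if film_scale_modulation else channels
        self.cond_proj = nn.Linear(cond_dim, out_features)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        cond = self.cond_proj(cond).unsqueeze(-1)  # (B, C or 2C, 1)
        if self.film_scale_modulation:
            scale, bias = cond.chunk(2, dim=1)
            return scale * x + bias
        return x + cond  # bias-only modulation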
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 535ac935..06459a85 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -121,7 +121,7 @@ def eval_policy(
 
         # get the next action for the environment
         with torch.inference_mode():
-            action = policy.select_action(observation, step)
+            action = policy.select_action(observation, step=step)
 
         # apply inverse transform to unnormalize the action
         action = postprocess_action(action, transform)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index caaf5182..dd3da978 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -213,7 +213,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy(batch, step)
+        train_info = policy.update(batch, step=step)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:

From 5666ec3ec7c0294f71f647486aa19b4699e32718 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Thu, 11 Apr 2024 18:33:54 +0100
Subject: [PATCH 2/8] backup wip

---
 .../model/diffusion_unet_image_policy.py      | 117 +++++++-----------
 .../policies/diffusion/model/ema_model.py     |  21 ++--
 lerobot/common/policies/diffusion/policy.py   |   4 +-
 3 files changed, 53 insertions(+), 89 deletions(-)

diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
index b6b78925..92928c70 100644
--- a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
@@ -66,53 +66,33 @@ class DiffusionUnetImagePolicy(nn.Module):
         self.num_inference_steps = num_inference_steps
 
     # ========= inference  ============
-    def conditional_sample(
-        self,
-        condition_data,
-        inpainting_mask,
-        local_cond=None,
-        global_cond=None,
-        generator=None,
-    ):
-        model = self.unet
-        scheduler = self.noise_scheduler
+    def conditional_sample(self, batch_size, global_cond=None, generator=None):
+        device = get_device_from_parameters(self)
+        dtype = get_dtype_from_parameters(self)
 
-        trajectory = torch.randn(
-            size=condition_data.shape,
-            dtype=condition_data.dtype,
-            device=condition_data.device,
+        # Sample prior.
+        sample = torch.randn(
+            size=(batch_size, self.horizon, self.action_dim),
+            dtype=dtype,
+            device=device,
             generator=generator,
         )
 
-        # set step values
-        scheduler.set_timesteps(self.num_inference_steps)
+        self.noise_scheduler.set_timesteps(self.num_inference_steps)
 
-        for t in scheduler.timesteps:
-            # 1. apply conditioning
-            trajectory[inpainting_mask] = condition_data[inpainting_mask]
-
-            # 2. predict model output
-            model_output = model(
-                trajectory,
-                torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device),
-                local_cond=local_cond,
+        for t in self.noise_scheduler.timesteps:
+            # Predict model output.
+            model_output = self.unet(
+                sample,
+                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
             )
+            # Compute previous sample: x_t -> x_t-1  # TODO(now): Is this right?
+            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
 
-            # 3. compute previous image: x_t -> x_t-1
-            trajectory = scheduler.step(
-                model_output,
-                t,
-                trajectory,
-                generator=generator,
-            ).prev_sample
+        return sample
 
-        # finally make sure conditioning is enforced
-        trajectory[inpainting_mask] = condition_data[inpainting_mask]
-
-        return trajectory
-
-    def predict_action(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
         This function expects `batch` to have (at least):
         {
@@ -125,27 +105,19 @@ class DiffusionUnetImagePolicy(nn.Module):
         assert n_obs_steps == self.n_obs_steps
         assert self.n_obs_steps == n_obs_steps
 
-        # build input
-        device = get_device_from_parameters(self)
-        dtype = get_dtype_from_parameters(self)
-
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
         # Concatenate state and image features then flatten to (B, global_cond_dim).
         global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-        # reshape back to B, Do
-        # empty data for action
-        cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype)
-        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
 
         # run sampling
-        nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)
+        sample = self.conditional_sample(batch_size, global_cond=global_cond)
 
         # `horizon` steps worth of actions (from the first observation).
-        action = nsample[..., : self.action_dim]
-        # Extract `n_action_steps` steps worth of action (from the current observation).
+        action = sample[..., : self.action_dim]
+        # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
         end = start + self.n_action_steps
         action = action[:, start:end]
@@ -159,9 +131,10 @@ class DiffusionUnetImagePolicy(nn.Module):
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
             "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon)  # TODO(now) maybe this is (B, horizon, 1)
+            "action_is_pad": (B, horizon)
         }
         """
+        # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         horizon = batch["action"].shape[1]
@@ -169,12 +142,6 @@ class DiffusionUnetImagePolicy(nn.Module):
         assert n_obs_steps == self.n_obs_steps
         assert self.n_obs_steps == n_obs_steps
 
-        # handle different ways of passing observation
-        local_cond = None
-        global_cond = None
-        trajectory = batch["action"]
-        cond_data = trajectory
-
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
@@ -182,39 +149,39 @@ class DiffusionUnetImagePolicy(nn.Module):
         # Concatenate state and image features then flatten to (B, global_cond_dim).
         global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
 
-        # Sample noise that we'll add to the images
-        noise = torch.randn(trajectory.shape, device=trajectory.device)
-        # Sample a random timestep for each image
+        trajectory = batch["action"]
+
+        # Forward diffusion.
+        # Sample noise to add to the trajectory.
+        eps = torch.randn(trajectory.shape, device=trajectory.device)
+        # Sample a random noising timestep for each item in the batch.
         timesteps = torch.randint(
-            0,
-            self.noise_scheduler.config.num_train_timesteps,
-            (trajectory.shape[0],),
+            low=0,
+            high=self.noise_scheduler.config.num_train_timesteps,
+            size=(trajectory.shape[0],),
             device=trajectory.device,
         ).long()
-        # Add noise to the clean images according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
+        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
+        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
 
-        # Apply inpainting. TODO(now): implement?
-        inpainting_mask = torch.zeros_like(trajectory, dtype=bool)
-        noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask]
-
-        # Predict the noise residual
-        pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
+        # Run the denoising network (it may predict the denoised trajectory or the noise, depending on the prediction type).
+        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
 
+        # Compute the loss.
+        # The target is either the original trajectory or the noise.
         pred_type = self.noise_scheduler.config.prediction_type
         if pred_type == "epsilon":
-            target = noise
+            target = eps
         elif pred_type == "sample":
-            target = trajectory
+            target = batch["action"]
         else:
             raise ValueError(f"Unsupported prediction type {pred_type}")
 
         loss = F.mse_loss(pred, target, reduction="none")
-        loss = loss * (~inpainting_mask)
 
+        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
         if "action_is_pad" in batch:
             in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
+            loss = loss * in_episode_bound.unsqueeze(-1)
 
         return loss.mean()
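The loss above is the standard DDPM recipe: sample a noising timestep, corrupt the clean action trajectory with `add_noise`, run the denoiser conditioned on the observation features, and regress either the noise or the clean sample depending on the scheduler's `prediction_type`. A self-contained sketch with a dummy denoiser standing in for the conditional U-Net (shapes and the conditioning dimension are illustrative only):

import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2", prediction_type="epsilon")


def dummy_denoiser(noisy_trajectory, timesteps, global_cond):
    # Stand-in for the conditional U-Net: here it "predicts" zero noise.
    return torch.zeros_like(noisy_trajectory)


# Training step on a (B, horizon, action_dim) trajectory with (B, cond_dim) global conditioning.
trajectory = torch.randn(4, 16, 2)
global_cond = torch.randn(4, 132)
eps = torch.randn_like(trajectory)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (trajectory.shape[0],)).long()
noisy_trajectory = scheduler.add_noise(trajectory, eps, timesteps)
pred = dummy_denoiser(noisy_trajectory, timesteps, global_cond)
target = eps if scheduler.config.prediction_type == "epsilon" else trajectory
loss = F.mse_loss(pred, target)

# Inference mirrors `conditional_sample`: start from a noise prior and iterate the reverse process.
sample = torch.randn(4, 16, 2)
scheduler.set_timesteps(100)
for t in scheduler.timesteps:
    model_output = dummy_denoiser(sample, torch.full(sample.shape[:1], t, dtype=torch.long), global_cond)
    sample = scheduler.step(model_output, t, sample).prev_sample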
diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py
index 3cb1dfbd..1e3447f3 100644
--- a/lerobot/common/policies/diffusion/model/ema_model.py
+++ b/lerobot/common/policies/diffusion/model/ema_model.py
@@ -32,7 +32,7 @@ class EMAModel:
         self.min_value = min_value
         self.max_value = max_value
 
-        self.decay = 0.0
+        self.alpha = 0.0
         self.optimization_step = 0
 
     def get_decay(self, optimization_step):
@@ -49,23 +49,20 @@ class EMAModel:
 
     @torch.no_grad()
     def step(self, new_model):
-        self.decay = self.get_decay(self.optimization_step)
+        self.alpha = self.get_decay(self.optimization_step)
 
-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
+        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
+            # Iterate over immediate parameters only.
             for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
+                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
             ):
-                # iterative over immediate parameters only.
                 if isinstance(param, dict):
                     raise RuntimeError("Dict parameter not supported")
-
-                if isinstance(module, _BatchNorm):
-                    # skip batchnorms
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                elif not param.requires_grad:
+                if isinstance(module, _BatchNorm) or not param.requires_grad:
+                    # Copy BatchNorm parameters and non-trainable parameters directly.
                     ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                 else:
-                    ema_param.mul_(self.decay)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+                    ema_param.mul_(self.alpha)
+                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
 
         self.optimization_step += 1
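The EMA step above blends a shadow copy of the policy toward the live weights after every optimization step: `ema = alpha * ema + (1 - alpha) * new`, with BatchNorm and non-trainable parameters copied directly. A minimal parameter-level sketch (the warm-up schedule computed by `get_decay` is not shown in this diff, so a fixed decay is assumed here):

import copy

import torch
from torch import nn

model = nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
alpha = 0.995  # assumed fixed decay; the real EMAModel ramps this up via get_decay()


@torch.no_grad()
def ema_step(model, ema_model, alpha):
    for param, ema_param in zip(model.parameters(), ema_model.parameters(), strict=True):
        if not param.requires_grad:
            ema_param.copy_(param.to(dtype=ema_param.dtype))
        else:
            ema_param.mul_(alpha)
            ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - alpha)


# After each optimizer.step():
ema_step(model, ema_model, alpha)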
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index fca89d46..b1713869 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -130,9 +130,9 @@ class DiffusionPolicy(nn.Module):
 
     def _generate_actions(self, batch):
         if not self.training and self.ema_diffusion is not None:
-            return self.ema_diffusion.predict_action(batch)
+            return self.ema_diffusion.generate_actions(batch)
         else:
-            return self.diffusion.predict_action(batch)
+            return self.diffusion.generate_actions(batch)
 
     def update(self, batch, **_):
         """Run the model in train mode, compute the loss, and do an optimization step."""

From 6d0a45a97d0d04d324023fe5a1b50815085b14c4 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Fri, 12 Apr 2024 11:36:52 +0100
Subject: [PATCH 3/8] ready for review

---
 examples/3_train_policy.py                    |  2 --
 lerobot/common/policies/act/policy.py         |  2 +-
 .../diffusion/model/conditional_unet1d.py     |  1 -
 .../model/diffusion_unet_image_policy.py      | 14 +--------
 lerobot/common/policies/diffusion/policy.py   | 30 +++++--------------
 lerobot/common/policies/utils.py              |  2 --
 lerobot/scripts/train.py                      |  2 +-
 7 files changed, 11 insertions(+), 42 deletions(-)

diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 238f953d..d2fff13b 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -32,8 +32,6 @@ policy = DiffusionPolicy(
     cfg=cfg.policy,
     cfg_device=cfg.device,
     cfg_noise_scheduler=cfg.noise_scheduler,
-    cfg_rgb_model=cfg.rgb_model,
-    cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
     **cfg.policy,
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index 821b0196..24667795 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -213,7 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return action[: self.n_action_steps]
 
     def __call__(self, *args, **kwargs) -> dict:
-        # TODO(now): Temporary bridge until we know what to do about the `update` method.
+        # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
         return self.update(*args, **kwargs)
 
     def _preprocess_batch(
diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
index 5c43d488..c3dcc198 100644
--- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py
+++ b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
@@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
 
 
 class _SinusoidalPosEmb(nn.Module):
-    # TODO(now): consolidate?
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
index 92928c70..3e7727f3 100644
--- a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
@@ -10,18 +10,6 @@ from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_
 
 
 class DiffusionUnetImagePolicy(nn.Module):
-    """
-    TODO(now): Add DDIM scheduler.
-
-    Changes:  TODO(now)
-      - Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
-        needed. Code for a general observation encoder can be found at:
-        https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
-      - Uses the observation as global conditioning for the Unet by default.
-      - Does not do any inpainting (which would be applicable if the observation were not used to condition
-        the Unet).
-    """
-
     def __init__(
         self,
         cfg,
@@ -87,7 +75,7 @@ class DiffusionUnetImagePolicy(nn.Module):
                 torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
             )
-            # Compute previous sample: x_t -> x_t-1  # TODO(now): Is this right?
+            # Compute previous sample: x_t -> x_t-1
             sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
 
         return sample
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index b1713869..f88e2f25 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -6,7 +6,7 @@ from collections import deque
 import hydra
 import torch
 from diffusers.optimization import get_scheduler
-from torch import Tensor, nn
+from torch import nn
 
 from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.utils import populate_queues
@@ -43,7 +43,6 @@ class DiffusionPolicy(nn.Module):
         # Queues are populated during rollout of the policy; they contain the n latest observations and actions.
         self._queues = None
 
-        # TODO(now): In-house this.
         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
 
         self.diffusion = DiffusionUnetImagePolicy(
@@ -103,45 +102,35 @@ class DiffusionPolicy(nn.Module):
             "action": deque(maxlen=self.n_action_steps),
         }
 
-    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        return self.select_action(batch)
-
     @torch.no_grad
     def select_action(self, batch, **_):
         """
         Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
-        # TODO(now): Handle a batch
         """
         assert "observation.image" in batch
         assert "observation.state" in batch
-        assert len(batch) == 2  # TODO(now): Does this not have a batch dim?
+        assert len(batch) == 2
 
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-            actions = self._generate_actions(batch)
+            if not self.training and self.ema_diffusion is not None:
+                actions = self.ema_diffusion.generate_actions(batch)
+            else:
+                actions = self.diffusion.generate_actions(batch)
             self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
         return action
 
-    def _generate_actions(self, batch):
-        if not self.training and self.ema_diffusion is not None:
-            return self.ema_diffusion.generate_actions(batch)
-        else:
-            return self.diffusion.generate_actions(batch)
-
-    def update(self, batch, **_):
-        """Run the model in train mode, compute the loss, and do an optimization step."""
+    def forward(self, batch, **_):
         start_time = time.time()
 
         self.diffusion.train()
 
-        loss = self.compute_loss(batch)
-
+        loss = self.diffusion.compute_loss(batch)
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -166,9 +155,6 @@ class DiffusionPolicy(nn.Module):
 
         return info
 
-    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
-        return self.diffusion.compute_loss(batch)
-
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
index 9d4b42f0..b23c1336 100644
--- a/lerobot/common/policies/utils.py
+++ b/lerobot/common/policies/utils.py
@@ -18,7 +18,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
     """Get a module's device by checking one of its parameters.
 
     Note: assumes that all parameters have the same device
-    TODO(now): Add test.
     """
     return next(iter(module.parameters())).device
 
@@ -27,6 +26,5 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     """Get a module's parameter dtype by checking one of its parameters.
 
     Note: assumes that all parameters have the same dtype.
-    TODO(now): Add test.
     """
     return next(iter(module.parameters())).dtype
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 300a8617..5ff6538d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step=step)
+        train_info = policy(batch, step=step)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
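For orientation: after this commit the training loop calls `policy(batch, step=step)` and the policy's `forward` performs the whole optimization step. A hedged sketch of that step as it reads across these patches (the optimizer, LR scheduler, EMA wiring and the clipping norm are assumptions; only the loss/backward/clipping portion is visible in the hunks above):

import torch


def training_forward(policy, optimizer, lr_scheduler, ema, batch, grad_clip_norm=10.0):
    # grad_clip_norm, and the exact order of the steps below, are assumptions for illustration.
    policy.diffusion.train()
    loss = policy.diffusion.compute_loss(batch)
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.diffusion.parameters(), grad_clip_norm)
    optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()
    if ema is not None:
        ema.step(policy.diffusion)  # keep the shadow weights used at evaluation time up to date
    return {"loss": loss.item(), "grad_norm": float(grad_norm)}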

From 5608e659e68f0b89f768a12b741d2aa3df3b3666 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Mon, 15 Apr 2024 19:06:44 +0100
Subject: [PATCH 4/8] backup wip

---
 examples/3_train_policy.py                    |   2 +-
 .../common/policies/act/configuration_act.py  |   4 +-
 lerobot/common/policies/act/modeling_act.py   |   3 +-
 .../diffusion/configuration_diffusion.py      |  83 ++
 .../diffusion/model/conditional_unet1d.py     | 306 ------
 .../model/diffusion_unet_image_policy.py      | 175 ----
 .../policies/diffusion/model/ema_model.py     |  68 --
 .../policies/diffusion/model/rgb_encoder.py   | 147 ---
 .../policies/diffusion/modeling_diffusion.py  | 878 ++++++++++++++++++
 lerobot/common/policies/diffusion/policy.py   | 169 ----
 lerobot/common/policies/factory.py            |  72 +-
 lerobot/configs/policy/act.yaml               |   2 +-
 lerobot/configs/policy/diffusion.yaml         | 125 ++-
 tests/test_available.py                       |   2 +-
 14 files changed, 1059 insertions(+), 977 deletions(-)
 create mode 100644 lerobot/common/policies/diffusion/configuration_diffusion.py
 delete mode 100644 lerobot/common/policies/diffusion/model/conditional_unet1d.py
 delete mode 100644 lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
 delete mode 100644 lerobot/common/policies/diffusion/model/ema_model.py
 delete mode 100644 lerobot/common/policies/diffusion/model/rgb_encoder.py
 create mode 100644 lerobot/common/policies/diffusion/modeling_diffusion.py
 delete mode 100644 lerobot/common/policies/diffusion/policy.py

diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index d2fff13b..64804d8f 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -11,7 +11,7 @@ import torch
 from omegaconf import OmegaConf
 
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
 
 output_directory = Path("outputs/train/example_pusht_diffusion")
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 74ed270e..72d35eb3 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -56,7 +56,7 @@ class ActionChunkingTransformerConfig:
 
     # Inputs / output structure.
     n_obs_steps: int = 1
-    camera_names: list[str] = field(default_factory=lambda: ["top"])
+    camera_names: tuple[str, ...] = ("top",)
     chunk_size: int = 100
     n_action_steps: int = 100
 
@@ -101,7 +101,7 @@ class ActionChunkingTransformerConfig:
     utd: int = 1
 
     def __post_init__(self):
-        """Input validation."""
+        """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
             raise ValueError("`vision_backbone` must be one of the ResNet variants.")
         if self.use_temporal_aggregation:
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 1361e071..18ea3377 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -163,7 +163,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """
+        """Select a single action given environment observations.
+
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
new file mode 100644
index 00000000..272c80ea
--- /dev/null
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class DiffusionConfig:
+    """Configuration class for Diffusion Policy.
+
+    Defaults are configured for training with PushT, providing proprioceptive and single-camera observations.
+
+    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
+    Those are: `state_dim`, `action_dim` and `image_size`.
+
+    Args:
+        state_dim: Dimensionality of the observation state space (excluding images).
+        action_dim: Dimensionality of the action space.
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
+            current step and additional steps going back).
+        horizon: Diffusion model action prediction horizon as detailed in the main policy documentation.
+    """
+
+    # Environment.
+    # Inherit these from the environment config.
+    state_dim: int = 2
+    action_dim: int = 2
+    image_size: tuple[int, int] = (96, 96)
+
+    # Inputs / output structure.
+    n_obs_steps: int = 2
+    horizon: int = 16
+    n_action_steps: int = 8
+
+    # Vision preprocessing.
+    image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
+    image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
+
+    # Architecture / modeling.
+    # Vision backbone.
+    vision_backbone: str = "resnet18"
+    crop_shape: tuple[int, int] = (84, 84)
+    crop_is_random: bool = True
+    use_pretrained_backbone: bool = False
+    use_group_norm: bool = True
+    spatial_softmax_num_keypoints: int = 32
+    # Unet.
+    down_dims: tuple[int, ...] = (512, 1024, 2048)
+    kernel_size: int = 5
+    n_groups: int = 8
+    diffusion_step_embed_dim: int = 128
+    film_scale_modulation: bool = True
+    # Noise scheduler.
+    num_train_timesteps: int = 100
+    beta_schedule: str = "squaredcos_cap_v2"
+    beta_start: float = 0.0001
+    beta_end: float = 0.02
+    variance_type: str = "fixed_small"
+    prediction_type: str = "epsilon"
+    clip_sample: bool = True
+
+    # Inference
+    num_inference_steps: int = 100
+
+    # ---
+    # TODO(alexander-soare): Remove these from the policy config.
+    batch_size: int = 64
+    grad_clip_norm: int = 10
+    lr: float = 1.0e-4
+    lr_scheduler: str = "cosine"
+    lr_warmup_steps: int = 500
+    adam_betas: tuple[float, float] = (0.95, 0.999)
+    adam_eps: float = 1.0e-8
+    adam_weight_decay: float = 1.0e-6
+    utd: int = 1
+    use_ema: bool = True
+    ema_update_after_step: int = 0
+    ema_min_rate: float = 0.0
+    ema_max_rate: float = 0.9999
+    ema_inv_gamma: float = 1.0
+    ema_power: float = 0.75
+
+    def __post_init__(self):
+        """Input validation (not exhaustive)."""
+        if not self.vision_backbone.startswith("resnet"):
+            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
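
A minimal usage sketch of the config added above (illustrative values, not part of the patch; it only exercises the dataclass and its `__post_init__` validation):

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

# Defaults already target PushT; override a few structural fields for illustration.
cfg = DiffusionConfig(horizon=16, n_obs_steps=2, n_action_steps=8)
assert cfg.vision_backbone.startswith("resnet")  # enforced by __post_init__

# An unsupported backbone fails at construction time.
try:
    DiffusionConfig(vision_backbone="vit_b_16")
except ValueError as e:
    print(e)  # "`vision_backbone` must be one of the ResNet variants."
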
diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
deleted file mode 100644
index c3dcc198..00000000
--- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py
+++ /dev/null
@@ -1,306 +0,0 @@
-import logging
-import math
-
-import einops
-import torch
-import torch.nn as nn
-from torch import Tensor
-
-logger = logging.getLogger(__name__)
-
-
-class _SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-class _Conv1dBlock(nn.Module):
-    """Conv1d --> GroupNorm --> Mish"""
-
-    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
-        super().__init__()
-
-        self.block = nn.Sequential(
-            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            nn.GroupNorm(n_groups, out_channels),
-            nn.Mish(),
-        )
-
-    def forward(self, x):
-        return self.block(x)
-
-
-class _ConditionalResidualBlock1D(nn.Module):
-    """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        cond_dim: int,
-        kernel_size: int = 3,
-        n_groups: int = 8,
-        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
-        # FiLM just modulates bias).
-        film_scale_modulation: bool = False,
-    ):
-        super().__init__()
-
-        self.film_scale_modulation = film_scale_modulation
-        self.out_channels = out_channels
-
-        self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
-
-        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
-        cond_channels = out_channels * 2 if film_scale_modulation else out_channels
-        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
-
-        self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
-
-        # A final convolution for dimension matching the residual (if needed).
-        self.residual_conv = (
-            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
-        )
-
-    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
-        """
-        Args:
-            x: (B, in_channels, T)
-            cond: (B, cond_dim)
-        Returns:
-            (B, out_channels, T)
-        """
-        out = self.conv1(x)
-
-        # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
-        cond_embed = self.cond_encoder(cond).unsqueeze(-1)
-        if self.film_scale_modulation:
-            # Treat the embedding as a list of scales and biases.
-            scale = cond_embed[:, : self.out_channels]
-            bias = cond_embed[:, self.out_channels :]
-            out = scale * out + bias
-        else:
-            # Treat the embedding as biases.
-            out = out + cond_embed
-
-        out = self.conv2(out)
-        out = out + self.residual_conv(x)
-        return out
-
-
-class ConditionalUnet1D(nn.Module):
-    """A 1D convolutional UNet with FiLM modulation for conditioning.
-
-    Two types of conditioning can be applied:
-    - Global: Conditioning information that is aggregated over the whole observation window. This is
-        incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
-    - Local: Conditioning information for each timestep in the observation window. This is incorporated
-        by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
-        outputs of the Unet's encoder/decoder.
-    """
-
-    def __init__(
-        self,
-        input_dim: int,
-        local_cond_dim: int | None = None,
-        global_cond_dim: int | None = None,
-        diffusion_step_embed_dim: int = 256,
-        down_dims: int | None = None,
-        kernel_size: int = 3,
-        n_groups: int = 8,
-        film_scale_modulation: bool = False,
-    ):
-        super().__init__()
-
-        if down_dims is None:
-            down_dims = [256, 512, 1024]
-
-        # Encoder for the diffusion timestep.
-        self.diffusion_step_encoder = nn.Sequential(
-            _SinusoidalPosEmb(diffusion_step_embed_dim),
-            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
-            nn.Mish(),
-            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
-        )
-
-        # The FiLM conditioning dimension.
-        cond_dim = diffusion_step_embed_dim
-        if global_cond_dim is not None:
-            cond_dim += global_cond_dim
-
-        self.local_cond_down_encoder = None
-        self.local_cond_up_encoder = None
-        if local_cond_dim is not None:
-            # Encoder for the local conditioning. The output gets added to the Unet encoder input.
-            self.local_cond_down_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
-            # Encoder for the local conditioning. The output gets added to the Unet encoder output.
-            self.local_cond_up_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
-
-        # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
-        # just reverse these.
-        in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
-
-        # Unet encoder.
-        self.down_modules = nn.ModuleList([])
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (len(in_out) - 1)
-            self.down_modules.append(
-                nn.ModuleList(
-                    [
-                        _ConditionalResidualBlock1D(
-                            dim_in,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        # Downsample as long as it is not the last block.
-                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        # Processing in the middle of the auto-encoder.
-        self.mid_modules = nn.ModuleList(
-            [
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
-            ]
-        )
-
-        # Unet decoder.
-        self.up_modules = nn.ModuleList([])
-        for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (len(in_out) - 1)
-            self.up_modules.append(
-                nn.ModuleList(
-                    [
-                        _ConditionalResidualBlock1D(
-                            dim_in * 2,  # x2 as it takes the encoder's skip connection as well
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        # Upsample as long as it is not the last block.
-                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        self.final_conv = nn.Sequential(
-            _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
-            nn.Conv1d(down_dims[0], input_dim, 1),
-        )
-
-    def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
-        """
-        Args:
-            x: (B, T, input_dim) tensor for input to the Unet.
-            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
-            local_cond: (B, T, local_cond_dim)
-            global_cond: (B, global_cond_dim)
-            output: (B, T, input_dim)
-        Returns:
-            (B, T, input_dim)
-        """
-        # For 1D convolutions we'll need feature dimension first.
-        x = einops.rearrange(x, "b t d -> b d t")
-        if local_cond is not None:
-            if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
-                raise ValueError(
-                    "`local_cond` was provided but the relevant encoders weren't built at initialization."
-                )
-            local_cond = einops.rearrange(local_cond, "b t d -> b d t")
-
-        timesteps_embed = self.diffusion_step_encoder(timestep)
-
-        # If there is a global conditioning feature, concatenate it to the timestep embedding.
-        if global_cond is not None:
-            global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
-        else:
-            global_feature = timesteps_embed
-
-        encoder_skip_features: list[Tensor] = []
-        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
-            x = resnet(x, global_feature)
-            if idx == 0 and local_cond is not None:
-                x = x + self.local_cond_down_encoder(local_cond, global_feature)
-            x = resnet2(x, global_feature)
-            encoder_skip_features.append(x)
-            x = downsample(x)
-
-        for mid_module in self.mid_modules:
-            x = mid_module(x, global_feature)
-
-        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
-            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
-            x = resnet(x, global_feature)
-            # Note: The condition in the original implementation is:
-            # if idx == len(self.up_modules) and local_cond is not None:
-            # But as they mention in their comments, this is incorrect. We use the correct condition here.
-            if idx == (len(self.up_modules) - 1) and local_cond is not None:
-                x = x + self.local_cond_up_encoder(local_cond, global_feature)
-            x = resnet2(x, global_feature)
-            x = upsample(x)
-
-        x = self.final_conv(x)
-
-        x = einops.rearrange(x, "b d t -> b t d")
-        return x
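
The FiLM conditioning described in `_ConditionalResidualBlock1D` boils down to a per-channel affine transform predicted from the conditioning vector. A standalone sketch of that step (illustrative shapes, plain PyTorch):

import torch
import torch.nn as nn

batch, channels, horizon, cond_dim = 4, 64, 16, 128
out = torch.randn(batch, channels, horizon)  # feature map from the first conv block
cond = torch.randn(batch, cond_dim)          # diffusion step embedding (+ global conditioning)

# FiLM with scale modulation: predict a (scale, bias) pair per channel.
cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, channels * 2))
cond_embed = cond_encoder(cond).unsqueeze(-1)  # (batch, 2 * channels, 1), broadcasts over time
scale, bias = cond_embed[:, :channels], cond_embed[:, channels:]
out = scale * out + bias  # bias-only FiLM would just add cond_embed instead
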
diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
deleted file mode 100644
index 3e7727f3..00000000
--- a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
+++ /dev/null
@@ -1,175 +0,0 @@
-import einops
-import torch
-import torch.nn.functional as F  # noqa: N812
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from torch import Tensor, nn
-
-from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
-from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder
-from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters
-
-
-class DiffusionUnetImagePolicy(nn.Module):
-    def __init__(
-        self,
-        cfg,
-        shape_meta: dict,
-        noise_scheduler: DDPMScheduler,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        film_scale_modulation=True,
-    ):
-        super().__init__()
-        action_shape = shape_meta["action"]["shape"]
-        assert len(action_shape) == 1
-        action_dim = action_shape[0]
-
-        self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
-
-        self.unet = ConditionalUnet1D(
-            input_dim=action_dim,
-            global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
-        )
-
-        self.noise_scheduler = noise_scheduler
-        self.horizon = horizon
-        self.action_dim = action_dim
-        self.n_action_steps = n_action_steps
-        self.n_obs_steps = n_obs_steps
-
-        if num_inference_steps is None:
-            num_inference_steps = noise_scheduler.config.num_train_timesteps
-
-        self.num_inference_steps = num_inference_steps
-
-    # ========= inference  ============
-    def conditional_sample(self, batch_size, global_cond=None, generator=None):
-        device = get_device_from_parameters(self)
-        dtype = get_dtype_from_parameters(self)
-
-        # Sample prior.
-        sample = torch.randn(
-            size=(batch_size, self.horizon, self.action_dim),
-            dtype=dtype,
-            device=device,
-            generator=generator,
-        )
-
-        self.noise_scheduler.set_timesteps(self.num_inference_steps)
-
-        for t in self.noise_scheduler.timesteps:
-            # Predict model output.
-            model_output = self.unet(
-                sample,
-                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
-                global_cond=global_cond,
-            )
-            # Compute previous image: x_t -> x_t-1
-            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
-
-        return sample
-
-    def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
-        }
-        """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
-
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-
-        # run sampling
-        sample = self.conditional_sample(batch_size, global_cond=global_cond)
-
-        # `horizon` steps worth of actions (from the first observation).
-        action = sample[..., : self.action_dim]
-        # Extract `n_action_steps` steps worth of actions (from the current observation).
-        start = n_obs_steps - 1
-        end = start + self.n_action_steps
-        action = action[:, start:end]
-
-        return action
-
-    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
-        """
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
-            "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon)
-        }
-        """
-        # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        horizon = batch["action"].shape[1]
-        assert horizon == self.horizon
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
-
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-
-        trajectory = batch["action"]
-
-        # Forward diffusion.
-        # Sample noise to add to the trajectory.
-        eps = torch.randn(trajectory.shape, device=trajectory.device)
-        # Sample a random noising timestep for each item in the batch.
-        timesteps = torch.randint(
-            low=0,
-            high=self.noise_scheduler.config.num_train_timesteps,
-            size=(trajectory.shape[0],),
-            device=trajectory.device,
-        ).long()
-        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
-
-        # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
-        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
-
-        # Compute the loss.
-        # The targe is either the original trajectory, or the noise.
-        pred_type = self.noise_scheduler.config.prediction_type
-        if pred_type == "epsilon":
-            target = eps
-        elif pred_type == "sample":
-            target = batch["action"]
-        else:
-            raise ValueError(f"Unsupported prediction type {pred_type}")
-
-        loss = F.mse_loss(pred, target, reduction="none")
-
-        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
-        if "action_is_pad" in batch:
-            in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound.unsqueeze(-1)
-
-        return loss.mean()
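
The forward-diffusion step and the choice of regression target in `compute_loss` can be reproduced in isolation with the same diffusers scheduler; a sketch with illustrative shapes (batch of 8, horizon 16, 2-D actions):

import torch
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

scheduler = DDPMScheduler(
    num_train_timesteps=100, beta_schedule="squaredcos_cap_v2", prediction_type="epsilon"
)

trajectory = torch.randn(8, 16, 2)  # (batch, horizon, action_dim): clean action sequence
eps = torch.randn_like(trajectory)  # noise to mix in
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (trajectory.shape[0],))
noisy_trajectory = scheduler.add_noise(trajectory, eps, timesteps)

# The denoising network regresses either the noise ("epsilon") or the clean sample ("sample").
target = eps if scheduler.config.prediction_type == "epsilon" else trajectory
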
diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py
deleted file mode 100644
index 1e3447f3..00000000
--- a/lerobot/common/policies/diffusion/model/ema_model.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import torch
-from torch.nn.modules.batchnorm import _BatchNorm
-
-
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(
-        self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
-    ):
-        """
-        @crowsonkb's notes on EMA Warmup:
-            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-            at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
-        """
-
-        self.averaged_model = model
-        self.averaged_model.eval()
-        self.averaged_model.requires_grad_(False)
-
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        self.alpha = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_value, min(value, self.max_value))
-
-    @torch.no_grad()
-    def step(self, new_model):
-        self.alpha = self.get_decay(self.optimization_step)
-
-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
-            # Iterate over immediate parameters only.
-            for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
-            ):
-                if isinstance(param, dict):
-                    raise RuntimeError("Dict parameter not supported")
-                if isinstance(module, _BatchNorm) or not param.requires_grad:
-                    # Copy BatchNorm parameters, and non-trainable parameters directly.
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                else:
-                    ema_param.mul_(self.alpha)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
-
-        self.optimization_step += 1
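
The EMA warmup in `get_decay` is a closed-form function of the step count. A sketch of the schedule with the defaults `inv_gamma=1.0`, `power=2/3` (ignoring the `update_after_step` offset the class applies):

def ema_decay(step, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
    # Decay factor used to blend new weights into the running average.
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

for step in (1, 100, 1_000, 31_600, 1_000_000):
    print(step, round(ema_decay(step), 4))
# Roughly: ~0.37 at step 1, ~0.95 at 100, ~0.99 at 1k, ~0.999 at ~31.6k, capped at 0.9999 by 1M.
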
diff --git a/lerobot/common/policies/diffusion/model/rgb_encoder.py b/lerobot/common/policies/diffusion/model/rgb_encoder.py
deleted file mode 100644
index 2a5edf45..00000000
--- a/lerobot/common/policies/diffusion/model/rgb_encoder.py
+++ /dev/null
@@ -1,147 +0,0 @@
-from typing import Callable
-
-import torch
-import torchvision
-from robomimic.models.base_nets import SpatialSoftmax
-from torch import Tensor, nn
-from torchvision.transforms import CenterCrop, RandomCrop
-
-
-class RgbEncoder(nn.Module):
-    """Encoder an RGB image into a 1D feature vector.
-
-    Includes the ability to normalize and crop the image first.
-    """
-
-    def __init__(
-        self,
-        input_shape: tuple[int, int, int],
-        norm_mean_std: tuple[float, float] = [1.0, 1.0],
-        crop_shape: tuple[int, int] | None = None,
-        random_crop: bool = False,
-        backbone_name: str = "resnet18",
-        pretrained_backbone: bool = False,
-        use_group_norm: bool = False,
-        relu: bool = True,
-        num_keypoints: int = 32,
-    ):
-        """
-        Args:
-            input_shape: channel-first input shape (C, H, W)
-            norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
-                (image - mean) / std.
-            crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
-                cropping is done.
-            random_crop: Whether the crop should be random at training time (it's always a center crop in
-                eval mode).
-            backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
-            pretrained_backbone: whether to use timm pretrained weights.
-            use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
-                The group sizes are set to be about 16 (to be precise, feature_dim // 16).
-            relu: whether to use relu as a final step.
-            num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
-        """
-        super().__init__()
-        if input_shape[0] != 3:
-            raise ValueError("Only RGB images are handled")
-        if not backbone_name.startswith("resnet"):
-            raise ValueError(
-                "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
-            )
-
-        # Set up optional preprocessing.
-        if norm_mean_std == [1.0, 1.0]:
-            self.normalizer = nn.Identity()
-        else:
-            self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
-
-        if crop_shape is not None:
-            self.do_crop = True
-            self.center_crop = CenterCrop(crop_shape)  # always use center crop for eval
-            if random_crop:
-                self.maybe_random_crop = RandomCrop(crop_shape)
-            else:
-                self.maybe_random_crop = self.center_crop
-        else:
-            self.do_crop = False
-
-        # Set up backbone.
-        backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
-        # Note: This assumes that the layer4 feature map is children()[-3]
-        # TODO(alexander-soare): Use a safer alternative.
-        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
-        if use_group_norm:
-            if pretrained_backbone:
-                raise ValueError(
-                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
-                )
-            self.backbone = _replace_submodules(
-                root_module=self.backbone,
-                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
-                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
-            )
-
-        # Set up pooling and final layers.
-        # Use a dry run to get the feature map shape.
-        with torch.inference_mode():
-            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
-        self.feature_dim = num_keypoints * 2
-        self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
-        self.maybe_relu = nn.ReLU() if relu else nn.Identity()
-
-    def forward(self, x: Tensor) -> Tensor:
-        """
-        Args:
-            x: (B, C, H, W) image tensor with pixel values in [0, 1].
-        Returns:
-            (B, D) image feature.
-        """
-        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
-        x = self.normalizer(x)
-        if self.do_crop:
-            if self.training:  # noqa: SIM108
-                x = self.maybe_random_crop(x)
-            else:
-                # Always use center crop for eval.
-                x = self.center_crop(x)
-        # Extract backbone feature.
-        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
-        # Final linear layer.
-        x = self.out(x)
-        # Maybe a final non-linearity.
-        x = self.maybe_relu(x)
-        return x
-
-
-def _replace_submodules(
-    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
-) -> nn.Module:
-    """
-    Args:
-        root_module: The module for which the submodules need to be replaced
-        predicate: Takes a module as an argument and must return True if the that module is to be replaced.
-        func: Takes a module as an argument and returns a new module to replace it with.
-    Returns:
-        The root module with its submodules replaced.
-    """
-    if predicate(root_module):
-        return func(root_module)
-
-    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
-    for *parents, k in replace_list:
-        parent_module = root_module
-        if len(parents) > 0:
-            parent_module = root_module.get_submodule(".".join(parents))
-        if isinstance(parent_module, nn.Sequential):
-            src_module = parent_module[int(k)]
-        else:
-            src_module = getattr(parent_module, k)
-        tgt_module = func(src_module)
-        if isinstance(parent_module, nn.Sequential):
-            parent_module[int(k)] = tgt_module
-        else:
-            setattr(parent_module, k, tgt_module)
-    # verify that all BN are replaced
-    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
-    return root_module
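
The group-norm swap that `_replace_submodules` enables can be exercised on its own with torchvision; a rough standalone sketch that uses a simple recursive replacement rather than the helper above:

import torch.nn as nn
import torchvision

# Backbone construction mirrors the patch: drop the avgpool and fc layers of a ResNet.
backbone = nn.Sequential(*list(torchvision.models.resnet18(weights=None).children())[:-2])

def swap_batchnorm_for_groupnorm(module: nn.Module) -> nn.Module:
    # Recursively replace every BatchNorm2d with a GroupNorm of ~16 channels per group.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(
                module,
                name,
                nn.GroupNorm(num_groups=child.num_features // 16, num_channels=child.num_features),
            )
        else:
            swap_batchnorm_for_groupnorm(child)
    return module

backbone = swap_batchnorm_for_groupnorm(backbone)
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())
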
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
new file mode 100644
index 00000000..4853dbcf
--- /dev/null
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -0,0 +1,878 @@
+"""
+TODO(alexander-soare):
+  - Remove reliance on Robomimic for SpatialSoftmax.
+  - Remove reliance on diffusers for DDPMScheduler.
+  - Move EMA out of policy.
+"""
+
+import copy
+import logging
+import math
+import time
+from collections import deque
+from typing import Callable
+
+import einops
+import hydra
+import torch
+import torch.nn.functional as F  # noqa: N812
+import torchvision
+from diffusers.optimization import get_scheduler
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from robomimic.models.base_nets import SpatialSoftmax
+from torch import Tensor, nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from lerobot.common.policies.utils import (
+    get_device_from_parameters,
+    get_dtype_from_parameters,
+    populate_queues,
+)
+from lerobot.common.utils import get_safe_torch_device
+
+logger = logging.getLogger(__name__)
+
+
+class DiffusionPolicy(nn.Module):
+    """
+    Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
+    (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
+    """
+
+    name = "diffusion"
+
+    def __init__(
+        self,
+        cfg,
+        cfg_device,
+        cfg_noise_scheduler,
+        cfg_optimizer,
+        cfg_ema,
+        shape_meta: dict,
+        horizon,
+        n_action_steps,
+        n_obs_steps,
+        num_inference_steps=None,
+        diffusion_step_embed_dim=256,
+        down_dims=(256, 512, 1024),
+        kernel_size=5,
+        n_groups=8,
+        film_scale_modulation=True,
+        **_,
+    ):
+        super().__init__()
+        self.cfg = cfg
+        self.n_obs_steps = n_obs_steps
+        self.n_action_steps = n_action_steps
+
+        # Queues are populated during rollout of the policy. They contain the n latest observations and actions.
+        self._queues = None
+
+        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
+
+        self.diffusion = _DiffusionUnetImagePolicy(
+            cfg,
+            shape_meta=shape_meta,
+            noise_scheduler=noise_scheduler,
+            horizon=horizon,
+            n_action_steps=n_action_steps,
+            n_obs_steps=n_obs_steps,
+            num_inference_steps=num_inference_steps,
+            diffusion_step_embed_dim=diffusion_step_embed_dim,
+            down_dims=down_dims,
+            kernel_size=kernel_size,
+            n_groups=n_groups,
+            film_scale_modulation=film_scale_modulation,
+        )
+
+        self.device = get_safe_torch_device(cfg_device)
+        self.diffusion.to(self.device)
+
+        # TODO(alexander-soare): This should probably be managed outside of the policy class.
+        self.ema_diffusion = None
+        self.ema = None
+        if self.cfg.use_ema:
+            self.ema_diffusion = copy.deepcopy(self.diffusion)
+            self.ema = hydra.utils.instantiate(
+                cfg_ema,
+                model=self.ema_diffusion,
+            )
+
+        self.optimizer = hydra.utils.instantiate(
+            cfg_optimizer,
+            params=self.diffusion.parameters(),
+        )
+
+        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
+        self.global_step = 0
+
+        # configure lr scheduler
+        self.lr_scheduler = get_scheduler(
+            cfg.lr_scheduler,
+            optimizer=self.optimizer,
+            num_warmup_steps=cfg.lr_warmup_steps,
+            num_training_steps=cfg.offline_steps,
+            # PyTorch assumes the LR scheduler is stepped once per epoch, whereas HuggingFace
+            # diffusers steps it once per batch.
+            last_epoch=self.global_step - 1,
+        )
+
+    def reset(self):
+        """
+        Clear observation and action queues. Should be called on `env.reset()`
+        """
+        self._queues = {
+            "observation.image": deque(maxlen=self.n_obs_steps),
+            "observation.state": deque(maxlen=self.n_obs_steps),
+            "action": deque(maxlen=self.n_action_steps),
+        }
+
+    @torch.no_grad
+    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
+        """Select a single action given environment observations.
+
+        This method handles caching a history of observations and an action trajectory generated by the
+        underlying diffusion model. Here's how it works:
+          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
+            copied `n_obs_steps` times to fill the cache).
+          - The diffusion model generates `horizon` steps worth of actions.
+          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
+        Schematically this looks like:
+            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
+            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... |n-o+1+h|
+            |observation is used | YES   | YES   | ..... | NO    | NO    | NO    | NO    | NO    | NO    |
+            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
+            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
+        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
+        "horizon" may not be the best name for this variable, because the period it describes is actually
+        measured from the first observation, which (if `n_obs_steps` > 1) happened in the past.
+
+        Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
+        weights.
+        """
+        assert "observation.image" in batch
+        assert "observation.state" in batch
+        assert len(batch) == 2
+
+        self._queues = populate_queues(self._queues, batch)
+
+        if len(self._queues["action"]) == 0:
+            # stack n latest observations from the queue
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            if not self.training and self.ema_diffusion is not None:
+                actions = self.ema_diffusion.generate_actions(batch)
+            else:
+                actions = self.diffusion.generate_actions(batch)
+            self._queues["action"].extend(actions.transpose(0, 1))
+
+        action = self._queues["action"].popleft()
+        return action
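# A quick numeric check (illustrative, not part of the patch) of the scheduling described in the
# docstring above, using the config defaults n_obs_steps=2, horizon=16, n_action_steps=8: the
# diffusion model predicts `horizon` actions anchored at the oldest cached observation, and
# generate_actions keeps the slice [n_obs_steps - 1 : n_obs_steps - 1 + n_action_steps].
n_obs_steps, horizon, n_action_steps = 2, 16, 8
start = n_obs_steps - 1       # 1: index of the action aligned with the current step
end = start + n_action_steps  # 9: one past the last action queued for execution
assert n_action_steps <= horizon - n_obs_steps + 1  # actions to execute must fit inside the horizon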
+
+    def forward(self, batch, **_):
+        start_time = time.time()
+
+        self.diffusion.train()
+
+        loss = self.diffusion.compute_loss(batch)
+        loss.backward()
+
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.diffusion.parameters(),
+            self.cfg.grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        self.lr_scheduler.step()
+
+        if self.ema is not None:
+            self.ema.step(self.diffusion)
+
+        info = {
+            "loss": loss.item(),
+            "grad_norm": float(grad_norm),
+            "lr": self.lr_scheduler.get_last_lr()[0],
+            "update_s": time.time() - start_time,
+        }
+
+        return info
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
+        if len(missing_keys) > 0:
+            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
+            logging.warning(
+                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
+            )
+        assert len(unexpected_keys) == 0
+
+
+class _DiffusionUnetImagePolicy(nn.Module):
+    def __init__(
+        self,
+        cfg,
+        shape_meta: dict,
+        noise_scheduler: DDPMScheduler,
+        horizon,
+        n_action_steps,
+        n_obs_steps,
+        num_inference_steps=None,
+        diffusion_step_embed_dim=256,
+        down_dims=(256, 512, 1024),
+        kernel_size=5,
+        n_groups=8,
+        film_scale_modulation=True,
+    ):
+        super().__init__()
+        action_shape = shape_meta["action"]["shape"]
+        assert len(action_shape) == 1
+        action_dim = action_shape[0]
+
+        self.rgb_encoder = _RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
+
+        self.unet = _ConditionalUnet1D(
+            input_dim=action_dim,
+            global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
+            diffusion_step_embed_dim=diffusion_step_embed_dim,
+            down_dims=down_dims,
+            kernel_size=kernel_size,
+            n_groups=n_groups,
+            film_scale_modulation=film_scale_modulation,
+        )
+
+        self.noise_scheduler = noise_scheduler
+        self.horizon = horizon
+        self.action_dim = action_dim
+        self.n_action_steps = n_action_steps
+        self.n_obs_steps = n_obs_steps
+
+        if num_inference_steps is None:
+            num_inference_steps = noise_scheduler.config.num_train_timesteps
+
+        self.num_inference_steps = num_inference_steps
+
+    # ========= inference  ============
+    def conditional_sample(self, batch_size, global_cond=None, generator=None):
+        device = get_device_from_parameters(self)
+        dtype = get_dtype_from_parameters(self)
+
+        # Sample prior.
+        sample = torch.randn(
+            size=(batch_size, self.horizon, self.action_dim),
+            dtype=dtype,
+            device=device,
+            generator=generator,
+        )
+
+        self.noise_scheduler.set_timesteps(self.num_inference_steps)
+
+        for t in self.noise_scheduler.timesteps:
+            # Predict model output.
+            model_output = self.unet(
+                sample,
+                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
+                global_cond=global_cond,
+            )
+            # Compute previous image: x_t -> x_t-1
+            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
+
+        return sample
+
+    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
+        """
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, n_obs_steps, state_dim)
+            "observation.image": (B, n_obs_steps, C, H, W)
+        }
+        """
+        assert set(batch).issuperset({"observation.state", "observation.image"})
+        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        assert n_obs_steps == self.n_obs_steps
+
+        # Extract image feature (first combine batch and sequence dims).
+        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
+        # Separate batch and sequence dims.
+        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
+        # Concatenate state and image features then flatten to (B, global_cond_dim).
+        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+
+        # run sampling
+        sample = self.conditional_sample(batch_size, global_cond=global_cond)
+
+        # `horizon` steps worth of actions (from the first observation).
+        action = sample[..., : self.action_dim]
+        # Extract `n_action_steps` steps worth of actions (from the current observation).
+        start = n_obs_steps - 1
+        end = start + self.n_action_steps
+        action = action[:, start:end]
+
+        return action
+
+    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
+        """
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, n_obs_steps, state_dim)
+            "observation.image": (B, n_obs_steps, C, H, W)
+            "action": (B, horizon, action_dim)
+            "action_is_pad": (B, horizon)
+        }
+        """
+        # Input validation.
+        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
+        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        horizon = batch["action"].shape[1]
+        assert horizon == self.horizon
+        assert n_obs_steps == self.n_obs_steps
+
+        # Extract image feature (first combine batch and sequence dims).
+        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
+        # Separate batch and sequence dims.
+        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
+        # Concatenate state and image features then flatten to (B, global_cond_dim).
+        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+
+        trajectory = batch["action"]
+
+        # Forward diffusion.
+        # Sample noise to add to the trajectory.
+        eps = torch.randn(trajectory.shape, device=trajectory.device)
+        # Sample a random noising timestep for each item in the batch.
+        timesteps = torch.randint(
+            low=0,
+            high=self.noise_scheduler.config.num_train_timesteps,
+            size=(trajectory.shape[0],),
+            device=trajectory.device,
+        ).long()
+        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
+        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
+
+        # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
+        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
+
+        # Compute the loss.
+        # The target is either the original trajectory or the noise.
+        pred_type = self.noise_scheduler.config.prediction_type
+        if pred_type == "epsilon":
+            target = eps
+        elif pred_type == "sample":
+            target = batch["action"]
+        else:
+            raise ValueError(f"Unsupported prediction type {pred_type}")
+
+        loss = F.mse_loss(pred, target, reduction="none")
+
+        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
+        if "action_is_pad" in batch:
+            in_episode_bound = ~batch["action_is_pad"]
+            loss = loss * in_episode_bound.unsqueeze(-1)
+
+        return loss.mean()
+
+
+class _RgbEncoder(nn.Module):
+    """Encoder an RGB image into a 1D feature vector.
+
+    Includes the ability to normalize and crop the image first.
+    """
+
+    def __init__(
+        self,
+        input_shape: tuple[int, int, int],
+        norm_mean_std: tuple[float, float] = (1.0, 1.0),
+        crop_shape: tuple[int, int] | None = None,
+        random_crop: bool = False,
+        backbone_name: str = "resnet18",
+        pretrained_backbone: bool = False,
+        use_group_norm: bool = False,
+        num_keypoints: int = 32,
+    ):
+        """
+        Args:
+            input_shape: channel-first input shape (C, H, W)
+            norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
+                (image - mean) / std.
+            crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
+                cropping is done.
+            random_crop: Whether the crop should be random at training time (it's always a center crop in
+                eval mode).
+            backbone_name: The name of one of the available ResNet models from torchvision (e.g. resnet18).
+            pretrained_backbone: whether to use torchvision's pretrained weights.
+            use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
+                The group sizes are set to be about 16 (to be precise, feature_dim // 16).
+            num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
+        """
+        super().__init__()
+        if input_shape[0] != 3:
+            raise ValueError("Only RGB images are handled")
+        if not backbone_name.startswith("resnet"):
+            raise ValueError(
+                "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
+            )
+
+        # Set up optional preprocessing.
+        if tuple(norm_mean_std) == (1.0, 1.0):
+            self.normalizer = nn.Identity()
+        else:
+            self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
+
+        if crop_shape is not None:
+            self.do_crop = True
+            # Always use center crop for eval
+            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
+            if random_crop:
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
+            else:
+                self.maybe_random_crop = self.center_crop
+        else:
+            self.do_crop = False
+
+        # Set up backbone.
+        backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
+        # Note: This assumes that the layer4 feature map is children()[-3]
+        # TODO(alexander-soare): Use a safer alternative.
+        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
+        if use_group_norm:
+            if pretrained_backbone:
+                raise ValueError(
+                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
+                )
+            self.backbone = _replace_submodules(
+                root_module=self.backbone,
+                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
+                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
+            )
+
+        # Set up pooling and final layers.
+        # Use a dry run to get the feature map shape.
+        with torch.inference_mode():
+            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
+        self.feature_dim = num_keypoints * 2
+        self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
+        self.relu = nn.ReLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: (B, C, H, W) image tensor with pixel values in [0, 1].
+        Returns:
+            (B, D) image feature.
+        """
+        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
+        x = self.normalizer(x)
+        if self.do_crop:
+            if self.training:  # noqa: SIM108
+                x = self.maybe_random_crop(x)
+            else:
+                # Always use center crop for eval.
+                x = self.center_crop(x)
+        # Extract backbone feature.
+        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
+        # Final linear layer with non-linearity.
+        x = self.relu(self.out(x))
+        return x
+
+
+def _replace_submodules(
+    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
+) -> nn.Module:
+    """
+    Args:
+        root_module: The module for which the submodules need to be replaced
+        predicate: Takes a module as an argument and must return True if that module is to be replaced.
+        func: Takes a module as an argument and returns a new module to replace it with.
+    Returns:
+        The root module with its submodules replaced.
+    """
+    if predicate(root_module):
+        return func(root_module)
+
+    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
+    for *parents, k in replace_list:
+        parent_module = root_module
+        if len(parents) > 0:
+            parent_module = root_module.get_submodule(".".join(parents))
+        if isinstance(parent_module, nn.Sequential):
+            src_module = parent_module[int(k)]
+        else:
+            src_module = getattr(parent_module, k)
+        tgt_module = func(src_module)
+        if isinstance(parent_module, nn.Sequential):
+            parent_module[int(k)] = tgt_module
+        else:
+            setattr(parent_module, k, tgt_module)
+    # verify that all BN are replaced
+    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
+    return root_module
+
+
+class _SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class _Conv1dBlock(nn.Module):
+    """Conv1d --> GroupNorm --> Mish"""
+
+    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+            nn.GroupNorm(n_groups, out_channels),
+            nn.Mish(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class _ConditionalUnet1D(nn.Module):
+    """A 1D convolutional UNet with FiLM modulation for conditioning.
+
+    Two types of conditioning can be applied:
+    - Global: Conditioning information that is aggregated over the whole observation window. This is
+        incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
+    - Local: Conditioning information for each timestep in the observation window. This is incorporated
+        by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
+        outputs of the Unet's encoder/decoder.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        local_cond_dim: int | None = None,
+        global_cond_dim: int | None = None,
+        diffusion_step_embed_dim: int = 256,
+        down_dims: list[int] | None = None,
+        kernel_size: int = 3,
+        n_groups: int = 8,
+        film_scale_modulation: bool = False,
+    ):
+        super().__init__()
+
+        if down_dims is None:
+            down_dims = [256, 512, 1024]
+
+        # Encoder for the diffusion timestep.
+        self.diffusion_step_encoder = nn.Sequential(
+            _SinusoidalPosEmb(diffusion_step_embed_dim),
+            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
+            nn.Mish(),
+            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
+        )
+
+        # The FiLM conditioning dimension.
+        cond_dim = diffusion_step_embed_dim
+        if global_cond_dim is not None:
+            cond_dim += global_cond_dim
+
+        self.local_cond_down_encoder = None
+        self.local_cond_up_encoder = None
+        if local_cond_dim is not None:
+            # Encoder for the local conditioning. The output gets added to the Unet encoder input.
+            self.local_cond_down_encoder = _ConditionalResidualBlock1D(
+                local_cond_dim,
+                down_dims[0],
+                cond_dim=cond_dim,
+                kernel_size=kernel_size,
+                n_groups=n_groups,
+                film_scale_modulation=film_scale_modulation,
+            )
+            # Encoder for the local conditioning. The output gets added to the Unet decoder output.
+            self.local_cond_up_encoder = _ConditionalResidualBlock1D(
+                local_cond_dim,
+                down_dims[0],
+                cond_dim=cond_dim,
+                kernel_size=kernel_size,
+                n_groups=n_groups,
+                film_scale_modulation=film_scale_modulation,
+            )
+
+        # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
+        # just reverse these.
+        in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
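+        # e.g. with input_dim=2 and down_dims=[256, 512, 1024] this is [(2, 256), (256, 512), (512, 1024)].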
+
+        # Unet encoder.
+        self.down_modules = nn.ModuleList([])
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (len(in_out) - 1)
+            self.down_modules.append(
+                nn.ModuleList(
+                    [
+                        _ConditionalResidualBlock1D(
+                            dim_in,
+                            dim_out,
+                            cond_dim=cond_dim,
+                            kernel_size=kernel_size,
+                            n_groups=n_groups,
+                            film_scale_modulation=film_scale_modulation,
+                        ),
+                        _ConditionalResidualBlock1D(
+                            dim_out,
+                            dim_out,
+                            cond_dim=cond_dim,
+                            kernel_size=kernel_size,
+                            n_groups=n_groups,
+                            film_scale_modulation=film_scale_modulation,
+                        ),
+                        # Downsample as long as it is not the last block.
+                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        # Processing in the middle of the auto-encoder.
+        self.mid_modules = nn.ModuleList(
+            [
+                _ConditionalResidualBlock1D(
+                    down_dims[-1],
+                    down_dims[-1],
+                    cond_dim=cond_dim,
+                    kernel_size=kernel_size,
+                    n_groups=n_groups,
+                    film_scale_modulation=film_scale_modulation,
+                ),
+                _ConditionalResidualBlock1D(
+                    down_dims[-1],
+                    down_dims[-1],
+                    cond_dim=cond_dim,
+                    kernel_size=kernel_size,
+                    n_groups=n_groups,
+                    film_scale_modulation=film_scale_modulation,
+                ),
+            ]
+        )
+
+        # Unet decoder.
+        self.up_modules = nn.ModuleList([])
+        for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (len(in_out) - 1)
+            self.up_modules.append(
+                nn.ModuleList(
+                    [
+                        _ConditionalResidualBlock1D(
+                            dim_in * 2,  # x2 as it takes the encoder's skip connection as well
+                            dim_out,
+                            cond_dim=cond_dim,
+                            kernel_size=kernel_size,
+                            n_groups=n_groups,
+                            film_scale_modulation=film_scale_modulation,
+                        ),
+                        _ConditionalResidualBlock1D(
+                            dim_out,
+                            dim_out,
+                            cond_dim=cond_dim,
+                            kernel_size=kernel_size,
+                            n_groups=n_groups,
+                            film_scale_modulation=film_scale_modulation,
+                        ),
+                        # Upsample as long as it is not the last block.
+                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+        self.final_conv = nn.Sequential(
+            _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
+            nn.Conv1d(down_dims[0], input_dim, 1),
+        )
+
+    def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
+        """
+        Args:
+            x: (B, T, input_dim) tensor for input to the Unet.
+            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
+            local_cond: (B, T, local_cond_dim)
+            global_cond: (B, global_cond_dim)
+        Returns:
+            (B, T, input_dim)
+        """
+        # For 1D convolutions we'll need feature dimension first.
+        x = einops.rearrange(x, "b t d -> b d t")
+        if local_cond is not None:
+            if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
+                raise ValueError(
+                    "`local_cond` was provided but the relevant encoders weren't built at initialization."
+                )
+            local_cond = einops.rearrange(local_cond, "b t d -> b d t")
+
+        timesteps_embed = self.diffusion_step_encoder(timestep)
+
+        # If there is a global conditioning feature, concatenate it to the timestep embedding.
+        if global_cond is not None:
+            global_feature = torch.cat([timesteps_embed, global_cond], dim=-1)
+        else:
+            global_feature = timesteps_embed
+
+        encoder_skip_features: list[Tensor] = []
+        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
+            x = resnet(x, global_feature)
+            if idx == 0 and local_cond is not None:
+                x = x + self.local_cond_down_encoder(local_cond, global_feature)
+            x = resnet2(x, global_feature)
+            encoder_skip_features.append(x)
+            x = downsample(x)
+
+        for mid_module in self.mid_modules:
+            x = mid_module(x, global_feature)
+
+        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
+            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
+            x = resnet(x, global_feature)
+            # Note: The condition in the original implementation is:
+            # if idx == len(self.up_modules) and local_cond is not None:
+            # But as they mention in their comments, this is incorrect. We use the correct condition here.
+            if idx == (len(self.up_modules) - 1) and local_cond is not None:
+                x = x + self.local_cond_up_encoder(local_cond, global_feature)
+            x = resnet2(x, global_feature)
+            x = upsample(x)
+
+        x = self.final_conv(x)
+
+        x = einops.rearrange(x, "b d t -> b t d")
+        return x
+
+
+class _ConditionalResidualBlock1D(nn.Module):
+    """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        cond_dim: int,
+        kernel_size: int = 3,
+        n_groups: int = 8,
+        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
+        # FiLM just modulates bias).
+        film_scale_modulation: bool = False,
+    ):
+        super().__init__()
+
+        self.film_scale_modulation = film_scale_modulation
+        self.out_channels = out_channels
+
+        self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
+
+        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
+        cond_channels = out_channels * 2 if film_scale_modulation else out_channels
+        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
+
+        self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
+
+        # A 1x1 convolution on the residual connection to match the channel dimension (only if needed).
+        self.residual_conv = (
+            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
+        )
+
+    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
+        """
+        Args:
+            x: (B, in_channels, T)
+            cond: (B, cond_dim)
+        Returns:
+            (B, out_channels, T)
+        """
+        out = self.conv1(x)
+
+        # Get condition embedding. Unsqueeze for broadcasting to `out`: the result has shape
+        # (B, out_channels, 1), or (B, 2 * out_channels, 1) when scale modulation is used.
+        cond_embed = self.cond_encoder(cond).unsqueeze(-1)
+        if self.film_scale_modulation:
+            # Treat the embedding as a list of scales and biases.
+            scale = cond_embed[:, : self.out_channels]
+            bias = cond_embed[:, self.out_channels :]
+            out = scale * out + bias
+        else:
+            # Treat the embedding as biases.
+            out = out + cond_embed
+
+        out = self.conv2(out)
+        out = out + self.residual_conv(x)
+        return out
+
+
+class _EMA:
+    """
+    Exponential Moving Average of model weights.
+    """
+
+    def __init__(
+        self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+            at 215.4k steps).
+        Args:
+            update_after_step (int): Number of optimization steps to wait before EMA updates begin. Default: 0.
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+            max_value (float): The maximum EMA decay rate. Default: 0.9999.
+        """
+
+        self.averaged_model = model
+        self.averaged_model.eval()
+        self.averaged_model.requires_grad_(False)
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.alpha = 0.0
+        self.optimization_step = 0
+
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
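+        # With the default inv_gamma=1.0 and power=2/3 this warms up to ~0.999 after about 31.6K steps
+        # and ~0.9999 after about 1M steps (see the notes in `__init__`).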
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
+    @torch.no_grad()
+    def step(self, new_model):
+        self.alpha = self.get_decay(self.optimization_step)
+
+        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
+            # Iterate over immediate parameters only.
+            for param, ema_param in zip(
+                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
+            ):
+                if isinstance(param, dict):
+                    raise RuntimeError("Dict parameter not supported")
+                if isinstance(module, _BatchNorm) or not param.requires_grad:
+                    # Copy BatchNorm parameters, and non-trainable parameters directly.
+                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
+                else:
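+                    # Standard EMA update: ema_param <- alpha * ema_param + (1 - alpha) * param.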
+                    ema_param.mul_(self.alpha)
+                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
+
+        self.optimization_step += 1
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
deleted file mode 100644
index f88e2f25..00000000
--- a/lerobot/common/policies/diffusion/policy.py
+++ /dev/null
@@ -1,169 +0,0 @@
-import copy
-import logging
-import time
-from collections import deque
-
-import hydra
-import torch
-from diffusers.optimization import get_scheduler
-from torch import nn
-
-from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
-from lerobot.common.policies.utils import populate_queues
-from lerobot.common.utils import get_safe_torch_device
-
-
-class DiffusionPolicy(nn.Module):
-    name = "diffusion"
-
-    def __init__(
-        self,
-        cfg,
-        cfg_device,
-        cfg_noise_scheduler,
-        cfg_optimizer,
-        cfg_ema,
-        shape_meta: dict,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        film_scale_modulation=True,
-        **_,
-    ):
-        super().__init__()
-        self.cfg = cfg
-        self.n_obs_steps = n_obs_steps
-        self.n_action_steps = n_action_steps
-
-        # queues are populated during rollout of the policy, they contain the n latest observations and actions
-        self._queues = None
-
-        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
-
-        self.diffusion = DiffusionUnetImagePolicy(
-            cfg,
-            shape_meta=shape_meta,
-            noise_scheduler=noise_scheduler,
-            horizon=horizon,
-            n_action_steps=n_action_steps,
-            n_obs_steps=n_obs_steps,
-            num_inference_steps=num_inference_steps,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
-        )
-
-        self.device = get_safe_torch_device(cfg_device)
-        self.diffusion.to(self.device)
-
-        # TODO(alexander-soare): This should probably be managed outside of the policy class.
-        self.ema_diffusion = None
-        self.ema = None
-        if self.cfg.use_ema:
-            self.ema_diffusion = copy.deepcopy(self.diffusion)
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=self.ema_diffusion,
-            )
-
-        self.optimizer = hydra.utils.instantiate(
-            cfg_optimizer,
-            params=self.diffusion.parameters(),
-        )
-
-        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
-        self.global_step = 0
-
-        # configure lr scheduler
-        self.lr_scheduler = get_scheduler(
-            cfg.lr_scheduler,
-            optimizer=self.optimizer,
-            num_warmup_steps=cfg.lr_warmup_steps,
-            num_training_steps=cfg.offline_steps,
-            # pytorch assumes stepping LRScheduler every epoch
-            # however huggingface diffusers steps it every batch
-            last_epoch=self.global_step - 1,
-        )
-
-    def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
-        self._queues = {
-            "observation.image": deque(maxlen=self.n_obs_steps),
-            "observation.state": deque(maxlen=self.n_obs_steps),
-            "action": deque(maxlen=self.n_action_steps),
-        }
-
-    @torch.no_grad
-    def select_action(self, batch, **_):
-        """
-        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
-        """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-        assert len(batch) == 2
-
-        self._queues = populate_queues(self._queues, batch)
-
-        if len(self._queues["action"]) == 0:
-            # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
-            if not self.training and self.ema_diffusion is not None:
-                actions = self.ema_diffusion.generate_actions(batch)
-            else:
-                actions = self.diffusion.generate_actions(batch)
-            self._queues["action"].extend(actions.transpose(0, 1))
-
-        action = self._queues["action"].popleft()
-        return action
-
-    def forward(self, batch, **_):
-        start_time = time.time()
-
-        self.diffusion.train()
-
-        loss = self.diffusion.compute_loss(batch)
-        loss.backward()
-
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.diffusion.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.lr_scheduler.get_last_lr()[0],
-            "update_s": time.time() - start_time,
-        }
-
-        return info
-
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
-        if len(missing_keys) > 0:
-            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
-            logging.warning(
-                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
-            )
-        assert len(unexpected_keys) == 0
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index f0454b8e..0b185e86 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,59 +1,61 @@
 import inspect
 
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 from lerobot.common.utils import get_safe_torch_device
 
 
-def make_policy(cfg):
-    if cfg.policy.name == "tdmpc":
+def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
+    expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
+    assert set(hydra_cfg.policy).issuperset(
+        expected_kwargs
+    ), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}"
+    policy_cfg = policy_cfg_class(
+        **{
+            k: v
+            for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
+            if k in expected_kwargs
+        }
+    )
+    return policy_cfg
+
+
+def make_policy(hydra_cfg: DictConfig):
+    if hydra_cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
         policy = TDMPCPolicy(
-            cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
+            hydra_cfg.policy,
+            n_obs_steps=hydra_cfg.n_obs_steps,
+            n_action_steps=hydra_cfg.n_action_steps,
+            device=hydra_cfg.device,
         )
-    elif cfg.policy.name == "diffusion":
-        from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+    elif hydra_cfg.policy.name == "diffusion":
+        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
-        policy = DiffusionPolicy(
-            cfg=cfg.policy,
-            cfg_device=cfg.device,
-            cfg_noise_scheduler=cfg.noise_scheduler,
-            cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
-            # n_obs_steps=cfg.n_obs_steps,
-            # n_action_steps=cfg.n_action_steps,
-            **cfg.policy,
-        )
-    elif cfg.policy.name == "act":
+        policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
+        policy = DiffusionPolicy(policy_cfg)
+        policy.to(get_safe_torch_device(hydra_cfg.device))
+    elif hydra_cfg.policy.name == "act":
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
         from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 
-        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
-        assert set(cfg.policy).issuperset(
-            expected_kwargs
-        ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
-        policy_cfg = ActionChunkingTransformerConfig(
-            **{
-                k: v
-                for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
-                if k in expected_kwargs
-            }
-        )
+        policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
         policy = ActionChunkingTransformerPolicy(policy_cfg)
-        policy.to(get_safe_torch_device(cfg.device))
+        policy.to(get_safe_torch_device(hydra_cfg.device))
     else:
-        raise ValueError(cfg.policy.name)
+        raise ValueError(hydra_cfg.policy.name)
 
-    if cfg.policy.pretrained_model_path:
+    if hydra_cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
-        if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.policy.pretrained_model_path:
+        if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
+            if "offline" in hydra_cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.policy.pretrained_model_path:
+            elif "final" in hydra_cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
-        policy.load(cfg.policy.pretrained_model_path)
+        policy.load(hydra_cfg.policy.pretrained_model_path)
 
     return policy
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index bd883613..5dd70d71 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -18,7 +18,7 @@ policy:
   pretrained_model_path:
 
   # Environment.
-  # Inherit these from the environment.
+  # Inherit these from the environment config.
   state_dim: ???
   action_dim: ???
 
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 005b0517..c8bdb0c4 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -1,17 +1,5 @@
 # @package _global_
 
-shape_meta:
-  # acceptable types: rgb, low_dim
-  obs:
-    image:
-      shape: [3, 96, 96]
-      type: rgb
-    agent_pos:
-      shape: [2]
-      type: low_dim
-  action:
-    shape: [2]
-
 seed: 100000
 horizon: 16
 n_obs_steps: 2
@@ -33,75 +21,70 @@ offline_prioritized_sampler: true
 policy:
   name: diffusion
 
-  shape_meta: ${shape_meta}
+  pretrained_model_path:
 
-  horizon: ${horizon}
+  # Environment.
+  # Inherit these from the environment config.
+  state_dim: ???
+  action_dim: ???
+  image_size:
+    - ${env.image_size}  # height
+    - ${env.image_size}  # width
+
+  # Inputs / output structure.
   n_obs_steps: ${n_obs_steps}
+  horizon: ${horizon}
   n_action_steps: ${n_action_steps}
-  num_inference_steps: 100
-  # crop_shape: null
-  diffusion_step_embed_dim: 128
+
+  # Vision preprocessing.
+  image_normalization_mean: [0.5, 0.5, 0.5]
+  image_normalization_std: [0.5, 0.5, 0.5]
+
+  # Architecture / modeling.
+  # Vision backbone.
+  vision_backbone: resnet18
+  crop_shape: [84, 84]
+  random_crop: True
+  use_pretrained_backbone: false
+  use_group_norm: True
+  spatial_softmax_num_keypoints: 32
+  # Unet.
   down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
+  diffusion_step_embed_dim: 128
   film_scale_modulation: True
-
-  pretrained_model_path:
-
-  batch_size: 64
-
-  per_alpha: 0.6
-  per_beta: 0.4
-
-  balanced_sampling: false
-  utd: 1
-  offline_steps: ${offline_steps}
-  use_ema: true
-  lr_scheduler: cosine
-  lr_warmup_steps: 500
-  grad_clip_norm: 10
-
-  delta_timestamps:
-    observation.image: [-0.1, 0]
-    observation.state: [-0.1, 0]
-    action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
-
-  rgb_encoder:
-    backbone_name: resnet18
-    pretrained_backbone: false
-    use_group_norm: True
-    num_keypoints: 32
-    relu: true
-    norm_mean_std: [0.5, 0.5]  # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
-    crop_shape: [84, 84]
-    random_crop: True
-
-noise_scheduler:
-  _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  # Noise scheduler.
   num_train_timesteps: 100
+  beta_schedule: squaredcos_cap_v2
   beta_start: 0.0001
   beta_end: 0.02
-  beta_schedule: squaredcos_cap_v2
-  variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
-  clip_sample: True # required when predict_epsilon=False
-  prediction_type: epsilon # or sample
+  variance_type: fixed_small
+  prediction_type: epsilon # epsilon / sample
+  clip_sample: True
 
-rgb_model:
-  pretrained: false
-  num_keypoints: 32
-  relu: true
+  # Inference
+  num_inference_steps: 100
 
-ema:
-  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
-  update_after_step: 0
-  inv_gamma: 1.0
-  power: 0.75
-  min_value: 0.0
-  max_value: 0.9999
-
-optimizer:
-  _target_: torch.optim.AdamW
+  # ---
+  # TODO(alexander-soare): Remove these from the policy config.
+  batch_size: 64
+  grad_clip_norm: 10
   lr: 1.0e-4
-  betas: [0.95, 0.999]
-  eps: 1.0e-8
-  weight_decay: 1.0e-6
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  adam_betas: [0.95, 0.999]
+  adam_eps: 1.0e-8
+  adam_weight_decay: 1.0e-6
+  utd: 1
+  use_ema: true
+  ema_update_after_step: 0
+  ema_min_rate: 0.0
+  ema_max_rate: 0.9999
+  ema_inv_gamma: 1.0
+  ema_power: 0.75
+
+  delta_timestamps:
+    observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+    observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+    action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
diff --git a/tests/test_available.py b/tests/test_available.py
index b25a921f..373cc1a7 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -19,7 +19,7 @@ from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
 
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
 

From 03b08eb74eb328206bfa214c20db894f4d4df9dc Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Tue, 16 Apr 2024 12:51:32 +0100
Subject: [PATCH 5/8] backup wip

---
 examples/3_train_policy.py                    |  60 +--
 lerobot/common/datasets/utils.py              |   4 +
 .../common/policies/act/configuration_act.py  |   4 +-
 .../diffusion/configuration_diffusion.py      |  70 ++-
 .../policies/diffusion/modeling_diffusion.py  | 409 +++++-------------
 lerobot/common/policies/factory.py            |   4 +-
 lerobot/configs/policy/diffusion.yaml         |  12 +-
 tests/test_examples.py                        |  33 +-
 8 files changed, 240 insertions(+), 356 deletions(-)

diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 64804d8f..012efddd 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -11,54 +11,54 @@ import torch
 from omegaconf import OmegaConf
 
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.utils import cycle
+from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
 
 output_directory = Path("outputs/train/example_pusht_diffusion")
 os.makedirs(output_directory, exist_ok=True)
 
-overrides = [
-    "env=pusht",
-    "policy=diffusion",
-    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
-    "offline_steps=5000",
-    "log_freq=250",
-    "device=cuda",
-]
-
-cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
-
-policy = DiffusionPolicy(
-    cfg=cfg.policy,
-    cfg_device=cfg.device,
-    cfg_noise_scheduler=cfg.noise_scheduler,
-    cfg_optimizer=cfg.optimizer,
-    cfg_ema=cfg.ema,
-    **cfg.policy,
-)
-policy.train()
+# Number of offline training steps (we'll only do offline training for this example).
+# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
+training_steps = 5000
+device = torch.device("cuda")
+log_freq = 250
 
+# Set up the dataset.
+cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
 dataset = make_dataset(cfg)
 
-# create dataloader for offline training
+# Set up the policy.
+# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
+# For this example, no arguments need to be passed because the defaults are set up for PushT.
+# If you're doing something different, you will likely need to change at least some of the defaults.
+cfg = DiffusionConfig()
+# TODO(alexander-soare): Remove LR scheduler from the policy.
+policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
+policy.train()
+policy.to(device)
+
+# Create dataloader for offline training.
 dataloader = torch.utils.data.DataLoader(
     dataset,
     num_workers=4,
-    batch_size=cfg.policy.batch_size,
+    batch_size=cfg.batch_size,
     shuffle=True,
-    pin_memory=cfg.device != "cpu",
+    pin_memory=device != torch.device("cpu"),
     drop_last=True,
 )
 
-for step, batch in enumerate(dataloader):
-    info = policy(batch, step)
-
-    if step % cfg.log_freq == 0:
-        num_samples = (step + 1) * cfg.policy.batch_size
+# Run training loop.
+dataloader = cycle(dataloader)
+for step in range(training_steps):
+    batch = {k: v.to(device, non_blocking=True) for k, v in next(dataloader).items()}
+    info = policy(batch)
+    if step % log_freq == 0:
+        num_samples = (step + 1) * cfg.batch_size
         loss = info["loss"]
         update_s = info["update_s"]
-        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
-
+        print(f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)")
 
 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index e67d8a04..34430ff1 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -208,6 +208,10 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
 
 
 def cycle(iterable):
+    """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
+
+    See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
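+    Unlike itertools.cycle, this does not cache items from the first pass; it simply re-creates the
+    iterator once the underlying iterable is exhausted.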
+    """
     iterator = iter(iterable)
     while True:
         try:
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 72d35eb3..1b438f2d 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -26,8 +26,8 @@ class ActionChunkingTransformerConfig:
         image_normalization_std: Value by which to divide the input image pixels (after the mean has been
             subtracted).
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
-        use_pretrained_backbone: Whether the backbone should be initialized with ImageNet, pretrained weights
-            from torchvision.
+        use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
+            torchvision.
         replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
             convolution.
         pre_norm: Whether to use "pre-norm" in the transformer blocks.
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 272c80ea..d8820a0b 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -13,9 +13,49 @@ class DiffusionConfig:
     Args:
         state_dim: Dimensionality of the observation state space (excluding images).
         action_dim: Dimensionality of the action space.
+        image_size: (H, W) size of the input images.
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        horizon: Diffusion model action prediction horizon as detailed in the main policy documentation.
+        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
+        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
+            See `DiffusionPolicy.select_action` for more details.
+        image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
+            [0, 1]) for normalization.
+        image_normalization_std: Value by which to divide the input image pixels (after the mean has been
+            subtracted).
+        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
+        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
+            within the image size. If None, no cropping is done.
+        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
+            mode).
+        use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
+            torchvision.
+        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
+            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
+        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
+        down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
+            You may provide a variable number of dimensions, therefore also controlling the degree of
+            downsampling.
+        kernel_size: The convolutional kernel size of the diffusion modeling Unet.
+        n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
+        diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
+            network. This is the output dimension of that network, i.e., the embedding dimension.
+        use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
+            Bias modulation is used by default, while this parameter indicates whether to also use scale
+            modulation.
+        num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
+        beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
+        beta_start: Beta value for the first forward-diffusion step.
+        beta_end: Beta value for the last forward-diffusion step.
+        prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
+            or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
+            "epsilon" has been shown to work better in many deep neural network settings.
+        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
+            denoising step at inference time. WARNING: you will need to make sure your action-space is
+            normalized to fit within this range.
+        clip_sample_range: The magnitude of the clipping range as described above.
+        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
+            spaced). If not provided, this defaults to the value of `num_train_timesteps`.
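+
+    Example (illustrative; the defaults are tuned for PushT):
+        >>> cfg = DiffusionConfig()
+        >>> cfg.crop_shape
+        (84, 84)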
     """
 
     # Environment.
@@ -36,7 +76,7 @@ class DiffusionConfig:
     # Architecture / modeling.
     # Vision backbone.
     vision_backbone: str = "resnet18"
-    crop_shape: tuple[int, int] = (84, 84)
+    crop_shape: tuple[int, int] | None = (84, 84)
     crop_is_random: bool = True
     use_pretrained_backbone: bool = False
     use_group_norm: bool = True
@@ -46,18 +86,18 @@ class DiffusionConfig:
     kernel_size: int = 5
     n_groups: int = 8
     diffusion_step_embed_dim: int = 128
-    film_scale_modulation: bool = True
+    use_film_scale_modulation: bool = True
     # Noise scheduler.
     num_train_timesteps: int = 100
     beta_schedule: str = "squaredcos_cap_v2"
     beta_start: float = 0.0001
     beta_end: float = 0.02
-    variance_type: str = "fixed_small"
     prediction_type: str = "epsilon"
-    clip_sample: True
+    clip_sample: bool = True
+    clip_sample_range: float = 1.0
 
     # Inference
-    num_inference_steps: int = 100
+    num_inference_steps: int | None = None
 
     # ---
     # TODO(alexander-soare): Remove these from the policy config.
@@ -72,12 +112,24 @@ class DiffusionConfig:
     utd: int = 1
     use_ema: bool = True
     ema_update_after_step: int = 0
-    ema_min_rate: float = 0.0
-    ema_max_rate: float = 0.9999
+    ema_min_alpha: float = 0.0
+    ema_max_alpha: float = 0.9999
     ema_inv_gamma: float = 1.0
     ema_power: float = 0.75
 
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
-            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
+            raise ValueError(
+                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
+            )
+        if self.crop_shape is not None and (
+            self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]
+        ):
+            raise ValueError(
+                f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and "
+                f"{self.image_size} for `image_size`."
+            )
+        supported_prediction_types = ["epsilon", "sample"]
+        if self.prediction_type not in supported_prediction_types:
+            raise ValueError(
+                f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
+            )
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 4853dbcf..a95effb2 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -1,8 +1,10 @@
 """
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
-  - Remove reliance on diffusers for DDPMScheduler.
+  - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
   - Move EMA out of policy.
+  - Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
+  - One more pass on comments and documentation.
 """
 
 import copy
@@ -10,10 +12,10 @@ import logging
 import math
 import time
 from collections import deque
+from itertools import chain
 from typing import Callable
 
 import einops
-import hydra
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
@@ -23,12 +25,12 @@ from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn
 from torch.nn.modules.batchnorm import _BatchNorm
 
+from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.utils import (
     get_device_from_parameters,
     get_dtype_from_parameters,
     populate_queues,
 )
-from lerobot.common.utils import get_safe_torch_device
 
 logger = logging.getLogger(__name__)
 
@@ -41,69 +43,29 @@ class DiffusionPolicy(nn.Module):
 
     name = "diffusion"
 
-    def __init__(
-        self,
-        cfg,
-        cfg_device,
-        cfg_noise_scheduler,
-        cfg_optimizer,
-        cfg_ema,
-        shape_meta: dict,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        film_scale_modulation=True,
-        **_,
-    ):
+    def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int):
         super().__init__()
         self.cfg = cfg
-        self.n_obs_steps = n_obs_steps
-        self.n_action_steps = n_action_steps
 
         # queues are populated during rollout of the policy, they contain the n latest observations and actions
         self._queues = None
 
-        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
-
-        self.diffusion = _DiffusionUnetImagePolicy(
-            cfg,
-            shape_meta=shape_meta,
-            noise_scheduler=noise_scheduler,
-            horizon=horizon,
-            n_action_steps=n_action_steps,
-            n_obs_steps=n_obs_steps,
-            num_inference_steps=num_inference_steps,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
-        )
-
-        self.device = get_safe_torch_device(cfg_device)
-        self.diffusion.to(self.device)
+        self.diffusion = _DiffusionUnetImagePolicy(cfg)
 
         # TODO(alexander-soare): This should probably be managed outside of the policy class.
         self.ema_diffusion = None
         self.ema = None
         if self.cfg.use_ema:
             self.ema_diffusion = copy.deepcopy(self.diffusion)
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=self.ema_diffusion,
-            )
+            self.ema = _EMA(cfg, model=self.ema_diffusion)
 
-        self.optimizer = hydra.utils.instantiate(
-            cfg_optimizer,
-            params=self.diffusion.parameters(),
+        # TODO(alexander-soare): Move optimizer out of policy.
+        self.optimizer = torch.optim.Adam(
+            self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
         )
 
-        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
+        # TODO(alexander-soare): Move LR scheduler out of policy.
+        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
         self.global_step = 0
 
         # configure lr scheduler
@@ -111,7 +73,7 @@ class DiffusionPolicy(nn.Module):
             cfg.lr_scheduler,
             optimizer=self.optimizer,
             num_warmup_steps=cfg.lr_warmup_steps,
-            num_training_steps=cfg.offline_steps,
+            num_training_steps=lr_scheduler_num_training_steps,
             # pytorch assumes stepping LRScheduler every epoch
             # however huggingface diffusers steps it every batch
             last_epoch=self.global_step - 1,
@@ -122,9 +84,9 @@ class DiffusionPolicy(nn.Module):
         Clear observation and action queues. Should be called on `env.reset()`
         """
         self._queues = {
-            "observation.image": deque(maxlen=self.n_obs_steps),
-            "observation.state": deque(maxlen=self.n_obs_steps),
-            "action": deque(maxlen=self.n_action_steps),
+            "observation.image": deque(maxlen=self.cfg.n_obs_steps),
+            "observation.state": deque(maxlen=self.cfg.n_obs_steps),
+            "action": deque(maxlen=self.cfg.n_action_steps),
         }
 
     @torch.no_grad
@@ -138,11 +100,13 @@ class DiffusionPolicy(nn.Module):
           - The diffusion model generates `horizon` steps worth of actions.
           - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
         Schematically this looks like:
+            ----------------------------------------------------------------------------------------------
             (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
             |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... |n-o+1+h|
-            |observation is used | YES   | YES   | ..... | NO    | NO    | NO    | NO    | NO    | NO    |
+            |observation is used | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    | NO    |
             |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
             |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
+            ----------------------------------------------------------------------------------------------
         Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
         "horizon" may not be the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
@@ -213,57 +177,41 @@ class DiffusionPolicy(nn.Module):
 
 
 class _DiffusionUnetImagePolicy(nn.Module):
-    def __init__(
-        self,
-        cfg,
-        shape_meta: dict,
-        noise_scheduler: DDPMScheduler,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        film_scale_modulation=True,
-    ):
+    def __init__(self, cfg: DiffusionConfig):
         super().__init__()
-        action_shape = shape_meta["action"]["shape"]
-        assert len(action_shape) == 1
-        action_dim = action_shape[0]
-
-        self.rgb_encoder = _RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
+        self.cfg = cfg
 
+        self.rgb_encoder = _RgbEncoder(cfg)
         self.unet = _ConditionalUnet1D(
-            input_dim=action_dim,
-            global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
+            cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
         )
 
-        self.noise_scheduler = noise_scheduler
-        self.horizon = horizon
-        self.action_dim = action_dim
-        self.n_action_steps = n_action_steps
-        self.n_obs_steps = n_obs_steps
+        self.noise_scheduler = DDPMScheduler(
+            num_train_timesteps=cfg.num_train_timesteps,
+            beta_start=cfg.beta_start,
+            beta_end=cfg.beta_end,
+            beta_schedule=cfg.beta_schedule,
+            variance_type="fixed_small",
+            clip_sample=cfg.clip_sample,
+            clip_sample_range=cfg.clip_sample_range,
+            prediction_type=cfg.prediction_type,
+        )
 
-        if num_inference_steps is None:
-            num_inference_steps = noise_scheduler.config.num_train_timesteps
-
-        self.num_inference_steps = num_inference_steps
+        if cfg.num_inference_steps is None:
+            self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
+        else:
+            self.num_inference_steps = cfg.num_inference_steps
 
     # ========= inference  ============
-    def conditional_sample(self, batch_size, global_cond=None, generator=None):
+    def conditional_sample(
+        self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
+    ) -> Tensor:
         device = get_device_from_parameters(self)
         dtype = get_dtype_from_parameters(self)
 
         # Sample prior.
         sample = torch.randn(
-            size=(batch_size, self.horizon, self.action_dim),
+            size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
             dtype=dtype,
             device=device,
             generator=generator,
@@ -283,7 +231,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
 
         return sample
 
-    def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
         This function expects `batch` to have (at least):
         {
@@ -293,8 +241,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
         """
         assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
+        assert n_obs_steps == self.cfg.n_obs_steps
 
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@@ -307,13 +254,13 @@ class _DiffusionUnetImagePolicy(nn.Module):
         sample = self.conditional_sample(batch_size, global_cond=global_cond)
 
         # `horizon` steps worth of actions (from the first observation).
-        action = sample[..., : self.action_dim]
+        actions = sample[..., : self.cfg.action_dim]
         # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
-        end = start + self.n_action_steps
-        action = action[:, start:end]
+        end = start + self.cfg.n_action_steps
+        actions = actions[:, start:end]
 
-        return action
+        return actions
 
     def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
         """
@@ -329,9 +276,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
         assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         horizon = batch["action"].shape[1]
-        assert horizon == self.horizon
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
+        assert horizon == self.cfg.horizon
+        assert n_obs_steps == self.cfg.n_obs_steps
 
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@@ -359,14 +305,13 @@ class _DiffusionUnetImagePolicy(nn.Module):
         pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
 
         # Compute the loss.
-        # The targe is either the original trajectory, or the noise.
-        pred_type = self.noise_scheduler.config.prediction_type
-        if pred_type == "epsilon":
+        # The target is either the original trajectory, or the noise.
+        if self.cfg.prediction_type == "epsilon":
             target = eps
-        elif pred_type == "sample":
+        elif self.cfg.prediction_type == "sample":
             target = batch["action"]
         else:
-            raise ValueError(f"Unsupported prediction type {pred_type}")
+            raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
 
         loss = F.mse_loss(pred, target, reduction="none")
 
@@ -384,64 +329,35 @@ class _RgbEncoder(nn.Module):
     Includes the ability to normalize and crop the image first.
     """
 
-    def __init__(
-        self,
-        input_shape: tuple[int, int, int],
-        norm_mean_std: tuple[float, float] = [1.0, 1.0],
-        crop_shape: tuple[int, int] | None = None,
-        random_crop: bool = False,
-        backbone_name: str = "resnet18",
-        pretrained_backbone: bool = False,
-        use_group_norm: bool = False,
-        num_keypoints: int = 32,
-    ):
-        """
-        Args:
-            input_shape: channel-first input shape (C, H, W)
-            norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
-                (image - mean) / std.
-            crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
-                cropping is done.
-            random_crop: Whether the crop should be random at training time (it's always a center crop in
-                eval mode).
-            backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
-            pretrained_backbone: whether to use timm pretrained weights.
-            use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
-                The group sizes are set to be about 16 (to be precise, feature_dim // 16).
-            num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
-        """
+    def __init__(self, cfg: DiffusionConfig):
         super().__init__()
-        if input_shape[0] != 3:
-            raise ValueError("Only RGB images are handled")
-        if not backbone_name.startswith("resnet"):
-            raise ValueError(
-                "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
-            )
-
         # Set up optional preprocessing.
-        if norm_mean_std == [1.0, 1.0]:
+        if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
             self.normalizer = nn.Identity()
         else:
-            self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
-
-        if crop_shape is not None:
+            self.normalizer = torchvision.transforms.Normalize(
+                mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
+            )
+        if cfg.crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
-            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
-            if random_crop:
-                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
+            self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
+            if cfg.crop_is_random:
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
             else:
                 self.maybe_random_crop = self.center_crop
         else:
             self.do_crop = False
 
         # Set up backbone.
-        backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
+        backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
+            pretrained=cfg.use_pretrained_backbone
+        )
         # Note: This assumes that the layer4 feature map is children()[-3]
         # TODO(alexander-soare): Use a safer alternative.
         self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
-        if use_group_norm:
-            if pretrained_backbone:
+        if cfg.use_group_norm:
+            if cfg.use_pretrained_backbone:
                 raise ValueError(
                     "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                 )
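
Aside, not taken from the patch: the `use_group_norm` branch swaps every BatchNorm2d in the backbone for GroupNorm with roughly 16 channels per group, as the removed docstring noted (num_features // 16 groups). A minimal independent sketch of that replacement, using an untrained resnet18 purely for illustration (`_replace_submodules` in the patch does this generically):

    # Illustrative sketch only, not the patch's implementation.
    import torch.nn as nn
    import torchvision

    backbone = nn.Sequential(*list(torchvision.models.resnet18().children())[:-2])

    def bn_to_gn(module: nn.Module) -> nn.Module:
        """Recursively replace BatchNorm2d with GroupNorm (~16 channels per group)."""
        for name, child in module.named_children():
            if isinstance(child, nn.BatchNorm2d):
                setattr(
                    module,
                    name,
                    nn.GroupNorm(num_groups=child.num_features // 16, num_channels=child.num_features),
                )
            else:
                bn_to_gn(child)
        return module

    backbone = bn_to_gn(backbone)
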
@@ -454,10 +370,10 @@ class _RgbEncoder(nn.Module):
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
         with torch.inference_mode():
-            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
-        self.feature_dim = num_keypoints * 2
-        self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
+            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
+        self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
+        self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()
 
     def forward(self, x: Tensor) -> Tensor:
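
Aside, not taken from the patch: the dry run above infers the backbone's output shape by pushing a zero image through it. A standalone sketch of the same trick, assuming a resnet18 backbone and, for illustration, a 96x96 PushT image (an 84x84 crop yields the same 3x3 spatial map, since resnet18 downsamples by 32):

    import torch
    import torch.nn as nn
    import torchvision

    backbone = nn.Sequential(*list(torchvision.models.resnet18().children())[:-2])
    with torch.inference_mode():
        feat_map_shape = tuple(backbone(torch.zeros(1, 3, 96, 96)).shape[1:])
    print(feat_map_shape)  # (512, 3, 3) -> fed to SpatialSoftmax, which returns num_keypoints (x, y) pairs
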
@@ -516,16 +432,18 @@ def _replace_submodules(
 
 
 class _SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
+    """1D sinusoidal positional embeddings as in Attention is All You Need."""
+
+    def __init__(self, dim: int):
         super().__init__()
         self.dim = dim
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         device = x.device
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
+        emb = x.unsqueeze(-1) * emb.unsqueeze(0)
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
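
Aside, not taken from the patch: a self-contained sketch of what `_SinusoidalPosEmb` computes for a batch of diffusion timesteps, assuming an even embedding dimension:

    import math
    import torch

    def sinusoidal_pos_emb(x: torch.Tensor, dim: int) -> torch.Tensor:
        half_dim = dim // 2
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)  # geometrically spaced frequencies
        emb = x.unsqueeze(-1) * freqs.unsqueeze(0)                           # (B, half_dim)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)                     # (B, dim)

    timesteps = torch.tensor([0, 10, 99])            # e.g. three denoising timesteps in a batch
    print(sinusoidal_pos_emb(timesteps, 128).shape)  # torch.Size([3, 128])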
 
@@ -549,92 +467,46 @@ class _Conv1dBlock(nn.Module):
 class _ConditionalUnet1D(nn.Module):
     """A 1D convolutional UNet with FiLM modulation for conditioning.
 
-    Two types of conditioning can be applied:
-    - Global: Conditioning information that is aggregated over the whole observation window. This is
-        incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
-    - Local: Conditioning information for each timestep in the observation window. This is incorporated
-        by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
-        outputs of the Unet's encoder/decoder.
+    Note: this removes local conditioning as compared to the original diffusion policy code.
     """
 
-    def __init__(
-        self,
-        input_dim: int,
-        local_cond_dim: int | None = None,
-        global_cond_dim: int | None = None,
-        diffusion_step_embed_dim: int = 256,
-        down_dims: int | None = None,
-        kernel_size: int = 3,
-        n_groups: int = 8,
-        film_scale_modulation: bool = False,
-    ):
+    def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
         super().__init__()
 
-        if down_dims is None:
-            down_dims = [256, 512, 1024]
+        self.cfg = cfg
 
         # Encoder for the diffusion timestep.
         self.diffusion_step_encoder = nn.Sequential(
-            _SinusoidalPosEmb(diffusion_step_embed_dim),
-            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
+            _SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
+            nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
             nn.Mish(),
-            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
+            nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
         )
 
         # The FiLM conditioning dimension.
-        cond_dim = diffusion_step_embed_dim
-        if global_cond_dim is not None:
-            cond_dim += global_cond_dim
-
-        self.local_cond_down_encoder = None
-        self.local_cond_up_encoder = None
-        if local_cond_dim is not None:
-            # Encoder for the local conditioning. The output gets added to the Unet encoder input.
-            self.local_cond_down_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
-            # Encoder for the local conditioning. The output gets added to the Unet encoder output.
-            self.local_cond_up_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
+        cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
 
         # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
         # just reverse these.
-        in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
+        in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
+            zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
+        )
 
         # Unet encoder.
+        common_res_block_kwargs = {
+            "cond_dim": cond_dim,
+            "kernel_size": cfg.kernel_size,
+            "n_groups": cfg.n_groups,
+            "use_film_scale_modulation": cfg.use_film_scale_modulation,
+        }
         self.down_modules = nn.ModuleList([])
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (len(in_out) - 1)
             self.down_modules.append(
                 nn.ModuleList(
                     [
-                        _ConditionalResidualBlock1D(
-                            dim_in,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
+                        _ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
+                        _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
                         # Downsample as long as it is not the last block.
                         nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
                     ]
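
Aside, not taken from the patch: the `in_out` list above fixes the (dim_in, dim_out) channel pairs for the encoder stages; per the comment, the decoder mirrors them in reverse, doubling the input channels to accept the skip connection. With, for illustration, a 2-D action space and the default down_dims of [256, 512, 1024]:

    action_dim, down_dims = 2, [256, 512, 1024]  # illustrative values
    in_out = [(action_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
    print(in_out)  # [(2, 256), (256, 512), (512, 1024)] -> one (dim_in, dim_out) pair per encoder stage
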
@@ -644,22 +516,8 @@ class _ConditionalUnet1D(nn.Module):
         # Processing in the middle of the auto-encoder.
         self.mid_modules = nn.ModuleList(
             [
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
+                _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
+                _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
             ]
         )
 
@@ -670,22 +528,9 @@ class _ConditionalUnet1D(nn.Module):
             self.up_modules.append(
                 nn.ModuleList(
                     [
-                        _ConditionalResidualBlock1D(
-                            dim_in * 2,  # x2 as it takes the encoder's skip connection as well
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
+                        # dim_in * 2, because it takes the encoder's skip connection as well
+                        _ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
+                        _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
                         # Upsample as long as it is not the last block.
                         nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                     ]
@@ -693,29 +538,22 @@ class _ConditionalUnet1D(nn.Module):
             )
 
         self.final_conv = nn.Sequential(
-            _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
-            nn.Conv1d(down_dims[0], input_dim, 1),
+            _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
+            nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
         )
 
-    def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
+    def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
         """
         Args:
             x: (B, T, input_dim) tensor for input to the Unet.
             timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
-            local_cond: (B, T, local_cond_dim)
             global_cond: (B, global_cond_dim)
             output: (B, T, input_dim)
         Returns:
-            (B, T, input_dim)
+            (B, T, input_dim) diffusion model prediction.
         """
         # For 1D convolutions we'll need feature dimension first.
         x = einops.rearrange(x, "b t d -> b d t")
-        if local_cond is not None:
-            if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
-                raise ValueError(
-                    "`local_cond` was provided but the relevant encoders weren't built at initialization."
-                )
-            local_cond = einops.rearrange(local_cond, "b t d -> b d t")
 
         timesteps_embed = self.diffusion_step_encoder(timestep)
 
@@ -725,11 +563,10 @@ class _ConditionalUnet1D(nn.Module):
         else:
             global_feature = timesteps_embed
 
+        # Run encoder, keeping track of skip features to pass to the decoder.
         encoder_skip_features: list[Tensor] = []
-        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
+        for resnet, resnet2, downsample in self.down_modules:
             x = resnet(x, global_feature)
-            if idx == 0 and local_cond is not None:
-                x = x + self.local_cond_down_encoder(local_cond, global_feature)
             x = resnet2(x, global_feature)
             encoder_skip_features.append(x)
             x = downsample(x)
@@ -737,14 +574,10 @@ class _ConditionalUnet1D(nn.Module):
         for mid_module in self.mid_modules:
             x = mid_module(x, global_feature)
 
-        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
+        # Run decoder, using the skip features from the encoder.
+        for resnet, resnet2, upsample in self.up_modules:
             x = torch.cat((x, encoder_skip_features.pop()), dim=1)
             x = resnet(x, global_feature)
-            # Note: The condition in the original implementation is:
-            # if idx == len(self.up_modules) and local_cond is not None:
-            # But as they mention in their comments, this is incorrect. We use the correct condition here.
-            if idx == (len(self.up_modules) - 1) and local_cond is not None:
-                x = x + self.local_cond_up_encoder(local_cond, global_feature)
             x = resnet2(x, global_feature)
             x = upsample(x)
 
@@ -766,17 +599,17 @@ class _ConditionalResidualBlock1D(nn.Module):
         n_groups: int = 8,
         # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
         # FiLM just modulates bias).
-        film_scale_modulation: bool = False,
+        use_film_scale_modulation: bool = False,
     ):
         super().__init__()
 
-        self.film_scale_modulation = film_scale_modulation
+        self.use_film_scale_modulation = use_film_scale_modulation
         self.out_channels = out_channels
 
         self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
 
         # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
-        cond_channels = out_channels * 2 if film_scale_modulation else out_channels
+        cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
         self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
 
         self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
@@ -798,7 +631,7 @@ class _ConditionalResidualBlock1D(nn.Module):
 
         # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
         cond_embed = self.cond_encoder(cond).unsqueeze(-1)
-        if self.film_scale_modulation:
+        if self.use_film_scale_modulation:
             # Treat the embedding as a list of scales and biases.
             scale = cond_embed[:, : self.out_channels]
             bias = cond_embed[:, self.out_channels :]
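
Aside, not taken from the patch: a small standalone sketch of the FiLM branch above with scale modulation enabled. The conditioning vector is projected to per-channel scale and bias terms that modulate the Conv1d output; all shapes below are made up for the example:

    import torch
    import torch.nn as nn

    batch, out_channels, horizon, cond_dim = 2, 64, 16, 192   # illustrative shapes
    out = torch.randn(batch, out_channels, horizon)           # output of the first _Conv1dBlock
    cond = torch.randn(batch, cond_dim)                       # diffusion-step embedding + global conditioning

    cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, out_channels * 2))
    cond_embed = cond_encoder(cond).unsqueeze(-1)             # (B, 2 * out_channels, 1)
    scale, bias = cond_embed[:, :out_channels], cond_embed[:, out_channels:]
    out = scale * out + bias                                  # per-channel affine modulation, broadcast over time
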
@@ -817,9 +650,7 @@ class _EMA:
     Exponential Moving Average of models weights
     """
 
-    def __init__(
-        self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
-    ):
+    def __init__(self, cfg: DiffusionConfig, model: nn.Module):
         """
         @crowsonkb's notes on EMA Warmup:
             If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
@@ -829,18 +660,18 @@ class _EMA:
         Args:
             inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
             power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
+            min_alpha (float): The minimum EMA decay rate. Default: 0.
         """
 
         self.averaged_model = model
         self.averaged_model.eval()
         self.averaged_model.requires_grad_(False)
 
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
+        self.update_after_step = cfg.ema_update_after_step
+        self.inv_gamma = cfg.ema_inv_gamma
+        self.power = cfg.ema_power
+        self.min_alpha = cfg.ema_min_alpha
+        self.max_alpha = cfg.ema_max_alpha
 
         self.alpha = 0.0
         self.optimization_step = 0
@@ -855,7 +686,7 @@ class _EMA:
         if step <= 0:
             return 0.0
 
-        return max(self.min_value, min(value, self.max_value))
+        return max(self.min_alpha, min(value, self.max_alpha))
 
     @torch.no_grad()
     def step(self, new_model):
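
Aside, not taken from the patch: the body of the decay computation falls outside this hunk. The sketch below follows the warmup schedule the docstring describes (alpha = 1 - (1 + step / inv_gamma) ** -power, clamped to the renamed [min_alpha, max_alpha] bounds) with the defaults from diffusion.yaml; treat it as an approximation of `_EMA.get_decay`, not a verbatim copy:

    def ema_alpha(optimization_step: int, update_after_step: int = 0, inv_gamma: float = 1.0,
                  power: float = 0.75, min_alpha: float = 0.0, max_alpha: float = 0.9999) -> float:
        step = max(0, optimization_step - update_after_step - 1)
        if step <= 0:
            return 0.0
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(min_alpha, min(value, max_alpha))

    # The decay ramps up from 0 and saturates near max_alpha as training progresses.
    print([round(ema_alpha(s), 4) for s in (1, 10, 100, 10_000)])  # [0.0, 0.8222, 0.9684, 0.999]
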
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 0b185e86..b5b5f861 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -9,7 +9,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
     expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
     assert set(hydra_cfg.policy).issuperset(
         expected_kwargs
-    ), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}"
+    ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
     policy_cfg = policy_cfg_class(
         **{
             k: v
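
Aside, not taken from the patch: the one-line change above flips the direction of the set difference so the assertion message actually lists the missing keys. A toy illustration with hypothetical key names:

    expected_kwargs = {"horizon", "n_obs_steps", "crop_shape"}   # hypothetical config fields
    hydra_policy = {"horizon", "n_obs_steps"}                    # hydra config missing "crop_shape"

    print(set(hydra_policy).difference(expected_kwargs))   # set()            <- old message: unhelpful
    print(set(expected_kwargs).difference(hydra_policy))   # {'crop_shape'}   <- new message: what is missing
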
@@ -35,7 +35,7 @@ def make_policy(hydra_cfg: DictConfig):
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
         policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
-        policy = DiffusionPolicy(policy_cfg)
+        policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
         policy.to(get_safe_torch_device(hydra_cfg.device))
     elif hydra_cfg.policy.name == "act":
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index c8bdb0c4..44746dfc 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -44,7 +44,7 @@ policy:
   # Vision backbone.
   vision_backbone: resnet18
   crop_shape: [84, 84]
-  random_crop: True
+  crop_is_random: True
   use_pretrained_backbone: false
   use_group_norm: True
   spatial_softmax_num_keypoints: 32
@@ -53,15 +53,15 @@ policy:
   kernel_size: 5
   n_groups: 8
   diffusion_step_embed_dim: 128
-  film_scale_modulation: True
+  use_film_scale_modulation: True
   # Noise scheduler.
   num_train_timesteps: 100
   beta_schedule: squaredcos_cap_v2
   beta_start: 0.0001
   beta_end: 0.02
-  variance_type: fixed_small
   prediction_type: epsilon # epsilon / sample
   clip_sample: True
+  clip_sample_range: 1.0
 
   # Inference
   num_inference_steps: 100
@@ -79,12 +79,12 @@ policy:
   utd: 1
   use_ema: true
   ema_update_after_step: 0
-  ema_min_rate: 0.0
-  ema_max_rate: 0.9999
+  ema_min_alpha: 0.0
+  ema_max_alpha: 0.9999
   ema_inv_gamma: 1.0
   ema_power: 0.75
 
   delta_timestamps:
-    observation.images: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+    observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
     action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
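
Aside, not taken from the patch: each `delta_timestamps` entry is a string that Hydra interpolates and the dataset evaluates into a list of offsets, in seconds, relative to the current frame. A sketch of the expansion with made-up example values (the real fps, n_obs_steps and horizon come from the config):

    fps, n_obs_steps, horizon = 10, 2, 16   # illustrative values only

    obs_deltas = [i / fps for i in range(1 - n_obs_steps, 1)]
    action_deltas = [i / fps for i in range(1 - n_obs_steps, 1 - n_obs_steps + horizon)]

    print(obs_deltas)         # [-0.1, 0.0]            -> current frame plus one step of history
    print(action_deltas[:4])  # [-0.1, 0.0, 0.1, 0.2]  -> action horizon starts at the oldest observation
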
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 4263e452..83fdad5e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,8 +1,8 @@
 from pathlib import Path
 
 
-def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
-    for f, r in zip(finds, replaces):
+def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
+    for f, r in finds_and_replaces:
         assert f in text
         text = text.replace(f, r)
     return text
@@ -32,8 +32,10 @@ def test_examples_3_and_2():
     # Do less steps and use CPU.
     file_contents = _find_and_replace(
         file_contents,
-        ['"offline_steps=5000"', '"device=cuda"'],
-        ['"offline_steps=1"', '"device=cpu"'],
+        [
+            ("offline_steps = 5000", "offline_steps = 1"),
+            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
+        ],
     )
 
     exec(file_contents)
@@ -50,20 +52,15 @@ def test_examples_3_and_2():
     file_contents = _find_and_replace(
         file_contents,
         [
-            '"eval_episodes=10"',
-            '"rollout_batch_size=10"',
-            '"device=cuda"',
-            '# folder = Path("outputs/train/example_pusht_diffusion")',
-            'hub_id = "lerobot/diffusion_policy_pusht_image"',
-            "folder = Path(snapshot_download(hub_id)",
-        ],
-        [
-            '"eval_episodes=1"',
-            '"rollout_batch_size=1"',
-            '"device=cpu"',
-            'folder = Path("outputs/train/example_pusht_diffusion")',
-            "",
-            "",
+            ('"eval_episodes=10"', '"eval_episodes=1"'),
+            ('"rollout_batch_size=10"', '"rollout_batch_size=1"'),
+            ('"device=cuda"', '"device=cpu"'),
+            (
+                '# folder = Path("outputs/train/example_pusht_diffusion")',
+                'folder = Path("outputs/train/example_pusht_diffusion")',
+            ),
+            ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
+            ("folder = Path(snapshot_download(hub_id)", ""),
         ],
     )
 

From 9c2f10bd0460d4b623e1059dc1fa6c45cdb631c6 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Tue, 16 Apr 2024 13:43:58 +0100
Subject: [PATCH 6/8] ready for review

---
 examples/2_evaluate_pretrained_policy.py      |  1 +
 examples/3_train_policy.py                    | 33 +++++++++++--------
 .../policies/diffusion/modeling_diffusion.py  |  3 +-
 tests/test_examples.py                        |  5 +--
 4 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
index be6abd1b..b3d13f74 100644
--- a/examples/2_evaluate_pretrained_policy.py
+++ b/examples/2_evaluate_pretrained_policy.py
@@ -11,6 +11,7 @@ from lerobot.common.utils import init_hydra_config
 from lerobot.scripts.eval import eval
 
 # Get a pretrained policy from the hub.
+# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
 hub_id = "lerobot/diffusion_policy_pusht_image"
 folder = Path(snapshot_download(hub_id))
 # OR uncomment the following to evaluate a policy from the local outputs/train folder.
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 012efddd..83563ffd 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -11,7 +11,6 @@ import torch
 from omegaconf import OmegaConf
 
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.utils import cycle
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
@@ -26,8 +25,8 @@ device = torch.device("cuda")
 log_freq = 250
 
 # Set up the dataset.
-cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
-dataset = make_dataset(cfg)
+hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
+dataset = make_dataset(hydra_cfg)
 
 # Set up the the policy.
 # Policies are initialized with a configuration class, in this case `DiffusionConfig`.
@@ -50,17 +49,25 @@ dataloader = torch.utils.data.DataLoader(
 )
 
 # Run training loop.
-dataloader = cycle(dataloader)
-for step in range(training_steps):
-    batch = {k: v.to(device, non_blocking=True) for k, v in next(dataloader).items()}
-    info = policy(batch)
-    if step % log_freq == 0:
-        num_samples = (step + 1) * cfg.batch_size
-        loss = info["loss"]
-        update_s = info["update_s"]
-        print(f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)")
+step = 0
+done = False
+while not done:
+    for batch in dataloader:
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        info = policy(batch)
+        if step % log_freq == 0:
+            num_samples = (step + 1) * cfg.batch_size
+            loss = info["loss"]
+            update_s = info["update_s"]
+            print(
+                f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)"
+            )
+        step += 1
+        if step >= training_steps:
+            done = True
+            break
 
 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
-OmegaConf.save(cfg, output_directory / "config.yaml")
+OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
 torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index a95effb2..9a02c6a2 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -1,4 +1,5 @@
-"""
+"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
+
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 83fdad5e..c264610b 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -29,11 +29,12 @@ def test_examples_3_and_2():
     with open(path, "r") as file:
         file_contents = file.read()
 
-    # Do less steps and use CPU.
+    # Do less steps, use CPU, and don't complicate things with dataloader workers.
     file_contents = _find_and_replace(
         file_contents,
         [
-            ("offline_steps = 5000", "offline_steps = 1"),
+            ("training_steps = 5000", "training_steps = 1"),
+            ("num_workers=4", "num_workers=0"),
             ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
         ],
     )

From 43a614c17371bbc39f06111c45858c0964c63358 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Tue, 16 Apr 2024 14:07:16 +0100
Subject: [PATCH 7/8] Fix test_examples

---
 examples/3_train_policy.py | 3 ++-
 tests/test_examples.py     | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 83563ffd..d2e8b8c9 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -53,7 +53,8 @@ step = 0
 done = False
 while not done:
     for batch in dataloader:
-        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        for k in batch:
+            batch[k] = batch[k].to(device, non_blocking=True)
         info = policy(batch)
         if step % log_freq == 0:
             num_samples = (step + 1) * cfg.batch_size
diff --git a/tests/test_examples.py b/tests/test_examples.py
index c264610b..6cab7a1a 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -29,13 +29,14 @@ def test_examples_3_and_2():
     with open(path, "r") as file:
         file_contents = file.read()
 
-    # Do less steps, use CPU, and don't complicate things with dataloader workers.
+    # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
     file_contents = _find_and_replace(
         file_contents,
         [
             ("training_steps = 5000", "training_steps = 1"),
             ("num_workers=4", "num_workers=0"),
             ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
+            ("batch_size=cfg.batch_size", "batch_size=1"),
         ],
     )
 

From a9496fde3976a29ac41674a5c1e826245dd42f22 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Tue, 16 Apr 2024 17:15:51 +0100
Subject: [PATCH 8/8] revision 1

---
 examples/3_train_policy.py                          | 10 ++--------
 lerobot/common/policies/act/modeling_act.py         | 10 +++++++---
 .../common/policies/diffusion/modeling_diffusion.py | 13 ++++++++++---
 tests/test_examples.py                              |  3 ++-
 4 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index d2e8b8c9..0c8decc4 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -53,16 +53,10 @@ step = 0
 done = False
 while not done:
     for batch in dataloader:
-        for k in batch:
-            batch[k] = batch[k].to(device, non_blocking=True)
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
         info = policy(batch)
         if step % log_freq == 0:
-            num_samples = (step + 1) * cfg.batch_size
-            loss = info["loss"]
-            update_s = info["update_s"]
-            print(
-                f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)"
-            )
+            print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
         step += 1
         if step >= training_steps:
             done = True
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 18ea3377..5f2429a6 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
         "ActionChunkingTransformerPolicy does not handle multiple observation steps."
     )
 
-    def __init__(self, cfg: ActionChunkingTransformerConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
         """
-        TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
+        Args:
+            cfg: Policy configuration class instance or None, in which case the default instantiation of the
+                 configuration class is used.
         """
         super().__init__()
-        if getattr(cfg, "n_obs_steps", 1) != 1:
+        if cfg is None:
+            cfg = ActionChunkingTransformerConfig()
+        if cfg.n_obs_steps != 1:
             raise ValueError(self._multiple_obs_steps_not_handled_msg)
         self.cfg = cfg
 
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 9a02c6a2..dfab9bb7 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -33,8 +33,6 @@ from lerobot.common.policies.utils import (
     populate_queues,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class DiffusionPolicy(nn.Module):
     """
@@ -44,8 +42,17 @@ class DiffusionPolicy(nn.Module):
 
     name = "diffusion"
 
-    def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int):
+    def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
+        """
+        Args:
+            cfg: Policy configuration class instance or None, in which case the default instantiation of the
+                 configuration class is used.
+        """
         super().__init__()
+        # TODO(alexander-soare): LR scheduler will be removed.
+        assert lr_scheduler_num_training_steps > 0
+        if cfg is None:
+            cfg = DiffusionConfig()
         self.cfg = cfg
 
         # queues are populated during rollout of the policy, they contain the n latest observations and actions
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 6cab7a1a..c510eb1e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -40,7 +40,8 @@ def test_examples_3_and_2():
         ],
     )
 
-    exec(file_contents)
+    # Pass an empty dict as globals so the example's top-level names are visible inside its comprehensions (https://stackoverflow.com/a/32897127/4391249).
+    exec(file_contents, {})
 
     for file_name in ["model.pt", "stats.pth", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()