diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 22384ca0..911e7121 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -26,6 +26,108 @@ class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
+    """
+
+    name = "act"
+
+    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
+        """
+        Args:
+            cfg: Policy configuration class instance or None, in which case the default instantiation of the
+                 configuration class is used.
+        """
+        super().__init__()
+        if cfg is None:
+            cfg = ActionChunkingTransformerConfig()
+        self.cfg = cfg
+        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
+        self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
+        self.unnormalize_outputs = Unnormalize(
+            cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
+        )
+        self.model = _ActionChunkingTransformer(cfg)
+
+    def reset(self):
+        """This should be called whenever the environment is reset."""
+        if self.cfg.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
+
+    @torch.no_grad
+    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
+        """Select a single action given environment observations.
+
+        This method wraps `select_actions` in order to return one action at a time for execution in the
+        environment. It works by managing the actions in a queue and only calling `select_actions` when the
+        queue is empty.
+        """
+        self.eval()
+
+        batch = self.normalize_inputs(batch)
+        self._stack_images(batch)
+
+        if len(self._action_queue) == 0:
+            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
+            actions = self.model(batch)[0][: self.cfg.n_action_steps]
+
+            # TODO(rcadene): make _forward return output dictionary?
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+
+            self._action_queue.extend(actions.transpose(0, 1))
+        return self._action_queue.popleft()
+
+    def forward(self, batch, **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        batch = self.normalize_inputs(batch)
+        batch = self.normalize_targets(batch)
+        self._stack_images(batch)
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        loss_dict = {"l1_loss": l1_loss}
+        if self.cfg.use_vae:
+            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss_dict["kld_loss"] = mean_kld
+            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss_dict["loss"] = l1_loss
+
+        return loss_dict
+
+    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Stacks all the images in a batch and puts them in a new key: "observation.images".
+
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, state_dim) batch of robot states.
+            "observation.images.{name}": (B, C, H, W) tensor of images.
+        }
+        """
+        # Stack images in the order dictated by input_shapes.
+        batch["observation.images"] = torch.stack(
+            [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
+            dim=-4,
+        )
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+
+class _ActionChunkingTransformer(nn.Module):
+    """Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy.
 
     Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
         - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
@@ -59,24 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
                                 └───────────────────────┘
     """
 
-    name = "act"
-
-    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
-        """
-        Args:
-            cfg: Policy configuration class instance or None, in which case the default instantiation of the
-                 configuration class is used.
-        """
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
-        if cfg is None:
-            cfg = ActionChunkingTransformerConfig()
         self.cfg = cfg
-        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
-        self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
-        self.unnormalize_outputs = Unnormalize(
-            cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
-        )
-
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
         if self.cfg.use_vae:
@@ -141,76 +228,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def reset(self):
-        """This should be called whenever the environment is reset."""
-        if self.cfg.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
-
-    @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
-        self.eval()
-
-        batch = self.normalize_inputs(batch)
-
-        if len(self._action_queue) == 0:
-            # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
-            # has shape (n_action_steps, batch_size, *), hence the transpose.
-            actions = self._forward(batch)[0][: self.cfg.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
-
-    def forward(self, batch, **_) -> dict[str, Tensor]:
-        """Run the batch through the model and compute the loss for training or validation."""
-        batch = self.normalize_inputs(batch)
-        batch = self.normalize_targets(batch)
-        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
-
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
-
-        loss_dict = {"l1_loss": l1_loss}
-        if self.cfg.use_vae:
-            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-            # each dimension independently, we sum over the latent dimension to get the total
-            # KL-divergence per batch element, then take the mean over the batch.
-            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld
-            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
-        else:
-            loss_dict["loss"] = l1_loss
-
-        return loss_dict
-
-    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, state_dim) batch of robot states.
-            "observation.images.{name}": (B, C, H, W) tensor of images.
-        }
-        """
-        # Stack images in the order dictated by input_shapes.
-        batch["observation.images"] = torch.stack(
-            [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
-            dim=-4,
-        )
-
-    def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
+    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
         """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
 
         `batch` should have the following structure:
@@ -231,8 +249,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        self._stack_images(batch)
-
         batch_size = batch["observation.state"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
@@ -324,13 +340,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         return actions, (mu, log_sigma_x2)
 
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-
 
 class _TransformerEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""