diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index b3411e11..a86193b8 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -146,7 +146,8 @@ jobs:
             device=cpu \
             save_model=true \
             save_freq=2 \
-            horizon=20 \
+            policy.n_action_steps=20 \
+            policy.chunk_size=20 \
             policy.batch_size=2 \
             hydra.run.dir=tests/outputs/act/
 
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
new file mode 100644
index 00000000..74ed270e
--- /dev/null
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -0,0 +1,114 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class ActionChunkingTransformerConfig:
+    """Configuration class for the Action Chunking Transformers policy.
+
+    Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
+
+    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
+    Those are: `state_dim`, `action_dim` and `camera_names`.
+
+    Args:
+        state_dim: Dimensionality of the observation state space (excluding images).
+        action_dim: Dimensionality of the action space.
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
+            current step and additional steps going back).
+        camera_names: The (unique) set of names for the cameras.
+        chunk_size: The size of the action prediction "chunks" in units of environment steps.
+        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
+            This should be no greater than the chunk size. For example, if the chunk size size 100, you may
+            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
+            environment, and throws the other 50 out.
+        image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
+            [0, 1]) for normalization.
+        image_normalization_std: Value by which to divide the input image pixels (after the mean has been
+            subtracted).
+        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
+        use_pretrained_backbone: Whether the backbone should be initialized with ImageNet, pretrained weights
+            from torchvision.
+        replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
+            convolution.
+        pre_norm: Whether to use "pre-norm" in the transformer blocks.
+        d_model: The transformer blocks' main hidden dimension.
+        n_heads: The number of heads to use in the transformer blocks' multi-head attention.
+        dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
+            layers.
+        feedforward_activation: The activation to use in the transformer block's feed-forward layers.
+        n_encoder_layers: The number of transformer layers to use for the transformer encoder.
+        n_decoder_layers: The number of transformer layers to use for the transformer decoder.
+        use_vae: Whether to use a variational objective during training. This introduces another transformer
+            which is used as the VAE's encoder (not to be confused with the transformer encoder - see
+            documentation in the policy class).
+        latent_dim: The VAE's latent dimension.
+        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
+        use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
+            environment step.
+        dropout: Dropout to use in the transformer layers (see code for details).
+        kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
+            is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
+    """
+
+    # Environment.
+    state_dim: int = 14
+    action_dim: int = 14
+
+    # Inputs / output structure.
+    n_obs_steps: int = 1
+    camera_names: list[str] = field(default_factory=lambda: ["top"])
+    chunk_size: int = 100
+    n_action_steps: int = 100
+
+    # Vision preprocessing.
+    image_normalization_mean: tuple[float, float, float] = field(
+        default_factory=lambda: [0.485, 0.456, 0.406]
+    )
+    image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
+
+    # Architecture.
+    # Vision backbone.
+    vision_backbone: str = "resnet18"
+    use_pretrained_backbone: bool = True
+    replace_final_stride_with_dilation: int = False
+    # Transformer layers.
+    pre_norm: bool = False
+    d_model: int = 512
+    n_heads: int = 8
+    dim_feedforward: int = 3200
+    feedforward_activation: str = "relu"
+    n_encoder_layers: int = 4
+    n_decoder_layers: int = 1
+    # VAE.
+    use_vae: bool = True
+    latent_dim: int = 32
+    n_vae_encoder_layers: int = 4
+
+    # Inference.
+    use_temporal_aggregation: bool = False
+
+    # Training and loss computation.
+    dropout: float = 0.1
+    kl_weight: float = 10.0
+
+    # ---
+    # TODO(alexander-soare): Remove these from the policy config.
+    batch_size: int = 8
+    lr: float = 1e-5
+    lr_backbone: float = 1e-5
+    weight_decay: float = 1e-4
+    grad_clip_norm: float = 10
+    utd: int = 1
+
+    def __post_init__(self):
+        """Input validation."""
+        if not self.vision_backbone.startswith("resnet"):
+            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
+        if self.use_temporal_aggregation:
+            raise NotImplementedError("Temporal aggregation is not yet implemented.")
+        if self.n_action_steps > self.chunk_size:
+            raise ValueError(
+                "The chunk size is the upper bound for the number of action steps per model invocation."
+            )
+        if self.camera_names != ["top"]:
+            raise ValueError("For now, `camera_names` can only be ['top']")
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/modeling_act.py
similarity index 80%
rename from lerobot/common/policies/act/policy.py
rename to lerobot/common/policies/act/modeling_act.py
index 24667795..1361e071 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -20,7 +20,7 @@ from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d
 
-from lerobot.common.utils import get_safe_torch_device
+from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
 
 
 class ActionChunkingTransformerPolicy(nn.Module):
@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         "ActionChunkingTransformerPolicy does not handle multiple observation steps."
     )
 
-    def __init__(self, cfg, device):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """
         TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
@@ -73,79 +73,64 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if getattr(cfg, "n_obs_steps", 1) != 1:
             raise ValueError(self._multiple_obs_steps_not_handled_msg)
         self.cfg = cfg
-        self.n_action_steps = cfg.n_action_steps
-        self.device = get_safe_torch_device(device)
-        self.camera_names = cfg.camera_names
-        self.use_vae = cfg.use_vae
-        self.horizon = cfg.horizon
-        self.d_model = cfg.d_model
-
-        transformer_common_kwargs = dict(  # noqa: C408
-            d_model=self.d_model,
-            num_heads=cfg.num_heads,
-            dim_feedforward=cfg.dim_feedforward,
-            dropout=cfg.dropout,
-            activation=cfg.activation,
-            normalize_before=cfg.pre_norm,
-        )
 
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        if self.use_vae:
-            self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs)
-            self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model)
+        if self.cfg.use_vae:
+            self.vae_encoder = _TransformerEncoder(cfg)
+            self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
+            self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
             # Projection layer for action (joint-space target) to hidden dimension.
-            self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model)
+            self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
             self.latent_dim = cfg.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-            self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2)
+            self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
             # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
             # dimension.
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0),
+                _create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
             )
 
         # Backbone for image feature extraction.
         self.image_normalizer = transforms.Normalize(
-            mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
+            mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
         )
-        backbone_model = getattr(torchvision.models, cfg.backbone)(
-            replace_stride_with_dilation=[False, False, cfg.dilation],
-            pretrained=cfg.pretrained_backbone,
+        backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
+            replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
+            pretrained=cfg.use_pretrained_backbone,
             norm_layer=FrozenBatchNorm2d,
         )
+        # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
+        # map).
         # Note: The forward method of this returns a dict: {"feature_map": output}.
         self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
 
         # Transformer (acts as VAE decoder when training with the variational objective).
-        self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs)
-        self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs)
+        self.encoder = _TransformerEncoder(cfg)
+        self.decoder = _TransformerDecoder(cfg)
 
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
-        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model)
+        self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
+        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
-            backbone_model.fc.in_features, self.d_model, kernel_size=1
+            backbone_model.fc.in_features, cfg.d_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model)
-        self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2)
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
+        self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)
 
         # Transformer decoder.
         # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
-        self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model)
+        self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
 
         # Final action regression head on the output of the transformer's decoder.
-        self.action_head = nn.Linear(self.d_model, cfg.action_dim)
+        self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
 
         self._reset_parameters()
-
         self._create_optimizer()
-        self.to(self.device)
 
     def _create_optimizer(self):
         optimizer_params_dicts = [
@@ -173,8 +158,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
     def reset(self):
         """This should be called whenever the environment is reset."""
-        if self.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.n_action_steps)
+        if self.cfg.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
@@ -184,8 +169,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         queue is empty.
         """
         if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
-            # (n_action_steps, batch_size, *), hence the transpose.
+            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
+            # shape (n_action_steps, batch_size, *), hence the transpose.
             self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
         return self._action_queue.popleft()
 
@@ -197,20 +182,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         action = self.forward(batch, return_loss=False)
 
-        if self.cfg.temporal_agg:
-            # TODO(rcadene): implement temporal aggregation
-            raise NotImplementedError()
-            # all_time_actions[[t], t:t+num_queries] = action
-            # actions_for_curr_step = all_time_actions[:, t]
-            # actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
-            # actions_for_curr_step = actions_for_curr_step[actions_populated]
-            # k = 0.01
-            # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
-            # exp_weights = exp_weights / exp_weights.sum()
-            # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
-            # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
-
-        return action[: self.n_action_steps]
+        return action[: self.cfg.n_action_steps]
 
     def __call__(self, *args, **kwargs) -> dict:
         # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
@@ -251,9 +223,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
         self.train()
 
         num_slices = self.cfg.batch_size
-        batch_size = self.cfg.horizon * num_slices
+        batch_size = self.cfg.chunk_size * num_slices
 
-        assert batch_size % self.cfg.horizon == 0
+        assert batch_size % self.cfg.chunk_size == 0
         assert batch_size % num_slices == 0
 
         loss = self.forward(batch, return_loss=True)["loss"]
@@ -324,7 +296,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
             Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
             latent dimension.
         """
-        if self.use_vae and self.training:
+        if self.cfg.use_vae and self.training:
             assert (
                 actions is not None
             ), "actions must be provided when using the variational objective in training mode."
@@ -332,7 +304,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         batch_size = robot_state.shape[0]
 
         # Prepare the latent for input to the transformer encoder.
-        if self.use_vae and actions is not None:
+        if self.cfg.use_vae and actions is not None:
             # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@@ -367,7 +339,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Camera observation features and positional embeddings.
         all_cam_features = []
         all_cam_pos_embeds = []
-        for cam_id, _ in enumerate(self.camera_names):
+        for cam_id, _ in enumerate(self.cfg.camera_names):
             cam_features = self.backbone(image[:, cam_id])["feature_map"]
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
@@ -399,7 +371,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
         decoder_in = torch.zeros(
-            (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device
+            (self.cfg.chunk_size, batch_size, self.cfg.d_model),
+            dtype=pos_embed.dtype,
+            device=pos_embed.device,
         )
         decoder_out = self.decoder(
             decoder_in,
@@ -426,16 +400,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
 class _TransformerEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""
 
-    def __init__(self, num_layers: int, **encoder_layer_kwargs: dict):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
-        self.layers = nn.ModuleList(
-            [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)]
-        )
-        self.norm = (
-            nn.LayerNorm(encoder_layer_kwargs["d_model"])
-            if encoder_layer_kwargs["normalize_before"]
-            else nn.Identity()
-        )
+        self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
+        self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
 
     def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
         for layer in self.layers:
@@ -445,39 +413,31 @@ class _TransformerEncoder(nn.Module):
 
 
 class _TransformerEncoderLayer(nn.Module):
-    def __init__(
-        self,
-        d_model: int,
-        num_heads: int,
-        dim_feedforward: int,
-        dropout: float,
-        activation: str,
-        normalize_before: bool,
-    ):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
-        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
 
         # Feed forward layers.
-        self.linear1 = nn.Linear(d_model, dim_feedforward)
-        self.dropout = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(dim_feedforward, d_model)
+        self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
+        self.dropout = nn.Dropout(cfg.dropout)
+        self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
 
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.dropout1 = nn.Dropout(dropout)
-        self.dropout2 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(cfg.d_model)
+        self.norm2 = nn.LayerNorm(cfg.d_model)
+        self.dropout1 = nn.Dropout(cfg.dropout)
+        self.dropout2 = nn.Dropout(cfg.dropout)
 
-        self.activation = _get_activation_fn(activation)
-        self.normalize_before = normalize_before
+        self.activation = _get_activation_fn(cfg.feedforward_activation)
+        self.pre_norm = cfg.pre_norm
 
     def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
         skip = x
-        if self.normalize_before:
+        if self.pre_norm:
             x = self.norm1(x)
         q = k = x if pos_embed is None else x + pos_embed
         x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
         x = skip + self.dropout1(x)
-        if self.normalize_before:
+        if self.pre_norm:
             skip = x
             x = self.norm2(x)
         else:
@@ -485,20 +445,17 @@ class _TransformerEncoderLayer(nn.Module):
             skip = x
         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
         x = skip + self.dropout2(x)
-        if not self.normalize_before:
+        if not self.pre_norm:
             x = self.norm2(x)
         return x
 
 
 class _TransformerDecoder(nn.Module):
-    def __init__(self, num_layers: int, **decoder_layer_kwargs):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """Convenience module for running multiple decoder layers followed by normalization."""
         super().__init__()
-        self.layers = nn.ModuleList(
-            [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)]
-        )
-        self.num_layers = num_layers
-        self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"])
+        self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
+        self.norm = nn.LayerNorm(cfg.d_model)
 
     def forward(
         self,
@@ -517,33 +474,25 @@ class _TransformerDecoder(nn.Module):
 
 
 class _TransformerDecoderLayer(nn.Module):
-    def __init__(
-        self,
-        d_model: int,
-        num_heads: int,
-        dim_feedforward: int,
-        dropout: float,
-        activation: str,
-        normalize_before: bool,
-    ):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
-        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
-        self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
+        self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
 
         # Feed forward layers.
-        self.linear1 = nn.Linear(d_model, dim_feedforward)
-        self.dropout = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(dim_feedforward, d_model)
+        self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
+        self.dropout = nn.Dropout(cfg.dropout)
+        self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
 
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.norm3 = nn.LayerNorm(d_model)
-        self.dropout1 = nn.Dropout(dropout)
-        self.dropout2 = nn.Dropout(dropout)
-        self.dropout3 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(cfg.d_model)
+        self.norm2 = nn.LayerNorm(cfg.d_model)
+        self.norm3 = nn.LayerNorm(cfg.d_model)
+        self.dropout1 = nn.Dropout(cfg.dropout)
+        self.dropout2 = nn.Dropout(cfg.dropout)
+        self.dropout3 = nn.Dropout(cfg.dropout)
 
-        self.activation = _get_activation_fn(activation)
-        self.normalize_before = normalize_before
+        self.activation = _get_activation_fn(cfg.feedforward_activation)
+        self.pre_norm = cfg.pre_norm
 
     def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
         return tensor if pos_embed is None else tensor + pos_embed
@@ -566,12 +515,12 @@ class _TransformerDecoderLayer(nn.Module):
             (DS, B, C) tensor of decoder output features.
         """
         skip = x
-        if self.normalize_before:
+        if self.pre_norm:
             x = self.norm1(x)
         q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
         x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
         x = skip + self.dropout1(x)
-        if self.normalize_before:
+        if self.pre_norm:
             skip = x
             x = self.norm2(x)
         else:
@@ -583,7 +532,7 @@ class _TransformerDecoderLayer(nn.Module):
             value=encoder_out,
         )[0]  # select just the output, not the attention weights
         x = skip + self.dropout2(x)
-        if self.normalize_before:
+        if self.pre_norm:
             skip = x
             x = self.norm3(x)
         else:
@@ -591,7 +540,7 @@ class _TransformerDecoderLayer(nn.Module):
             skip = x
         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
         x = skip + self.dropout3(x)
-        if not self.normalize_before:
+        if not self.pre_norm:
             x = self.norm3(x)
         return x
 
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index a287614d..f0454b8e 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,3 +1,10 @@
+import inspect
+
+from omegaconf import OmegaConf
+
+from lerobot.common.utils import get_safe_torch_device
+
+
 def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -19,10 +26,22 @@ def make_policy(cfg):
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
-        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+        from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
+        from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 
-        policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
-        policy.to(cfg.device)
+        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
+        assert set(cfg.policy).issuperset(
+            expected_kwargs
+        ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
+        policy_cfg = ActionChunkingTransformerConfig(
+            **{
+                k: v
+                for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
+                if k in expected_kwargs
+            }
+        )
+        policy = ActionChunkingTransformerPolicy(policy_cfg)
+        policy.to(get_safe_torch_device(cfg.device))
     else:
         raise ValueError(cfg.policy.name)
 
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index e2074b46..bd883613 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -8,61 +8,65 @@ eval_freq: 10000
 save_freq: 100000
 log_freq: 250
 
-horizon: 100
 n_obs_steps: 1
 # when temporal_agg=False, n_action_steps=horizon
-n_action_steps: ${horizon}
 
+# See `configuration_act.py` for more details.
 policy:
   name: act
 
   pretrained_model_path:
 
+  # Environment.
+  # Inherit these from the environment.
+  state_dim: ???
+  action_dim: ???
+
+  # Inputs / output structure.
+  n_obs_steps: ${n_obs_steps}
+  camera_names: [top]  # [top, front_close, left_pillar, right_pillar]
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  # Vision preprocessing.
+  image_normalization_mean: [0.485, 0.456, 0.406]
+  image_normalization_std: [0.229, 0.224, 0.225]
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  use_pretrained_backbone: true
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  d_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  use_temporal_aggregation: false
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0
+
+  # ---
+  # TODO(alexander-soare): Remove these from the policy config.
+  batch_size: 8
   lr: 1e-5
   lr_backbone: 1e-5
-  pretrained_backbone: true
   weight_decay: 1e-4
   grad_clip_norm: 10
-  backbone: resnet18
-  horizon: ${horizon} # chunk_size
-  kl_weight: 10
-  d_model: 512
-  dim_feedforward: 3200
-  vae_enc_layers: 4
-  enc_layers: 4
-  dec_layers: 1
-  num_heads: 8
-  #camera_names: [top, front_close, left_pillar, right_pillar]
-  camera_names: [top]
-  dilation: false
-  dropout: 0.1
-  pre_norm: false
-  activation: relu
-  latent_dim: 32
-
-  use_vae: true
-
-  batch_size: 8
-
-  per_alpha: 0.6
-  per_beta: 0.4
-
-  balanced_sampling: false
   utd: 1
 
-  n_obs_steps: ${n_obs_steps}
-  n_action_steps: ${n_action_steps}
-
-  temporal_agg: false
-
-  state_dim: 14
-  action_dim: 14
-
-  image_normalization:
-    mean: [0.485, 0.456, 0.406]
-    std: [0.229, 0.224, 0.225]
-
   delta_timestamps:
-    observation.images.top: [0.0]
-    observation.state: [0.0]
-    action: "[i / ${fps} for i in range(${horizon})]"
+    observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+    observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
diff --git a/tests/test_available.py b/tests/test_available.py
index be74a42a..b25a921f 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -18,7 +18,7 @@ from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
 
-from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy