diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 3f5dae1c..97ba04b1 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
+from typing import Any
 
 
 @dataclass
@@ -26,6 +27,7 @@ class SACConfig:
             "observation.state": [4],
         }
     )
+
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "action": [2],
@@ -43,36 +45,64 @@ class SACConfig:
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )
+from dataclasses import dataclass, field
 
-    shared_encoder = False
-    discount = 0.99
-    temperature_init = 1.0
-    num_critics = 2
-    # num_critics = 8
-    num_subsample_critics = None
-    # num_subsample_critics = 2
-    # critic_lr = 1e-3
-    critic_lr = 3e-4
-    actor_lr = 3e-4
-    temperature_lr = 3e-4
-    critic_target_update_weight = 0.005
-    # utd_ratio = 8
-    utd_ratio = 1  # If you want enable utd_ratio, you need to set it to >1
-    state_encoder_hidden_dim = 256
-    latent_dim = 256
-    target_entropy = None
-    # backup_entropy = False
-    use_backup_entropy = True
-    critic_network_kwargs = {
-        "hidden_dims": [256, 256],
-        "activate_final": True,
-    }
-    actor_network_kwargs = {
-        "hidden_dims": [256, 256],
-        "activate_final": True,
-    }
-    policy_kwargs = {
-        "use_tanh_squash": True,
-        "log_std_min": -5,
-        "log_std_max": 2,
-    }
+@dataclass
+class SACConfig:
+    input_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "observation.image": [3, 84, 84],
+            "observation.state": [4],
+        }
+    )
+    output_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "action": [2],
+        }
+    )
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
+    output_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {"action": "min_max"}
+    )
+
+    # Add type annotations for these fields:
+    image_encoder_hidden_dim: int = 32
+    shared_encoder: bool = False
+    discount: float = 0.99
+    temperature_init: float = 1.0
+    num_critics: int = 2
+    num_subsample_critics: int | None = None
+    critic_lr: float = 3e-4
+    actor_lr: float = 3e-4
+    temperature_lr: float = 3e-4
+    critic_target_update_weight: float = 0.005
+    utd_ratio: int = 1  # If you want enable utd_ratio, you need to set it to >1
+    state_encoder_hidden_dim: int = 256
+    latent_dim: int = 256
+    target_entropy: float | None = None
+    use_backup_entropy: bool = True
+    critic_network_kwargs: dict[str, Any] = field(
+        default_factory=lambda: {
+            "hidden_dims": [256, 256],
+            "activate_final": True,
+        }
+    )
+    actor_network_kwargs: dict[str, Any] = field(
+        default_factory=lambda: {
+            "hidden_dims": [256, 256],
+            "activate_final": True,
+        }
+    )
+    policy_kwargs: dict[str, Any] = field(
+        default_factory=lambda: {
+            "use_tanh_squash": True,
+            "log_std_min": -5,
+            "log_std_max": 2,
+        }
+    )