diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index d0f83325..79d513b8 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -198,7 +198,7 @@ class SACPolicy(
         # We cached the encoder output to avoid recomputing it
         observations_features = None
         if self.shared_encoder:
-            observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True)
+            observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
 
         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
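
A minimal sketch of the action-selection path this enables, assuming policy is a SACPolicy instance with a shared encoder and batch is a raw observation dict:

    # Sketch only: base image features are computed once, then reused by the actor.
    features = policy.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
    actions, _, _ = policy.actor(batch, features)
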
@@ -596,6 +596,7 @@ class SACObservationEncoder(nn.Module):
         self,
         obs_dict: dict[str, torch.Tensor],
         vision_encoder_cache: torch.Tensor | None = None,
+        detach: bool = False,
     ) -> torch.Tensor:
         """Encode the image and/or state vector.
 
@@ -606,7 +607,11 @@ class SACObservationEncoder(nn.Module):
         obs_dict = self.input_normalization(obs_dict)
         if len(self.all_image_keys) > 0:
             if vision_encoder_cache is None:
-                vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False)
+                vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
+
+            vision_encoder_cache = self.get_full_image_representation_with_cached_features(
+                batch_image_cached_features=vision_encoder_cache, detach=detach
+            )
             feat.append(vision_encoder_cache)
 
         if "observation.environment_state" in self.config.input_features:
@@ -618,10 +623,10 @@ class SACObservationEncoder(nn.Module):
 
         return features
 
-    def get_base_image_features(
+    def get_cached_image_features(
         self, batch: dict[str, torch.Tensor], normalize: bool = True
     ) -> dict[str, torch.Tensor]:
-        """Process all images through the base encoder in a batched manner"""
+        """Get the cached image features for a batch of observations, when the image encoder is frozen"""
         if normalize:
             batch = self.input_normalization(batch)
 
@@ -634,9 +639,6 @@ class SACObservationEncoder(nn.Module):
         # Process through the image encoder in one pass
         batched_output = self.image_enc_layers(batched_input)
 
-        if self.freeze_image_encoder:
-            batched_output = batched_output.detach()
-
         # Split the output back into individual tensors
         image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
 
@@ -645,11 +647,24 @@ class SACObservationEncoder(nn.Module):
         for key, features in zip(sorted_keys, image_features, strict=True):
             result[key] = features
 
+        return result
+
+    def get_full_image_representation_with_cached_features(
+        self,
+        batch_image_cached_features: dict[str, torch.Tensor],
+        detach: bool = False,
+    ) -> torch.Tensor:
+        """Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
+
         image_features = []
-        for key in result:
+        for key in batch_image_cached_features:
             safe_key = key.replace(".", "_")
-            x = self.spatial_embeddings[safe_key](result[key])
+            x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
             x = self.post_encoders[safe_key](x)
+
+            # When detach is set, stop gradients here so the policy update does not backprop into the shared encoder
+            if detach:
+                x = x.detach()
             image_features.append(x)
 
         image_features = torch.cat(image_features, dim=-1)
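
Encoding is thus split into two stages: a batched base-encoder pass whose output can be cached, and a per-camera spatial-embedding and post-encoder stage that can detach its output. A hedged sketch of the two stages chained by hand, with encoder and obs assumed as above:

    # Sketch only: stage 1 is cacheable, stage 2 is cheap and recomputed per consumer.
    cache = encoder.get_cached_image_features(obs, normalize=True)
    feats = encoder.get_full_image_representation_with_cached_features(
        batch_image_cached_features=cache,
        detach=True,  # block gradients for consumers that must not train the encoder
    )
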
@@ -916,10 +931,11 @@ class Policy(nn.Module):
         observations: torch.Tensor,
         observation_features: torch.Tensor | None = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Encode observations if encoder exists
-        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
-        if self.encoder_is_shared:
-            obs_enc = obs_enc.detach()
+        # Detach the encoder output when the encoder is shared so that no gradients flow back through it:
+        # the policy loss must not update the shared encoder.
+        obs_enc = self.encoder(
+            observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
+        )
 
         # Get network outputs
         outputs = self.network(obs_enc)
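
With a shared encoder this keeps the division of labor explicit: only the critic loss should train the encoder, so the policy path blocks gradients. A sketch of the assumed gradient flow, where encoder is the shared SACObservationEncoder:

    # Sketch only: the critic update trains shared encoder weights; the policy update does not.
    obs_enc_critic = encoder(obs)               # gradients flow into the encoder
    obs_enc_policy = encoder(obs, detach=True)  # policy loss cannot update the encoder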