diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index d0f83325..79d513b8 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -198,7 +198,7 @@ class SACPolicy( # We cached the encoder output to avoid recomputing it observations_features = None if self.shared_encoder: - observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True) + observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True) actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -596,6 +596,7 @@ class SACObservationEncoder(nn.Module): self, obs_dict: dict[str, torch.Tensor], vision_encoder_cache: torch.Tensor | None = None, + detach: bool = False, ) -> torch.Tensor: """Encode the image and/or state vector. @@ -606,7 +607,11 @@ class SACObservationEncoder(nn.Module): obs_dict = self.input_normalization(obs_dict) if len(self.all_image_keys) > 0: if vision_encoder_cache is None: - vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False) + vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False) + + vision_encoder_cache = self.get_full_image_representation_with_cached_features( + batch_image_cached_features=vision_encoder_cache, detach=detach + ) feat.append(vision_encoder_cache) if "observation.environment_state" in self.config.input_features: @@ -618,10 +623,10 @@ class SACObservationEncoder(nn.Module): return features - def get_base_image_features( + def get_cached_image_features( self, batch: dict[str, torch.Tensor], normalize: bool = True ) -> dict[str, torch.Tensor]: - """Process all images through the base encoder in a batched manner""" + """Get the cached image features for a batch of observations, when the image encoder is frozen""" if normalize: batch = self.input_normalization(batch) @@ -634,9 +639,6 @@ 
class SACObservationEncoder(nn.Module): # Process through the image encoder in one pass batched_output = self.image_enc_layers(batched_input) - if self.freeze_image_encoder: - batched_output = batched_output.detach() - # Split the output back into individual tensors image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0) @@ -645,11 +647,24 @@ class SACObservationEncoder(nn.Module): for key, features in zip(sorted_keys, image_features, strict=True): result[key] = features + return result + + def get_full_image_representation_with_cached_features( + self, + batch_image_cached_features: dict[str, torch.Tensor], + detach: bool = False, + ) -> dict[str, torch.Tensor]: + """Get the full image representation from the cached features, applying the spatial embedding and then the post-encoder""" + image_features = [] - for key in result: + for key in batch_image_cached_features: safe_key = key.replace(".", "_") - x = self.spatial_embeddings[safe_key](result[key]) + x = self.spatial_embeddings[safe_key](batch_image_cached_features[key]) x = self.post_encoders[safe_key](x) + + # The gradient of the image encoder is not needed to update the policy + if detach: + x = x.detach() image_features.append(x) image_features = torch.cat(image_features, dim=-1) @@ -916,10 +931,11 @@ class Policy(nn.Module): observations: torch.Tensor, observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # Encode observations if encoder exists - obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) - if self.encoder_is_shared: - obs_enc = obs_enc.detach() + # We detach the encoder if it is shared to avoid backprop through it + # This is important to prevent the encoder from being updated through the policy + obs_enc = self.encoder( + observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared + ) # Get network outputs outputs = self.network(obs_enc)