fix caching

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine 2025-04-15 13:16:22 +00:00
parent cf8d995c3a
commit 201215cd6c
1 changed files with 29 additions and 13 deletions

View File

@ -198,7 +198,7 @@ class SACPolicy(
# We cached the encoder output to avoid recomputing it # We cached the encoder output to avoid recomputing it
observations_features = None observations_features = None
if self.shared_encoder: if self.shared_encoder:
observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True) observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features) actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
@ -596,6 +596,7 @@ class SACObservationEncoder(nn.Module):
self, self,
obs_dict: dict[str, torch.Tensor], obs_dict: dict[str, torch.Tensor],
vision_encoder_cache: torch.Tensor | None = None, vision_encoder_cache: torch.Tensor | None = None,
detach: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""Encode the image and/or state vector. """Encode the image and/or state vector.
@ -606,7 +607,11 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict) obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0: if len(self.all_image_keys) > 0:
if vision_encoder_cache is None: if vision_encoder_cache is None:
vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False) vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
vision_encoder_cache = self.get_full_image_representation_with_cached_features(
batch_image_cached_features=vision_encoder_cache, detach=detach
)
feat.append(vision_encoder_cache) feat.append(vision_encoder_cache)
if "observation.environment_state" in self.config.input_features: if "observation.environment_state" in self.config.input_features:
@ -618,10 +623,10 @@ class SACObservationEncoder(nn.Module):
return features return features
def get_base_image_features( def get_cached_image_features(
self, batch: dict[str, torch.Tensor], normalize: bool = True self, batch: dict[str, torch.Tensor], normalize: bool = True
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
"""Process all images through the base encoder in a batched manner""" """Get the cached image features for a batch of observations, when the image encoder is frozen"""
if normalize: if normalize:
batch = self.input_normalization(batch) batch = self.input_normalization(batch)
@ -634,9 +639,6 @@ class SACObservationEncoder(nn.Module):
# Process through the image encoder in one pass # Process through the image encoder in one pass
batched_output = self.image_enc_layers(batched_input) batched_output = self.image_enc_layers(batched_input)
if self.freeze_image_encoder:
batched_output = batched_output.detach()
# Split the output back into individual tensors # Split the output back into individual tensors
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0) image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
@ -645,11 +647,24 @@ class SACObservationEncoder(nn.Module):
for key, features in zip(sorted_keys, image_features, strict=True): for key, features in zip(sorted_keys, image_features, strict=True):
result[key] = features result[key] = features
return result
def get_full_image_representation_with_cached_features(
self,
batch_image_cached_features: dict[str, torch.Tensor],
detach: bool = False,
) -> dict[str, torch.Tensor]:
"""Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
image_features = [] image_features = []
for key in result: for key in batch_image_cached_features:
safe_key = key.replace(".", "_") safe_key = key.replace(".", "_")
x = self.spatial_embeddings[safe_key](result[key]) x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
x = self.post_encoders[safe_key](x) x = self.post_encoders[safe_key](x)
# The gradient of the image encoder is not needed to update the policy
if detach:
x = x.detach()
image_features.append(x) image_features.append(x)
image_features = torch.cat(image_features, dim=-1) image_features = torch.cat(image_features, dim=-1)
@ -916,10 +931,11 @@ class Policy(nn.Module):
observations: torch.Tensor, observations: torch.Tensor,
observation_features: torch.Tensor | None = None, observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists # We detach the encoder if it is shared to avoid backprop through it
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) # This is important to avoid the encoder to be updated through the policy
if self.encoder_is_shared: obs_enc = self.encoder(
obs_enc = obs_enc.detach() observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
)
# Get network outputs # Get network outputs
outputs = self.network(obs_enc) outputs = self.network(obs_enc)