fix caching
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
parent
cf8d995c3a
commit
201215cd6c
|
@ -198,7 +198,7 @@ class SACPolicy(
|
||||||
# We cached the encoder output to avoid recomputing it
|
# We cached the encoder output to avoid recomputing it
|
||||||
observations_features = None
|
observations_features = None
|
||||||
if self.shared_encoder:
|
if self.shared_encoder:
|
||||||
observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True)
|
observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
|
||||||
|
|
||||||
actions, _, _ = self.actor(batch, observations_features)
|
actions, _, _ = self.actor(batch, observations_features)
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
@ -596,6 +596,7 @@ class SACObservationEncoder(nn.Module):
|
||||||
self,
|
self,
|
||||||
obs_dict: dict[str, torch.Tensor],
|
obs_dict: dict[str, torch.Tensor],
|
||||||
vision_encoder_cache: torch.Tensor | None = None,
|
vision_encoder_cache: torch.Tensor | None = None,
|
||||||
|
detach: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Encode the image and/or state vector.
|
"""Encode the image and/or state vector.
|
||||||
|
|
||||||
|
@ -606,7 +607,11 @@ class SACObservationEncoder(nn.Module):
|
||||||
obs_dict = self.input_normalization(obs_dict)
|
obs_dict = self.input_normalization(obs_dict)
|
||||||
if len(self.all_image_keys) > 0:
|
if len(self.all_image_keys) > 0:
|
||||||
if vision_encoder_cache is None:
|
if vision_encoder_cache is None:
|
||||||
vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False)
|
vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
|
||||||
|
|
||||||
|
vision_encoder_cache = self.get_full_image_representation_with_cached_features(
|
||||||
|
batch_image_cached_features=vision_encoder_cache, detach=detach
|
||||||
|
)
|
||||||
feat.append(vision_encoder_cache)
|
feat.append(vision_encoder_cache)
|
||||||
|
|
||||||
if "observation.environment_state" in self.config.input_features:
|
if "observation.environment_state" in self.config.input_features:
|
||||||
|
@ -618,10 +623,10 @@ class SACObservationEncoder(nn.Module):
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def get_base_image_features(
|
def get_cached_image_features(
|
||||||
self, batch: dict[str, torch.Tensor], normalize: bool = True
|
self, batch: dict[str, torch.Tensor], normalize: bool = True
|
||||||
) -> dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
"""Process all images through the base encoder in a batched manner"""
|
"""Get the cached image features for a batch of observations, when the image encoder is frozen"""
|
||||||
if normalize:
|
if normalize:
|
||||||
batch = self.input_normalization(batch)
|
batch = self.input_normalization(batch)
|
||||||
|
|
||||||
|
@ -634,9 +639,6 @@ class SACObservationEncoder(nn.Module):
|
||||||
# Process through the image encoder in one pass
|
# Process through the image encoder in one pass
|
||||||
batched_output = self.image_enc_layers(batched_input)
|
batched_output = self.image_enc_layers(batched_input)
|
||||||
|
|
||||||
if self.freeze_image_encoder:
|
|
||||||
batched_output = batched_output.detach()
|
|
||||||
|
|
||||||
# Split the output back into individual tensors
|
# Split the output back into individual tensors
|
||||||
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
|
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
|
||||||
|
|
||||||
|
@ -645,11 +647,24 @@ class SACObservationEncoder(nn.Module):
|
||||||
for key, features in zip(sorted_keys, image_features, strict=True):
|
for key, features in zip(sorted_keys, image_features, strict=True):
|
||||||
result[key] = features
|
result[key] = features
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_full_image_representation_with_cached_features(
|
||||||
|
self,
|
||||||
|
batch_image_cached_features: dict[str, torch.Tensor],
|
||||||
|
detach: bool = False,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
|
||||||
|
|
||||||
image_features = []
|
image_features = []
|
||||||
for key in result:
|
for key in batch_image_cached_features:
|
||||||
safe_key = key.replace(".", "_")
|
safe_key = key.replace(".", "_")
|
||||||
x = self.spatial_embeddings[safe_key](result[key])
|
x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
|
||||||
x = self.post_encoders[safe_key](x)
|
x = self.post_encoders[safe_key](x)
|
||||||
|
|
||||||
|
# The gradient of the image encoder is not needed to update the policy
|
||||||
|
if detach:
|
||||||
|
x = x.detach()
|
||||||
image_features.append(x)
|
image_features.append(x)
|
||||||
|
|
||||||
image_features = torch.cat(image_features, dim=-1)
|
image_features = torch.cat(image_features, dim=-1)
|
||||||
|
@ -916,10 +931,11 @@ class Policy(nn.Module):
|
||||||
observations: torch.Tensor,
|
observations: torch.Tensor,
|
||||||
observation_features: torch.Tensor | None = None,
|
observation_features: torch.Tensor | None = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Encode observations if encoder exists
|
# We detach the encoder if it is shared to avoid backprop through it
|
||||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
# This is important to avoid the encoder to be updated through the policy
|
||||||
if self.encoder_is_shared:
|
obs_enc = self.encoder(
|
||||||
obs_enc = obs_enc.detach()
|
observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
|
||||||
|
)
|
||||||
|
|
||||||
# Get network outputs
|
# Get network outputs
|
||||||
outputs = self.network(obs_enc)
|
outputs = self.network(obs_enc)
|
||||||
|
|
Loading…
Reference in New Issue