Fix image-feature caching and make dataset stats optional

This commit is contained in:
AdilZouitine 2025-04-09 13:20:51 +00:00
parent a8135629b4
commit d948b95d22
3 changed files with 22 additions and 16 deletions

View File

@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig):
} }
) )
dataset_stats: dict[str, dict[str, list[float]]] = field( dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": { "observation.image": {
"mean": [0.485, 0.456, 0.406], "mean": [0.485, 0.456, 0.406],

View File

@ -63,6 +63,8 @@ class SACPolicy(
else: else:
self.normalize_inputs = nn.Identity() self.normalize_inputs = nn.Identity()
if config.dataset_stats is not None:
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
# HACK: This is hacky and should be removed # HACK: This is hacky and should be removed
@ -73,6 +75,9 @@ class SACPolicy(
self.unnormalize_outputs = Unnormalize( self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats config.output_features, config.normalization_mapping, dataset_stats
) )
else:
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
# NOTE: For images the encoder should be shared between the actor and critic # NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder: if config.shared_encoder:
@ -188,7 +193,7 @@ class SACPolicy(
# We cached the encoder output to avoid recomputing it # We cached the encoder output to avoid recomputing it
observations_features = None observations_features = None
if self.shared_encoder: if self.shared_encoder:
observations_features = self.actor.encoder.get_image_features(batch) observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features) actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
@ -564,8 +569,7 @@ class SACObservationEncoder(nn.Module):
feat = [] feat = []
obs_dict = self.input_normalization(obs_dict) obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None: if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict) vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
feat.append(vision_encoder_cache)
if vision_encoder_cache is not None: if vision_encoder_cache is not None:
feat.append(vision_encoder_cache) feat.append(vision_encoder_cache)
@ -580,8 +584,10 @@ class SACObservationEncoder(nn.Module):
return features return features
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
# [N*B, C, H, W] # [N*B, C, H, W]
if normalize:
batch = self.input_normalization(batch)
if len(self.all_image_keys) > 0: if len(self.all_image_keys) > 0:
# Batch all images along the batch dimension, then encode them. # Batch all images along the batch dimension, then encode them.
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)

View File

@ -1026,8 +1026,8 @@ def get_observation_features(
return None, None return None, None
with torch.no_grad(): with torch.no_grad():
observation_features = policy.actor.encoder.get_image_features(observations) observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_image_features(next_observations) next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
return observation_features, next_observation_features return observation_features, next_observation_features