fix caching and dataset stats is optional

This commit is contained in:
AdilZouitine 2025-04-09 13:20:51 +00:00
parent a8135629b4
commit d948b95d22
3 changed files with 22 additions and 16 deletions

View File

@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig):
}
)
dataset_stats: dict[str, dict[str, list[float]]] = field(
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
"observation.image": {
"mean": [0.485, 0.456, 0.406],

View File

@ -63,16 +63,21 @@ class SACPolicy(
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
if config.dataset_stats is not None:
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
else:
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
@ -188,7 +193,7 @@ class SACPolicy(
# We cached the encoder output to avoid recomputing it
observations_features = None
if self.shared_encoder:
observations_features = self.actor.encoder.get_image_features(batch)
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"]
@ -564,8 +569,7 @@ class SACObservationEncoder(nn.Module):
feat = []
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict)
feat.append(vision_encoder_cache)
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
if vision_encoder_cache is not None:
feat.append(vision_encoder_cache)
@ -580,8 +584,10 @@ class SACObservationEncoder(nn.Module):
return features
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
# [N*B, C, H, W]
if normalize:
batch = self.input_normalization(batch)
if len(self.all_image_keys) > 0:
# Batch all images along the batch dimension, then encode them.
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)

View File

@ -1026,8 +1026,8 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_image_features(observations)
next_observation_features = policy.actor.encoder.get_image_features(next_observations)
observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
return observation_features, next_observation_features