From 7ad93bdbf11a5fb182ac6086dbd4623fa4c423bd Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 9 Apr 2025 13:20:51 +0000 Subject: [PATCH] fix caching and dataset stats is optional --- .../common/policies/sac/configuration_sac.py | 2 +- lerobot/common/policies/sac/modeling_sac.py | 32 +++++++++++-------- lerobot/scripts/server/learner_server.py | 4 +-- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 3d01f47c..684ac17f 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig): } ) - dataset_stats: dict[str, dict[str, list[float]]] = field( + dataset_stats: dict[str, dict[str, list[float]]] | None = field( default_factory=lambda: { "observation.image": { "mean": [0.485, 0.456, 0.406], diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 0e6f8fda..3d2eca86 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -65,16 +65,21 @@ class SACPolicy( else: self.normalize_inputs = nn.Identity() - output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) - # HACK: This is hacky and should be removed - dataset_stats = dataset_stats or output_normalization_params - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) + if config.dataset_stats is not None: + output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) + + # HACK: This is hacky and should be removed + dataset_stats = dataset_stats or output_normalization_params + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + else: + self.normalize_targets = nn.Identity() + self.unnormalize_outputs = nn.Identity() # NOTE: For images the encoder should be shared between the actor and critic if config.shared_encoder: @@ -192,7 +197,7 @@ class SACPolicy( # We cached the encoder output to avoid recomputing it observations_features = None if self.shared_encoder: - observations_features = self.actor.encoder.get_image_features(batch) + observations_features = self.actor.encoder.get_image_features(batch, normalize=True) actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -568,8 +573,7 @@ class SACObservationEncoder(nn.Module): feat = [] obs_dict = self.input_normalization(obs_dict) if len(self.all_image_keys) > 0 and vision_encoder_cache is None: - vision_encoder_cache = self.get_image_features(obs_dict) - feat.append(vision_encoder_cache) + vision_encoder_cache = self.get_image_features(obs_dict, normalize=False) if vision_encoder_cache is not None: feat.append(vision_encoder_cache) @@ -584,8 +588,10 @@ class SACObservationEncoder(nn.Module): return features - def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: + def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor: # [N*B, C, H, W] + if normalize: + batch = self.input_normalization(batch) if len(self.all_image_keys) > 0: # Batch all images along the batch dimension, then encode them. images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 5489d6dc..707547a1 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -1026,8 +1026,8 @@ def get_observation_features( return None, None with torch.no_grad(): - observation_features = policy.actor.encoder.get_image_features(observations) - next_observation_features = policy.actor.encoder.get_image_features(next_observations) + observation_features = policy.actor.encoder.get_image_features(observations, normalize=True) + next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True) return observation_features, next_observation_features