fix caching and dataset stats is optional
This commit is contained in:
parent
a8135629b4
commit
d948b95d22
|
@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig):
|
|||
}
|
||||
)
|
||||
|
||||
dataset_stats: dict[str, dict[str, list[float]]] = field(
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
|
|
|
@ -63,16 +63,21 @@ class SACPolicy(
|
|||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
if config.dataset_stats is not None:
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_targets = nn.Identity()
|
||||
self.unnormalize_outputs = nn.Identity()
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
if config.shared_encoder:
|
||||
|
@ -188,7 +193,7 @@ class SACPolicy(
|
|||
# We cached the encoder output to avoid recomputing it
|
||||
observations_features = None
|
||||
if self.shared_encoder:
|
||||
observations_features = self.actor.encoder.get_image_features(batch)
|
||||
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
@ -564,8 +569,7 @@ class SACObservationEncoder(nn.Module):
|
|||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_image_features(obs_dict)
|
||||
feat.append(vision_encoder_cache)
|
||||
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
|
||||
|
||||
if vision_encoder_cache is not None:
|
||||
feat.append(vision_encoder_cache)
|
||||
|
@ -580,8 +584,10 @@ class SACObservationEncoder(nn.Module):
|
|||
|
||||
return features
|
||||
|
||||
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
|
||||
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
|
||||
# [N*B, C, H, W]
|
||||
if normalize:
|
||||
batch = self.input_normalization(batch)
|
||||
if len(self.all_image_keys) > 0:
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
|
||||
|
|
|
@ -1026,8 +1026,8 @@ def get_observation_features(
|
|||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_image_features(next_observations)
|
||||
observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
|
Loading…
Reference in New Issue