fix caching and make dataset stats optional
parent ab2c2d39fb
commit 7ad93bdbf1
@@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig):
             }
         )
 
-    dataset_stats: dict[str, dict[str, list[float]]] = field(
+    dataset_stats: dict[str, dict[str, list[float]]] | None = field(
         default_factory=lambda: {
             "observation.image": {
                 "mean": [0.485, 0.456, 0.406],
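Note: with this change, `dataset_stats` may now be set to None. A minimal sketch of the two resulting configurations (the import path and the `std` values are assumptions, not shown in this diff):

from lerobot.common.policies.sac.configuration_sac import SACConfig  # path assumed

# As before: explicit per-feature statistics (the ImageNet mean appears in the
# hunk above; the std values here are the usual ImageNet companions, assumed).
config_with_stats = SACConfig(
    dataset_stats={
        "observation.image": {
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
        }
    }
)

# Newly valid: no statistics at all; the policy then falls back to identity
# normalization (see the __init__ hunk below).
config_without_stats = SACConfig(dataset_stats=None)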
@@ -65,16 +65,21 @@ class SACPolicy(
         else:
             self.normalize_inputs = nn.Identity()
 
-        output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
-
-        # HACK: This is hacky and should be removed
-        dataset_stats = dataset_stats or output_normalization_params
-        self.normalize_targets = Normalize(
-            config.output_features, config.normalization_mapping, dataset_stats
-        )
-        self.unnormalize_outputs = Unnormalize(
-            config.output_features, config.normalization_mapping, dataset_stats
-        )
+        if config.dataset_stats is not None:
+            output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
+
+            # HACK: This is hacky and should be removed
+            dataset_stats = dataset_stats or output_normalization_params
+            self.normalize_targets = Normalize(
+                config.output_features, config.normalization_mapping, dataset_stats
+            )
+            self.unnormalize_outputs = Unnormalize(
+                config.output_features, config.normalization_mapping, dataset_stats
+            )
+        else:
+            self.normalize_targets = nn.Identity()
+            self.unnormalize_outputs = nn.Identity()
 
         # NOTE: For images the encoder should be shared between the actor and critic
         if config.shared_encoder:
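The new `else` branch works because `nn.Identity()` is a callable no-op, so downstream call sites such as `self.normalize_targets(batch)` need no changes. A standalone sketch of the fallback pattern (toy code, not lerobot's Normalize API):

import torch
from torch import nn

def make_target_normalizer(stats: dict | None) -> nn.Module:
    """Return a mean/std normalizer when stats exist, else a pass-through."""
    if stats is None:
        return nn.Identity()  # forward(x) simply returns x unchanged
    mean = torch.tensor(stats["mean"])
    std = torch.tensor(stats["std"])

    class MeanStd(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return (x - mean) / std

    return MeanStd()

x = torch.randn(4, 3)
assert torch.equal(make_target_normalizer(None)(x), x)  # identity without stats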
@@ -192,7 +197,7 @@ class SACPolicy(
         # We cached the encoder output to avoid recomputing it
         observations_features = None
         if self.shared_encoder:
-            observations_features = self.actor.encoder.get_image_features(batch)
+            observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
 
         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
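For context, this cache exists because a shared vision encoder would otherwise run once inside the actor and again inside each critic. A standalone sketch of the idea (toy modules, not lerobot's classes):

import torch
from torch import nn

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
actor_head = nn.Linear(16, 4)
critic_head = nn.Linear(16, 1)

images = torch.rand(2, 3, 8, 8)
features = encoder(images)       # encoded once per batch...
action = actor_head(features)    # ...reused by the actor
q_value = critic_head(features)  # ...and by the critic, no second encoder pass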
@@ -568,8 +573,7 @@ class SACObservationEncoder(nn.Module):
         feat = []
         obs_dict = self.input_normalization(obs_dict)
         if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
-            vision_encoder_cache = self.get_image_features(obs_dict)
-            feat.append(vision_encoder_cache)
+            vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
 
         if vision_encoder_cache is not None:
             feat.append(vision_encoder_cache)
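Two bugs are fixed here: the image features were appended to `feat` twice (once inside the `if`, once after it), and the batch was normalized twice, since `obs_dict` has already passed through `self.input_normalization` by the time `get_image_features` is called. A standalone sketch of the double-normalization hazard (toy code, not lerobot's classes):

import torch
from torch import nn

class Encoder(nn.Module):
    def input_normalization(self, batch):
        return {k: (v - 0.5) / 0.5 for k, v in batch.items()}

    def get_image_features(self, batch, normalize: bool = True):
        if normalize:  # only callers passing *raw* observations set this
            batch = self.input_normalization(batch)
        return batch["observation.image"].flatten(1)

    def forward(self, obs_dict):
        obs_dict = self.input_normalization(obs_dict)  # normalized here...
        # ...so the nested call must skip normalization, or stats apply twice
        return self.get_image_features(obs_dict, normalize=False)

enc = Encoder()
obs = {"observation.image": torch.rand(2, 3, 4, 4)}
once = enc(obs)  # normalized exactly once
twice = enc.get_image_features(enc.input_normalization(obs))  # normalized twice
assert not torch.equal(once, twice)  # double normalization skews the features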
@@ -584,8 +588,10 @@ class SACObservationEncoder(nn.Module):
 
         return features
 
-    def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
+    def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
         # [N*B, C, H, W]
+        if normalize:
+            batch = self.input_normalization(batch)
         if len(self.all_image_keys) > 0:
             # Batch all images along the batch dimension, then encode them.
             images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
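The method now guards its own normalization behind the new `normalize` flag; the existing `[N*B, C, H, W]` batching trick is unchanged. A standalone sketch of that trick, which concatenates every camera stream along the batch dimension so the backbone runs once (the camera keys are hypothetical):

import torch

batch = {
    "observation.images.front": torch.rand(8, 3, 64, 64),
    "observation.images.wrist": torch.rand(8, 3, 64, 64),
}
all_image_keys = sorted(batch)

# [N*B, C, H, W]: one backbone forward instead of one per camera
images_batched = torch.cat([batch[key] for key in all_image_keys], dim=0)
assert images_batched.shape == (16, 3, 64, 64)
# features = backbone(images_batched); torch.chunk(features, len(all_image_keys))
# would split the result back into per-camera features.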
@@ -1026,8 +1026,8 @@ def get_observation_features(
         return None, None
 
     with torch.no_grad():
-        observation_features = policy.actor.encoder.get_image_features(observations)
-        next_observation_features = policy.actor.encoder.get_image_features(next_observations)
+        observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
+        next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
 
     return observation_features, next_observation_features
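Here the caller hands in raw observations, so it requests normalization explicitly, and the features are precomputed under `torch.no_grad()` because this path only builds a reusable cache. A toy sketch of the pattern (shapes are hypothetical):

import torch
from torch import nn

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
observations = torch.rand(4, 3, 8, 8)
next_observations = torch.rand(4, 3, 8, 8)

with torch.no_grad():  # no autograd graph: the cache is plain tensors
    observation_features = encoder(observations)
    next_observation_features = encoder(next_observations)

assert not observation_features.requires_grad  # safe to reuse across updates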