Refactor SACObservationEncoder to improve modularity and readability. Split initialization into dedicated methods for image and state layers, and enhance caching logic for image features. Update forward method to streamline feature encoding and ensure proper normalization handling.

AdilZouitine 2025-04-18 12:22:14 +00:00 committed by Michel Aractingi
parent 1ce368503d
commit dcd850feab
1 changed file with 140 additions and 148 deletions

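At a glance, the refactor described above gives SACObservationEncoder the following layout (a condensed, non-authoritative sketch assembled from the new lines of the diff below; helper bodies are elided, and SACConfig plus the concrete sub-modules come from the LeRobot codebase):

from torch import Tensor, nn


class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations (layout sketch only)."""

    def __init__(self, config, input_normalizer: nn.Module) -> None:
        super().__init__()
        self.config = config
        self.input_normalization = input_normalizer
        self._init_image_layers()   # vision encoder + per-camera spatial embeddings and post-encoders
        self._init_state_layers()   # MLPs for "observation.state" / "observation.environment_state"
        self._compute_output_dim()  # sums latent_dim per active modality into self._out_dim

    def _init_image_layers(self) -> None: ...

    def _init_state_layers(self) -> None: ...

    def _compute_output_dim(self) -> None: ...

    def forward(self, obs: dict[str, Tensor], cache=None, detach: bool = False) -> Tensor:
        ...  # normalize, encode each present modality, concatenate features along the last dim

    def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
        ...  # one batched pass through the vision encoder, split back into one entry per camera key

    def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
        ...  # spatial embedding + post-encoder per camera, optional stop-gradient for shared encoders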

@@ -81,10 +81,10 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
# We cached the encoder output to avoid recomputing it if the encoder is shared
observations_features = None
if self.shared_encoder:
observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
# Cache and normalize image features
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"]
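The cached features in this hunk are a plain dictionary with one tensor of encoder activations per camera key. A short sketch of the inference-time pattern, assuming `policy` is a SACPolicy instance and `batch` holds raw (un-normalized) observations; the attribute and method names are the ones used in the hunk above:

import torch


@torch.no_grad()
def select_action_with_cache(policy, batch: dict[str, torch.Tensor]) -> torch.Tensor:
    """Inference-time caching pattern (sketch; `policy` stands for a SACPolicy)."""
    observations_features = None
    if policy.shared_encoder:
        # One forward pass through the shared vision encoder; normalize=True because
        # the raw batch has not been through input_normalization yet.
        # Result: {"observation.image.<camera>": Tensor[B, C, H, W], ...} (key names illustrative).
        observations_features = policy.actor.encoder.get_cached_image_features(batch, normalize=True)
    actions, _, _ = policy.actor(batch, observations_features)
    return policy.unnormalize_outputs({"action": actions})["action"]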
@@ -489,171 +489,165 @@ class SACPolicy(
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
super(SACObservationEncoder, self).__init__()
self.config = config
self.freeze_image_encoder = config.freeze_vision_encoder
self.input_normalization = input_normalizer
self._out_dim = 0
self._init_image_layers()
self._init_state_layers()
self._compute_output_dim()
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
self.all_image_keys = sorted(
[k for k in config.input_features if k.startswith("observation.image")]
def _init_image_layers(self) -> None:
self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
self.has_images = bool(self.image_keys)
if not self.has_images:
return
if self.config.vision_encoder_name:
self.image_encoder = PretrainedImageEncoder(self.config)
else:
self.image_encoder = DefaultImageEncoder(self.config)
if self.config.freeze_vision_encoder:
freeze_image_encoder(self.image_encoder)
dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
with torch.no_grad():
_, channels, height, width = self.image_encoder(dummy).shape
self.spatial_embeddings = nn.ModuleDict()
self.post_encoders = nn.ModuleDict()
for key in self.image_keys:
name = key.replace(".", "_")
self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
height=height,
width=width,
channel=channels,
num_features=self.config.image_embedding_pooling_dim,
)
self._out_dim += len(self.all_image_keys) * config.latent_dim
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
if self.freeze_image_encoder:
freeze_image_encoder(self.image_enc_layers)
# Separate components for each image stream
self.spatial_embeddings = nn.ModuleDict()
self.post_encoders = nn.ModuleDict()
# determine the nb_channels, height and width of the image
# Get first image key from input features
image_key = next(key for key in config.input_features if key.startswith("observation.image")) # noqa: SIM118
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
with torch.inference_mode():
dummy_output = self.image_enc_layers(dummy_batch)
_, channels, height, width = dummy_output.shape
for key in self.all_image_keys:
# HACK: This is a hack because the state_dict uses . to separate the keys
safe_key = key.replace(".", "_")
# Separate spatial embedding per image
self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
height=height,
width=width,
channel=channels,
num_features=config.image_embedding_pooling_dim,
)
# Separate post-encoder per image
self.post_encoders[safe_key] = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(
in_features=channels * config.image_embedding_pooling_dim,
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
self.post_encoders[name] = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
in_features=channels * self.config.image_embedding_pooling_dim,
out_features=self.config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.LayerNorm(normalized_shape=self.config.latent_dim),
nn.Tanh(),
)
self._out_dim += config.latent_dim
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
def _init_state_layers(self) -> None:
self.has_env = "observation.environment_state" in self.config.input_features
self.has_state = "observation.state" in self.config.input_features
if self.has_env:
dim = self.config.input_features["observation.environment_state"].shape[0]
self.env_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
self._out_dim += config.latent_dim
if self.has_state:
dim = self.config.input_features["observation.state"].shape[0]
self.state_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
def _compute_output_dim(self) -> None:
out = 0
if self.has_images:
out += len(self.image_keys) * self.config.latent_dim
if self.has_env:
out += self.config.latent_dim
if self.has_state:
out += self.config.latent_dim
self._out_dim = out
def forward(
self,
obs_dict: dict[str, torch.Tensor],
vision_encoder_cache: torch.Tensor | None = None,
detach: bool = False,
) -> torch.Tensor:
"""Encode the image and/or state vector.
self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
) -> Tensor:
obs = self.input_normalization(obs)
parts = []
if self.has_images:
if cache is None:
cache = self.get_cached_image_features(obs, normalize=False)
parts.append(self._encode_images(cache, detach))
if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"]))
if self.has_state:
parts.append(self.state_encoder(obs["observation.state"]))
if parts:
return torch.cat(parts, dim=-1)
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
raise ValueError(
"No parts to concatenate, you should have at least one image or environment state or state"
)
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
"""Extract and optionally cache image features from observations.
This function processes image observations through the vision encoder once and returns
the resulting features.
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
Performance impact:
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
- Caching these features can provide 2-4x speedup in training and inference
Normalization behavior:
- When called from inside forward(): set normalize=False since inputs are already normalized
- When called from outside forward(): set normalize=True to ensure proper input normalization
Usage patterns:
- Called in select_action() with normalize=True
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
- Called internally by forward() with normalize=False
Args:
obs: Dictionary of observation tensors containing image keys
normalize: Whether to normalize observations before encoding
Set to True when calling directly from outside the encoder's forward method
Set to False when calling from within forward() where inputs are already normalized
Returns:
Dictionary mapping image keys to their corresponding encoded features
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0:
if vision_encoder_cache is None:
vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
vision_encoder_cache = self.get_full_image_representation_with_cached_features(
batch_image_cached_features=vision_encoder_cache, detach=detach
)
feat.append(vision_encoder_cache)
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
return features
def get_cached_image_features(
self, batch: dict[str, torch.Tensor], normalize: bool = True
) -> dict[str, torch.Tensor]:
"""Get the cached image features for a batch of observations, when the image encoder is frozen"""
if normalize:
batch = self.input_normalization(batch)
obs = self.input_normalization(obs)
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
out = self.image_encoder(batched)
chunks = torch.chunk(out, len(self.image_keys), dim=0)
return dict(zip(self.image_keys, chunks, strict=False))
# Sort keys for consistent ordering
sorted_keys = sorted(self.all_image_keys)
def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
"""Encode image features from cached observations.
# Stack all images into a single batch
batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
It also supports detaching the encoded features if specified.
# Process through the image encoder in one pass
batched_output = self.image_enc_layers(batched_input)
Args:
cache (dict[str, Tensor]): The cached image features.
detach (bool): Usually when the encoder is shared between actor and critics,
we want to detach the encoded features on the policy side to avoid backprop through the encoder.
More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
# Split the output back into individual tensors
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
# Create a dictionary mapping the original keys to their features
result = {}
for key, features in zip(sorted_keys, image_features, strict=True):
result[key] = features
return result
def get_full_image_representation_with_cached_features(
self,
batch_image_cached_features: dict[str, torch.Tensor],
detach: bool = False,
) -> dict[str, torch.Tensor]:
"""Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
image_features = []
for key in batch_image_cached_features:
safe_key = key.replace(".", "_")
x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
Returns:
Tensor: The encoded image features.
"""
feats = []
for k, feat in cache.items():
safe_key = k.replace(".", "_")
x = self.spatial_embeddings[safe_key](feat)
x = self.post_encoders[safe_key](x)
# The gradient of the image encoder is not needed to update the policy
if detach:
x = x.detach()
image_features.append(x)
image_features = torch.cat(image_features, dim=-1)
return image_features
feats.append(x)
return torch.cat(feats, dim=-1)
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self._out_dim
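The get_cached_image_features docstring above spells out when the cache may be reused (frozen encoder, shared between actor and critics) and which normalize value to use from each call site. A minimal sketch of the learner-side reuse pattern it describes; the helper name and the downstream call shapes are illustrative assumptions, while get_cached_image_features, its normalize flag, and has_images follow the diff:

import torch
from torch import Tensor


def precompute_observation_features(encoder, batch: dict[str, Tensor]) -> dict[str, Tensor] | None:
    """Run the shared vision encoder once per batch (illustrative helper name).

    Only safe when the encoder is frozen and shared between actor and critics;
    otherwise each component must encode its own, gradient-carrying features.
    """
    if not getattr(encoder, "has_images", False):
        return None
    with torch.no_grad():
        # normalize=True: we are outside the encoder's forward(), so the raw
        # observations have not gone through input_normalization yet.
        return encoder.get_cached_image_features(batch, normalize=True)


# Training-step usage (illustrative): one vision pass feeds every component
# instead of one pass per actor / critic / grasp_critic call.
#   obs_features = precompute_observation_features(policy.actor.encoder, batch)
#   critic_out = critic(observations, actions, observation_features=obs_features)
#   new_actions, log_probs, _ = actor(observations, obs_features)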
@@ -805,7 +799,7 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = self.encoder(observations, observation_features)
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
@@ -858,7 +852,7 @@ class GraspCritic(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device by cloning first to avoid inplace operations
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
obs_enc = self.encoder(observations, cache=observation_features)
return self.output_layer(self.net(obs_enc))
@@ -911,12 +905,10 @@ class Policy(nn.Module):
self,
observations: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# We detach the encoder if it is shared to avoid backprop through it
# This is important to avoid the encoder being updated through the policy
obs_enc = self.encoder(
observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
)
obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
# Get network outputs
outputs = self.network(obs_enc)
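The detach=self.encoder_is_shared argument above implements the stop-gradient motivated in the _encode_images docstring: when the encoder is shared, the policy loss should not backpropagate into it. A self-contained illustration of the effect with stand-in modules (nothing here is SAC-specific):

import torch
from torch import nn

encoder = nn.Linear(4, 8)       # stand-in for the shared observation encoder
policy_head = nn.Linear(8, 2)   # stand-in for the actor network

features = encoder(torch.randn(3, 4))

# Policy side: cut the graph so the actor loss cannot update the encoder.
actor_loss = policy_head(features.detach()).sum()
actor_loss.backward()

print(encoder.weight.grad)      # None  -> no gradient reached the shared encoder
print(policy_head.weight.grad)  # tensor -> the actor head itself still trains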