Refactor SACObservationEncoder to improve modularity and readability. Split initialization into dedicated methods for image and state layers, and enhance caching logic for image features. Update forward method to streamline feature encoding and ensure proper normalization handling.

This commit is contained in:
AdilZouitine 2025-04-18 12:22:14 +00:00 committed by Michel Aractingi
parent 1ce368503d
commit dcd850feab
1 changed file with 140 additions and 148 deletions

View File

@ -81,10 +81,10 @@ class SACPolicy(
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation""" """Select action for inference/evaluation"""
# We cached the encoder output to avoid recomputing it if the encoder is shared
observations_features = None observations_features = None
if self.shared_encoder: if self.shared_encoder:
observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True) # Cache and normalize image features
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features) actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
@ -489,171 +489,165 @@ class SACPolicy(
class SACObservationEncoder(nn.Module): class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations.""" """Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module): def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
""" super(SACObservationEncoder, self).__init__()
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config self.config = config
self.freeze_image_encoder = config.freeze_vision_encoder
self.input_normalization = input_normalizer self.input_normalization = input_normalizer
self._out_dim = 0 self._init_image_layers()
self._init_state_layers()
self._compute_output_dim()
if any("observation.image" in key for key in config.input_features): def _init_image_layers(self) -> None:
self.camera_number = config.camera_number self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
self.all_image_keys = sorted( self.has_images = bool(self.image_keys)
[k for k in config.input_features if k.startswith("observation.image")] if not self.has_images:
) return
self._out_dim += len(self.all_image_keys) * config.latent_dim if self.config.vision_encoder_name:
self.image_encoder = PretrainedImageEncoder(self.config)
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else: else:
self.image_enc_layers = DefaultImageEncoder(config) self.image_encoder = DefaultImageEncoder(self.config)
if self.freeze_image_encoder: if self.config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers) freeze_image_encoder(self.image_encoder)
dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
with torch.no_grad():
_, channels, height, width = self.image_encoder(dummy).shape
# Separate components for each image stream
self.spatial_embeddings = nn.ModuleDict() self.spatial_embeddings = nn.ModuleDict()
self.post_encoders = nn.ModuleDict() self.post_encoders = nn.ModuleDict()
# determine the nb_channels, height and width of the image for key in self.image_keys:
name = key.replace(".", "_")
# Get first image key from input features self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
image_key = next(key for key in config.input_features if key.startswith("observation.image")) # noqa: SIM118
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
with torch.inference_mode():
dummy_output = self.image_enc_layers(dummy_batch)
_, channels, height, width = dummy_output.shape
for key in self.all_image_keys:
# HACK: This a hack because the state_dict use . to separate the keys
safe_key = key.replace(".", "_")
# Separate spatial embedding per image
self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
height=height, height=height,
width=width, width=width,
channel=channels, channel=channels,
num_features=config.image_embedding_pooling_dim, num_features=self.config.image_embedding_pooling_dim,
) )
# Separate post-encoder per image self.post_encoders[name] = nn.Sequential(
self.post_encoders[safe_key] = nn.Sequential(
nn.Dropout(0.1), nn.Dropout(0.1),
nn.Linear( nn.Linear(
in_features=channels * config.image_embedding_pooling_dim, in_features=channels * self.config.image_embedding_pooling_dim,
out_features=config.latent_dim, out_features=self.config.latent_dim,
), ),
nn.LayerNorm(normalized_shape=config.latent_dim), nn.LayerNorm(normalized_shape=self.config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
if "observation.state" in config.input_features: def _init_state_layers(self) -> None:
self.state_enc_layers = nn.Sequential( self.has_env = "observation.environment_state" in self.config.input_features
nn.Linear( self.has_state = "observation.state" in self.config.input_features
in_features=config.input_features["observation.state"].shape[0], if self.has_env:
out_features=config.latent_dim, dim = self.config.input_features["observation.environment_state"].shape[0]
), self.env_encoder = nn.Sequential(
nn.LayerNorm(normalized_shape=config.latent_dim), nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
self._out_dim += config.latent_dim if self.has_state:
if "observation.environment_state" in config.input_features: dim = self.config.input_features["observation.state"].shape[0]
self.env_state_enc_layers = nn.Sequential( self.state_encoder = nn.Sequential(
nn.Linear( nn.Linear(dim, self.config.latent_dim),
in_features=config.input_features["observation.environment_state"].shape[0], nn.LayerNorm(self.config.latent_dim),
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
self._out_dim += config.latent_dim
def _compute_output_dim(self) -> None:
out = 0
if self.has_images:
out += len(self.image_keys) * self.config.latent_dim
if self.has_env:
out += self.config.latent_dim
if self.has_state:
out += self.config.latent_dim
self._out_dim = out
def forward( def forward(
self, self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
obs_dict: dict[str, torch.Tensor], ) -> Tensor:
vision_encoder_cache: torch.Tensor | None = None, obs = self.input_normalization(obs)
detach: bool = False, parts = []
) -> torch.Tensor: if self.has_images:
"""Encode the image and/or state vector. if cache is None:
cache = self.get_cached_image_features(obs, normalize=False)
parts.append(self._encode_images(cache, detach))
if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"]))
if self.has_state:
parts.append(self.state_encoder(obs["observation.state"]))
if parts:
return torch.cat(parts, dim=-1)
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken raise ValueError(
over all features. "No parts to concatenate, you should have at least one image or environment state or state"
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0:
if vision_encoder_cache is None:
vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
vision_encoder_cache = self.get_full_image_representation_with_cached_features(
batch_image_cached_features=vision_encoder_cache, detach=detach
) )
feat.append(vision_encoder_cache)
if "observation.environment_state" in self.config.input_features: def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) """Extract and optionally cache image features from observations.
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1) This function processes image observations through the vision encoder once and returns
the resulting features.
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
return features Performance impact:
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
- Caching these features can provide 2-4x speedup in training and inference
def get_cached_image_features( Normalization behavior:
self, batch: dict[str, torch.Tensor], normalize: bool = True - When called from inside forward(): set normalize=False since inputs are already normalized
) -> dict[str, torch.Tensor]: - When called from outside forward(): set normalize=True to ensure proper input normalization
"""Get the cached image features for a batch of observations, when the image encoder is frozen"""
Usage patterns:
- Called in select_action() with normalize=True
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
- Called internally by forward() with normalize=False
Args:
obs: Dictionary of observation tensors containing image keys
normalize: Whether to normalize observations before encoding
Set to True when calling directly from outside the encoder's forward method
Set to False when calling from within forward() where inputs are already normalized
Returns:
Dictionary mapping image keys to their corresponding encoded features
"""
if normalize: if normalize:
batch = self.input_normalization(batch) obs = self.input_normalization(obs)
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
out = self.image_encoder(batched)
chunks = torch.chunk(out, len(self.image_keys), dim=0)
return dict(zip(self.image_keys, chunks, strict=False))
# Sort keys for consistent ordering def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
sorted_keys = sorted(self.all_image_keys) """Encode image features from cached observations.
# Stack all images into a single batch This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0) It also supports detaching the encoded features if specified.
# Process through the image encoder in one pass Args:
batched_output = self.image_enc_layers(batched_input) cache (dict[str, Tensor]): The cached image features.
detach (bool): Usually when the encoder is shared between actor and critics,
we want to detach the encoded features on the policy side to avoid backprop through the encoder.
More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
# Split the output back into individual tensors Returns:
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0) Tensor: The encoded image features.
"""
# Create a dictionary mapping the original keys to their features feats = []
result = {} for k, feat in cache.items():
for key, features in zip(sorted_keys, image_features, strict=True): safe_key = k.replace(".", "_")
result[key] = features x = self.spatial_embeddings[safe_key](feat)
return result
def get_full_image_representation_with_cached_features(
self,
batch_image_cached_features: dict[str, torch.Tensor],
detach: bool = False,
) -> dict[str, torch.Tensor]:
"""Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
image_features = []
for key in batch_image_cached_features:
safe_key = key.replace(".", "_")
x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
x = self.post_encoders[safe_key](x) x = self.post_encoders[safe_key](x)
# The gradient of the image encoder is not needed to update the policy
if detach: if detach:
x = x.detach() x = x.detach()
image_features.append(x) feats.append(x)
return torch.cat(feats, dim=-1)
image_features = torch.cat(image_features, dim=-1)
return image_features
@property @property
def output_dim(self) -> int: def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self._out_dim return self._out_dim
@ -805,7 +799,7 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"] actions = self.output_normalization(actions)["action"]
actions = actions.to(device) actions = actions.to(device)
obs_enc = self.encoder(observations, observation_features) obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1) inputs = torch.cat([obs_enc, actions], dim=-1)
@ -858,7 +852,7 @@ class GraspCritic(nn.Module):
device = get_device_from_parameters(self) device = get_device_from_parameters(self)
# Move each tensor in observations to device by cloning first to avoid inplace operations # Move each tensor in observations to device by cloning first to avoid inplace operations
observations = {k: v.to(device) for k, v in observations.items()} observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) obs_enc = self.encoder(observations, cache=observation_features)
return self.output_layer(self.net(obs_enc)) return self.output_layer(self.net(obs_enc))
@ -911,12 +905,10 @@ class Policy(nn.Module):
self, self,
observations: torch.Tensor, observations: torch.Tensor,
observation_features: torch.Tensor | None = None, observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# We detach the encoder if it is shared to avoid backprop through it # We detach the encoder if it is shared to avoid backprop through it
# This is important to avoid the encoder to be updated through the policy # This is important to avoid the encoder to be updated through the policy
obs_enc = self.encoder( obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
)
# Get network outputs # Get network outputs
outputs = self.network(obs_enc) outputs = self.network(obs_enc)