Refactor SACObservationEncoder to improve modularity and readability. Split initialization into dedicated methods for image and state layers, and enhance caching logic for image features. Update forward method to streamline feature encoding and ensure proper normalization handling.
This commit is contained in:
parent
1ce368503d
commit
dcd850feab
|
@ -81,10 +81,10 @@ class SACPolicy(
|
|||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
# We cached the encoder output to avoid recomputing it if the encoder is shared
|
||||
observations_features = None
|
||||
if self.shared_encoder:
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
@ -489,171 +489,165 @@ class SACPolicy(
|
|||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
super(SACObservationEncoder, self).__init__()
|
||||
self.config = config
|
||||
|
||||
self.freeze_image_encoder = config.freeze_vision_encoder
|
||||
|
||||
self.input_normalization = input_normalizer
|
||||
self._out_dim = 0
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
self.all_image_keys = sorted(
|
||||
[k for k in config.input_features if k.startswith("observation.image")]
|
||||
def _init_image_layers(self) -> None:
|
||||
self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
|
||||
self.has_images = bool(self.image_keys)
|
||||
if not self.has_images:
|
||||
return
|
||||
|
||||
if self.config.vision_encoder_name:
|
||||
self.image_encoder = PretrainedImageEncoder(self.config)
|
||||
else:
|
||||
self.image_encoder = DefaultImageEncoder(self.config)
|
||||
|
||||
if self.config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_encoder)
|
||||
|
||||
dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
|
||||
with torch.no_grad():
|
||||
_, channels, height, width = self.image_encoder(dummy).shape
|
||||
|
||||
self.spatial_embeddings = nn.ModuleDict()
|
||||
self.post_encoders = nn.ModuleDict()
|
||||
|
||||
for key in self.image_keys:
|
||||
name = key.replace(".", "_")
|
||||
self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
|
||||
height=height,
|
||||
width=width,
|
||||
channel=channels,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
)
|
||||
|
||||
self._out_dim += len(self.all_image_keys) * config.latent_dim
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
self.has_pretrained_vision_encoder = True
|
||||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
|
||||
if self.freeze_image_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
|
||||
# Separate components for each image stream
|
||||
self.spatial_embeddings = nn.ModuleDict()
|
||||
self.post_encoders = nn.ModuleDict()
|
||||
|
||||
# determine the nb_channels, height and width of the image
|
||||
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
dummy_output = self.image_enc_layers(dummy_batch)
|
||||
_, channels, height, width = dummy_output.shape
|
||||
|
||||
for key in self.all_image_keys:
|
||||
# HACK: This a hack because the state_dict use . to separate the keys
|
||||
safe_key = key.replace(".", "_")
|
||||
# Separate spatial embedding per image
|
||||
self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
|
||||
height=height,
|
||||
width=width,
|
||||
channel=channels,
|
||||
num_features=config.image_embedding_pooling_dim,
|
||||
)
|
||||
# Separate post-encoder per image
|
||||
self.post_encoders[safe_key] = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(
|
||||
in_features=channels * config.image_embedding_pooling_dim,
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
self.post_encoders[name] = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
in_features=channels * self.config.image_embedding_pooling_dim,
|
||||
out_features=self.config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.LayerNorm(normalized_shape=self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self._out_dim += config.latent_dim
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
|
||||
def _init_state_layers(self) -> None:
|
||||
self.has_env = "observation.environment_state" in self.config.input_features
|
||||
self.has_state = "observation.state" in self.config.input_features
|
||||
if self.has_env:
|
||||
dim = self.config.input_features["observation.environment_state"].shape[0]
|
||||
self.env_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self._out_dim += config.latent_dim
|
||||
if self.has_state:
|
||||
dim = self.config.input_features["observation.state"].shape[0]
|
||||
self.state_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def _compute_output_dim(self) -> None:
|
||||
out = 0
|
||||
if self.has_images:
|
||||
out += len(self.image_keys) * self.config.latent_dim
|
||||
if self.has_env:
|
||||
out += self.config.latent_dim
|
||||
if self.has_state:
|
||||
out += self.config.latent_dim
|
||||
self._out_dim = out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs_dict: dict[str, torch.Tensor],
|
||||
vision_encoder_cache: torch.Tensor | None = None,
|
||||
detach: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
if self.has_state:
|
||||
parts.append(self.state_encoder(obs["observation.state"]))
|
||||
if parts:
|
||||
return torch.cat(parts, dim=-1)
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
raise ValueError(
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
the resulting features.
|
||||
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
|
||||
reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
|
||||
|
||||
Performance impact:
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward() with normalize=False
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
if len(self.all_image_keys) > 0:
|
||||
if vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
|
||||
|
||||
vision_encoder_cache = self.get_full_image_representation_with_cached_features(
|
||||
batch_image_cached_features=vision_encoder_cache, detach=detach
|
||||
)
|
||||
feat.append(vision_encoder_cache)
|
||||
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_features:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
|
||||
return features
|
||||
|
||||
def get_cached_image_features(
|
||||
self, batch: dict[str, torch.Tensor], normalize: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Get the cached image features for a batch of observations, when the image encoder is frozen"""
|
||||
if normalize:
|
||||
batch = self.input_normalization(batch)
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
return dict(zip(self.image_keys, chunks, strict=False))
|
||||
|
||||
# Sort keys for consistent ordering
|
||||
sorted_keys = sorted(self.all_image_keys)
|
||||
def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
|
||||
"""Encode image features from cached observations.
|
||||
|
||||
# Stack all images into a single batch
|
||||
batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
|
||||
This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
|
||||
It also supports detaching the encoded features if specified.
|
||||
|
||||
# Process through the image encoder in one pass
|
||||
batched_output = self.image_enc_layers(batched_input)
|
||||
Args:
|
||||
cache (dict[str, Tensor]): The cached image features.
|
||||
detach (bool): Usually when the encoder is shared between actor and critics,
|
||||
we want to detach the encoded features on the policy side to avoid backprop through the encoder.
|
||||
More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
|
||||
|
||||
# Split the output back into individual tensors
|
||||
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
|
||||
|
||||
# Create a dictionary mapping the original keys to their features
|
||||
result = {}
|
||||
for key, features in zip(sorted_keys, image_features, strict=True):
|
||||
result[key] = features
|
||||
|
||||
return result
|
||||
|
||||
def get_full_image_representation_with_cached_features(
|
||||
self,
|
||||
batch_image_cached_features: dict[str, torch.Tensor],
|
||||
detach: bool = False,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
|
||||
|
||||
image_features = []
|
||||
for key in batch_image_cached_features:
|
||||
safe_key = key.replace(".", "_")
|
||||
x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
|
||||
Returns:
|
||||
Tensor: The encoded image features.
|
||||
"""
|
||||
feats = []
|
||||
for k, feat in cache.items():
|
||||
safe_key = k.replace(".", "_")
|
||||
x = self.spatial_embeddings[safe_key](feat)
|
||||
x = self.post_encoders[safe_key](x)
|
||||
|
||||
# The gradient of the image encoder is not needed to update the policy
|
||||
if detach:
|
||||
x = x.detach()
|
||||
image_features.append(x)
|
||||
|
||||
image_features = torch.cat(image_features, dim=-1)
|
||||
return image_features
|
||||
feats.append(x)
|
||||
return torch.cat(feats, dim=-1)
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self._out_dim
|
||||
|
||||
|
||||
|
@ -805,7 +799,7 @@ class CriticEnsemble(nn.Module):
|
|||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, observation_features)
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
|
@ -858,7 +852,7 @@ class GraspCritic(nn.Module):
|
|||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device by cloning first to avoid inplace operations
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
return self.output_layer(self.net(obs_enc))
|
||||
|
||||
|
||||
|
@ -911,12 +905,10 @@ class Policy(nn.Module):
|
|||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# We detach the encoder if it is shared to avoid backprop through it
|
||||
# This is important to avoid the encoder to be updated through the policy
|
||||
obs_enc = self.encoder(
|
||||
observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
|
||||
)
|
||||
obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
|
|
Loading…
Reference in New Issue