Refactor SACObservationEncoder to improve modularity and readability. Split initialization into dedicated methods for image and state layers, and enhance caching logic for image features. Update forward method to streamline feature encoding and ensure proper normalization handling.
parent 1ce368503d
commit dcd850feab
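The refactor keeps the encoder's public surface small: a forward() that accepts an optional feature cache, and get_cached_image_features() for computing that cache up front. The sketch below shows the intended call pattern; it is illustrative only and assumes `encoder` is an already-constructed SACObservationEncoder whose config matches the observation keys and shapes used here.

    # Illustrative only: intended call pattern for the refactored encoder.
    # Assumes `encoder` is an already-constructed SACObservationEncoder and that
    # the observation keys/shapes below match its SACConfig.
    import torch

    obs = {
        "observation.image": torch.randn(1, 3, 84, 84),  # assumed camera shape
        "observation.state": torch.randn(1, 10),          # assumed state dim
    }

    # No cache: forward() normalizes the batch and encodes every modality itself.
    features = encoder(obs)

    # Shared/frozen vision encoder: compute image features once, reuse them, and
    # detach on the policy side so the actor loss never reaches the backbone.
    cache = encoder.get_cached_image_features(obs, normalize=True)
    actor_features = encoder(obs, cache=cache, detach=True)
    critic_features = encoder(obs, cache=cache, detach=False)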
@@ -81,10 +81,10 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        # We cached the encoder output to avoid recomputing it if the encoder is shared
         observations_features = None
         if self.shared_encoder:
-            observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
+            # Cache and normalize image features
+            observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)

         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
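The hunk above only touches select_action(), but it reflects the broader caching idea: when the vision encoder is shared (and typically frozen), its forward pass should run once per batch and its output should be reused by every head. A minimal stand-in version of that pattern, with toy modules in place of the real actor, critic, and backbone:

    # Toy version of the shared-encoder caching idea behind select_action():
    # the (expensive) vision backbone runs once, and both the actor head and the
    # critic head consume the same cached features. All modules here are
    # illustrative stand-ins for the real lerobot components.
    import torch
    from torch import nn

    backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten())  # stand-in frozen image encoder
    actor_head = nn.Linear(8 * 64 * 64, 4)
    critic_head = nn.Linear(8 * 64 * 64, 1)

    image = torch.randn(1, 3, 64, 64)

    with torch.no_grad():
        cached_features = backbone(image)   # computed once, analogous to get_cached_image_features()

    action = actor_head(cached_features)    # actor reuses the cache
    value = critic_head(cached_features)    # critic reuses the same cache, no second backbone pass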
@@ -489,171 +489,165 @@ class SACPolicy(
 class SACObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""

-    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
-        """
-        Creates encoders for pixel and/or state modalities.
-        """
-        super().__init__()
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
+        super(SACObservationEncoder, self).__init__()
         self.config = config
-        self.freeze_image_encoder = config.freeze_vision_encoder
         self.input_normalization = input_normalizer
-        self._out_dim = 0
-
-        if any("observation.image" in key for key in config.input_features):
-            self.camera_number = config.camera_number
-            self.all_image_keys = sorted(
-                [k for k in config.input_features if k.startswith("observation.image")]
-            )
-
-            self._out_dim += len(self.all_image_keys) * config.latent_dim
-
-            if self.config.vision_encoder_name is not None:
-                self.image_enc_layers = PretrainedImageEncoder(config)
-                self.has_pretrained_vision_encoder = True
-            else:
-                self.image_enc_layers = DefaultImageEncoder(config)
-
-            if self.freeze_image_encoder:
-                freeze_image_encoder(self.image_enc_layers)
-
-            # Separate components for each image stream
-            self.spatial_embeddings = nn.ModuleDict()
-            self.post_encoders = nn.ModuleDict()
-
-            # determine the nb_channels, height and width of the image
-            # Get first image key from input features
-            image_key = next(key for key in config.input_features if key.startswith("observation.image"))  # noqa: SIM118
-            dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
-            with torch.inference_mode():
-                dummy_output = self.image_enc_layers(dummy_batch)
-            _, channels, height, width = dummy_output.shape
-
-            for key in self.all_image_keys:
-                # HACK: This a hack because the state_dict use . to separate the keys
-                safe_key = key.replace(".", "_")
-                # Separate spatial embedding per image
-                self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
-                    height=height,
-                    width=width,
-                    channel=channels,
-                    num_features=config.image_embedding_pooling_dim,
-                )
-                # Separate post-encoder per image
-                self.post_encoders[safe_key] = nn.Sequential(
-                    nn.Dropout(0.1),
-                    nn.Linear(
-                        in_features=channels * config.image_embedding_pooling_dim,
-                        out_features=config.latent_dim,
-                    ),
-                    nn.LayerNorm(normalized_shape=config.latent_dim),
-                    nn.Tanh(),
-                )
-
-        if "observation.state" in config.input_features:
-            self.state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self._out_dim += config.latent_dim
-        if "observation.environment_state" in config.input_features:
-            self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.environment_state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self._out_dim += config.latent_dim
-
-    def forward(
-        self,
-        obs_dict: dict[str, torch.Tensor],
-        vision_encoder_cache: torch.Tensor | None = None,
-        detach: bool = False,
-    ) -> torch.Tensor:
-        """Encode the image and/or state vector.
-
-        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
-        over all features.
-        """
-        feat = []
-        obs_dict = self.input_normalization(obs_dict)
-        if len(self.all_image_keys) > 0:
-            if vision_encoder_cache is None:
-                vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
-
-            vision_encoder_cache = self.get_full_image_representation_with_cached_features(
-                batch_image_cached_features=vision_encoder_cache, detach=detach
-            )
-            feat.append(vision_encoder_cache)
-
-        if "observation.environment_state" in self.config.input_features:
-            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_features:
-            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
-
-        features = torch.cat(tensors=feat, dim=-1)
-        return features
-
-    def get_cached_image_features(
-        self, batch: dict[str, torch.Tensor], normalize: bool = True
-    ) -> dict[str, torch.Tensor]:
-        """Get the cached image features for a batch of observations, when the image encoder is frozen"""
-        if normalize:
-            batch = self.input_normalization(batch)
-
-        # Sort keys for consistent ordering
-        sorted_keys = sorted(self.all_image_keys)
-
-        # Stack all images into a single batch
-        batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
-
-        # Process through the image encoder in one pass
-        batched_output = self.image_enc_layers(batched_input)
-
-        # Split the output back into individual tensors
-        image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
-
-        # Create a dictionary mapping the original keys to their features
-        result = {}
-        for key, features in zip(sorted_keys, image_features, strict=True):
-            result[key] = features
-
-        return result
-
-    def get_full_image_representation_with_cached_features(
-        self,
-        batch_image_cached_features: dict[str, torch.Tensor],
-        detach: bool = False,
-    ) -> dict[str, torch.Tensor]:
-        """Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
-        image_features = []
-        for key in batch_image_cached_features:
-            safe_key = key.replace(".", "_")
-            x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
-            x = self.post_encoders[safe_key](x)
-            # The gradient of the image encoder is not needed to update the policy
-            if detach:
-                x = x.detach()
-            image_features.append(x)
-
-        image_features = torch.cat(image_features, dim=-1)
-        return image_features
+        self._init_image_layers()
+        self._init_state_layers()
+        self._compute_output_dim()
+
+    def _init_image_layers(self) -> None:
+        self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
+        self.has_images = bool(self.image_keys)
+        if not self.has_images:
+            return
+
+        if self.config.vision_encoder_name:
+            self.image_encoder = PretrainedImageEncoder(self.config)
+        else:
+            self.image_encoder = DefaultImageEncoder(self.config)
+
+        if self.config.freeze_vision_encoder:
+            freeze_image_encoder(self.image_encoder)
+
+        dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
+        with torch.no_grad():
+            _, channels, height, width = self.image_encoder(dummy).shape
+
+        self.spatial_embeddings = nn.ModuleDict()
+        self.post_encoders = nn.ModuleDict()
+        for key in self.image_keys:
+            name = key.replace(".", "_")
+            self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
+                height=height,
+                width=width,
+                channel=channels,
+                num_features=self.config.image_embedding_pooling_dim,
+            )
+            self.post_encoders[name] = nn.Sequential(
+                nn.Dropout(0.1),
+                nn.Linear(
+                    in_features=channels * self.config.image_embedding_pooling_dim,
+                    out_features=self.config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=self.config.latent_dim),
+                nn.Tanh(),
+            )
+
+    def _init_state_layers(self) -> None:
+        self.has_env = "observation.environment_state" in self.config.input_features
+        self.has_state = "observation.state" in self.config.input_features
+        if self.has_env:
+            dim = self.config.input_features["observation.environment_state"].shape[0]
+            self.env_encoder = nn.Sequential(
+                nn.Linear(dim, self.config.latent_dim),
+                nn.LayerNorm(self.config.latent_dim),
+                nn.Tanh(),
+            )
+        if self.has_state:
+            dim = self.config.input_features["observation.state"].shape[0]
+            self.state_encoder = nn.Sequential(
+                nn.Linear(dim, self.config.latent_dim),
+                nn.LayerNorm(self.config.latent_dim),
+                nn.Tanh(),
+            )
+
+    def _compute_output_dim(self) -> None:
+        out = 0
+        if self.has_images:
+            out += len(self.image_keys) * self.config.latent_dim
+        if self.has_env:
+            out += self.config.latent_dim
+        if self.has_state:
+            out += self.config.latent_dim
+        self._out_dim = out
+
+    def forward(
+        self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
+    ) -> Tensor:
+        obs = self.input_normalization(obs)
+        parts = []
+        if self.has_images:
+            if cache is None:
+                cache = self.get_cached_image_features(obs, normalize=False)
+            parts.append(self._encode_images(cache, detach))
+        if self.has_env:
+            parts.append(self.env_encoder(obs["observation.environment_state"]))
+        if self.has_state:
+            parts.append(self.state_encoder(obs["observation.state"]))
+        if parts:
+            return torch.cat(parts, dim=-1)
+
+        raise ValueError(
+            "No parts to concatenate, you should have at least one image or environment state or state"
+        )
+
+    def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
+        """Extract and optionally cache image features from observations.
+
+        This function processes image observations through the vision encoder once and returns
+        the resulting features.
+        When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
+        reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
+
+        Performance impact:
+        - The vision encoder forward pass is typically the main computational bottleneck during training and inference
+        - Caching these features can provide 2-4x speedup in training and inference
+
+        Normalization behavior:
+        - When called from inside forward(): set normalize=False since inputs are already normalized
+        - When called from outside forward(): set normalize=True to ensure proper input normalization
+
+        Usage patterns:
+        - Called in select_action() with normalize=True
+        - Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
+        - Called internally by forward() with normalize=False
+
+        Args:
+            obs: Dictionary of observation tensors containing image keys
+            normalize: Whether to normalize observations before encoding
+                Set to True when calling directly from outside the encoder's forward method
+                Set to False when calling from within forward() where inputs are already normalized
+
+        Returns:
+            Dictionary mapping image keys to their corresponding encoded features
+        """
+        if normalize:
+            obs = self.input_normalization(obs)
+        batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
+        out = self.image_encoder(batched)
+        chunks = torch.chunk(out, len(self.image_keys), dim=0)
+        return dict(zip(self.image_keys, chunks, strict=False))
+
+    def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
+        """Encode image features from cached observations.
+
+        This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
+        It also supports detaching the encoded features if specified.
+
+        Args:
+            cache (dict[str, Tensor]): The cached image features.
+            detach (bool): Usually when the encoder is shared between actor and critics,
+                we want to detach the encoded features on the policy side to avoid backprop through the encoder.
+                More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
+
+        Returns:
+            Tensor: The encoded image features.
+        """
+        feats = []
+        for k, feat in cache.items():
+            safe_key = k.replace(".", "_")
+            x = self.spatial_embeddings[safe_key](feat)
+            x = self.post_encoders[safe_key](x)
+            if detach:
+                x = x.detach()
+            feats.append(x)
+        return torch.cat(feats, dim=-1)

     @property
     def output_dim(self) -> int:
-        """Returns the dimension of the encoder output"""
         return self._out_dim
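Inside get_cached_image_features(), the new code concatenates all camera views along the batch dimension, runs the vision encoder once, and chunks the result back into a per-key dictionary. A standalone sketch of that trick, using a toy conv layer and made-up camera keys rather than the real lerobot encoder:

    # Standalone sketch of the batching trick used by get_cached_image_features():
    # stack every camera view along the batch dimension, run the vision encoder
    # once, then split the result back into a per-key dict. The tiny conv net and
    # the two image keys below are illustrative stand-ins, not lerobot code.
    import torch
    from torch import nn

    image_keys = ["observation.image.top", "observation.image.wrist"]  # assumed keys
    obs = {k: torch.randn(4, 3, 64, 64) for k in image_keys}            # assumed shapes

    backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)                # stand-in encoder

    batched = torch.cat([obs[k] for k in image_keys], dim=0)   # (2 * 4, 3, 64, 64)
    out = backbone(batched)                                    # single forward pass
    chunks = torch.chunk(out, len(image_keys), dim=0)          # split back per camera
    cache = dict(zip(image_keys, chunks))

    assert cache["observation.image.top"].shape == (4, 8, 64, 64)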
@@ -805,7 +799,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

-        obs_enc = self.encoder(observations, observation_features)
+        obs_enc = self.encoder(observations, cache=observation_features)

         inputs = torch.cat([obs_enc, actions], dim=-1)
@@ -858,7 +852,7 @@ class GraspCritic(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
-        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
+        obs_enc = self.encoder(observations, cache=observation_features)
         return self.output_layer(self.net(obs_enc))
@@ -911,12 +905,10 @@ class Policy(nn.Module):
         self,
         observations: torch.Tensor,
         observation_features: torch.Tensor | None = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # We detach the encoder if it is shared to avoid backprop through it
         # This is important to avoid the encoder to be updated through the policy
-        obs_enc = self.encoder(
-            observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
-        )
+        obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)

         # Get network outputs
         outputs = self.network(obs_enc)
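The detach=self.encoder_is_shared argument above is what keeps the actor loss from updating a shared encoder: the encoded observation keeps its values but stops carrying gradients back into the backbone. A minimal demonstration with stand-in modules and a stand-in loss:

    # Minimal illustration of why Policy.forward passes detach=encoder_is_shared:
    # detaching the encoded observation keeps the forward values but blocks the
    # actor loss from producing gradients in the shared encoder. The modules and
    # loss below are stand-ins, not the real SAC objective.
    import torch
    from torch import nn

    encoder = nn.Linear(6, 4)     # stand-in for the shared observation encoder
    policy_head = nn.Linear(4, 2)

    obs = torch.randn(3, 6)
    obs_enc = encoder(obs).detach()             # what detach=True does on the policy side
    loss = policy_head(obs_enc).pow(2).mean()   # stand-in actor loss
    loss.backward()

    assert encoder.weight.grad is None          # encoder untouched by the policy update
    assert policy_head.weight.grad is not None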