diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 05c0e02a..81859119 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -81,10 +81,10 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        # We cached the encoder output to avoid recomputing it if the encoder is shared
         observations_features = None
         if self.shared_encoder:
-            observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
+            # Cache normalized image features so the shared encoder is not recomputed
+            observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
 
         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -489,171 +489,166 @@ class SACPolicy(
 class SACObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""
 
-    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
-        """
-        Creates encoders for pixel and/or state modalities.
-        """
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
         super().__init__()
         self.config = config
-
-        self.freeze_image_encoder = config.freeze_vision_encoder
         self.input_normalization = input_normalizer
-        self._out_dim = 0
+        self._init_image_layers()
+        self._init_state_layers()
+        self._compute_output_dim()
 
-        if any("observation.image" in key for key in config.input_features):
-            self.camera_number = config.camera_number
-            self.all_image_keys = sorted(
-                [k for k in config.input_features if k.startswith("observation.image")]
+    def _init_image_layers(self) -> None:
+        self.image_keys = sorted(k for k in self.config.input_features if k.startswith("observation.image"))
+        self.has_images = bool(self.image_keys)
+        if not self.has_images:
+            return
+
+        if self.config.vision_encoder_name:
+            self.image_encoder = PretrainedImageEncoder(self.config)
+        else:
+            self.image_encoder = DefaultImageEncoder(self.config)
+
+        if self.config.freeze_vision_encoder:
+            freeze_image_encoder(self.image_encoder)
+
+        dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
+        with torch.no_grad():
+            _, channels, height, width = self.image_encoder(dummy).shape
+
+        self.spatial_embeddings = nn.ModuleDict()
+        self.post_encoders = nn.ModuleDict()
+
+        for key in self.image_keys:
+            name = key.replace(".", "_")
+            self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
+                height=height,
+                width=width,
+                channel=channels,
+                num_features=self.config.image_embedding_pooling_dim,
            )
-
-            self._out_dim += len(self.all_image_keys) * config.latent_dim
-
-            if self.config.vision_encoder_name is not None:
-                self.image_enc_layers = PretrainedImageEncoder(config)
-                self.has_pretrained_vision_encoder = True
-            else:
-                self.image_enc_layers = DefaultImageEncoder(config)
-
-            if self.freeze_image_encoder:
-                freeze_image_encoder(self.image_enc_layers)
-
-            # Separate components for each image stream
-            self.spatial_embeddings = nn.ModuleDict()
-            self.post_encoders = nn.ModuleDict()
-
-            # determine the nb_channels, height and width of the image
-
-            # Get first image key from input features
-            image_key = next(key for key in config.input_features if key.startswith("observation.image"))  # noqa: SIM118
-            dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
-            with torch.inference_mode():
-                dummy_output = self.image_enc_layers(dummy_batch)
-                _, channels, height, width = dummy_output.shape
-
-            for key in self.all_image_keys:
-                # HACK: This a hack because the state_dict use . to separate the keys
-                safe_key = key.replace(".", "_")
-                # Separate spatial embedding per image
-                self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
-                    height=height,
-                    width=width,
-                    channel=channels,
-                    num_features=config.image_embedding_pooling_dim,
-                )
-                # Separate post-encoder per image
-                self.post_encoders[safe_key] = nn.Sequential(
-                    nn.Dropout(0.1),
-                    nn.Linear(
-                        in_features=channels * config.image_embedding_pooling_dim,
-                        out_features=config.latent_dim,
-                    ),
-                    nn.LayerNorm(normalized_shape=config.latent_dim),
-                    nn.Tanh(),
-                )
-
-        if "observation.state" in config.input_features:
-            self.state_enc_layers = nn.Sequential(
+            self.post_encoders[name] = nn.Sequential(
+                nn.Dropout(0.1),
                 nn.Linear(
-                    in_features=config.input_features["observation.state"].shape[0],
-                    out_features=config.latent_dim,
+                    in_features=channels * self.config.image_embedding_pooling_dim,
+                    out_features=self.config.latent_dim,
                 ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.LayerNorm(normalized_shape=self.config.latent_dim),
                 nn.Tanh(),
             )
-            self._out_dim += config.latent_dim
-        if "observation.environment_state" in config.input_features:
-            self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.environment_state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
+
+    def _init_state_layers(self) -> None:
+        self.has_env = "observation.environment_state" in self.config.input_features
+        self.has_state = "observation.state" in self.config.input_features
+        if self.has_env:
+            dim = self.config.input_features["observation.environment_state"].shape[0]
+            self.env_encoder = nn.Sequential(
+                nn.Linear(dim, self.config.latent_dim),
+                nn.LayerNorm(self.config.latent_dim),
                 nn.Tanh(),
             )
-            self._out_dim += config.latent_dim
+        if self.has_state:
+            dim = self.config.input_features["observation.state"].shape[0]
+            self.state_encoder = nn.Sequential(
+                nn.Linear(dim, self.config.latent_dim),
+                nn.LayerNorm(self.config.latent_dim),
+                nn.Tanh(),
+            )
+
+    def _compute_output_dim(self) -> None:
+        out = 0
+        if self.has_images:
+            out += len(self.image_keys) * self.config.latent_dim
+        if self.has_env:
+            out += self.config.latent_dim
+        if self.has_state:
+            out += self.config.latent_dim
+        self._out_dim = out
 
     def forward(
-        self,
-        obs_dict: dict[str, torch.Tensor],
-        vision_encoder_cache: torch.Tensor | None = None,
-        detach: bool = False,
-    ) -> torch.Tensor:
-        """Encode the image and/or state vector.
+        self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
+    ) -> Tensor:
+        obs = self.input_normalization(obs)
+        parts = []
+        if self.has_images:
+            if cache is None:
+                cache = self.get_cached_image_features(obs, normalize=False)
+            parts.append(self._encode_images(cache, detach))
+        if self.has_env:
+            parts.append(self.env_encoder(obs["observation.environment_state"]))
+        if self.has_state:
+            parts.append(self.state_encoder(obs["observation.state"]))
+        if parts:
+            return torch.cat(parts, dim=-1)
 
-        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
-        over all features.
+        raise ValueError(
+            "No features to encode: expected at least one image, state, or environment_state observation"
+        )
+
+    def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
+        """Extract and optionally cache image features from observations.
+
+        This function processes image observations through the vision encoder once and returns
+        the resulting features.
+        When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
+        reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
+
+        Performance impact:
+        - The vision encoder forward pass is typically the main computational bottleneck during training and inference
+        - Caching these features can provide 2-4x speedup in training and inference
+
+        Normalization behavior:
+        - When called from inside forward(): set normalize=False since inputs are already normalized
+        - When called from outside forward(): set normalize=True to ensure proper input normalization
+
+        Usage patterns:
+        - Called in select_action() with normalize=True
+        - Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
+        - Called internally by forward() with normalize=False
+
+        Args:
+            obs: Dictionary of observation tensors containing image keys
+            normalize: Whether to normalize observations before encoding
+                Set to True when calling directly from outside the encoder's forward method
+                Set to False when calling from within forward() where inputs are already normalized
+
+        Returns:
+            Dictionary mapping image keys to their corresponding encoded features
         """
-        feat = []
-        obs_dict = self.input_normalization(obs_dict)
-        if len(self.all_image_keys) > 0:
-            if vision_encoder_cache is None:
-                vision_encoder_cache = self.get_cached_image_features(obs_dict, normalize=False)
-
-            vision_encoder_cache = self.get_full_image_representation_with_cached_features(
-                batch_image_cached_features=vision_encoder_cache, detach=detach
-            )
-            feat.append(vision_encoder_cache)
-
-        if "observation.environment_state" in self.config.input_features:
-            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_features:
-            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
-
-        features = torch.cat(tensors=feat, dim=-1)
-
-        return features
-
-    def get_cached_image_features(
-        self, batch: dict[str, torch.Tensor], normalize: bool = True
-    ) -> dict[str, torch.Tensor]:
-        """Get the cached image features for a batch of observations, when the image encoder is frozen"""
         if normalize:
-            batch = self.input_normalization(batch)
+            obs = self.input_normalization(obs)
+        batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
+        out = self.image_encoder(batched)
+        chunks = torch.chunk(out, len(self.image_keys), dim=0)
+        return dict(zip(self.image_keys, chunks, strict=True))
 
-        # Sort keys for consistent ordering
-        sorted_keys = sorted(self.all_image_keys)
+    def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
+        """Encode image features from cached observations.
 
-        # Stack all images into a single batch
-        batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
+        This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
+        It also supports detaching the encoded features if specified.
 
-        # Process through the image encoder in one pass
-        batched_output = self.image_enc_layers(batched_input)
+        Args:
+            cache (dict[str, Tensor]): The cached image features.
+            detach (bool): When the encoder is shared between the actor and the critics, the features are
+                detached on the policy side so that gradients do not propagate back through the encoder.
+                More details: `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
 
-        # Split the output back into individual tensors
-        image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
-
-        # Create a dictionary mapping the original keys to their features
-        result = {}
-        for key, features in zip(sorted_keys, image_features, strict=True):
-            result[key] = features
-
-        return result
-
-    def get_full_image_representation_with_cached_features(
-        self,
-        batch_image_cached_features: dict[str, torch.Tensor],
-        detach: bool = False,
-    ) -> dict[str, torch.Tensor]:
-        """Get the full image representation with the cached features, applying the post-encoder and the spatial embedding"""
-
-        image_features = []
-        for key in batch_image_cached_features:
-            safe_key = key.replace(".", "_")
-            x = self.spatial_embeddings[safe_key](batch_image_cached_features[key])
+        Returns:
+            Tensor: The encoded image features.
+        """
+        feats = []
+        for k in self.image_keys:  # fixed key order keeps the concatenated features deterministic
+            safe_key = k.replace(".", "_")
+            x = self.spatial_embeddings[safe_key](cache[k])
             x = self.post_encoders[safe_key](x)
-
-            # The gradient of the image encoder is not needed to update the policy
             if detach:
                 x = x.detach()
-            image_features.append(x)
-
-        image_features = torch.cat(image_features, dim=-1)
-        return image_features
+            feats.append(x)
+        return torch.cat(feats, dim=-1)
 
     @property
     def output_dim(self) -> int:
-        """Returns the dimension of the encoder output"""
         return self._out_dim
 
@@ -805,7 +800,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)
 
-        obs_enc = self.encoder(observations, observation_features)
+        obs_enc = self.encoder(observations, cache=observation_features)
 
         inputs = torch.cat([obs_enc, actions], dim=-1)
 
@@ -858,7 +853,7 @@ class GraspCritic(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
-        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
+        obs_enc = self.encoder(observations, cache=observation_features)
 
         return self.output_layer(self.net(obs_enc))
 
@@ -911,12 +906,10 @@ class Policy(nn.Module):
         self,
         observations: torch.Tensor,
         observation_features: torch.Tensor | None = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # We detach the encoder if it is shared to avoid backprop through it
         # This is important to avoid the encoder to be updated through the policy
-        obs_enc = self.encoder(
-            observations, vision_encoder_cache=observation_features, detach=self.encoder_is_shared
-        )
+        obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
 
         # Get network outputs
         outputs = self.network(obs_enc)
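
Reviewer note, not part of the patch: a minimal sketch of how the caching API above is meant to be used.
It assumes an already-built SACObservationEncoder instance (`encoder`) and a dict of observation tensors
(`batch`); the helper name is hypothetical and only mirrors the pattern described in the
get_cached_image_features docstring (select_action and the learner server).

import torch


def encode_with_shared_cache(encoder, batch):
    """Illustrative only: run the shared, frozen vision encoder once and reuse its features."""
    # Encode the raw images a single time; normalize=True because we call from outside forward().
    with torch.no_grad():
        cached = encoder.get_cached_image_features(batch, normalize=True)
    # Reuse the cache for every component: detach on the actor side so no gradients
    # flow back through the shared encoder, keep gradients for the critics.
    actor_obs = encoder(batch, cache=cached, detach=True)
    critic_obs = encoder(batch, cache=cached, detach=False)
    return actor_obs, critic_obs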