diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 684ac17f..0e886f90 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -86,6 +86,7 @@ class SACConfig(PreTrainedConfig):
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
         num_discrete_actions: Number of discrete actions, eg for gripper actions.
+        image_embedding_pooling_dim: Number of learned spatial features used to pool each image embedding.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.
@@ -147,6 +148,7 @@ class SACConfig(PreTrainedConfig):
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
     num_discrete_actions: int | None = None
+    image_embedding_pooling_dim: int = 8

     # Training parameter
     online_steps: int = 1000000
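Note on the new config field: `image_embedding_pooling_dim` sets `num_features` of the `SpatialLearnedEmbeddings` module introduced in `modeling_sac.py` below, so each camera feature map of shape `[B, C, H, W]` is pooled to `[B, C * image_embedding_pooling_dim]` before the per-camera post-encoder projects it to `latent_dim`. A minimal sketch of the shape arithmetic, using hypothetical sizes (512 channels at 4x4):

```python
import torch

# Hypothetical sizes: B=2 batch, C=512 channels, H=W=4 encoder feature map.
B, C, H, W = 2, 512, 4, 4
num_features = 8  # image_embedding_pooling_dim

features = torch.randn(B, C, H, W)
kernel = torch.randn(C, H, W, num_features)  # the learned pooling kernel

# Broadcast-multiply and sum over H, W: one scalar per (channel, feature) pair.
pooled = (features.unsqueeze(-1) * kernel.unsqueeze(0)).sum(dim=(2, 3))
print(pooled.reshape(B, -1).shape)  # torch.Size([2, 4096]) == [B, C * num_features]
```

The per-camera post-encoder added below then projects this `C * image_embedding_pooling_dim` vector down to `latent_dim`.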
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 05937240..35828467 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -196,7 +196,7 @@ class SACPolicy(
         # We cached the encoder output to avoid recomputing it
         observations_features = None
         if self.shared_encoder:
-            observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
+            observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True)

         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -512,12 +512,19 @@ class SACObservationEncoder(nn.Module):
         """
         super().__init__()
         self.config = config
-        self.input_normalization = input_normalizer
-        self.has_pretrained_vision_encoder = False
-        self.aggregation_size: int = 0
+        self.freeze_image_encoder = config.freeze_vision_encoder
+
+        self.input_normalization = input_normalizer
+        self._out_dim = 0
+
         if any("observation.image" in key for key in config.input_features):
             self.camera_number = config.camera_number
+            self.all_image_keys = sorted(
+                [k for k in config.input_features if k.startswith("observation.image")]
+            )
+
+            self._out_dim += len(self.all_image_keys) * config.latent_dim

             if self.config.vision_encoder_name is not None:
                 self.image_enc_layers = PretrainedImageEncoder(config)
@@ -525,12 +532,42 @@
             else:
                 self.image_enc_layers = DefaultImageEncoder(config)

-            self.aggregation_size += config.latent_dim * self.camera_number
+            if self.freeze_image_encoder:
+                freeze_image_encoder(self.image_enc_layers)

-            if config.freeze_vision_encoder:
-                freeze_image_encoder(self.image_enc_layers.image_enc_layers)
+            # Separate components for each image stream
+            self.spatial_embeddings = nn.ModuleDict()
+            self.post_encoders = nn.ModuleDict()

-            self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
+            # Determine the number of channels, height, and width of the encoder output feature map
+
+            # Get first image key from input features
+            image_key = next(key for key in config.input_features if key.startswith("observation.image"))  # noqa: SIM118
+            dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
+            with torch.inference_mode():
+                dummy_output = self.image_enc_layers(dummy_batch)
+            _, channels, height, width = dummy_output.shape
+
+            for key in self.all_image_keys:
+                # HACK: replace "." in keys, because state_dict uses "." to separate module names
+                safe_key = key.replace(".", "_")
+                # Separate spatial embedding per image
+                self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
+                    height=height,
+                    width=width,
+                    channel=channels,
+                    num_features=config.image_embedding_pooling_dim,
+                )
+                # Separate post-encoder per image
+                self.post_encoders[safe_key] = nn.Sequential(
+                    nn.Dropout(0.1),
+                    nn.Linear(
+                        in_features=channels * config.image_embedding_pooling_dim,
+                        out_features=config.latent_dim,
+                    ),
+                    nn.LayerNorm(normalized_shape=config.latent_dim),
+                    nn.Tanh(),
+                )

         if "observation.state" in config.input_features:
             self.state_enc_layers = nn.Sequential(
@@ -541,8 +578,7 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
-            self.aggregation_size += config.latent_dim
-
+            self._out_dim += config.latent_dim
         if "observation.environment_state" in config.input_features:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -552,13 +588,13 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
-            self.aggregation_size += config.latent_dim
-
-        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+            self._out_dim += config.latent_dim

     def forward(
-        self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
-    ) -> Tensor:
+        self,
+        obs_dict: dict[str, torch.Tensor],
+        vision_encoder_cache: torch.Tensor | None = None,
+    ) -> torch.Tensor:
         """Encode the image and/or state vector.

         Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
@@ -566,10 +602,9 @@
         """
         feat = []
         obs_dict = self.input_normalization(obs_dict)
-        if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
-            vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
-
-        if vision_encoder_cache is not None:
+        if len(self.all_image_keys) > 0:
+            if vision_encoder_cache is None:
+                vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False)
             feat.append(vision_encoder_cache)

         if "observation.environment_state" in self.config.input_features:
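The reworked `forward` takes an optional `vision_encoder_cache`, so with `shared_encoder=True` the policy (see the `get_base_image_features` call in `SACPolicy` above) can encode images once and hand the cached features to both actor and critics. A toy sketch of that contract; `ToyEncoder` and its `Linear` backbone are illustrative stand-ins, not lerobot code:

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 4)  # stands in for the image tower

    def get_base_image_features(self, obs):
        return self.backbone(obs["observation.image"])

    def forward(self, obs, vision_encoder_cache=None):
        # Mirrors the patched forward: encode only when no cache is supplied.
        if vision_encoder_cache is None:
            vision_encoder_cache = self.get_base_image_features(obs)
        return vision_encoder_cache

enc = ToyEncoder()
batch = {"observation.image": torch.randn(2, 16)}
cache = enc.get_base_image_features(batch)              # single backbone pass
actor_feats = enc(batch, vision_encoder_cache=cache)    # reuses the cache
critic_feats = enc(batch, vision_encoder_cache=cache)   # reuses the cache
```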
@@ -578,27 +613,50 @@ class SACObservationEncoder(nn.Module):
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))

         features = torch.cat(tensors=feat, dim=-1)
-        features = self.aggregation_layer(features)

         return features

-    def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
-        # [N*B, C, H, W]
+    def get_base_image_features(
+        self, batch: dict[str, torch.Tensor], normalize: bool = True
+    ) -> torch.Tensor:
+        """Encode all camera images in one batched pass and return their concatenated features."""
         if normalize:
             batch = self.input_normalization(batch)
-        if len(self.all_image_keys) > 0:
-            # Batch all images along the batch dimension, then encode them.
-            images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
-            images_batched = self.image_enc_layers(images_batched)
-            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
-            embeddings_image = torch.cat(embeddings_chunks, dim=-1)
-            return embeddings_image
-        return None
+
+        # Sort keys for consistent ordering
+        sorted_keys = sorted(self.all_image_keys)
+
+        # Stack all images into a single batch
+        batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
+
+        # Process through the image encoder in one pass
+        batched_output = self.image_enc_layers(batched_input)
+
+        if self.freeze_image_encoder:
+            batched_output = batched_output.detach()
+
+        # Split the output back into individual tensors
+        image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
+
+        # Create a dictionary mapping the original keys to their features
+        result = {}
+        for key, features in zip(sorted_keys, image_features, strict=True):
+            result[key] = features
+
+        image_features = []
+        for key in result:
+            safe_key = key.replace(".", "_")
+            x = self.spatial_embeddings[safe_key](result[key])
+            x = self.post_encoders[safe_key](x)
+            image_features.append(x)
+
+        image_features = torch.cat(image_features, dim=-1)
+        return image_features

     @property
     def output_dim(self) -> int:
         """Returns the dimension of the encoder output"""
-        return self.config.latent_dim
+        return self._out_dim


 class MLP(nn.Module):
@@ -938,44 +996,29 @@ class DefaultImageEncoder(nn.Module):
             nn.ReLU(),
         )
         # Get first image key from input features
-        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
-        dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
-        with torch.inference_mode():
-            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-        self.image_enc_proj = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-            nn.LayerNorm(config.latent_dim),
-            nn.Tanh(),
-        )
-
-        self.freeze_image_encoder = config.freeze_vision_encoder

     def forward(self, x):
         x = self.image_enc_layers(x)
-        if self.freeze_image_encoder:
-            x = x.detach()
-        return self.image_enc_proj(x)
+        return x
+
+
+def freeze_image_encoder(image_encoder: nn.Module):
+    """Freeze all parameters in the encoder"""
+    for param in image_encoder.parameters():
+        param.requires_grad = False


 class PretrainedImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()

-        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
-        self.image_enc_proj = nn.Sequential(
-            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-            nn.LayerNorm(config.latent_dim),
-            nn.Tanh(),
-        )
-        self.freeze_image_encoder = config.freeze_vision_encoder
+        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)

     def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
         from transformers import AutoModel

         self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
-        # self.image_enc_layers.pooler = Identity()

         if hasattr(self.image_enc_layers.config, "hidden_sizes"):
             self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
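For reference, the new `get_base_image_features` above amortizes the vision backbone over cameras by concatenating inputs along the batch dimension, encoding once, and splitting the result. A self-contained sketch with dummy tensors, using an identity in place of `image_enc_layers`:

```python
import torch

cams = {
    "observation.image.left": torch.randn(2, 3, 64, 64),
    "observation.image.right": torch.randn(2, 3, 64, 64),
}
keys = sorted(cams)
batched = torch.cat([cams[k] for k in keys], dim=0)  # [2 * B, 3, 64, 64]
encoded = batched  # stand-in for self.image_enc_layers(batched)
chunks = torch.chunk(encoded, chunks=len(keys), dim=0)
per_cam = dict(zip(keys, chunks, strict=True))       # back to per-camera tensors
print(per_cam["observation.image.left"].shape)       # torch.Size([2, 3, 64, 64])
```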
@@ -986,21 +1029,10 @@ class PretrainedImageEncoder(nn.Module):
         return self.image_enc_layers, self.image_enc_out_shape

     def forward(self, x):
-        # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
-        # doesn't reach the classifier layer because we don't need it
-        enc_feat = self.image_enc_layers(x).pooler_output
-        if self.freeze_image_encoder:
-            enc_feat = enc_feat.detach()
-        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+        enc_feat = self.image_enc_layers(x).last_hidden_state
         return enc_feat


-def freeze_image_encoder(image_encoder: nn.Module):
-    """Freeze all parameters in the encoder"""
-    for param in image_encoder.parameters():
-        param.requires_grad = False
-
-
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)

@@ -1013,6 +1045,50 @@ class Identity(nn.Module):
         return x


+class SpatialLearnedEmbeddings(nn.Module):
+    def __init__(self, height, width, channel, num_features=8):
+        """
+        PyTorch implementation of learned spatial embeddings
+
+        Args:
+            height: Spatial height of input features
+            width: Spatial width of input features
+            channel: Number of input channels
+            num_features: Number of learned spatial features per channel
+        """
+        super().__init__()
+        self.height = height
+        self.width = width
+        self.channel = channel
+        self.num_features = num_features
+
+        self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
+
+        nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
+
+    def forward(self, features):
+        """
+        Forward pass for spatial embedding
+
+        Args:
+            features: Input tensor of shape [B, C, H, W] where B is batch size,
+                C is number of channels, H is height, and W is width
+        Returns:
+            Output tensor of shape [B, C*F] where F is the number of features
+        """
+
+        features_expanded = features.unsqueeze(-1)  # [B, C, H, W, 1]
+        kernel_expanded = self.kernel.unsqueeze(0)  # [1, C, H, W, F]
+
+        # Element-wise multiplication and spatial reduction
+        output = (features_expanded * kernel_expanded).sum(dim=(2, 3))  # Sum over H and W dimensions
+
+        # Reshape to combine channel and feature dimensions
+        output = output.view(output.size(0), -1)  # [B, C*F]
+
+        return output
+
+
 def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
     converted_params = {}
     for outer_key, inner_dict in normalization_params.items():
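Finally, a quick shape check for the new `SpatialLearnedEmbeddings` module; it assumes the patched `modeling_sac.py` is on the import path:

```python
import torch
from lerobot.common.policies.sac.modeling_sac import SpatialLearnedEmbeddings

emb = SpatialLearnedEmbeddings(height=4, width=4, channel=32, num_features=8)
out = emb(torch.randn(2, 32, 4, 4))  # input [B, C, H, W]
print(out.shape)  # torch.Size([2, 256]) == [B, channel * num_features]
```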