From 854bfb4ff8548c2276e70de0852acd3d620b948e Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Fri, 11 Apr 2025 11:50:46 +0000 Subject: [PATCH] fix encoder training --- lerobot/common/policies/sac/modeling_sac.py | 32 +++++++++++++-------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b8827a1b..9ffdf154 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -525,9 +525,10 @@ class SACObservationEncoder(nn.Module): self.aggregation_size += config.latent_dim * self.camera_number if config.freeze_vision_encoder: - freeze_image_encoder(self.image_enc_layers) - else: - self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + freeze_image_encoder(self.image_enc_layers.image_enc_layers) + + self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize + self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] if "observation.state" in config.input_features: @@ -958,23 +959,25 @@ class DefaultImageEncoder(nn.Module): dummy_batch = torch.zeros(1, *config.input_features[image_key].shape) with torch.inference_mode(): self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:] - self.image_enc_layers.extend( - nn.Sequential( - nn.Flatten(), - nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), - nn.LayerNorm(config.latent_dim), - nn.Tanh(), - ) + self.image_enc_proj = nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), ) + self.parameters_to_optimize = [] + if not config.freeze_vision_encoder: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.parameters_to_optimize += list(self.image_enc_proj.parameters()) + def forward(self, x): - return self.image_enc_layers(x) + return self.image_enc_proj(self.image_enc_layers(x)) class PretrainedImageEncoder(nn.Module): def __init__(self, config: SACConfig): super().__init__() - self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) self.image_enc_proj = nn.Sequential( nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), @@ -982,6 +985,11 @@ class PretrainedImageEncoder(nn.Module): nn.Tanh(), ) + self.parameters_to_optimize = [] + if not config.freeze_vision_encoder: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.parameters_to_optimize += list(self.image_enc_proj.parameters()) + def _load_pretrained_vision_encoder(self, config: SACConfig): """Set up CNN encoder""" from transformers import AutoModel