fix encoder training

This commit is contained in:
AdilZouitine 2025-04-11 11:50:46 +00:00
parent ba09f44eb7
commit 854bfb4ff8
1 changed file with 20 additions and 12 deletions

View File

@ -525,9 +525,10 @@ class SACObservationEncoder(nn.Module):
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
@ -958,23 +959,25 @@ class DefaultImageEncoder(nn.Module):
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
self.image_enc_proj = nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
self.parameters_to_optimize = []
if not config.freeze_vision_encoder:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
def forward(self, x):
    """Encode ``x`` with the backbone, then apply the projection head.

    The backbone (``self.image_enc_layers``) and the projection head
    (``self.image_enc_proj``) are held as two separate modules, so both
    must be chained here. Returning the backbone output alone would skip
    the projection entirely.

    Args:
        x: Input forwarded unchanged to ``self.image_enc_layers``
            (presumably a batch of image tensors; confirm against caller).

    Returns:
        The result of ``self.image_enc_proj`` applied to the backbone
        features.
    """
    features = self.image_enc_layers(x)
    return self.image_enc_proj(features)
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
@ -982,6 +985,11 @@ class PretrainedImageEncoder(nn.Module):
nn.Tanh(),
)
self.parameters_to_optimize = []
if not config.freeze_vision_encoder:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder"""
from transformers import AutoModel