fix encoder training
This commit is contained in:
parent
ba09f44eb7
commit
854bfb4ff8
|
@ -525,9 +525,10 @@ class SACObservationEncoder(nn.Module):
|
|||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
|
||||
|
||||
self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
|
||||
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
|
@ -958,23 +959,25 @@ class DefaultImageEncoder(nn.Module):
|
|||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
if not config.freeze_vision_encoder:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
|
||||
|
||||
def forward(self, x):
|
||||
return self.image_enc_layers(x)
|
||||
return self.image_enc_proj(self.image_enc_layers(x))
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
|
@ -982,6 +985,11 @@ class PretrainedImageEncoder(nn.Module):
|
|||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
if not config.freeze_vision_encoder:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
|
Loading…
Reference in New Issue