fix encoder training
This commit is contained in:
parent
ba09f44eb7
commit
854bfb4ff8
|
@ -525,9 +525,10 @@ class SACObservationEncoder(nn.Module):
|
||||||
self.aggregation_size += config.latent_dim * self.camera_number
|
self.aggregation_size += config.latent_dim * self.camera_number
|
||||||
|
|
||||||
if config.freeze_vision_encoder:
|
if config.freeze_vision_encoder:
|
||||||
freeze_image_encoder(self.image_enc_layers)
|
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
|
||||||
else:
|
|
||||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
|
||||||
|
|
||||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||||
|
|
||||||
if "observation.state" in config.input_features:
|
if "observation.state" in config.input_features:
|
||||||
|
@ -958,23 +959,25 @@ class DefaultImageEncoder(nn.Module):
|
||||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||||
self.image_enc_layers.extend(
|
self.image_enc_proj = nn.Sequential(
|
||||||
nn.Sequential(
|
|
||||||
nn.Flatten(),
|
nn.Flatten(),
|
||||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||||
nn.LayerNorm(config.latent_dim),
|
nn.LayerNorm(config.latent_dim),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
self.parameters_to_optimize = []
|
||||||
|
if not config.freeze_vision_encoder:
|
||||||
|
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||||
|
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.image_enc_layers(x)
|
return self.image_enc_proj(self.image_enc_layers(x))
|
||||||
|
|
||||||
|
|
||||||
class PretrainedImageEncoder(nn.Module):
|
class PretrainedImageEncoder(nn.Module):
|
||||||
def __init__(self, config: SACConfig):
|
def __init__(self, config: SACConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||||
self.image_enc_proj = nn.Sequential(
|
self.image_enc_proj = nn.Sequential(
|
||||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||||
|
@ -982,6 +985,11 @@ class PretrainedImageEncoder(nn.Module):
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.parameters_to_optimize = []
|
||||||
|
if not config.freeze_vision_encoder:
|
||||||
|
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||||
|
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
|
||||||
|
|
||||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||||
"""Set up CNN encoder"""
|
"""Set up CNN encoder"""
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|
Loading…
Reference in New Issue