[Port HIL-SERL] Add HF vision encoder option in SAC (#651)

Adds support for using a custom pretrained vision encoder in the SAC modeling implementation. Great job @ChorntonYoel!
Authored by Yoel on 2025-01-31 09:42:13 +01:00, committed by AdilZouitine
parent c620b0878f
commit faab32fe14
2 changed files with 77 additions and 46 deletions


@@ -55,6 +55,7 @@ class SACConfig:
     )
     camera_number: int = 1
     # Add type annotations for these fields:
+    vision_encoder_name: str = field(default="microsoft/resnet-18")
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = False
     discount: float = 0.99
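
For reference, a hedged usage sketch of the new field (assumes SACConfig is imported from the configuration module above; the None fallback matches the is-not-None check in the encoder diff below):

# Default: pooled features from a frozen Hugging Face CNN checkpoint.
config = SACConfig(vision_encoder_name="microsoft/resnet-18")

# Passing None (overriding the str-typed default) selects the small
# Conv2d stack trained from scratch instead.
config_scratch = SACConfig(vision_encoder_name=None)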


@@ -473,7 +473,20 @@ class SACObservationEncoder(nn.Module):
         """
         super().__init__()
         self.config = config
+        self.has_pretrained_vision_encoder = False
         if "observation.image" in config.input_shapes:
+            self.camera_number = config.camera_number
+            self.aggregation_size: int = 0
+            if self.config.vision_encoder_name is not None:
+                self.has_pretrained_vision_encoder = True
+                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
+                self.freeze_encoder()
+                self.image_enc_proj = nn.Sequential(
+                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                    nn.LayerNorm(config.latent_dim),
+                    nn.Tanh(),
+                )
+            else:
                 self.image_enc_layers = nn.Sequential(
                     nn.Conv2d(
                         in_channels=config.input_shapes["observation.image"][0],
@@ -504,23 +517,17 @@ class SACObservationEncoder(nn.Module):
                     ),
                     nn.ReLU(),
                 )
-            self.camera_number = config.camera_number
-            self.aggregation_size: int = 0
-            dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
-            with torch.inference_mode():
-                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-            self.image_enc_layers.extend(
-                sequential=nn.Sequential(
-                    nn.Flatten(),
-                    nn.Linear(
-                        in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
-                    ),
-                    nn.LayerNorm(normalized_shape=config.latent_dim),
-                    nn.Tanh(),
-                )
-            )
+                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+                with torch.inference_mode():
+                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
+                self.image_enc_layers.extend(
+                    nn.Sequential(
+                        nn.Flatten(),
+                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                        nn.LayerNorm(config.latent_dim),
+                        nn.Tanh(),
+                    )
+                )
             self.aggregation_size += config.latent_dim * self.camera_number

         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
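
The dummy-batch trick in the else branch is worth a note: the conv stack's output shape is discovered at construction time rather than hand-computed. A minimal standalone sketch (the 3x84x84 input and layer sizes are illustrative, not from the diff):

import numpy as np
import torch
import torch.nn as nn

# Run one zero tensor through the conv stack in inference mode to learn its
# output shape, then size the projection head from the flattened size.
conv = nn.Sequential(nn.Conv2d(3, 32, kernel_size=7, stride=2), nn.ReLU())
dummy = torch.zeros(1, 3, 84, 84)
with torch.inference_mode():
    out_shape = conv(dummy).shape[1:]  # (32, 39, 39) for this stack
proj = nn.Linear(int(np.prod(out_shape)), 256)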
@@ -541,10 +548,27 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
             self.aggregation_size += config.latent_dim

         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)

+    def _load_pretrained_vision_encoder(self):
+        """Set up CNN encoder"""
+        from transformers import AutoModel
+
+        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
+        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
+            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
+        elif hasattr(self.image_enc_layers, "fc"):
+            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
+        else:
+            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
+        return self.image_enc_layers, self.image_enc_out_shape
+
+    def freeze_encoder(self):
+        """Freeze all parameters in the encoder"""
+        for param in self.image_enc_layers.parameters():
+            param.requires_grad = False
+
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -555,7 +579,13 @@ class SACObservationEncoder(nn.Module):
         # Concatenate all images along the channel dimension.
         image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
         for image_key in image_keys:
-            feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
+            if self.has_pretrained_vision_encoder:
+                enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
+                enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+            else:
+                enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+            feat.append(enc_feat)

         if "observation.environment_state" in self.config.input_shapes:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes: