From bc7b6d3dafa0a235a6be025f15af180a5990192e Mon Sep 17 00:00:00 2001
From: Yoel
Date: Fri, 31 Jan 2025 09:42:13 +0100
Subject: [PATCH] [Port HIL-SERL] Add HF vision encoder option in SAC (#651)

Added support for a custom pretrained vision encoder in the SAC modeling
implementation. Great job @ChorntonYoel !
---
 .../common/policies/sac/configuration_sac.py |   1 +
 lerobot/common/policies/sac/modeling_sac.py  | 122 +++++++++++-------
 2 files changed, 77 insertions(+), 46 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 904679e8..3c6344de 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -55,6 +55,7 @@ class SACConfig:
     )
     camera_number: int = 1
     # Add type annotations for these fields:
+    vision_encoder_name: str = field(default="microsoft/resnet-18")
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = False
     discount: float = 0.99
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 64688b1b..bd6e9ef2 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -473,54 +473,61 @@ class SACObservationEncoder(nn.Module):
         """
         super().__init__()
         self.config = config
+        self.has_pretrained_vision_encoder = False
         if "observation.image" in config.input_shapes:
-            self.image_enc_layers = nn.Sequential(
-                nn.Conv2d(
-                    in_channels=config.input_shapes["observation.image"][0],
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=7,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=5,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=3,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=3,
-                    stride=2,
-                ),
-                nn.ReLU(),
-            )
             self.camera_number = config.camera_number
             self.aggregation_size: int = 0
-
-            dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
-            with torch.inference_mode():
-                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-            self.image_enc_layers.extend(
-                sequential=nn.Sequential(
-                    nn.Flatten(),
-                    nn.Linear(
-                        in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
-                    ),
-                    nn.LayerNorm(normalized_shape=config.latent_dim),
+            if self.config.vision_encoder_name is not None:
+                self.has_pretrained_vision_encoder = True
+                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
+                self.freeze_encoder()
+                self.image_enc_proj = nn.Sequential(
+                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                    nn.LayerNorm(config.latent_dim),
                     nn.Tanh(),
                 )
-            )
-
+            else:
+                self.image_enc_layers = nn.Sequential(
+                    nn.Conv2d(
+                        in_channels=config.input_shapes["observation.image"][0],
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=7,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=5,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=3,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=3,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                )
+                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+                with torch.inference_mode():
+                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
+                self.image_enc_layers.extend(
+                    nn.Sequential(
+                        nn.Flatten(),
+                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                        nn.LayerNorm(config.latent_dim),
+                        nn.Tanh(),
+                    )
+                )
             self.aggregation_size += config.latent_dim * self.camera_number
         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
@@ -541,10 +548,27 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
-
-            self.aggregation_size += config.latent_dim
+            self.aggregation_size += config.latent_dim
 
         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
 
+    def _load_pretrained_vision_encoder(self):
+        """Set up CNN encoder"""
+        from transformers import AutoModel
+
+        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
+        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
+            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
+        elif hasattr(self.image_enc_layers, "fc"):
+            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
+        else:
+            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
+        return self.image_enc_layers, self.image_enc_out_shape
+
+    def freeze_encoder(self):
+        """Freeze all parameters in the encoder"""
+        for param in self.image_enc_layers.parameters():
+            param.requires_grad = False
+
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -555,7 +579,13 @@ class SACObservationEncoder(nn.Module):
         # Concatenate all images along the channel dimension.
         image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
         for image_key in image_keys:
-            feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
+            if self.has_pretrained_vision_encoder:
+                enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
+                enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+            else:
+                enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+
+            feat.append(enc_feat)
         if "observation.environment_state" in self.config.input_shapes:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
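
Reviewer note: the snippet below is a minimal standalone sketch of the pretrained-encoder path this patch adds (frozen HF backbone, pooled features projected to the latent dimension). The latent_dim value, batch size, and 3x224x224 dummy images are illustrative assumptions, not values from the patch.

import torch
from torch import nn
from transformers import AutoModel

# Load and freeze the backbone, mirroring _load_pretrained_vision_encoder() / freeze_encoder().
encoder = AutoModel.from_pretrained("microsoft/resnet-18")
for param in encoder.parameters():
    param.requires_grad = False

out_dim = encoder.config.hidden_sizes[-1]  # last channel dimension (512 for resnet-18)
latent_dim = 256  # illustrative; in the patch this comes from SACConfig.latent_dim
proj = nn.Sequential(nn.Linear(out_dim, latent_dim), nn.LayerNorm(latent_dim), nn.Tanh())

images = torch.randn(4, 3, 224, 224)  # dummy camera observations
with torch.no_grad():
    pooled = encoder(images).pooler_output  # (4, 512, 1, 1) for ResNet backbones
feat = proj(pooled.view(pooled.shape[0], -1))  # (4, latent_dim), as in forward()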
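
Usage note: the pretrained encoder is on by default through vision_encoder_name="microsoft/resnet-18"; setting the field to None falls back to the original learned-from-scratch conv stack. A hypothetical override (import path taken from this patch, other SACConfig defaults assumed to exist):

from lerobot.common.policies.sac.configuration_sac import SACConfig

# None disables the pretrained backbone and keeps the original nn.Conv2d encoder.
config = SACConfig(vision_encoder_name=None)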