Stick to the HIL-SERL NN architecture

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
AdilZouitine 2025-04-15 07:44:32 +00:00
parent 320a1a92a3
commit 5d7820527d
2 changed files with 144 additions and 66 deletions


@@ -86,6 +86,7 @@ class SACConfig(PreTrainedConfig):
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
num_discrete_actions: Number of discrete actions, e.g., for gripper actions.
image_embedding_pooling_dim: Dimension of the image embedding pooling.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
@@ -147,6 +148,7 @@ class SACConfig(PreTrainedConfig):
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
num_discrete_actions: int | None = None
image_embedding_pooling_dim: int = 8
# Training parameter
online_steps: int = 1000000
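
For orientation, here is a minimal sketch of how the new image_embedding_pooling_dim field feeds into the per-camera feature sizes used by the encoder changes in the second file. Every concrete value below is an assumption chosen for the example, not a value taken from this config.

# Illustrative arithmetic only; all numbers are assumptions for the sketch.
encoder_channels = 32            # channels produced by the image backbone for one camera
image_embedding_pooling_dim = 8  # new config field: learned spatial-embedding features per channel
latent_dim = 256                 # assumed SACConfig.latent_dim
num_cameras = 2                  # assumed number of "observation.image.*" inputs

# Each camera stream is pooled to channels * pooling_dim, then projected to latent_dim,
# so the observation encoder output grows by latent_dim per camera, plus latent_dim
# for the state and/or environment-state branches.
pooled_per_camera = encoder_channels * image_embedding_pooling_dim   # 256
encoder_output_dim = num_cameras * latent_dim + latent_dim           # images + state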


@@ -196,7 +196,7 @@ class SACPolicy(
# We cache the encoder output to avoid recomputing it
observations_features = None
if self.shared_encoder:
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"]
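
The change above computes the shared vision features once per batch and hands them to every consumer of the encoder. A hedged sketch of that caching idea follows; policy and batch are hypothetical stand-ins, not objects defined in this diff.

# Hedged sketch of the caching pattern; `policy` and `batch` are hypothetical stand-ins.
observations_features = None
if policy.shared_encoder:
    # Encode all camera streams once through the shared backbone...
    observations_features = policy.actor.encoder.get_base_image_features(batch, normalize=True)

# ...and pass the cached features to every module that owns the same encoder,
# so the images are not re-encoded per consumer.
actions, _, _ = policy.actor(batch, observations_features)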
@@ -512,12 +512,19 @@ class SACObservationEncoder(nn.Module):
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.aggregation_size: int = 0
self.freeze_image_encoder = config.freeze_vision_encoder
self.input_normalization = input_normalizer
self._out_dim = 0
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
self.all_image_keys = sorted(
[k for k in config.input_features if k.startswith("observation.image")]
)
self._out_dim += len(self.all_image_keys) * config.latent_dim
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
@@ -525,12 +532,42 @@ class SACObservationEncoder(nn.Module):
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if self.freeze_image_encoder:
freeze_image_encoder(self.image_enc_layers)
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
# Separate components for each image stream
self.spatial_embeddings = nn.ModuleDict()
self.post_encoders = nn.ModuleDict()
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
# Determine the number of channels, height, and width of the image
# Get first image key from input features
image_key = next(key for key in config.input_features if key.startswith("observation.image")) # noqa: SIM118
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
with torch.inference_mode():
dummy_output = self.image_enc_layers(dummy_batch)
_, channels, height, width = dummy_output.shape
for key in self.all_image_keys:
# HACK: This is a hack because the state_dict uses "." to separate keys
safe_key = key.replace(".", "_")
# Separate spatial embedding per image
self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
height=height,
width=width,
channel=channels,
num_features=config.image_embedding_pooling_dim,
)
# Separate post-encoder per image
self.post_encoders[safe_key] = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(
in_features=channels * config.image_embedding_pooling_dim,
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
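
The key sanitization above is needed because nn.Module rejects submodule names that contain ".", since state_dict reserves the dot as its path separator. A small standalone illustration follows; the keys and layer sizes are hypothetical.

import torch.nn as nn

modules = nn.ModuleDict()
key = "observation.image.top"        # registering this key directly would raise: names can't contain "."
safe_key = key.replace(".", "_")     # "observation_image_top"
modules[safe_key] = nn.Linear(8, 4)

# The sanitized name becomes the state_dict prefix for this camera's parameters.
print(list(modules.state_dict()))    # ['observation_image_top.weight', 'observation_image_top.bias']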
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
@@ -541,8 +578,7 @@ class SACObservationEncoder(nn.Module):
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self._out_dim += config.latent_dim
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
@@ -552,13 +588,13 @@ class SACObservationEncoder(nn.Module):
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self._out_dim += config.latent_dim
def forward(
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
) -> Tensor:
self,
obs_dict: dict[str, torch.Tensor],
vision_encoder_cache: torch.Tensor | None = None,
) -> torch.Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
@@ -566,10 +602,9 @@ class SACObservationEncoder(nn.Module):
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
if vision_encoder_cache is not None:
if len(self.all_image_keys) > 0:
if vision_encoder_cache is None:
vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False)
feat.append(vision_encoder_cache)
if "observation.environment_state" in self.config.input_features:
@@ -578,27 +613,50 @@ class SACObservationEncoder(nn.Module):
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
# [N*B, C, H, W]
def get_base_image_features(
self, batch: dict[str, torch.Tensor], normalize: bool = True
) -> torch.Tensor:
"""Process all images through the base encoder in a batched manner"""
if normalize:
batch = self.input_normalization(batch)
if len(self.all_image_keys) > 0:
# Batch all images along the batch dimension, then encode them.
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
return embeddings_image
return None
# Sort keys for consistent ordering
sorted_keys = sorted(self.all_image_keys)
# Stack all images into a single batch
batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
# Process through the image encoder in one pass
batched_output = self.image_enc_layers(batched_input)
if self.freeze_image_encoder:
batched_output = batched_output.detach()
# Split the output back into individual tensors
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
# Create a dictionary mapping the original keys to their features
result = {}
for key, features in zip(sorted_keys, image_features, strict=True):
result[key] = features
image_features = []
for key in result:
safe_key = key.replace(".", "_")
x = self.spatial_embeddings[safe_key](result[key])
x = self.post_encoders[safe_key](x)
image_features.append(x)
image_features = torch.cat(image_features, dim=-1)
return image_features
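
The batched trick above (stack every camera along the batch axis, run the backbone once, then split the result back per camera) can be checked in isolation. Below is a minimal sketch with a toy backbone; every name, shape, and layer choice is an assumption made only for the example.

import torch
import torch.nn as nn

# Toy stand-ins: two camera streams and a tiny "backbone".
backbone = nn.Conv2d(3, 16, kernel_size=3, padding=1)
batch = {
    "observation.image.top": torch.randn(4, 3, 64, 64),
    "observation.image.wrist": torch.randn(4, 3, 64, 64),
}
sorted_keys = sorted(batch)

stacked = torch.cat([batch[k] for k in sorted_keys], dim=0)       # [N*B, C, H, W]
with torch.inference_mode():
    encoded = backbone(stacked)                                   # one backbone pass for all cameras
chunks = torch.chunk(encoded, chunks=len(sorted_keys), dim=0)     # back to N tensors of [B, 16, 64, 64]
per_camera = dict(zip(sorted_keys, chunks, strict=True))
assert per_camera["observation.image.top"].shape == (4, 16, 64, 64)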
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
return self._out_dim
class MLP(nn.Module):
@@ -938,44 +996,29 @@ class DefaultImageEncoder(nn.Module):
nn.ReLU(),
)
# Get first image key from input features
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_proj = nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
self.freeze_image_encoder = config.freeze_vision_encoder
def forward(self, x):
x = self.image_enc_layers(x)
if self.freeze_image_encoder:
x = x.detach()
return self.image_enc_proj(x)
return x
def freeze_image_encoder(image_encoder: nn.Module):
"""Freeze all parameters in the encoder"""
for param in image_encoder.parameters():
param.requires_grad = False
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
self.freeze_image_encoder = config.freeze_vision_encoder
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
@@ -986,21 +1029,10 @@ class PretrainedImageEncoder(nn.Module):
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
# doesn't reach the classifier layer because we don't need it
enc_feat = self.image_enc_layers(x).pooler_output
if self.freeze_image_encoder:
enc_feat = enc_feat.detach()
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
enc_feat = self.image_enc_layers(x).last_hidden_state
return enc_feat
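
The switch from pooler_output to last_hidden_state keeps the spatial feature map that the downstream spatial embeddings expect, instead of a globally pooled vector. A hedged sketch of the difference follows; the backbone name is chosen only for illustration and is not the encoder configured here.

import torch
from transformers import AutoModel

# "microsoft/resnet-18" is an illustrative choice, not the encoder used by this config.
encoder = AutoModel.from_pretrained("microsoft/resnet-18")
with torch.inference_mode():
    out = encoder(torch.randn(1, 3, 224, 224))

feature_map = out.last_hidden_state   # [1, C, H', W'] spatial map, kept for the spatial embeddings
pooled = out.pooler_output            # globally pooled vector, no longer used here
print(feature_map.shape, pooled.shape)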
def freeze_image_encoder(image_encoder: nn.Module):
"""Freeze all parameters in the encoder"""
for param in image_encoder.parameters():
param.requires_grad = False
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
@@ -1013,6 +1045,50 @@ class Identity(nn.Module):
return x
class SpatialLearnedEmbeddings(nn.Module):
def __init__(self, height, width, channel, num_features=8):
"""
PyTorch implementation of learned spatial embeddings
Args:
height: Spatial height of input features
width: Spatial width of input features
channel: Number of input channels
num_features: Number of output embedding dimensions
"""
super().__init__()
self.height = height
self.width = width
self.channel = channel
self.num_features = num_features
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
def forward(self, features):
"""
Forward pass for spatial embedding
Args:
features: Input tensor of shape [B, C, H, W] where B is batch size,
C is number of channels, H is height, and W is width
Returns:
Output tensor of shape [B, C*F] where F is the number of features
"""
features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1]
kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F]
# Element-wise multiplication and spatial reduction
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions
# Reshape to combine channel and feature dimensions
output = output.view(output.size(0), -1) # [B, C*F]
return output
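
As a sanity check on the reduction above, the multiply-and-sum over height and width is equivalent to a single einsum. A short sketch, assuming the SpatialLearnedEmbeddings class as defined in this diff is in scope; the shapes are arbitrary example values.

import torch

B, C, H, W, F = 4, 32, 7, 7, 8
emb = SpatialLearnedEmbeddings(height=H, width=W, channel=C, num_features=F)
feats = torch.randn(B, C, H, W)

out = emb(feats)
assert out.shape == (B, C * F)   # spatial dims reduced, channel and feature dims flattened

# The same computation expressed as an einsum, for intuition:
out_einsum = torch.einsum("bchw,chwf->bcf", feats, emb.kernel).reshape(B, C * F)
assert torch.allclose(out, out_einsum, atol=1e-5)
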
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():