stick to hil serl nn architecture
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
parent
320a1a92a3
commit
5d7820527d
|
@ -86,6 +86,7 @@ class SACConfig(PreTrainedConfig):
|
|||
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
|
||||
shared_encoder: Whether to use a shared encoder for actor and critic.
|
||||
num_discrete_actions: Number of discrete actions, eg for gripper actions.
|
||||
image_embedding_pooling_dim: Dimension of the image embedding pooling.
|
||||
concurrency: Configuration for concurrency settings.
|
||||
actor_learner: Configuration for actor-learner architecture.
|
||||
online_steps: Number of steps for online training.
|
||||
|
@ -147,6 +148,7 @@ class SACConfig(PreTrainedConfig):
|
|||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
num_discrete_actions: int | None = None
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
online_steps: int = 1000000
|
||||
|
|
|
@ -196,7 +196,7 @@ class SACPolicy(
|
|||
# We cached the encoder output to avoid recomputing it
|
||||
observations_features = None
|
||||
if self.shared_encoder:
|
||||
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
|
||||
observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
@ -512,12 +512,19 @@ class SACObservationEncoder(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
self.freeze_image_encoder = config.freeze_vision_encoder
|
||||
|
||||
self.input_normalization = input_normalizer
|
||||
self._out_dim = 0
|
||||
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
self.all_image_keys = sorted(
|
||||
[k for k in config.input_features if k.startswith("observation.image")]
|
||||
)
|
||||
|
||||
self._out_dim += len(self.all_image_keys) * config.latent_dim
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
|
@ -525,12 +532,42 @@ class SACObservationEncoder(nn.Module):
|
|||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
if self.freeze_image_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
|
||||
# Separate components for each image stream
|
||||
self.spatial_embeddings = nn.ModuleDict()
|
||||
self.post_encoders = nn.ModuleDict()
|
||||
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
# determine the nb_channels, height and width of the image
|
||||
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
dummy_output = self.image_enc_layers(dummy_batch)
|
||||
_, channels, height, width = dummy_output.shape
|
||||
|
||||
for key in self.all_image_keys:
|
||||
# HACK: This a hack because the state_dict use . to separate the keys
|
||||
safe_key = key.replace(".", "_")
|
||||
# Separate spatial embedding per image
|
||||
self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
|
||||
height=height,
|
||||
width=width,
|
||||
channel=channels,
|
||||
num_features=config.image_embedding_pooling_dim,
|
||||
)
|
||||
# Separate post-encoder per image
|
||||
self.post_encoders[safe_key] = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(
|
||||
in_features=channels * config.image_embedding_pooling_dim,
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
|
@ -541,8 +578,7 @@ class SACObservationEncoder(nn.Module):
|
|||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self._out_dim += config.latent_dim
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
|
@ -552,13 +588,13 @@ class SACObservationEncoder(nn.Module):
|
|||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self._out_dim += config.latent_dim
|
||||
|
||||
def forward(
|
||||
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
|
||||
) -> Tensor:
|
||||
self,
|
||||
obs_dict: dict[str, torch.Tensor],
|
||||
vision_encoder_cache: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
|
@ -566,10 +602,9 @@ class SACObservationEncoder(nn.Module):
|
|||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
|
||||
|
||||
if vision_encoder_cache is not None:
|
||||
if len(self.all_image_keys) > 0:
|
||||
if vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False)
|
||||
feat.append(vision_encoder_cache)
|
||||
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
|
@ -578,27 +613,50 @@ class SACObservationEncoder(nn.Module):
|
|||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
|
||||
# [N*B, C, H, W]
|
||||
def get_base_image_features(
|
||||
self, batch: dict[str, torch.Tensor], normalize: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Process all images through the base encoder in a batched manner"""
|
||||
if normalize:
|
||||
batch = self.input_normalization(batch)
|
||||
if len(self.all_image_keys) > 0:
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
|
||||
return embeddings_image
|
||||
return None
|
||||
|
||||
# Sort keys for consistent ordering
|
||||
sorted_keys = sorted(self.all_image_keys)
|
||||
|
||||
# Stack all images into a single batch
|
||||
batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
|
||||
|
||||
# Process through the image encoder in one pass
|
||||
batched_output = self.image_enc_layers(batched_input)
|
||||
|
||||
if self.freeze_image_encoder:
|
||||
batched_output = batched_output.detach()
|
||||
|
||||
# Split the output back into individual tensors
|
||||
image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
|
||||
|
||||
# Create a dictionary mapping the original keys to their features
|
||||
result = {}
|
||||
for key, features in zip(sorted_keys, image_features, strict=True):
|
||||
result[key] = features
|
||||
|
||||
image_features = []
|
||||
for key in result:
|
||||
safe_key = key.replace(".", "_")
|
||||
x = self.spatial_embeddings[safe_key](result[key])
|
||||
x = self.post_encoders[safe_key](x)
|
||||
image_features.append(x)
|
||||
|
||||
image_features = torch.cat(image_features, dim=-1)
|
||||
return image_features
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
return self._out_dim
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
@ -938,44 +996,29 @@ class DefaultImageEncoder(nn.Module):
|
|||
nn.ReLU(),
|
||||
)
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.freeze_image_encoder = config.freeze_vision_encoder
|
||||
|
||||
def forward(self, x):
|
||||
x = self.image_enc_layers(x)
|
||||
if self.freeze_image_encoder:
|
||||
x = x.detach()
|
||||
return self.image_enc_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
|
||||
"""Freeze all parameters in the encoder"""
|
||||
for param in image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.freeze_image_encoder = config.freeze_vision_encoder
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
|
@ -986,21 +1029,10 @@ class PretrainedImageEncoder(nn.Module):
|
|||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
|
||||
# doesn't reach the classifier layer because we don't need it
|
||||
enc_feat = self.image_enc_layers(x).pooler_output
|
||||
if self.freeze_image_encoder:
|
||||
enc_feat = enc_feat.detach()
|
||||
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
|
||||
enc_feat = self.image_enc_layers(x).last_hidden_state
|
||||
return enc_feat
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
|
||||
"""Freeze all parameters in the encoder"""
|
||||
for param in image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
@ -1013,6 +1045,50 @@ class Identity(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, C, H, W] where B is batch size,
|
||||
C is number of channels, H is height, and W is width
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] where F is the number of features
|
||||
"""
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params = {}
|
||||
for outer_key, inner_dict in normalization_params.items():
|
||||
|
|
Loading…
Reference in New Issue