Stick to the HIL-SERL NN architecture
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
parent 044ca3b039
commit fe7b47f459
@@ -86,6 +86,7 @@ class SACConfig(PreTrainedConfig):
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
         num_discrete_actions: Number of discrete actions, eg for gripper actions.
+        image_embedding_pooling_dim: Dimension of the image embedding pooling.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.
@@ -147,6 +148,7 @@ class SACConfig(PreTrainedConfig):
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
     num_discrete_actions: int | None = None
+    image_embedding_pooling_dim: int = 8

     # Training parameter
     online_steps: int = 1000000
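The new image_embedding_pooling_dim option sizes the per-camera spatial pooling introduced later in this patch. A minimal configuration sketch, illustrative only; the import path is an assumption and the other fields keep the defaults shown above:

from lerobot.common.policies.sac.configuration_sac import SACConfig  # assumed module path

config = SACConfig(
    image_encoder_hidden_dim=32,
    shared_encoder=True,
    num_discrete_actions=None,
    image_embedding_pooling_dim=8,  # new field: features per channel in SpatialLearnedEmbeddings
)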
@@ -196,7 +196,7 @@ class SACPolicy(
         # We cached the encoder output to avoid recomputing it
         observations_features = None
         if self.shared_encoder:
-            observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
+            observations_features = self.actor.encoder.get_base_image_features(batch, normalize=True)

         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
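The rename reflects that the shared encoder's base image features are computed once in select_action and reused by the actor instead of being re-encoded. A toy, self-contained illustration of that caching pattern; the modules are stand-ins, not the real SAC heads:

import torch
from torch import nn

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64))  # stand-in shared encoder
actor_head = nn.Linear(64, 4)
critic_head = nn.Linear(64 + 4, 1)

obs = torch.randn(8, 3, 32, 32)
cached_features = encoder(obs)        # encoded once (the "observations_features" cache)
action = actor_head(cached_features)  # actor reuses the cached features
q_value = critic_head(torch.cat([cached_features, action], dim=-1))
print(action.shape, q_value.shape)    # torch.Size([8, 4]) torch.Size([8, 1])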
@@ -512,12 +512,19 @@ class SACObservationEncoder(nn.Module):
         """
         super().__init__()
         self.config = config
-        self.input_normalization = input_normalizer
-        self.has_pretrained_vision_encoder = False
-
-        self.aggregation_size: int = 0
+        self.freeze_image_encoder = config.freeze_vision_encoder
+
+        self.input_normalization = input_normalizer
+        self._out_dim = 0

         if any("observation.image" in key for key in config.input_features):
             self.camera_number = config.camera_number
+            self.all_image_keys = sorted(
+                [k for k in config.input_features if k.startswith("observation.image")]
+            )
+
+            self._out_dim += len(self.all_image_keys) * config.latent_dim

             if self.config.vision_encoder_name is not None:
                 self.image_enc_layers = PretrainedImageEncoder(config)
@@ -525,12 +532,42 @@ class SACObservationEncoder(nn.Module):
             else:
                 self.image_enc_layers = DefaultImageEncoder(config)

-            self.aggregation_size += config.latent_dim * self.camera_number
+            if self.freeze_image_encoder:
+                freeze_image_encoder(self.image_enc_layers)

-            if config.freeze_vision_encoder:
-                freeze_image_encoder(self.image_enc_layers.image_enc_layers)
+            # Separate components for each image stream
+            self.spatial_embeddings = nn.ModuleDict()
+            self.post_encoders = nn.ModuleDict()

-            self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
+            # determine the nb_channels, height and width of the image
+            # Get first image key from input features
+            image_key = next(key for key in config.input_features if key.startswith("observation.image"))  # noqa: SIM118
+            dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
+            with torch.inference_mode():
+                dummy_output = self.image_enc_layers(dummy_batch)
+            _, channels, height, width = dummy_output.shape
+
+            for key in self.all_image_keys:
+                # HACK: This a hack because the state_dict use . to separate the keys
+                safe_key = key.replace(".", "_")
+                # Separate spatial embedding per image
+                self.spatial_embeddings[safe_key] = SpatialLearnedEmbeddings(
+                    height=height,
+                    width=width,
+                    channel=channels,
+                    num_features=config.image_embedding_pooling_dim,
+                )
+                # Separate post-encoder per image
+                self.post_encoders[safe_key] = nn.Sequential(
+                    nn.Dropout(0.1),
+                    nn.Linear(
+                        in_features=channels * config.image_embedding_pooling_dim,
+                        out_features=config.latent_dim,
+                    ),
+                    nn.LayerNorm(normalized_shape=config.latent_dim),
+                    nn.Tanh(),
+                )

         if "observation.state" in config.input_features:
             self.state_enc_layers = nn.Sequential(
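The constructor sizes the per-camera SpatialLearnedEmbeddings and post-encoders by probing the base encoder with a dummy zero batch. A self-contained sketch of that probe; the conv stack stands in for image_enc_layers and the input shape is assumed:

import torch
from torch import nn

encoder = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=2), nn.ReLU())  # stand-in
dummy_batch = torch.zeros(1, 3, 128, 128)  # (1, *input_feature_shape); shape assumed
with torch.inference_mode():
    dummy_output = encoder(dummy_batch)
_, channels, height, width = dummy_output.shape
print(channels, height, width)  # 32 63 63; these size the SpatialLearnedEmbeddings kernel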
@@ -541,8 +578,7 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
-            self.aggregation_size += config.latent_dim
+            self._out_dim += config.latent_dim

         if "observation.environment_state" in config.input_features:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -552,13 +588,13 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
-            self.aggregation_size += config.latent_dim
+            self._out_dim += config.latent_dim

-        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-
     def forward(
-        self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
-    ) -> Tensor:
+        self,
+        obs_dict: dict[str, torch.Tensor],
+        vision_encoder_cache: torch.Tensor | None = None,
+    ) -> torch.Tensor:
         """Encode the image and/or state vector.

         Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
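With the aggregation layer removed, the encoder's output dimension is simply the sum of the per-modality latent sizes tracked in _out_dim. A back-of-the-envelope check with assumed sizes (two cameras, a proprioceptive state and an environment state, latent_dim = 64):

latent_dim, num_cameras = 64, 2
out_dim = num_cameras * latent_dim + latent_dim + latent_dim  # images + state + env state
print(out_dim)  # 256; forward() now returns this concatenation directly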
@@ -566,10 +602,9 @@ class SACObservationEncoder(nn.Module):
         """
         feat = []
         obs_dict = self.input_normalization(obs_dict)
-        if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
-            vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
-        if vision_encoder_cache is not None:
+        if len(self.all_image_keys) > 0:
+            if vision_encoder_cache is None:
+                vision_encoder_cache = self.get_base_image_features(obs_dict, normalize=False)
             feat.append(vision_encoder_cache)

         if "observation.environment_state" in self.config.input_features:
@@ -578,27 +613,50 @@ class SACObservationEncoder(nn.Module):
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))

         features = torch.cat(tensors=feat, dim=-1)
-        features = self.aggregation_layer(features)

         return features

-    def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
-        # [N*B, C, H, W]
+    def get_base_image_features(
+        self, batch: dict[str, torch.Tensor], normalize: bool = True
+    ) -> dict[str, torch.Tensor]:
+        """Process all images through the base encoder in a batched manner"""
         if normalize:
             batch = self.input_normalization(batch)
-        if len(self.all_image_keys) > 0:
-            # Batch all images along the batch dimension, then encode them.
-            images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
-            images_batched = self.image_enc_layers(images_batched)
-            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
-            embeddings_image = torch.cat(embeddings_chunks, dim=-1)
-            return embeddings_image
-        return None
+        # Sort keys for consistent ordering
+        sorted_keys = sorted(self.all_image_keys)
+
+        # Stack all images into a single batch
+        batched_input = torch.cat([batch[key] for key in sorted_keys], dim=0)
+
+        # Process through the image encoder in one pass
+        batched_output = self.image_enc_layers(batched_input)
+
+        if self.freeze_image_encoder:
+            batched_output = batched_output.detach()
+
+        # Split the output back into individual tensors
+        image_features = torch.chunk(batched_output, chunks=len(sorted_keys), dim=0)
+
+        # Create a dictionary mapping the original keys to their features
+        result = {}
+        for key, features in zip(sorted_keys, image_features, strict=True):
+            result[key] = features
+
+        image_features = []
+        for key in result:
+            safe_key = key.replace(".", "_")
+            x = self.spatial_embeddings[safe_key](result[key])
+            x = self.post_encoders[safe_key](x)
+            image_features.append(x)
+
+        image_features = torch.cat(image_features, dim=-1)
+        return image_features

     @property
     def output_dim(self) -> int:
         """Returns the dimension of the encoder output"""
-        return self.config.latent_dim
+        return self._out_dim


 class MLP(nn.Module):
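get_base_image_features runs every camera through the backbone in a single pass by concatenating along the batch dimension and chunking the output back per key. A toy, runnable illustration of that trick; the Conv2d stands in for the real encoder, and the keys and shapes are made up:

import torch
from torch import nn

cams = {
    "observation.image.front": torch.randn(4, 3, 64, 64),
    "observation.image.wrist": torch.randn(4, 3, 64, 64),
}
keys = sorted(cams)
encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for image_enc_layers
batched = torch.cat([cams[k] for k in keys], dim=0)  # [N*B, 3, 64, 64]
out = encoder(batched)                               # one encoder pass for all cameras
per_cam = dict(zip(keys, torch.chunk(out, chunks=len(keys), dim=0), strict=True))
print({k: tuple(v.shape) for k, v in per_cam.items()})  # each (4, 8, 64, 64)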
@@ -938,44 +996,29 @@ class DefaultImageEncoder(nn.Module):
             nn.ReLU(),
         )
         # Get first image key from input features
-        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
-        dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
-        with torch.inference_mode():
-            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-        self.image_enc_proj = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-            nn.LayerNorm(config.latent_dim),
-            nn.Tanh(),
-        )
-
-        self.freeze_image_encoder = config.freeze_vision_encoder

     def forward(self, x):
         x = self.image_enc_layers(x)
-        if self.freeze_image_encoder:
-            x = x.detach()
-        return self.image_enc_proj(x)
+        return x
+
+
+def freeze_image_encoder(image_encoder: nn.Module):
+    """Freeze all parameters in the encoder"""
+    for param in image_encoder.parameters():
+        param.requires_grad = False


 class PretrainedImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
-        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
-        self.image_enc_proj = nn.Sequential(
-            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-            nn.LayerNorm(config.latent_dim),
-            nn.Tanh(),
-        )
-
-        self.freeze_image_encoder = config.freeze_vision_encoder
+        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)

     def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
         from transformers import AutoModel

         self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
-        # self.image_enc_layers.pooler = Identity()

         if hasattr(self.image_enc_layers.config, "hidden_sizes"):
             self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
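Freezing now lives in a single module-level helper instead of per-encoder flags; get_base_image_features additionally detaches the frozen backbone's output, so no graph is kept through it. A self-contained sketch of the helper's effect, re-declared here so the snippet runs on its own:

from torch import nn

def freeze_image_encoder(image_encoder: nn.Module):
    """Freeze all parameters in the encoder"""
    for param in image_encoder.parameters():
        param.requires_grad = False

encoder = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
freeze_image_encoder(encoder)
print(all(not p.requires_grad for p in encoder.parameters()))  # True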
@@ -986,21 +1029,10 @@ class PretrainedImageEncoder(nn.Module):
         return self.image_enc_layers, self.image_enc_out_shape

     def forward(self, x):
-        # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
-        # doesn't reach the classifier layer because we don't need it
-        enc_feat = self.image_enc_layers(x).pooler_output
-        if self.freeze_image_encoder:
-            enc_feat = enc_feat.detach()
-        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+        enc_feat = self.image_enc_layers(x).last_hidden_state
         return enc_feat


-def freeze_image_encoder(image_encoder: nn.Module):
-    """Freeze all parameters in the encoder"""
-    for param in image_encoder.parameters():
-        param.requires_grad = False
-
-
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
@@ -1013,6 +1045,50 @@ class Identity(nn.Module):
         return x


+class SpatialLearnedEmbeddings(nn.Module):
+    def __init__(self, height, width, channel, num_features=8):
+        """
+        PyTorch implementation of learned spatial embeddings
+
+        Args:
+            height: Spatial height of input features
+            width: Spatial width of input features
+            channel: Number of input channels
+            num_features: Number of output embedding dimensions
+        """
+        super().__init__()
+        self.height = height
+        self.width = width
+        self.channel = channel
+        self.num_features = num_features
+
+        self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
+
+        nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
+
+    def forward(self, features):
+        """
+        Forward pass for spatial embedding
+
+        Args:
+            features: Input tensor of shape [B, C, H, W] where B is batch size,
+                C is number of channels, H is height, and W is width
+        Returns:
+            Output tensor of shape [B, C*F] where F is the number of features
+        """
+        features_expanded = features.unsqueeze(-1)  # [B, C, H, W, 1]
+        kernel_expanded = self.kernel.unsqueeze(0)  # [1, C, H, W, F]
+
+        # Element-wise multiplication and spatial reduction
+        output = (features_expanded * kernel_expanded).sum(dim=(2, 3))  # Sum over H,W dimensions
+
+        # Reshape to combine channel and feature dimensions
+        output = output.view(output.size(0), -1)  # [B, C*F]
+
+        return output
+
+
 def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
     converted_params = {}
     for outer_key, inner_dict in normalization_params.items():
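SpatialLearnedEmbeddings pools each [B, C, H, W] feature map with a learned per-channel spatial kernel, producing C * num_features values per sample. A worked shape example that reproduces the forward math with assumed sizes:

import torch

B, C, H, W, F = 2, 32, 8, 8, 8
feature_map = torch.randn(B, C, H, W)  # base encoder output for one camera
kernel = torch.randn(C, H, W, F)       # plays the role of the module's learned kernel
pooled = (feature_map.unsqueeze(-1) * kernel.unsqueeze(0)).sum(dim=(2, 3))  # [B, C, F]
pooled = pooled.reshape(B, -1)         # [B, C * F]
print(pooled.shape)                    # torch.Size([2, 256])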