Refactor modeling_sac and parameter handling for clarity and reusability.

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine 2025-04-14 14:00:57 +00:00
parent 854bfb4ff8
commit 320a1a92a3
2 changed files with 67 additions and 45 deletions

View File

@ -167,8 +167,12 @@ class SACPolicy(
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
optim_params = { optim_params = {
"actor": self.actor.parameters_to_optimize, "actor": [
"critic": self.critic_ensemble.parameters_to_optimize, p
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha, "temperature": self.log_alpha,
} }
if self.config.num_discrete_actions is not None: if self.config.num_discrete_actions is not None:
@ -451,11 +455,11 @@ class SACPolicy(
target_next_grasp_qs, dim=1, index=best_next_grasp_action target_next_grasp_qs, dim=1, index=best_next_grasp_action
).squeeze(-1) ).squeeze(-1)
# Compute target Q-value with Bellman equation # Compute target Q-value with Bellman equation
rewards_gripper = rewards rewards_gripper = rewards
if gripper_penalties is not None: if gripper_penalties is not None:
rewards_gripper = rewards + gripper_penalties rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
# Get predicted Q-values for current observations # Get predicted Q-values for current observations
predicted_grasp_qs = self.grasp_critic_forward( predicted_grasp_qs = self.grasp_critic_forward(
@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module):
self.config = config self.config = config
self.input_normalization = input_normalizer self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0 self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features): if any("observation.image" in key for key in config.input_features):
@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module):
if config.freeze_vision_encoder: if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers.image_enc_layers) freeze_image_encoder(self.image_enc_layers.image_enc_layers)
self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features: if "observation.state" in config.input_features:
@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module):
) )
self.aggregation_size += config.latent_dim self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_features: if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential( self.env_state_enc_layers = nn.Sequential(
nn.Linear( nn.Linear(
@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module):
nn.Tanh(), nn.Tanh(),
) )
self.aggregation_size += config.latent_dim self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward( def forward(
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module):
self.output_normalization = output_normalization self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble) self.critics = nn.ModuleList(ensemble)
self.parameters_to_optimize = []
# Handle the case where a part of the encoder if frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.critics.parameters())
def forward( def forward(
self, self,
observations: dict[str, torch.Tensor], observations: dict[str, torch.Tensor],
@ -805,10 +796,6 @@ class GraspCritic(nn.Module):
else: else:
orthogonal_init()(self.output_layer.weight) orthogonal_init()(self.output_layer.weight)
self.parameters_to_optimize = []
self.parameters_to_optimize += list(self.net.parameters())
self.parameters_to_optimize += list(self.output_layer.parameters())
def forward( def forward(
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
@ -840,12 +827,8 @@ class Policy(nn.Module):
self.log_std_max = log_std_max self.log_std_max = log_std_max
self.fixed_std = fixed_std self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = [] self.encoder_is_shared = encoder_is_shared
self.parameters_to_optimize += list(self.network.parameters())
if self.encoder is not None and not encoder_is_shared:
self.parameters_to_optimize += list(self.encoder.parameters())
# Find the last Linear layer's output dimension # Find the last Linear layer's output dimension
for layer in reversed(network.net): for layer in reversed(network.net):
if isinstance(layer, nn.Linear): if isinstance(layer, nn.Linear):
@ -859,7 +842,6 @@ class Policy(nn.Module):
else: else:
orthogonal_init()(self.mean_layer.weight) orthogonal_init()(self.mean_layer.weight)
self.parameters_to_optimize += list(self.mean_layer.parameters())
# Standard deviation layer or parameter # Standard deviation layer or parameter
if fixed_std is None: if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim) self.std_layer = nn.Linear(out_features, action_dim)
@ -868,7 +850,6 @@ class Policy(nn.Module):
nn.init.uniform_(self.std_layer.bias, -init_final, init_final) nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else: else:
orthogonal_init()(self.std_layer.weight) orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
def forward( def forward(
self, self,
@ -877,6 +858,8 @@ class Policy(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists # Encode observations if encoder exists
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
if self.encoder_is_shared:
obs_enc = obs_enc.detach()
# Get network outputs # Get network outputs
outputs = self.network(obs_enc) outputs = self.network(obs_enc)
@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module):
nn.Tanh(), nn.Tanh(),
) )
self.parameters_to_optimize = [] self.freeze_image_encoder = config.freeze_vision_encoder
if not config.freeze_vision_encoder:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
def forward(self, x): def forward(self, x):
return self.image_enc_proj(self.image_enc_layers(x)) x = self.image_enc_layers(x)
if self.freeze_image_encoder:
x = x.detach()
return self.image_enc_proj(x)
class PretrainedImageEncoder(nn.Module): class PretrainedImageEncoder(nn.Module):
@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module):
nn.Tanh(), nn.Tanh(),
) )
self.parameters_to_optimize = [] self.freeze_image_encoder = config.freeze_vision_encoder
if not config.freeze_vision_encoder:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.parameters_to_optimize += list(self.image_enc_proj.parameters())
def _load_pretrained_vision_encoder(self, config: SACConfig): def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder""" """Set up CNN encoder"""
@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
# doesn't reach the classifier layer because we don't need it # doesn't reach the classifier layer because we don't need it
enc_feat = self.image_enc_layers(x).pooler_output enc_feat = self.image_enc_layers(x).pooler_output
if self.freeze_image_encoder:
enc_feat = enc_feat.detach()
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1)) enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
return enc_feat return enc_feat

View File

@ -510,7 +510,7 @@ def add_actor_information_and_train(
optimizers["actor"].zero_grad() optimizers["actor"].zero_grad()
loss_actor.backward() loss_actor.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_( actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item() ).item()
optimizers["actor"].step() optimizers["actor"].step()
@ -773,12 +773,14 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
""" """
optimizer_actor = torch.optim.Adam( optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize, params=[
p
for n, p in policy.actor.named_parameters()
if not n.startswith("encoder") or not policy.config.shared_encoder
],
lr=cfg.policy.actor_lr, lr=cfg.policy.actor_lr,
) )
optimizer_critic = torch.optim.Adam( optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
)
if cfg.policy.num_discrete_actions is not None: if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam( optimizer_grasp_critic = torch.optim.Adam(
@ -1089,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
parameters_queue.put(state_bytes) parameters_queue.put(state_bytes)
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
"""
Checks whether each parameter in the module has a gradient.
Args:
module (nn.Module): A PyTorch module whose parameters will be inspected.
Returns:
dict[str, bool]: A dictionary where each key is the parameter name and the value is
True if the parameter has an associated gradient (i.e. .grad is not None),
otherwise False.
"""
grad_status = {}
for name, param in module.named_parameters():
grad_status[name] = param.grad is not None
return grad_status
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
"""
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
Args:
actor (nn.Module): The actor model.
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
whether each parameter has a gradient.
Returns:
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
"""
# Get actor parameter names as a set.
model_param_names = {name for name, _ in model.named_parameters()}
# Intersect parameter names between actor and grad_status.
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
return overlapping
def process_interaction_message( def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
): ):