Refactor modeling_sac and parameter handling for clarity and reusability.
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
commit 320a1a92a3 · parent 854bfb4ff8
@@ -167,8 +167,12 @@ class SACPolicy(
     def get_optim_params(self) -> dict:
         optim_params = {
-            "actor": self.actor.parameters_to_optimize,
-            "critic": self.critic_ensemble.parameters_to_optimize,
+            "actor": [
+                p
+                for n, p in self.actor.named_parameters()
+                if not n.startswith("encoder") or not self.shared_encoder
+            ],
+            "critic": self.critic_ensemble.parameters(),
             "temperature": self.log_alpha,
         }
         if self.config.num_discrete_actions is not None:
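The actor entry now filters named parameters directly: a parameter is dropped only when its name starts with "encoder" and the encoder is shared, so a non-shared encoder still trains with the actor. A minimal sketch of the same predicate, using an illustrative module:

import torch.nn as nn

# Illustrative stand-ins for the real actor submodules.
actor = nn.ModuleDict({"encoder": nn.Linear(4, 8), "head": nn.Linear(8, 2)})
shared_encoder = True

trainable = [
    p
    for n, p in actor.named_parameters()
    if not n.startswith("encoder") or not shared_encoder
]
# shared_encoder=True  -> only head.* are kept
# shared_encoder=False -> encoder.* and head.* are kept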
@@ -451,11 +455,11 @@ class SACPolicy(
             target_next_grasp_qs, dim=1, index=best_next_grasp_action
         ).squeeze(-1)

-        # Compute target Q-value with Bellman equation
-        rewards_gripper = rewards
-        if gripper_penalties is not None:
-            rewards_gripper = rewards + gripper_penalties
-        target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
+        # Compute target Q-value with Bellman equation
+        rewards_gripper = rewards
+        if gripper_penalties is not None:
+            rewards_gripper = rewards + gripper_penalties
+        target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q

         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
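The grasp-critic target is the standard Bellman backup: penalized reward plus the discounted next-step Q-value, with the bootstrap term masked out on terminal transitions. A small numeric sketch with illustrative values:

import torch

rewards = torch.tensor([1.0, 0.0])
gripper_penalties = torch.tensor([-0.1, 0.0])
done = torch.tensor([0.0, 1.0])  # second transition is terminal
discount = 0.99
target_next_grasp_q = torch.tensor([2.0, 3.0])

rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * discount * target_next_grasp_q
# tensor([2.8800, 0.0000]); the terminal transition keeps only its reward.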
@@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module):
         self.config = config
         self.input_normalization = input_normalizer
         self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []

         self.aggregation_size: int = 0
         if any("observation.image" in key for key in config.input_features):
@@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module):
         if config.freeze_vision_encoder:
             freeze_image_encoder(self.image_enc_layers.image_enc_layers)

-        self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
-
         self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]

         if "observation.state" in config.input_features:
@@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module):
         )
         self.aggregation_size += config.latent_dim

-        self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
         if "observation.environment_state" in config.input_features:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module):
             nn.Tanh(),
         )
         self.aggregation_size += config.latent_dim
-        self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())

         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())

     def forward(
         self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
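Each active modality contributes config.latent_dim features to aggregation_size, and the aggregation layer projects the concatenated features back down to latent_dim. A sketch of that fusion step, with illustrative sizes:

import torch
import torch.nn as nn

latent_dim = 64
num_modalities = 3  # e.g. image + state + environment_state (illustrative)
aggregation_size = num_modalities * latent_dim

aggregation_layer = nn.Linear(in_features=aggregation_size, out_features=latent_dim)
feats = [torch.randn(1, latent_dim) for _ in range(num_modalities)]
fused = aggregation_layer(torch.cat(feats, dim=-1))  # shape (1, latent_dim)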
@@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module):
         self.output_normalization = output_normalization
         self.critics = nn.ModuleList(ensemble)

-        self.parameters_to_optimize = []
-        # Handle the case where a part of the encoder is frozen
-        if self.encoder is not None:
-            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
-        self.parameters_to_optimize += list(self.critics.parameters())
-
     def forward(
         self,
         observations: dict[str, torch.Tensor],
@@ -805,10 +796,6 @@ class GraspCritic(nn.Module):
         else:
             orthogonal_init()(self.output_layer.weight)

-        self.parameters_to_optimize = []
-        self.parameters_to_optimize += list(self.net.parameters())
-        self.parameters_to_optimize += list(self.output_layer.parameters())
-
     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
@@ -840,12 +827,8 @@ class Policy(nn.Module):
         self.log_std_max = log_std_max
         self.fixed_std = fixed_std
         self.use_tanh_squash = use_tanh_squash
-        self.parameters_to_optimize = []
         self.encoder_is_shared = encoder_is_shared

-        self.parameters_to_optimize += list(self.network.parameters())
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
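The scan over reversed(network.net) recovers the width of the last Linear layer so the mean/std heads can be sized without an extra config field. A sketch, assuming an illustrative Sequential:

import torch.nn as nn

net = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 128), nn.Tanh())
out_features = None
for layer in reversed(net):
    if isinstance(layer, nn.Linear):
        out_features = layer.out_features  # 128 here; sizes the mean/std heads
        break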
@@ -859,7 +842,6 @@ class Policy(nn.Module):
         else:
             orthogonal_init()(self.mean_layer.weight)

-        self.parameters_to_optimize += list(self.mean_layer.parameters())
         # Standard deviation layer or parameter
         if fixed_std is None:
             self.std_layer = nn.Linear(out_features, action_dim)
@@ -868,7 +850,6 @@ class Policy(nn.Module):
             nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.std_layer.weight)
-        self.parameters_to_optimize += list(self.std_layer.parameters())

     def forward(
         self,
@@ -877,6 +858,8 @@ class Policy(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
+        if self.encoder_is_shared:
+            obs_enc = obs_enc.detach()

         # Get network outputs
         outputs = self.network(obs_enc)
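Detaching when the encoder is shared cuts the actor's gradient path into the encoder, so only the critic loss updates the shared weights. A minimal sketch of the effect, with illustrative modules:

import torch
import torch.nn as nn

encoder = nn.Linear(4, 8)  # stands in for the shared encoder
head = nn.Linear(8, 2)     # stands in for the actor network

obs_enc = encoder(torch.randn(1, 4)).detach()  # gradient flow stops here
head(obs_enc).sum().backward()
assert encoder.weight.grad is None   # encoder untouched by the actor loss
assert head.weight.grad is not None  # actor head still trains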
@@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module):
             nn.Tanh(),
         )

-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder

     def forward(self, x):
-        return self.image_enc_proj(self.image_enc_layers(x))
+        x = self.image_enc_layers(x)
+        if self.freeze_image_encoder:
+            x = x.detach()
+        return self.image_enc_proj(x)


 class PretrainedImageEncoder(nn.Module):
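The detach in forward replaces the explicit parameters_to_optimize bookkeeping: the projection still trains, while the frozen backbone never accumulates gradients. An equivalent-effect alternative, shown for contrast under illustrative names, is to disable gradients on the backbone itself, which also skips the backward computation through it:

import torch.nn as nn

backbone = nn.Linear(16, 8)  # stands in for image_enc_layers (illustrative)
for p in backbone.parameters():
    p.requires_grad_(False)  # no gradients computed for the backbone at all

The freeze_image_encoder helper seen earlier in SACObservationEncoder presumably works along these lines; the forward-pass detach acts as a second guard at the activation level.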
@@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module):
             nn.Tanh(),
         )

-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder

     def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
@@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
         # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
         # doesn't reach the classifier layer because we don't need it
         enc_feat = self.image_enc_layers(x).pooler_output
+        if self.freeze_image_encoder:
+            enc_feat = enc_feat.detach()
         enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
         return enc_feat
@@ -510,7 +510,7 @@ def add_actor_information_and_train(
         optimizers["actor"].zero_grad()
         loss_actor.backward()
         actor_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
+            parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
         ).item()
         optimizers["actor"].step()
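Switching to plain policy.actor.parameters() is safe here because torch.nn.utils.clip_grad_norm_ ignores parameters whose .grad is None, which is exactly the state detached encoder weights stay in. A minimal sketch:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
net(torch.randn(1, 4)).sum().backward()
# Parameters without gradients are simply skipped by the clipper.
grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0).item()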
@@ -773,12 +773,14 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     optimizer_actor = torch.optim.Adam(
         # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
-        params=policy.actor.parameters_to_optimize,
+        params=[
+            p
+            for n, p in policy.actor.named_parameters()
+            if not n.startswith("encoder") or not policy.config.shared_encoder
+        ],
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)

     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
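Likewise, handing the full critic_ensemble.parameters() to Adam is harmless for frozen weights: torch.optim steps skip any parameter whose .grad is None, and the forward-pass detach guarantees frozen weights never receive one. A sketch with illustrative modules:

import torch
import torch.nn as nn

frozen = nn.Linear(4, 4)  # stands in for a frozen encoder
head = nn.Linear(4, 1)
opt = torch.optim.Adam(list(frozen.parameters()) + list(head.parameters()), lr=1e-3)

head(frozen(torch.randn(1, 4)).detach()).sum().backward()
before = frozen.weight.clone()
opt.step()
assert torch.equal(before, frozen.weight)  # frozen weights never move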
@@ -1089,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
     parameters_queue.put(state_bytes)


+def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
+    """
+    Checks whether each parameter in the module has a gradient.
+
+    Args:
+        module (nn.Module): A PyTorch module whose parameters will be inspected.
+
+    Returns:
+        dict[str, bool]: A dictionary where each key is a parameter name and the value is
+            True if the parameter has an associated gradient (i.e. .grad is not None),
+            otherwise False.
+    """
+    grad_status = {}
+    for name, param in module.named_parameters():
+        grad_status[name] = param.grad is not None
+    return grad_status
+
+
+def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
+    """
+    Returns the entries of grad_status whose parameter names also exist in the given model.
+
+    Args:
+        model (nn.Module): The model whose parameter names are intersected with grad_status.
+        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
+            whether each parameter has a gradient.
+
+    Returns:
+        dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
+    """
+    # Get the model's parameter names as a set.
+    model_param_names = {name for name, _ in model.named_parameters()}
+
+    # Intersect parameter names between the model and grad_status.
+    overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
+    return overlapping
+
+
 def process_interaction_message(
     message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
 ):
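A hypothetical way to combine the two new helpers during debugging (this usage is not part of the commit; it assumes the actor and critic ensemble name their shared encoder parameters identically, e.g. "encoder.*"):

# After loss_actor.backward(), inspect which parameters actually got gradients.
grad_status = check_weight_gradients(policy.actor)

# Restrict to names the critic ensemble also owns, e.g. a shared encoder.
shared = get_overlapping_parameters(policy.critic_ensemble, grad_status)
leaked = [n for n, has_grad in shared.items() if has_grad and n.startswith("encoder")]
print(f"encoder parameters receiving actor gradients: {leaked}")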