Enhance SAC configuration and policy with gradient clipping and temperature management

- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping
- Updated SACPolicy to store temperature as an instance variable for consistent usage
- Modified loss calculations in SACPolicy to utilize the instance temperature
- Enhanced MLP and CriticHead to support a customizable final activation function
- Implemented gradient clipping in the learner server during training steps for both actor and critic
- Added tracking for gradient norms in training information
This commit is contained in:
AdilZouitine 2025-03-17 10:50:28 +00:00
parent 41219fe81e
commit 1f23ef7889
3 changed files with 60 additions and 9 deletions

View File

@ -84,10 +84,12 @@ class SACConfig:
latent_dim: int = 256
target_entropy: float | None = None
use_backup_entropy: bool = True
grad_clip_norm: float = 40.0
critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
"final_activation": None,
}
)
actor_network_kwargs: dict[str, Any] = field(

View File

@ -330,7 +330,7 @@ class SACPolicy(
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
temperature = self.log_alpha.exp().item()
self.temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
@ -358,7 +358,7 @@ class SACPolicy(
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (temperature * next_log_probs)
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
@ -398,7 +398,7 @@ class SACPolicy(
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
temperature = self.log_alpha.exp().item()
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
@ -413,7 +413,7 @@ class SACPolicy(
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
@ -425,6 +425,7 @@ class MLP(nn.Module):
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.activate_final = activate_final
@ -451,11 +452,24 @@ class MLP(nn.Module):
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
# If we're at the final layer and a final activation is specified, use it
if (
i + 1 == len(hidden_dims)
and activate_final
and final_activation is not None
):
layers.append(
final_activation
if isinstance(final_activation, nn.Module)
else getattr(nn, final_activation)()
)
else:
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@ -516,6 +530,7 @@ class CriticHead(nn.Module):
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
@ -524,6 +539,7 @@ class CriticHead(nn.Module):
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:

View File

@ -390,6 +390,10 @@ def add_actor_information_and_train(
if cfg.resume
else None,
)
# Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value = cfg.training.grad_clip_norm
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
@ -507,6 +511,12 @@ def add_actor_information_and_train(
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
)
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
@ -541,10 +551,17 @@ def add_actor_information_and_train(
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
).item()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
@ -555,19 +572,35 @@ def add_actor_information_and_train(
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.actor.parameters_to_optimize, clip_grad_norm_value
).item()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
[policy.log_alpha], clip_grad_norm_value
).item()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
if (
time.time() - last_time_policy_pushed