Enhance SAC configuration and policy with gradient clipping and temperature management
- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping - Updated SACPolicy to store temperature as an instance variable for consistent usage - Modified loss calculations in SACPolicy to utilize the instance temperature - Enhanced MLP and CriticHead to support a customizable final activation function - Implemented gradient clipping in the learner server during training steps for both actor and critic - Added tracking for gradient norms in training information
This commit is contained in:
parent
41219fe81e
commit
1f23ef7889
|
@ -84,10 +84,12 @@ class SACConfig:
|
|||
latent_dim: int = 256
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
grad_clip_norm: float = 40.0
|
||||
critic_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
"final_activation": None,
|
||||
}
|
||||
)
|
||||
actor_network_kwargs: dict[str, Any] = field(
|
||||
|
|
|
@ -330,7 +330,7 @@ class SACPolicy(
|
|||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(
|
||||
next_observations, next_observation_features
|
||||
|
@ -358,7 +358,7 @@ class SACPolicy(
|
|||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (temperature * next_log_probs)
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
|
@ -398,7 +398,7 @@ class SACPolicy(
|
|||
def compute_loss_actor(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
|
@ -413,7 +413,7 @@ class SACPolicy(
|
|||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
|
||||
|
@ -425,6 +425,7 @@ class MLP(nn.Module):
|
|||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
|
@ -451,11 +452,24 @@ class MLP(nn.Module):
|
|||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
# If we're at the final layer and a final activation is specified, use it
|
||||
if (
|
||||
i + 1 == len(hidden_dims)
|
||||
and activate_final
|
||||
and final_activation is not None
|
||||
):
|
||||
layers.append(
|
||||
final_activation
|
||||
if isinstance(final_activation, nn.Module)
|
||||
else getattr(nn, final_activation)()
|
||||
)
|
||||
else:
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
|
@ -516,6 +530,7 @@ class CriticHead(nn.Module):
|
|||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
init_final: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
|
@ -524,6 +539,7 @@ class CriticHead(nn.Module):
|
|||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
|
|
|
@ -390,6 +390,10 @@ def add_actor_information_and_train(
|
|||
if cfg.resume
|
||||
else None,
|
||||
)
|
||||
|
||||
# Update the policy config with the grad_clip_norm value from training config if it exists
|
||||
clip_grad_norm_value = cfg.training.grad_clip_norm
|
||||
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
@ -507,6 +511,12 @@ def add_actor_information_and_train(
|
|||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
||||
)
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
@ -541,10 +551,17 @@ def add_actor_information_and_train(
|
|||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.critic_ensemble.parameters(), clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
training_infos["critic_grad_norm"] = critic_grad_norm
|
||||
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
|
@ -555,19 +572,35 @@ def add_actor_information_and_train(
|
|||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
|
||||
# clip gradients
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.actor.parameters_to_optimize, clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["actor"].step()
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
|
||||
# clip gradients
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
[policy.log_alpha], clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["temperature"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
if (
|
||||
time.time() - last_time_policy_pushed
|
||||
|
|
Loading…
Reference in New Issue