Enhance SAC configuration and policy with gradient clipping and temperature management

- Introduced a `grad_clip_norm` parameter in the SAC configuration for gradient clipping
- Updated `SACPolicy` to store the temperature as an instance variable for consistent usage
- Modified the loss calculations in `SACPolicy` to use the instance temperature
- Enhanced `MLP` and `CriticHead` to support a customizable final activation function
- Implemented gradient clipping in the learner server's training steps for the critic, actor, and temperature
- Added gradient-norm tracking to the training info
parent 41219fe81e
commit 1f23ef7889
```diff
@@ -84,10 +84,12 @@ class SACConfig:
     latent_dim: int = 256
     target_entropy: float | None = None
     use_backup_entropy: bool = True
+    grad_clip_norm: float = 40.0
     critic_network_kwargs: dict[str, Any] = field(
         default_factory=lambda: {
             "hidden_dims": [256, 256],
             "activate_final": True,
+            "final_activation": None,
         }
     )
     actor_network_kwargs: dict[str, Any] = field(
```
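For illustration, a minimal sketch of how the two new knobs could be set when building the config; the values are placeholders (assuming the remaining dataclass fields keep their defaults), not values taken from the repo:

```python
# Hypothetical usage of the fields added above; values are illustrative only.
config = SACConfig(
    grad_clip_norm=10.0,  # max global gradient norm before clipping
    critic_network_kwargs={
        "hidden_dims": [256, 256],
        "activate_final": True,
        "final_activation": None,  # None keeps the default activation on the last layer
    },
)
```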
```diff
@@ -330,7 +330,7 @@ class SACPolicy(
         observation_features: Tensor | None = None,
         next_observation_features: Tensor | None = None,
     ) -> Tensor:
-        temperature = self.log_alpha.exp().item()
+        self.temperature = self.log_alpha.exp().item()
         with torch.no_grad():
             next_action_preds, next_log_probs, _ = self.actor(
                 next_observations, next_observation_features
```
```diff
@@ -358,7 +358,7 @@ class SACPolicy(
             # critics subsample size
             min_q, _ = q_targets.min(dim=0)  # Get values from min operation
             if self.config.use_backup_entropy:
-                min_q = min_q - (temperature * next_log_probs)
+                min_q = min_q - (self.temperature * next_log_probs)

             td_target = rewards + (1 - done) * self.config.discount * min_q

```
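For reference, these critic hunks assemble the standard SAC soft TD target; a sketch of the quantity being computed, with α the temperature read from `log_alpha` and the minimum taken over the (subsampled) target critics:

```latex
y = r + \gamma \,(1 - d)\left[\min_i Q_{\bar\theta_i}(s', a') - \alpha \log \pi_\phi(a' \mid s')\right],
\qquad a' \sim \pi_\phi(\cdot \mid s')
```

The entropy term is only subtracted when `use_backup_entropy` is enabled, matching the `if` branch above.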
```diff
@@ -398,7 +398,7 @@ class SACPolicy(
     def compute_loss_actor(
         self, observations, observation_features: Tensor | None = None
     ) -> Tensor:
-        temperature = self.log_alpha.exp().item()
+        self.temperature = self.log_alpha.exp().item()

         actions_pi, log_probs, _ = self.actor(observations, observation_features)

```
```diff
@@ -413,7 +413,7 @@ class SACPolicy(
         )
         min_q_preds = q_preds.min(dim=0)[0]

-        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
+        actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
         return actor_loss


```
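Likewise, the actor hunks correspond to the usual SAC policy objective, sketched here with the same notation:

```latex
\mathcal{L}_{\text{actor}} = \mathbb{E}_{a \sim \pi_\phi(\cdot \mid s)}
\left[\alpha \log \pi_\phi(a \mid s) - \min_i Q_{\theta_i}(s, a)\right]
```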
```diff
@@ -425,6 +425,7 @@ class MLP(nn.Module):
         activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
         activate_final: bool = False,
         dropout_rate: Optional[float] = None,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.activate_final = activate_final
```
```diff
@@ -451,11 +452,24 @@ class MLP(nn.Module):
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
                 layers.append(nn.LayerNorm(hidden_dims[i]))
-                layers.append(
-                    activations
-                    if isinstance(activations, nn.Module)
-                    else getattr(nn, activations)()
-                )
+                # If we're at the final layer and a final activation is specified, use it
+                if (
+                    i + 1 == len(hidden_dims)
+                    and activate_final
+                    and final_activation is not None
+                ):
+                    layers.append(
+                        final_activation
+                        if isinstance(final_activation, nn.Module)
+                        else getattr(nn, final_activation)()
+                    )
+                else:
+                    layers.append(
+                        activations
+                        if isinstance(activations, nn.Module)
+                        else getattr(nn, activations)()
+                    )

         self.net = nn.Sequential(*layers)

```
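A hypothetical instantiation showing the new argument in action; the leading `input_dim` parameter and the other defaults are assumed from the surrounding code rather than quoted from it:

```python
import torch.nn as nn

# Assumed constructor shape: hidden layers use SiLU, but because
# final_activation is given (and activate_final is True), the last
# activated layer ends in Tanh instead.
mlp = MLP(
    input_dim=32,
    hidden_dims=[256, 256],
    activations=nn.SiLU(),
    activate_final=True,
    final_activation=nn.Tanh(),  # a string such as "Tanh" is also accepted
)
```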
```diff
@@ -516,6 +530,7 @@ class CriticHead(nn.Module):
         activate_final: bool = False,
         dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.net = MLP(
```
|
@ -524,6 +539,7 @@ class CriticHead(nn.Module):
|
||||||
activations=activations,
|
activations=activations,
|
||||||
activate_final=activate_final,
|
activate_final=activate_final,
|
||||||
dropout_rate=dropout_rate,
|
dropout_rate=dropout_rate,
|
||||||
|
final_activation=final_activation,
|
||||||
)
|
)
|
||||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||||
if init_final is not None:
|
if init_final is not None:
|
||||||
|
|
|
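The two `CriticHead` hunks simply thread the new argument through to the `MLP` trunk; a hypothetical instantiation, where any parameter name not visible in the diff is an assumption:

```python
import torch.nn as nn

# Assumed usage: the last hidden activation is replaced by Tanh; the scalar
# output layer added by CriticHead stays linear either way.
head = CriticHead(
    input_dim=256,
    hidden_dims=[256, 256],
    activate_final=True,
    final_activation=nn.Tanh(),
)
```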
```diff
@@ -390,6 +390,10 @@ def add_actor_information_and_train(
         if cfg.resume
         else None,
     )
+
+    # Update the policy config with the grad_clip_norm value from training config if it exists
+    clip_grad_norm_value = cfg.training.grad_clip_norm
+
     # compile policy
     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
```
```diff
@@ -507,6 +511,12 @@ def add_actor_information_and_train(
             )
             optimizers["critic"].zero_grad()
             loss_critic.backward()
+
+            # clip gradients
+            critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                policy.critic_ensemble.parameters(), clip_grad_norm_value
+            )
+
             optimizers["critic"].step()

         batch = replay_buffer.sample(batch_size)
```
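The clip-then-step pattern added here and below is the standard use of `torch.nn.utils.clip_grad_norm_`, which rescales gradients in place and returns the total norm measured before clipping; a self-contained sketch of the idiom, with a placeholder model and loss:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Rescales gradients in place so their global L2 norm is at most max_norm,
# and returns the norm measured *before* clipping (useful for logging).
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=40.0)
optimizer.step()

print(float(grad_norm))  # the quantity stored in training_infos below
```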
```diff
@@ -541,10 +551,17 @@ def add_actor_information_and_train(
         )
         optimizers["critic"].zero_grad()
         loss_critic.backward()
+
+        # clip gradients
+        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.critic_ensemble.parameters(), clip_grad_norm_value
+        ).item()
+
         optimizers["critic"].step()

         training_infos = {}
         training_infos["loss_critic"] = loss_critic.item()
+        training_infos["critic_grad_norm"] = critic_grad_norm

         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
```
```diff
@@ -555,19 +572,35 @@ def add_actor_information_and_train(

                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
+
+                # clip gradients
+                actor_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    policy.actor.parameters_to_optimize, clip_grad_norm_value
+                ).item()
+
                 optimizers["actor"].step()

                 training_infos["loss_actor"] = loss_actor.item()
+                training_infos["actor_grad_norm"] = actor_grad_norm

+                # Temperature optimization
                 loss_temperature = policy.compute_loss_temperature(
                     observations=observations,
                     observation_features=observation_features,
                 )
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
+
+                # clip gradients
+                temp_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    [policy.log_alpha], clip_grad_norm_value
+                ).item()
+
                 optimizers["temperature"].step()

                 training_infos["loss_temperature"] = loss_temperature.item()
+                training_infos["temperature_grad_norm"] = temp_grad_norm
+                training_infos["temperature"] = policy.temperature

                 if (
                     time.time() - last_time_policy_pushed
```
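Assuming `compute_loss_temperature` follows the standard SAC temperature update (which the `target_entropy` field in the config suggests), the quantity being clipped and logged here is the gradient of:

```latex
\mathcal{L}_{\alpha} = \mathbb{E}_{a \sim \pi_\phi(\cdot \mid s)}
\left[-\alpha \left(\log \pi_\phi(a \mid s) + \mathcal{H}_{\text{target}}\right)\right]
```

with α = exp(`log_alpha`), which is also the value exported as `training_infos["temperature"]` through the new `policy.temperature` attribute.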