Refactor SACPolicy for improved action sampling and standard deviation handling

- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.

Together, these changes refine the SAC implementation, improving its robustness and stability during training and inference. A minimal sketch of the overall pattern follows.
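To make the new pattern concrete, here is a minimal, self-contained sketch of a Gaussian policy head with clamped log-stds and distribution-based sampling. `GaussianHead` and its dimensions are illustrative stand-ins, not the classes used in this repository:

    import math

    import torch
    from torch import nn
    from torch.distributions import Normal

    class GaussianHead(nn.Module):
        """Minimal policy head: mean plus clamped log-std -> Normal distribution."""

        def __init__(self, hidden_dim: int, action_dim: int,
                     std_min: float = 0.05, std_max: float = 2.0):
            super().__init__()
            self.mean_layer = nn.Linear(hidden_dim, action_dim)
            self.std_layer = nn.Linear(hidden_dim, action_dim)
            self.log_std_min = math.log(std_min)
            self.log_std_max = math.log(std_max)

        def forward(self, features: torch.Tensor) -> Normal:
            means = self.mean_layer(features)
            # Clamp log-std so exp() cannot produce vanishing or exploding stds
            log_stds = torch.clamp(self.std_layer(features),
                                   self.log_std_min, self.log_std_max)
            return Normal(means, torch.exp(log_stds))

    head = GaussianHead(hidden_dim=32, action_dim=4)
    dist = head(torch.randn(8, 32))
    actions = dist.sample()                      # stochastic actions for training
    log_probs = dist.log_prob(actions).sum(-1)   # log probability per batch element
    actions = torch.clamp(actions, -1.0, 1.0)    # keep actions inside the env's bounds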
KeWang1017 2024-12-28 18:07:15 +00:00 committed by Michel Aractingi
parent 18a4598986
commit ca74a13d61
1 changed file with 52 additions and 23 deletions


@@ -19,6 +19,7 @@
 from collections import deque
 from copy import deepcopy
+import math
 from typing import Callable, Optional, Sequence, Tuple
 import einops
@@ -100,7 +101,12 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        actions, _ = self.actor(batch['observations'])
+        """Select action for inference/evaluation"""
+        distribution = self.actor(batch)
+        # Sample from the distribution and return just the actions
+        actions = distribution.mode()  # or distribution.sample() for stochastic actions
+        actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
@@ -126,7 +132,10 @@ class SACPolicy(
         # calculate critics loss
         # 1- compute actions from policy
-        action_preds, log_probs = self.actor(observations)
+        distribution = self.actor(observations)
+        action_preds = distribution.sample()
+        log_probs = distribution.log_prob(action_preds)
+        action_preds = torch.clamp(action_preds, -1, +1)
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
         # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
@@ -146,31 +155,46 @@ class SACPolicy(
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = (
-            (
-                F.mse_loss(
-                    q_preds,
-                    einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
-                    reduction="none",
-                ).sum(0)  # sum over ensemble
-                # `q_preds_ensemble` depends on the first observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
-                * ~batch["action_is_pad"]
-                # q_targets depends on the reward and the next observations.
-                * ~batch["next.reward_is_pad"]
-                * ~batch["observation.state_is_pad"][1:]
-            )
-            .sum(0)
-            .mean()
-        )
+        critics_loss = F.mse_loss(
+            q_preds,  # shape: [num_critics, batch_size]
+            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds
+            reduction="none",
+        ).sum(0).mean()  # sum over the critic ensemble, then average over the batch
         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor(observations)
+        distribution = self.actor(observations)
+        actions = distribution.sample()
+        log_probs = distribution.log_prob(actions)
+        actions = torch.clamp(actions, -1, +1)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -309,8 +333,8 @@ class Policy(nn.Module):
         network: nn.Module,
         action_dim: int,
         std_parameterization: str = "exp",
-        std_min: float = 1e-5,
-        std_max: float = 10.0,
+        std_min: float = 0.05,
+        std_max: float = 2.0,
         tanh_squash_distribution: bool = False,
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
@@ -372,6 +396,7 @@ class Policy(nn.Module):
             obs_enc = self.encoder(observations, train=train)
         else:
             obs_enc = observations
+        # Get network outputs
        outputs = self.network(obs_enc)
        means = self.mean_layer(outputs)
@@ -380,18 +405,22 @@ class Policy(nn.Module):
         if self.fixed_std is None:
             if self.std_parameterization == "exp":
                 log_stds = self.std_layer(outputs)
+                # Clamp log_stds to prevent too large or small values
+                log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
                 stds = torch.exp(log_stds)
             elif self.std_parameterization == "softplus":
                 stds = torch.nn.functional.softplus(self.std_layer(outputs))
+                stds = torch.clamp(stds, self.std_min, self.std_max)
             elif self.std_parameterization == "uniform":
-                stds = torch.exp(self.log_stds).expand_as(means)
+                log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
+                stds = torch.exp(log_stds).expand_as(means)
             else:
                 raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
         else:
             assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)

-        # Clip standard deviations and scale with temperature
+        # Scale with temperature
         temperature = torch.tensor(temperature, device=self.device)
         stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
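For reference, clamping log-stds at math.log(std_min) and math.log(std_max) bounds the stds to [std_min, std_max] after exp(). A quick standalone check with the new defaults (illustrative values, approximate output):

    import math

    import torch

    log_stds = torch.tensor([-10.0, 0.0, 5.0])  # raw network outputs
    lo, hi = math.log(0.05), math.log(2.0)      # approx. -3.00 and 0.69
    stds = torch.exp(torch.clamp(log_stds, lo, hi))
    print(stds)  # tensor([0.0500, 1.0000, 2.0000]) -> bounded in [0.05, 2.0]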