Refactor SACPolicy for improved action sampling and standard deviation handling
- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
parent 18a4598986
commit ca74a13d61
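For context on the first bullet of the commit message, here is a minimal, self-contained sketch of the pattern the diff below moves to: the actor returns a torch distribution object, and callers sample actions, query log-probabilities, and clamp, instead of unpacking an (actions, log_probs) tuple. All names here (GaussianActorSketch, the layer sizes) are illustrative assumptions, not classes from this repository.

import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianActorSketch(nn.Module):
    # Hypothetical actor head: maps observations to a Normal over actions.
    def __init__(self, obs_dim: int, action_dim: int):
        super().__init__()
        self.trunk = nn.Linear(obs_dim, 64)
        self.mean_layer = nn.Linear(64, action_dim)
        self.std_layer = nn.Linear(64, action_dim)

    def forward(self, obs: torch.Tensor) -> Normal:
        h = torch.relu(self.trunk(obs))
        # Bound the log-std so the std can neither vanish nor explode.
        log_std = torch.clamp(self.std_layer(h), -5.0, 2.0)
        return Normal(self.mean_layer(h), log_std.exp())

actor = GaussianActorSketch(obs_dim=8, action_dim=2)
distribution = actor(torch.randn(16, 8))
actions = distribution.sample()                     # stochastic actions
log_probs = distribution.log_prob(actions).sum(-1)  # summed over action dims
actions = torch.clamp(actions, -1, +1)              # bound actions as in the diff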
@@ -19,6 +19,7 @@
from collections import deque
from copy import deepcopy
import math
from typing import Callable, Optional, Sequence, Tuple

import einops
@@ -100,7 +101,12 @@ class SACPolicy(
    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        actions, _ = self.actor(batch['observations'])
        """Select action for inference/evaluation"""
        distribution = self.actor(batch)
        # Sample from the distribution and return just the actions
        actions = distribution.mode()  # or distribution.sample() for stochastic actions
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return actions

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the loss.
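A brief, hedged illustration of the select_action change above: distribution.mode() gives a deterministic action for evaluation, while distribution.sample() would add exploration noise. A plain torch Normal stands in for the repo's distribution class here, so its mean plays the role of the mode.

import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(1, 2), torch.full((1, 2), 0.1))  # hypothetical policy output
eval_action = dist.mean         # deterministic choice; the mode of a Gaussian is its mean
rollout_action = dist.sample()  # stochastic alternative used during data collection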
@@ -126,7 +132,10 @@ class SACPolicy(
        # calculate critics loss
        # 1- compute actions from policy
        action_preds, log_probs = self.actor(observations)
        distribution = self.actor(observations)
        action_preds = distribution.sample()
        log_probs = distribution.log_prob(action_preds)
        action_preds = torch.clamp(action_preds, -1, +1)
        # 2- compute q targets
        q_targets = self.target_qs(next_observations, action_preds)
        # subsample critics to prevent overfitting if use high UTD (update to date)
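The hunk above samples actions from the policy, clamps them, and queries the target critics; in standard SAC these pieces feed a TD target that also carries an entropy bonus. A sketch of that target with assumed tensor names (none of them taken from this file):

import torch

def sac_td_target(rewards: torch.Tensor, dones: torch.Tensor,
                  target_q: torch.Tensor, next_log_probs: torch.Tensor,
                  temperature: float, gamma: float = 0.99) -> torch.Tensor:
    # target_q: aggregated target-critic value for the sampled next actions, shape (batch,)
    # next_log_probs: log pi(a'|s') of those sampled actions, shape (batch,)
    return rewards + gamma * (1.0 - dones) * (target_q - temperature * next_log_probs)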
@@ -146,31 +155,46 @@ class SACPolicy(
        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
        critics_loss = (
            (
                F.mse_loss(
                    q_preds,
                    einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
                    reduction="none",
                ).sum(0)  # sum over ensemble
                # `q_preds_ensemble` depends on the first observation and the actions.
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
                # q_targets depends on the reward and the next observations.
                * ~batch["next.reward_is_pad"]
                * ~batch["observation.state_is_pad"][1:]
            )
            .sum(0)
            .mean()
        )
        # critics_loss = (
        #     (
        #         F.mse_loss(
        #             q_preds,
        #             einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
        #             reduction="none",
        #         ).sum(0)  # sum over ensemble
        #         # `q_preds_ensemble` depends on the first observation and the actions.
        #         * ~batch["observation.state_is_pad"][0]
        #         * ~batch["action_is_pad"]
        #         # q_targets depends on the reward and the next observations.
        #         * ~batch["next.reward_is_pad"]
        #         * ~batch["observation.state_is_pad"][1:]
        #     )
        #     .sum(0)
        #     .mean()
        # )
        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
        critics_loss = F.mse_loss(
            q_preds,  # shape: [num_critics, batch_size]
            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
            reduction="none"
        ).sum(0).mean()
        # breakpoint()

        # calculate actors loss
        # 1- temperature
        temperature = self.temperature()

        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
<<<<<<< HEAD
        actions, log_probs = self.actor(observations) \

=======
        distribution = self.actor(observations)
        actions = distribution.sample()
        log_probs = distribution.log_prob(actions)
        actions = torch.clamp(actions, -1, +1)
>>>>>>> d3c62b92 (Refactor SACPolicy for improved action sampling and standard deviation handling)
        # 3- get q-value predictions
        with torch.no_grad():
            q_preds = self.critic_ensemble(observations, actions, return_type="mean")
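To make the structure of the hunk above easier to follow, here is a self-contained sketch of the two losses it assembles: the ensemble critic regression toward a shared TD target, broadcast with einops.repeat as in the new code, and the usual SAC actor objective that trades Q-value against the temperature-weighted log-probability. All shapes and tensor values below are placeholders.

import torch
import torch.nn.functional as F
import einops

num_critics, batch_size = 2, 32
q_preds = torch.randn(num_critics, batch_size)  # one row of Q-values per critic
td_target = torch.randn(batch_size)             # shared target for every critic

critics_loss = F.mse_loss(
    q_preds,
    einops.repeat(td_target, "b -> e b", e=num_critics),  # tile the target across the ensemble
    reduction="none",
).sum(0).mean()  # sum over critics, mean over the batch

temperature = 0.2
log_probs = torch.randn(batch_size)              # log pi(a|s) for freshly sampled actions
q_of_sampled_actions = torch.randn(batch_size)   # e.g. mean over the critic ensemble
actor_loss = (temperature * log_probs - q_of_sampled_actions).mean()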
@@ -309,8 +333,8 @@ class Policy(nn.Module):
        network: nn.Module,
        action_dim: int,
        std_parameterization: str = "exp",
        std_min: float = 1e-5,
        std_max: float = 10.0,
        std_min: float = 0.05,
        std_max: float = 2.0,
        tanh_squash_distribution: bool = False,
        fixed_std: Optional[torch.Tensor] = None,
        init_final: Optional[float] = None,
@@ -372,6 +396,7 @@ class Policy(nn.Module):
            obs_enc = self.encoder(observations, train=train)
        else:
            obs_enc = observations

        # Get network outputs
        outputs = self.network(obs_enc)
        means = self.mean_layer(outputs)
@@ -380,18 +405,22 @@ class Policy(nn.Module):
        if self.fixed_std is None:
            if self.std_parameterization == "exp":
                log_stds = self.std_layer(outputs)
                # Clamp log_stds to prevent too large or small values
                log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
                stds = torch.exp(log_stds)
            elif self.std_parameterization == "softplus":
                stds = torch.nn.functional.softplus(self.std_layer(outputs))
                stds = torch.clamp(stds, self.std_min, self.std_max)
            elif self.std_parameterization == "uniform":
                stds = torch.exp(self.log_stds).expand_as(means)
                log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
                stds = torch.exp(log_stds).expand_as(means)
            else:
                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
        else:
            assert self.std_parameterization == "fixed"
            stds = self.fixed_std.expand_as(means)

        # Clip standard deviations and scale with temperature
        # Scale with temperature
        temperature = torch.tensor(temperature, device=self.device)
        stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
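The clamping added in this hunk keeps the standard deviation inside [std_min, std_max] for every parameterization. A minimal sketch of the exp case using the new defaults (the raw outputs here are random placeholders):

import math
import torch

std_min, std_max = 0.05, 2.0        # new defaults introduced by this commit
log_stds = torch.randn(4, 2) * 5.0  # hypothetical raw std-head outputs, possibly extreme
log_stds = torch.clamp(log_stds, math.log(std_min), math.log(std_max))
stds = torch.exp(log_stds)          # guaranteed to lie in [0.05, 2.0]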