Refactor SAC configuration and policy for improved action sampling and stability

- Updated SACConfig to replace the standard-deviation parameterization options (std_parameterization, std_min, std_max) with log_std_min and log_std_max bounds, for tighter control over the action distribution.
- Modified SACPolicy so the actor's forward pass returns sampled actions and their log probabilities directly, streamlining action selection (now stochastic sampling rather than the distribution mode) and the log-probability calculation.
- Removed the deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
KeWang1017 2024-12-29 12:30:39 +00:00 committed by AdilZouitine
parent 70e3b9248c
commit 91fefdecfa
2 changed files with 43 additions and 217 deletions
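For context while reading the diff below: the new policy head samples from a Normal distribution whose log standard deviation is clamped to [log_std_min, log_std_max], squashes the sample with tanh, and corrects the log probability with the change-of-variables term. A minimal, self-contained sketch of that scheme (the helper name sample_squashed_action and the standalone-function form are illustrative, not code from this repository; the -5/2 bounds and the 1e-6 epsilon match the diff):

    import torch

    def sample_squashed_action(
        means: torch.Tensor,
        log_std: torch.Tensor,
        log_std_min: float = -5.0,
        log_std_max: float = 2.0,
        use_tanh_squash: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Bound the predicted log-std so the resulting std stays in a numerically safe range.
        log_std = torch.clamp(log_std, log_std_min, log_std_max)
        normal = torch.distributions.Normal(means, log_std.exp())

        # Reparameterization trick: x = mean + std * N(0, 1), so gradients flow through the sample.
        x_t = normal.rsample()
        log_probs = normal.log_prob(x_t)

        if use_tanh_squash:
            # Squash to [-1, 1]; the log-prob picks up the change-of-variables
            # correction -log(1 - tanh(x)^2), with a small epsilon for stability.
            actions = torch.tanh(x_t)
            log_probs = log_probs - torch.log(1 - actions.pow(2) + 1e-6)
        else:
            actions = x_t
        # Sum over the action dimension: shapes (batch, action_dim) -> (batch,).
        return actions, log_probs.sum(-1)

Usage would look like actions, log_probs = sample_squashed_action(means, log_std) with means and log_std of shape (batch_size, action_dim).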

View File

@@ -59,24 +59,7 @@ class SACConfig:
             "activate_final": True,
         }
     policy_kwargs = {
-        "tanh_squash_distribution": True,
-        "std_parameterization": "softplus",
-        "std_min": 0.005,
-        "std_max": 5.0,
+        "use_tanh_squash": True,
+        "log_std_min": -5,
+        "log_std_max": 2,
     }
-    )
-    output_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "action": [4],
-        }
-    )
-    state_encoder_hidden_dim: int = 256
-    latent_dim: int = 256
-    network_hidden_dims: int = 256
-    # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
-    output_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {"action": "min_max"},
-    )

View File

@@ -102,9 +102,7 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        distribution = self.actor(batch)
-        # Sample from the distribution and return just the actions
-        actions = distribution.mode()  # or distribution.sample() for stochastic actions
+        actions, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions
@@ -132,9 +130,8 @@ class SACPolicy(
         # calculate critics loss
         # 1- compute actions from policy
-        distribution = self.actor(observations)
-        action_preds = distribution.sample()
-        action_preds = torch.clamp(action_preds, -1, +1)
+        action_preds, log_probs = self.actor(next_observations)
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
         # subsample critics to prevent overfitting if use high UTD (update to date)
@@ -147,7 +144,7 @@
         min_q = q_targets.min(dim=0)
         # compute td target
-        td_target = rewards + self.discount * min_q
+        td_target = rewards + self.config.discount * min_q  # + self.config.discount * self.temperature() * log_probs # add entropy term
         # 3- compute predicted qs
         q_preds = self.critic_ensemble(observations, actions)
@@ -178,18 +175,12 @@
             einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
             reduction="none"
         ).sum(0).mean()
-        # breakpoint()
         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        distribution = self.actor(observations)
-        actions = distribution.rsample()
-        log_probs = distribution.log_prob(actions).sum(-1)
-        # breakpoint()
-        actions = torch.clamp(actions, -1, +1)
+        actions, log_probs = self.actor(observations)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -264,15 +255,13 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        activate_final: bool = False,
-        device: str = "cuda",
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.init_final = init_final
-        self.activate_final = activate_final
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
@@ -304,22 +293,6 @@ class Critic(nn.Module):
         value = self.output_layer(x)
         return value.squeeze(-1)
-    def q_value_ensemble(
-        self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
-    ) -> torch.Tensor:
-        observations = observations.to(self.device)
-        actions = actions.to(self.device)
-        if len(actions.shape) == 3:  # [batch_size, num_actions, action_dim]
-            batch_size, num_actions = actions.shape[:2]
-            obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
-            obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
-            actions_flat = actions.reshape(-1, actions.shape[-1])
-            q_values = self(obs_flat, actions_flat, train)
-            return q_values.reshape(batch_size, num_actions)
-        else:
-            return self(observations, actions, train)
 
 class Policy(nn.Module):
     def __init__(
@@ -327,26 +300,22 @@ class Policy(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         action_dim: int,
-        std_parameterization: str = "exp",
-        std_min: float = 0.05,
-        std_max: float = 2.0,
-        tanh_squash_distribution: bool = False,
+        log_std_min: float = -5,
+        log_std_max: float = 2,
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
-        activate_final: bool = False,
-        device: str = "cuda",
+        use_tanh_squash: bool = False,
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.action_dim = action_dim
-        self.std_parameterization = std_parameterization
-        self.std_min = std_min
-        self.std_max = std_max
-        self.tanh_squash_distribution = tanh_squash_distribution
+        self.log_std_min = log_std_min
+        self.log_std_max = log_std_max
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
-        self.activate_final = activate_final
+        self.use_tanh_squash = use_tanh_squash
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
@@ -364,9 +333,6 @@ class Policy(nn.Module):
         # Standard deviation layer or parameter
         if fixed_std is None:
-            if std_parameterization == "uniform":
-                self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
-            else:
             self.std_layer = nn.Linear(out_features, action_dim)
             if init_final is not None:
                 nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
@@ -379,11 +345,7 @@
     def forward(
         self,
         observations: torch.Tensor,
-        temperature: float = 1.0,
-        train: bool = False,
-        non_squash_distribution: bool = False,
-    ) -> torch.distributions.Distribution:
-        self.train(train)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
         if self.encoder is not None:
@@ -398,40 +360,23 @@
         # Compute standard deviations
         if self.fixed_std is None:
-            if self.std_parameterization == "exp":
-                log_stds = self.std_layer(outputs)
-                # Clamp log_stds to prevent too large or small values
-                log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
-                stds = torch.exp(log_stds)
-            elif self.std_parameterization == "softplus":
-                stds = torch.nn.functional.softplus(self.std_layer(outputs))
-                stds = torch.clamp(stds, self.std_min, self.std_max)
-            elif self.std_parameterization == "uniform":
-                log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
-                stds = torch.exp(log_stds).expand_as(means)
-            else:
-                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
+            log_std = self.std_layer(outputs)
+            if self.use_tanh_squash:
+                log_std = torch.tanh(log_std)
+            log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
         else:
-            assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)
 
-        # Scale with temperature
-        temperature = torch.tensor(temperature, device=self.device)
-        stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
-
-        # Create distribution
-        if self.tanh_squash_distribution and not non_squash_distribution:
-            distribution = TanhMultivariateNormalDiag(
-                loc=means,
-                scale_diag=stds,
-            )
-        else:
-            distribution = torch.distributions.Normal(
-                loc=means,
-                scale=stds,
-            )
-        return distribution
+        # uses tahn activation function to squash the action to be in the range of [-1, 1]
+        normal = torch.distributions.Normal(means, stds)
+        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
+        log_probs = normal.log_prob(x_t)
+        if self.use_tanh_squash:
+            actions = torch.tanh(x_t)
+            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
+        log_probs = log_probs.sum(-1)  # sum over action dim
+        return actions, log_probs
 
     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
@@ -552,110 +497,8 @@ class LagrangeMultiplier(nn.Module):
         return multiplier * diff
 
-# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
-# 1. The base distribution is a diagonal multivariate normal distribution
-# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
-# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
-# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
-class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
-    DEFAULT_SAMPLE_SHAPE = torch.Size()
-
-    def __init__(
-        self,
-        loc: torch.Tensor,
-        scale_diag: torch.Tensor,
-        low: Optional[torch.Tensor] = None,
-        high: Optional[torch.Tensor] = None,
-    ):
-        # Create base normal distribution
-        base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
-
-        # Create list of transforms
-        transforms = []
-
-        # Add tanh transform
-        transforms.append(torch.distributions.transforms.TanhTransform())
-
-        # Add rescaling transform if bounds are provided
-        if low is not None and high is not None:
-            transforms.append(
-                torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
-            )
-
-        # Initialize parent class
-        super().__init__(base_distribution=base_distribution, transforms=transforms)
-
-        # Store parameters
-        self.loc = loc
-        self.scale_diag = scale_diag
-        self.low = low
-        self.high = high
-
-    def mode(self) -> torch.Tensor:
-        """Get the mode of the transformed distribution"""
-        # The mode of a normal distribution is its mean
-        mode = self.loc
-
-        # Apply transforms
-        for transform in self.transforms:
-            mode = transform(mode)
-
-        return mode
-
-    def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
-        """
-        Reparameterized sample from the distribution
-        """
-        # Sample from base distribution
-        x = self.base_dist.rsample(sample_shape)
-
-        # Apply transforms
-        for transform in self.transforms:
-            x = transform(x)
-
-        return x
-
-    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
-        """
-        Compute log probability of a value
-        Includes the log det jacobian for the transforms
-        """
-        # Initialize log prob
-        log_prob = torch.zeros_like(value)
-
-        # Inverse transforms to get back to normal distribution
-        q = value
-        for transform in reversed(self.transforms):
-            q_prev = transform.inv(q)  # Get the pre-transform value
-            log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q)  # Sum over action dimensions
-            q = q_prev
-
-        # Add base distribution log prob
-        log_prob = log_prob + self.base_dist.log_prob(q)  # Sum over action dimensions
-
-        return log_prob
-
-    def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Sample from the distribution and compute log probability
-        """
-        x = self.rsample(sample_shape)
-        log_prob = self.log_prob(x)
-        return x, log_prob
-
-    # def entropy(self) -> torch.Tensor:
-    #     """
-    #     Compute entropy of the distribution
-    #     """
-    #     # Start with base distribution entropy
-    #     entropy = self.base_dist.entropy().sum(-1)
-    #     # Add log det jacobian for each transform
-    #     x = self.rsample()
-    #     for transform in self.transforms:
-    #         entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
-    #         x = transform(x)
-    #     return entropy
+def orthogonal_init():
+    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
 
 def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
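The removed TanhMultivariateNormalDiag computed the same squashed-Gaussian log probability through torch.distributions transforms; the inline correction in the new Policy.forward is that change of variables written out by hand. A small sanity-check sketch of the equivalence, assuming only PyTorch's built-in Normal, TransformedDistribution, and TanhTransform (variable names here are illustrative, not repository code):

    import torch
    from torch.distributions import Normal, TransformedDistribution
    from torch.distributions.transforms import TanhTransform

    torch.manual_seed(0)
    means = torch.zeros(4, 2)
    stds = 0.5 * torch.ones(4, 2)

    # One shared pre-squash sample so both routes score the same action.
    x_t = Normal(means, stds).rsample()
    actions = torch.tanh(x_t)

    # Route 1: transformed distribution, as the removed class did internally.
    squashed = TransformedDistribution(Normal(means, stds), [TanhTransform()])
    log_prob_transform = squashed.log_prob(actions).sum(-1)

    # Route 2: inline correction, as in the new Policy.forward.
    base = Normal(means, stds)
    log_prob_inline = (base.log_prob(x_t) - torch.log(1 - actions.pow(2) + 1e-6)).sum(-1)

    # The two agree up to the 1e-6 epsilon added for numerical stability.
    print((log_prob_transform - log_prob_inline).abs().max())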