Refactor SAC configuration and policy for improved action sampling and stability

- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
This commit is contained in:
KeWang1017 2024-12-29 12:30:39 +00:00 committed by AdilZouitine
parent 70e3b9248c
commit 91fefdecfa
2 changed files with 43 additions and 217 deletions

View File

@ -53,30 +53,13 @@ class SACConfig:
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
}
actor_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
}
policy_kwargs = {
"tanh_squash_distribution": True,
"std_parameterization": "softplus",
"std_min": 0.005,
"std_max": 5.0,
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
state_encoder_hidden_dim: int = 256
latent_dim: int = 256
network_hidden_dims: int = 256
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)

View File

@ -102,9 +102,7 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
distribution = self.actor(batch)
# Sample from the distribution and return just the actions
actions = distribution.mode() # or distribution.sample() for stochastic actions
actions, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
@ -129,12 +127,11 @@ class SACPolicy(
# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
distribution = self.actor(observations)
action_preds = distribution.sample()
action_preds = torch.clamp(action_preds, -1, +1)
action_preds, log_probs = self.actor(next_observations)
# 2- compute q targets
q_targets = self.target_qs(next_observations, action_preds)
# subsample critics to prevent overfitting if use high UTD (update to date)
@ -147,7 +144,7 @@ class SACPolicy(
min_q = q_targets.min(dim=0)
# compute td target
td_target = rewards + self.discount * min_q
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
# 3- compute predicted qs
q_preds = self.critic_ensemble(observations, actions)
@ -178,18 +175,12 @@ class SACPolicy(
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
reduction="none"
).sum(0).mean()
# breakpoint()
# calculate actors loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
distribution = self.actor(observations)
actions = distribution.rsample()
log_probs = distribution.log_prob(actions).sum(-1)
# breakpoint()
actions = torch.clamp(actions, -1, +1)
actions, log_probs = self.actor(observations)
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@ -264,15 +255,13 @@ class Critic(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda",
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
self.activate_final = activate_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
@ -304,22 +293,6 @@ class Critic(nn.Module):
value = self.output_layer(x)
return value.squeeze(-1)
def q_value_ensemble(
self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
) -> torch.Tensor:
observations = observations.to(self.device)
actions = actions.to(self.device)
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
batch_size, num_actions = actions.shape[:2]
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
actions_flat = actions.reshape(-1, actions.shape[-1])
q_values = self(obs_flat, actions_flat, train)
return q_values.reshape(batch_size, num_actions)
else:
return self(observations, actions, train)
class Policy(nn.Module):
def __init__(
@ -327,26 +300,22 @@ class Policy(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
std_parameterization: str = "exp",
std_min: float = 0.05,
std_max: float = 2.0,
tanh_squash_distribution: bool = False,
log_std_min: float = -5,
log_std_max: float = 2,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda",
use_tanh_squash: bool = False,
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.std_parameterization = std_parameterization
self.std_min = std_min
self.std_max = std_max
self.tanh_squash_distribution = tanh_squash_distribution
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.activate_final = activate_final
self.use_tanh_squash = use_tanh_squash
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
@ -364,27 +333,20 @@ class Policy(nn.Module):
# Standard deviation layer or parameter
if fixed_std is None:
if std_parameterization == "uniform":
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
orthogonal_init()(self.std_layer.weight)
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
temperature: float = 1.0,
train: bool = False,
non_squash_distribution: bool = False,
) -> torch.distributions.Distribution:
self.train(train)
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
if self.encoder is not None:
with torch.set_grad_enabled(train):
@ -398,41 +360,24 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
if self.std_parameterization == "exp":
log_stds = self.std_layer(outputs)
# Clamp log_stds to prevent too large or small values
log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
stds = torch.exp(log_stds)
elif self.std_parameterization == "softplus":
stds = torch.nn.functional.softplus(self.std_layer(outputs))
stds = torch.clamp(stds, self.std_min, self.std_max)
elif self.std_parameterization == "uniform":
log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
stds = torch.exp(log_stds).expand_as(means)
else:
raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
log_std = self.std_layer(outputs)
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
assert self.std_parameterization == "fixed"
stds = self.fixed_std.expand_as(means)
# Scale with temperature
temperature = torch.tensor(temperature, device=self.device)
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
# Create distribution
if self.tanh_squash_distribution and not non_squash_distribution:
distribution = TanhMultivariateNormalDiag(
loc=means,
scale_diag=stds,
)
else:
distribution = torch.distributions.Normal(
loc=means,
scale=stds,
)
return distribution
# uses tahn activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, stds)
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t)
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
log_probs = log_probs.sum(-1) # sum over action dim
return actions, log_probs
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
@ -552,110 +497,8 @@ class LagrangeMultiplier(nn.Module):
return multiplier * diff
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
# 1. The base distribution is a diagonal multivariate normal distribution
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
DEFAULT_SAMPLE_SHAPE = torch.Size()
def __init__(
self,
loc: torch.Tensor,
scale_diag: torch.Tensor,
low: Optional[torch.Tensor] = None,
high: Optional[torch.Tensor] = None,
):
# Create base normal distribution
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
# Create list of transforms
transforms = []
# Add tanh transform
transforms.append(torch.distributions.transforms.TanhTransform())
# Add rescaling transform if bounds are provided
if low is not None and high is not None:
transforms.append(
torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
)
# Initialize parent class
super().__init__(base_distribution=base_distribution, transforms=transforms)
# Store parameters
self.loc = loc
self.scale_diag = scale_diag
self.low = low
self.high = high
def mode(self) -> torch.Tensor:
"""Get the mode of the transformed distribution"""
# The mode of a normal distribution is its mean
mode = self.loc
# Apply transforms
for transform in self.transforms:
mode = transform(mode)
return mode
def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
"""
Reparameterized sample from the distribution
"""
# Sample from base distributionrsample
x = self.base_dist.rsample(sample_shape)
# Apply transforms
for transform in self.transforms:
x = transform(x)
return x
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Compute log probability of a value
Includes the log det jacobian for the transforms
"""
# Initialize log prob
log_prob = torch.zeros_like(value)
# Inverse transforms to get back to normal distribution
q = value
for transform in reversed(self.transforms):
q_prev = transform.inv(q) # Get the pre-transform value
log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q) # Sum over action dimensions
q = q_prev
# Add base distribution log prob
log_prob = log_prob + self.base_dist.log_prob(q) # Sum over action dimensions
return log_prob
def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample from the distribution and compute log probability
"""
x = self.rsample(sample_shape)
log_prob = self.log_prob(x)
return x, log_prob
# def entropy(self) -> torch.Tensor:
# """
# Compute entropy of the distribution
# """
# # Start with base distribution entropy
# entropy = self.base_dist.entropy().sum(-1)
# # Add log det jacobian for each transform
# x = self.rsample()
# for transform in self.transforms:
# entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
# x = transform(x)
# return entropy
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList: