Refactor SAC configuration and policy for improved action sampling and stability
- Updated SACConfig to replace the standard-deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log-probability calculations, enhancing stochastic behavior.
- Removed the deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
commit 91fefdecfa
parent 70e3b9248c
@@ -53,30 +53,13 @@ class SACConfig:
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     }
     actor_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     }
     policy_kwargs = {
-        "tanh_squash_distribution": True,
-        "std_parameterization": "softplus",
-        "std_min": 0.005,
-        "std_max": 5.0,
+        "use_tanh_squash": True,
+        "log_std_min": -5,
+        "log_std_max": 2,
     }
     )
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "action": [4],
         }
     )

     state_encoder_hidden_dim: int = 256
     latent_dim: int = 256
     network_hidden_dims: int = 256

     # Normalization / Unnormalization
     input_normalization_modes: dict[str, str] | None = None
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )
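For reference, the new log_std_min/log_std_max bounds of [-5, 2] imply a standard-deviation range comparable to the old std_min/std_max clamp; a quick check, not part of the diff:

    import math

    log_std_min, log_std_max = -5, 2   # new policy_kwargs
    std_min, std_max = 0.005, 5.0      # old policy_kwargs
    # exp(-5) ~= 0.0067 and exp(2) ~= 7.39, versus the old clamp of [0.005, 5.0]
    print(math.exp(log_std_min), math.exp(log_std_max))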
@@ -102,9 +102,7 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        distribution = self.actor(batch)
-        # Sample from the distribution and return just the actions
-        actions = distribution.mode()  # or distribution.sample() for stochastic actions
+        actions, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions
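With this change the actor's forward pass returns an (actions, log_probs) tuple instead of a distribution object, so inference simply keeps the first element. A minimal self-contained sketch of the new calling convention, using a hypothetical stand-in actor and observation key:

    import torch

    def fake_actor(batch: dict[str, torch.Tensor]):
        # Stand-in with the new return convention: (actions, log_probs)
        actions = torch.tanh(torch.randn(1, 4))
        return actions, torch.zeros(1)

    actions, _ = fake_actor({"observation.state": torch.randn(1, 10)})  # inference keeps only the actions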
@@ -129,12 +127,11 @@ class SACPolicy(

         # reward bias from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

         # calculate critics loss
         # 1- compute actions from policy
-        distribution = self.actor(observations)
-        action_preds = distribution.sample()
-        action_preds = torch.clamp(action_preds, -1, +1)
+        action_preds, log_probs = self.actor(next_observations)

         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
         # subsample critics to prevent overfitting if use high UTD (update to date)
@@ -147,7 +144,7 @@ class SACPolicy(
         min_q = q_targets.min(dim=0)

         # compute td target
-        td_target = rewards + self.discount * min_q
+        td_target = rewards + self.config.discount * min_q  # + self.config.discount * self.temperature() * log_probs  # add entropy term

         # 3- compute predicted qs
         q_preds = self.critic_ensemble(observations, actions)
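A standalone sketch of the TD-target computation above, with the entropy term left out as in the current line. Shapes are assumed: q_targets is [num_critics, batch] and rewards is [batch]; note that Tensor.min(dim=0) returns a (values, indices) pair, so the sketch takes .values explicitly:

    import torch

    def compute_td_target(rewards: torch.Tensor, q_targets: torch.Tensor, discount: float) -> torch.Tensor:
        min_q = q_targets.min(dim=0).values  # pessimistic target across the critic ensemble
        return rewards + discount * min_q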
@@ -178,18 +175,12 @@ class SACPolicy(
             einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
             reduction="none"
         ).sum(0).mean()
-        # breakpoint()

         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()

         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        distribution = self.actor(observations)
-        actions = distribution.rsample()
-        log_probs = distribution.log_prob(actions).sum(-1)
-        # breakpoint()
-        actions = torch.clamp(actions, -1, +1)
+        actions, log_probs = self.actor(observations)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
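The hunk ends before the loss itself; the standard SAC policy objective these quantities usually feed into is sketched below (not taken from this diff):

    import torch

    def sac_actor_loss(log_probs: torch.Tensor, q_preds: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
        # Maximize Q while keeping entropy high: minimize E[alpha * log pi(a|s) - Q(s, a)]
        return (temperature * log_probs - q_preds).mean()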
@@ -264,15 +255,13 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        activate_final: bool = False,
-        device: str = "cuda",
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.init_final = init_final
-        self.activate_final = activate_final

         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
@@ -304,22 +293,6 @@ class Critic(nn.Module):
         value = self.output_layer(x)
         return value.squeeze(-1)

-    def q_value_ensemble(
-        self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
-    ) -> torch.Tensor:
-        observations = observations.to(self.device)
-        actions = actions.to(self.device)
-
-        if len(actions.shape) == 3:  # [batch_size, num_actions, action_dim]
-            batch_size, num_actions = actions.shape[:2]
-            obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
-            obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
-            actions_flat = actions.reshape(-1, actions.shape[-1])
-            q_values = self(obs_flat, actions_flat, train)
-            return q_values.reshape(batch_size, num_actions)
-        else:
-            return self(observations, actions, train)


 class Policy(nn.Module):
     def __init__(
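The removed q_value_ensemble helper flattened a [batch, num_actions, action_dim] stack of candidate actions before calling the critic; if that behavior is needed elsewhere, a standalone equivalent looks like this (sketch; critic is any callable taking flat (observations, actions) batches):

    import torch

    def q_values_per_action(critic, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        batch_size, num_actions, action_dim = actions.shape
        obs_flat = observations.unsqueeze(1).expand(-1, num_actions, -1).reshape(-1, observations.shape[-1])
        actions_flat = actions.reshape(-1, action_dim)
        return critic(obs_flat, actions_flat).reshape(batch_size, num_actions)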
@@ -327,26 +300,22 @@ class Policy(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         action_dim: int,
-        std_parameterization: str = "exp",
-        std_min: float = 0.05,
-        std_max: float = 2.0,
-        tanh_squash_distribution: bool = False,
+        log_std_min: float = -5,
+        log_std_max: float = 2,
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
-        activate_final: bool = False,
-        device: str = "cuda",
+        use_tanh_squash: bool = False,
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.action_dim = action_dim
-        self.std_parameterization = std_parameterization
-        self.std_min = std_min
-        self.std_max = std_max
-        self.tanh_squash_distribution = tanh_squash_distribution
+        self.log_std_min = log_std_min
+        self.log_std_max = log_std_max
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
-        self.activate_final = activate_final
+        self.use_tanh_squash = use_tanh_squash

         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
@@ -364,27 +333,20 @@ class Policy(nn.Module):

         # Standard deviation layer or parameter
         if fixed_std is None:
-            if std_parameterization == "uniform":
-                self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
+            self.std_layer = nn.Linear(out_features, action_dim)
+            if init_final is not None:
+                nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
+                nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
             else:
-                self.std_layer = nn.Linear(out_features, action_dim)
-                if init_final is not None:
-                    nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
-                    nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
-                else:
-                    orthogonal_init()(self.std_layer.weight)
-
+                orthogonal_init()(self.std_layer.weight)

         self.to(self.device)

     def forward(
         self,
         observations: torch.Tensor,
-        temperature: float = 1.0,
-        train: bool = False,
-        non_squash_distribution: bool = False,
-    ) -> torch.distributions.Distribution:
-        self.train(train)
-
+    ) -> Tuple[torch.Tensor, torch.Tensor]:

         # Encode observations if encoder exists
         if self.encoder is not None:
             with torch.set_grad_enabled(train):
@@ -398,41 +360,24 @@

         # Compute standard deviations
         if self.fixed_std is None:
-            if self.std_parameterization == "exp":
-                log_stds = self.std_layer(outputs)
-                # Clamp log_stds to prevent too large or small values
-                log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
-                stds = torch.exp(log_stds)
-            elif self.std_parameterization == "softplus":
-                stds = torch.nn.functional.softplus(self.std_layer(outputs))
-                stds = torch.clamp(stds, self.std_min, self.std_max)
-            elif self.std_parameterization == "uniform":
-                log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
-                stds = torch.exp(log_stds).expand_as(means)
-            else:
-                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
+            log_std = self.std_layer(outputs)
+            if self.use_tanh_squash:
+                log_std = torch.tanh(log_std)
+                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
         else:
-            assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)

-        # Scale with temperature
-        temperature = torch.tensor(temperature, device=self.device)
-        stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
-
-        # Create distribution
-        if self.tanh_squash_distribution and not non_squash_distribution:
-            distribution = TanhMultivariateNormalDiag(
-                loc=means,
-                scale_diag=stds,
-            )
-        else:
-            distribution = torch.distributions.Normal(
-                loc=means,
-                scale=stds,
-            )
-
-        return distribution
+        # uses tahn activation function to squash the action to be in the range of [-1, 1]
+        normal = torch.distributions.Normal(means, stds)
+        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
+        log_probs = normal.log_prob(x_t)
+        if self.use_tanh_squash:
+            actions = torch.tanh(x_t)
+            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
+        log_probs = log_probs.sum(-1)  # sum over action dim

+        return actions, log_probs

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
         observations = observations.to(self.device)
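The heart of the new forward pass is the tanh change-of-variables correction applied to the Gaussian log-density; a standalone sketch of that sampling path, with means and stds assumed given:

    import torch

    def sample_tanh_gaussian(means: torch.Tensor, stds: torch.Tensor, eps: float = 1e-6):
        normal = torch.distributions.Normal(means, stds)
        x_t = normal.rsample()                    # reparameterization: mean + std * N(0, 1)
        actions = torch.tanh(x_t)                 # squash into (-1, 1)
        log_probs = normal.log_prob(x_t) - torch.log(1 - actions.pow(2) + eps)
        return actions, log_probs.sum(-1)         # sum over the action dimension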
@@ -552,110 +497,8 @@ class LagrangeMultiplier(nn.Module):
         return multiplier * diff


-# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
-# 1. The base distribution is a diagonal multivariate normal distribution
-# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
-# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
-# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
-class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
-    DEFAULT_SAMPLE_SHAPE = torch.Size()
-
-    def __init__(
-        self,
-        loc: torch.Tensor,
-        scale_diag: torch.Tensor,
-        low: Optional[torch.Tensor] = None,
-        high: Optional[torch.Tensor] = None,
-    ):
-        # Create base normal distribution
-        base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
-
-        # Create list of transforms
-        transforms = []
-
-        # Add tanh transform
-        transforms.append(torch.distributions.transforms.TanhTransform())
-
-        # Add rescaling transform if bounds are provided
-        if low is not None and high is not None:
-            transforms.append(
-                torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
-            )
-
-        # Initialize parent class
-        super().__init__(base_distribution=base_distribution, transforms=transforms)
-
-        # Store parameters
-        self.loc = loc
-        self.scale_diag = scale_diag
-        self.low = low
-        self.high = high
-
-    def mode(self) -> torch.Tensor:
-        """Get the mode of the transformed distribution"""
-        # The mode of a normal distribution is its mean
-        mode = self.loc
-        # Apply transforms
-        for transform in self.transforms:
-            mode = transform(mode)
-
-        return mode
-
-    def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
-        """
-        Reparameterized sample from the distribution
-        """
-        # Sample from base distribution
-        x = self.base_dist.rsample(sample_shape)
-
-        # Apply transforms
-        for transform in self.transforms:
-            x = transform(x)
-
-        return x
-
-    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
-        """
-        Compute log probability of a value
-        Includes the log det jacobian for the transforms
-        """
-        # Initialize log prob
-        log_prob = torch.zeros_like(value)
-
-        # Inverse transforms to get back to normal distribution
-        q = value
-        for transform in reversed(self.transforms):
-            q_prev = transform.inv(q)  # Get the pre-transform value
-            log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q)  # Sum over action dimensions
-            q = q_prev
-
-        # Add base distribution log prob
-        log_prob = log_prob + self.base_dist.log_prob(q)  # Sum over action dimensions
-
-        return log_prob
-
-    def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Sample from the distribution and compute log probability
-        """
-        x = self.rsample(sample_shape)
-        log_prob = self.log_prob(x)
-        return x, log_prob
-
-    # def entropy(self) -> torch.Tensor:
-    #     """
-    #     Compute entropy of the distribution
-    #     """
-    #     # Start with base distribution entropy
-    #     entropy = self.base_dist.entropy().sum(-1)
-
-    #     # Add log det jacobian for each transform
-    #     x = self.rsample()
-    #     for transform in self.transforms:
-    #         entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
-    #         x = transform(x)
-
-    #     return entropy
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


 def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
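The deleted class was essentially a thin wrapper around torch.distributions.TransformedDistribution; where that behavior is still wanted, the built-ins cover it (sketch, not part of the diff):

    import torch
    from torch.distributions import Normal, TransformedDistribution
    from torch.distributions.transforms import TanhTransform

    def tanh_normal(loc: torch.Tensor, scale: torch.Tensor) -> TransformedDistribution:
        return TransformedDistribution(Normal(loc, scale), [TanhTransform(cache_size=1)])

    dist = tanh_normal(torch.zeros(4), torch.ones(4))
    action = dist.rsample()                   # squashed into (-1, 1)
    log_prob = dist.log_prob(action).sum(-1)  # includes the tanh log|det J| correction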