Refine SAC configuration and policy for enhanced performance

- Updated the standard deviation parameterization in SACConfig to 'softplus' with explicit std_min/std_max bounds for improved stability (a sketch of this parameterization appears below).
- Modified action sampling in SACPolicy to use reparameterized sampling (rsample), preserving gradient flow through sampled actions and correcting the log-probability computation.
- Cleaned up the log-probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Raised eval_freq in the YAML configuration from 10000 to 50000 so evaluation runs less often, reducing overhead during training.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
KeWang1017 2024-12-28 22:11:34 +00:00 committed by Michel Aractingi
parent ca74a13d61
commit 22fbc9ea4a
3 changed files with 31 additions and 39 deletions
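
For reference, a minimal sketch of the softplus std parameterization described above; the helper name and defaults are illustrative, not code from this commit:

import torch
import torch.nn.functional as F

def softplus_std(raw_std: torch.Tensor, std_min: float = 0.005, std_max: float = 5.0) -> torch.Tensor:
    # softplus keeps the std positive and smooth near zero; the clamp
    # enforces the configured [std_min, std_max] range for numerical stability.
    return torch.clamp(F.softplus(raw_std), min=std_min, max=std_max)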


@@ -59,14 +59,10 @@ class SACConfig:
         "activate_final": True,
     }
     policy_kwargs = {
-        "tanh_squash_distribution": True,
-        "std_parameterization": "uniform",
-    }
-    input_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "observation.image": [3, 84, 84],
-            "observation.state": [4],
+        "tanh_squash_distribution": True,
+        "std_parameterization": "softplus",
+        "std_min": 0.005,
+        "std_max": 5.0,
     }
 )
 output_shapes: dict[str, list[int]] = field(
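
As an aside, the field(default_factory=...) pattern shown in this hunk is how dataclasses must declare mutable defaults; a standalone sketch with an illustrative class name:

from dataclasses import dataclass, field

@dataclass
class ExampleConfig:
    # A mutable default must go through default_factory so each instance
    # gets its own dict rather than one shared across instances.
    policy_kwargs: dict = field(
        default_factory=lambda: {
            "std_parameterization": "softplus",
            "std_min": 0.005,
            "std_max": 5.0,
        }
    )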


@@ -134,7 +134,6 @@ class SACPolicy(
     # 1- compute actions from policy
     distribution = self.actor(observations)
     action_preds = distribution.sample()
-    log_probs = distribution.log_prob(action_preds)
     action_preds = torch.clamp(action_preds, -1, +1)
     # 2- compute q targets
     q_targets = self.target_qs(next_observations, action_preds)
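
For context, the target computed in this hunk feeds the soft Bellman backup; a sketch of the textbook SAC form, with illustrative names rather than the repo's API:

import torch

def soft_bellman_target(reward, done, q_next_min, log_prob_next, gamma=0.99, alpha=0.2):
    # Textbook SAC target: y = r + gamma * (1 - done) * (min_i Q'_i(s', a') - alpha * log pi(a'|s'))
    # Note that the hunk above no longer subtracts the log-prob term when forming q_targets.
    return reward + gamma * (1.0 - done) * (q_next_min - alpha * log_prob_next)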
@@ -186,15 +185,11 @@ class SACPolicy(
     temperature = self.temperature()
     # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-<<<<<<< HEAD
-    actions, log_probs = self.actor(observations) \
-=======
     distribution = self.actor(observations)
-    actions = distribution.sample()
-    log_probs = distribution.log_prob(actions)
+    actions = distribution.rsample()
+    log_probs = distribution.log_prob(actions).sum(-1)
+    # breakpoint()
     actions = torch.clamp(actions, -1, +1)
->>>>>>> d3c62b92 (Refactor SACPolicy for improved action sampling and standard deviation handling)
     # 3- get q-value predictions
     with torch.no_grad():
         q_preds = self.critic_ensemble(observations, actions, return_type="mean")
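
The switch from sample() to rsample() matters because only the latter applies the reparameterization trick; a self-contained illustration using a plain Normal rather than the repo's distribution class:

import torch

mean = torch.zeros(4, requires_grad=True)
std = torch.ones(4, requires_grad=True)
dist = torch.distributions.Normal(mean, std)

actions = dist.rsample()   # actions = mean + std * eps, differentiable w.r.t. mean and std
actions.sum().backward()
print(mean.grad)           # gradients flow; with dist.sample() the graph is cut and backward() fails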
@@ -610,7 +605,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
     """
     Reparameterized sample from the distribution
     """
     # Sample from base distribution
-    x = self.base_dist.sample(sample_shape)
+    x = self.base_dist.rsample(sample_shape)
     # Apply transforms
@@ -625,17 +620,18 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
     Includes the log det jacobian for the transforms
     """
     # Initialize log prob
-    log_prob = torch.zeros_like(value[..., 0])
+    log_prob = torch.zeros_like(value)
     # Inverse transforms to get back to normal distribution
     q = value
     for transform in reversed(self.transforms):
-        q = transform.inv(q)
-        log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
+        q_prev = transform.inv(q)  # Get the pre-transform value
+        log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q)  # Per-dimension jacobian term
+        q = q_prev
     # Add base distribution log prob
-    log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
+    log_prob = log_prob + self.base_dist.log_prob(q)  # Per-dimension; callers sum over the action axis
     return log_prob

 def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
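
The loop above implements the change-of-variables formula; for a single tanh transform it reduces to the familiar closed form, sketched here as a numerical reference (the epsilon is an illustrative stability guard):

import torch

base = torch.distributions.Normal(torch.zeros(3), torch.ones(3))
x = base.rsample()
a = torch.tanh(x)

# log p(a) = log N(x) - log(1 - a^2), computed per action dimension
log_prob = base.log_prob(x) - torch.log1p(-a.pow(2) + 1e-6)
total = log_prob.sum(-1)  # callers sum over the action axis, as in the actor loss above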
@@ -646,20 +642,20 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
     log_prob = self.log_prob(x)
     return x, log_prob

-def entropy(self) -> torch.Tensor:
-    """
-    Compute entropy of the distribution
-    """
-    # Start with base distribution entropy
-    entropy = self.base_dist.entropy().sum(-1)
-    # Add log det jacobian for each transform
-    x = self.rsample()
-    for transform in self.transforms:
-        entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
-        x = transform(x)
-    return entropy
+# def entropy(self) -> torch.Tensor:
+#     """
+#     Compute entropy of the distribution
+#     """
+#     # Start with base distribution entropy
+#     entropy = self.base_dist.entropy().sum(-1)
+#     # Add log det jacobian for each transform
+#     x = self.rsample()
+#     for transform in self.transforms:
+#         entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
+#         x = transform(x)
+#     return entropy

 def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
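
With the closed-form entropy commented out, a common substitute is a Monte-Carlo estimate, since tanh-transformed Gaussians admit no analytic entropy; this sketch is an assumption about intent, not code from the commit:

import torch

def mc_entropy(dist, num_samples: int = 64) -> torch.Tensor:
    # H(pi) = -E[log pi(a)], estimated by averaging log probs of sampled actions.
    a = dist.rsample((num_samples,))
    return -dist.log_prob(a).mean(0)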


@@ -19,7 +19,7 @@ training:
   grad_clip_norm: 10.0
   lr: 3e-4
-  eval_freq: 10000
+  eval_freq: 50000
   log_freq: 500
   save_freq: 50000