Refine SAC configuration and policy for enhanced performance
- Updated the standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling (rsample), ensuring better gradient flow and correct log probability calculations.
- Cleaned up the log probability calculation in TanhMultivariateNormalDiag for clarity and efficiency.
- Raised eval_freq in the YAML training configuration from 10000 to 50000 so evaluation interrupts training less often.

These changes aim to improve the robustness and performance of the SAC implementation during training and inference.
parent ca74a13d61
commit 22fbc9ea4a
@@ -59,14 +59,10 @@ class SACConfig:
         "activate_final": True,
     }
     policy_kwargs = {
-        "tanh_squash_distribution": True,
-        "std_parameterization": "uniform",
-    }
-
-    input_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "observation.image": [3, 84, 84],
-            "observation.state": [4],
+        "tanh_squash_distribution": True,
+        "std_parameterization": "softplus",
+        "std_min": 0.005,
+        "std_max": 5.0,
     }
 )
 output_shapes: dict[str, list[int]] = field(
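For context, a "softplus" std parameterization with a min/max clamp typically looks like the sketch below. This is only an illustration of the config knobs above, not the repository's actual actor head; the class name SoftplusStdHead and the hidden_dim/action_dim arguments are made up for the example.

import torch
from torch import nn
from torch.nn import functional as F


class SoftplusStdHead(nn.Module):
    """Hypothetical actor head: softplus keeps the std positive, and the
    clamp to [std_min, std_max] matches the new std_min/std_max options."""

    def __init__(self, hidden_dim: int, action_dim: int, std_min: float = 0.005, std_max: float = 5.0):
        super().__init__()
        self.mean_layer = nn.Linear(hidden_dim, action_dim)
        self.std_layer = nn.Linear(hidden_dim, action_dim)
        self.std_min = std_min
        self.std_max = std_max

    def forward(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        mean = self.mean_layer(features)
        # Softplus is smooth near zero, unlike exp, which can blow up; the clamp
        # bounds the std away from 0 and from very large values for stability.
        std = torch.clamp(F.softplus(self.std_layer(features)), self.std_min, self.std_max)
        return mean, std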
@@ -134,7 +134,6 @@ class SACPolicy(
         # 1- compute actions from policy
         distribution = self.actor(observations)
         action_preds = distribution.sample()
         log_probs = distribution.log_prob(action_preds)
         action_preds = torch.clamp(action_preds, -1, +1)
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
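The q_targets and log_probs computed here normally feed a soft Bellman target of the form sketched below; rewards, dones, gamma, and temperature are assumed names not shown in this hunk, and the exact wiring in the repository may differ.

# Soft Bellman backup used by SAC: the entropy bonus -temperature * log_probs
# is folded into the bootstrapped next-state value before the TD target.
with torch.no_grad():
    td_target = rewards + gamma * (1.0 - dones) * (q_targets - temperature * log_probs)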
@@ -186,15 +185,11 @@ class SACPolicy(
         temperature = self.temperature()
 
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-<<<<<<< HEAD
-        actions, log_probs = self.actor(observations) \
-
-=======
         distribution = self.actor(observations)
-        actions = distribution.sample()
-        log_probs = distribution.log_prob(actions)
+        actions = distribution.rsample()
+        log_probs = distribution.log_prob(actions).sum(-1)
+        # breakpoint()
         actions = torch.clamp(actions, -1, +1)
->>>>>>> d3c62b92 (Refactor SACPolicy for improved action sampling and standard deviation handling)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
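The switch from sample() to rsample() is what lets gradients from the actor loss reach the policy parameters. A self-contained toy check using PyTorch's built-in Normal and TanhTransform as a stand-in for TanhMultivariateNormalDiag:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

mean = torch.zeros(4, requires_grad=True)
std = torch.ones(4, requires_grad=True)
dist = TransformedDistribution(Normal(mean, std), [TanhTransform(cache_size=1)])

actions = dist.rsample()                    # a = tanh(mean + std * eps), differentiable in mean/std
log_probs = dist.log_prob(actions).sum(-1)  # per-dimension log probs summed over the action dimension
log_probs.mean().backward()
print(mean.grad is not None)                # True: rsample() keeps actions attached to mean/std; sample() detaches them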
@@ -610,7 +605,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         Reparameterized sample from the distribution
         """
-        # Sample from base distribution
+        # Sample from base distributionrsample
         x = self.base_dist.rsample(sample_shape)
 
         # Apply transforms
@@ -625,17 +620,18 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         Includes the log det jacobian for the transforms
         """
         # Initialize log prob
-        log_prob = torch.zeros_like(value[..., 0])
-
+        log_prob = torch.zeros_like(value)
+
         # Inverse transforms to get back to normal distribution
         q = value
         for transform in reversed(self.transforms):
-            q = transform.inv(q)
-            log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
-
+            q_prev = transform.inv(q)  # Get the pre-transform value
+            log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q)  # Sum over action dimensions
+            q = q_prev
+
         # Add base distribution log prob
-        log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
-
+        log_prob = log_prob + self.base_dist.log_prob(q)  # Sum over action dimensions
+
         return log_prob
 
     def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
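The log_prob being rebuilt above is the usual change-of-variables for a tanh-squashed Gaussian: log pi(a) = log N(x) - log |d tanh(x)/dx|, with x the pre-squash sample. A hedged sketch of the same computation using PyTorch built-ins rather than the repository's class:

import torch
from torch.distributions import Normal
from torch.distributions.transforms import TanhTransform

base = Normal(torch.zeros(2, 4), torch.full((2, 4), 0.5))
tanh = TanhTransform(cache_size=1)

x = base.rsample()   # pre-squash sample from the diagonal Gaussian
a = tanh(x)          # squashed action in (-1, 1)

# Per-dimension correction, then a sum over the action dimension, mirroring
# the "Sum over action dimensions" comments in the diff above.
log_prob = (base.log_prob(x) - tanh.log_abs_det_jacobian(x, a)).sum(-1)
print(log_prob.shape)  # torch.Size([2])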
@@ -646,20 +642,20 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         log_prob = self.log_prob(x)
         return x, log_prob
 
-    def entropy(self) -> torch.Tensor:
-        """
-        Compute entropy of the distribution
-        """
-        # Start with base distribution entropy
-        entropy = self.base_dist.entropy().sum(-1)
-
-        # Add log det jacobian for each transform
-        x = self.rsample()
-        for transform in self.transforms:
-            entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
-            x = transform(x)
-
-        return entropy
+    # def entropy(self) -> torch.Tensor:
+    #     """
+    #     Compute entropy of the distribution
+    #     """
+    #     # Start with base distribution entropy
+    #     entropy = self.base_dist.entropy().sum(-1)
+
+    #     # Add log det jacobian for each transform
+    #     x = self.rsample()
+    #     for transform in self.transforms:
+    #         entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
+    #         x = transform(x)
+
+    #     return entropy
 
 
 def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
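Commenting out entropy() is consistent with how tanh-squashed Gaussians are usually handled: the squashed distribution has no closed-form entropy, so SAC implementations typically fall back to the Monte-Carlo estimate -log pi(a|s) on freshly sampled actions. A minimal sketch, assuming a distribution object with rsample/log_prob as above:

# H(pi) = -E_{a ~ pi}[log pi(a)], estimated from one reparameterized sample.
actions = distribution.rsample()
entropy_estimate = -distribution.log_prob(actions).sum(-1)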
@@ -19,7 +19,7 @@ training:
   grad_clip_norm: 10.0
   lr: 3e-4
 
-  eval_freq: 10000
+  eval_freq: 50000
   log_freq: 500
   save_freq: 50000
 