diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 6df94761..7a4bd364 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -59,14 +59,10 @@ class SACConfig:
         "activate_final": True,
     }
     policy_kwargs = {
-        "tanh_squash_distribution": True,
-        "std_parameterization": "uniform",
-    }
-
-    input_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "observation.image": [3, 84, 84],
-            "observation.state": [4],
+        "tanh_squash_distribution": True,
+        "std_parameterization": "softplus",
+        "std_min": 0.005,
+        "std_max": 5.0,
         }
     )
     output_shapes: dict[str, list[int]] = field(
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 821cb93f..806cb767 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -134,7 +134,6 @@ class SACPolicy(
         # 1- compute actions from policy
         distribution = self.actor(observations)
         action_preds = distribution.sample()
-        log_probs = distribution.log_prob(action_preds)
         action_preds = torch.clamp(action_preds, -1, +1)
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
@@ -186,15 +185,11 @@
         temperature = self.temperature()

         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-<<<<<<< HEAD
-        actions, log_probs = self.actor(observations) \
-
-=======
         distribution = self.actor(observations)
-        actions = distribution.sample()
-        log_probs = distribution.log_prob(actions)
+        actions = distribution.rsample()
+        log_probs = distribution.log_prob(actions).sum(-1)
+        # breakpoint()
         actions = torch.clamp(actions, -1, +1)
->>>>>>> d3c62b92 (Refactor SACPolicy for improved action sampling and standard deviation handling)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -610,7 +605,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """
         Reparameterized sample from the distribution
         """
-        # Sample from base distribution
+        # Sample from base distribution using rsample
         x = self.base_dist.rsample(sample_shape)

         # Apply transforms
@@ -625,17 +620,18 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         Includes the log det jacobian for the transforms
         """
         # Initialize log prob
-        log_prob = torch.zeros_like(value[..., 0])
-
+        log_prob = torch.zeros_like(value)
+
         # Inverse transforms to get back to normal distribution
         q = value
         for transform in reversed(self.transforms):
-            q = transform.inv(q)
-            log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
-
+            q_prev = transform.inv(q)  # Get the pre-transform value
+            log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q)  # Sum over action dimensions
+            q = q_prev
+
         # Add base distribution log prob
-        log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
-
+        log_prob = log_prob + self.base_dist.log_prob(q)  # Sum over action dimensions
+
         return log_prob

     def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -646,20 +642,20 @@
         x = self.rsample(sample_shape)
         log_prob = self.log_prob(x)
         return x, log_prob
-    def entropy(self) -> torch.Tensor:
-        """
-        Compute entropy of the distribution
-        """
-        # Start with base distribution entropy
-        entropy = self.base_dist.entropy().sum(-1)
-
-        # Add log det jacobian for each transform
-        x = self.rsample()
-        for transform in self.transforms:
-            entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
-            x = transform(x)
-
-        return entropy
+    # def entropy(self) -> torch.Tensor:
+    #     """
+    #     Compute entropy of the distribution
+    #     """
+    #     # Start with base distribution entropy
+    #     entropy = self.base_dist.entropy().sum(-1)
+
+    #     # Add log det jacobian for each transform
+    #     x = self.rsample()
+    #     for transform in self.transforms:
+    #         entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
+    #         x = transform(x)
+
+    #     return entropy


 def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml
index 19af60d4..6d8971a2 100644
--- a/lerobot/configs/policy/sac_pusht_keypoints.yaml
+++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml
@@ -19,7 +19,7 @@ training:
   grad_clip_norm: 10.0
   lr: 3e-4

-  eval_freq: 10000
+  eval_freq: 50000
   log_freq: 500
   save_freq: 50000