SAC works

This commit is contained in:
Adil Zouitine 2025-01-14 11:34:52 +01:00
parent e5b83aab5e
commit 20d31ab8e0
2 changed files with 6 additions and 0 deletions

View File

@ -271,6 +271,9 @@ class SACPolicy(
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:

View File

@ -265,6 +265,9 @@ class ReplayBuffer:
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# Return a BatchTransition typed dict
return BatchTransition(