SAC works

This commit is contained in:
Adil Zouitine 2025-01-14 11:34:52 +01:00 committed by Michel Aractingi
parent 86df8a433d
commit c1d4bf4b63
2 changed files with 6 additions and 0 deletions

View File

@@ -271,6 +271,9 @@ class SACPolicy(
q_targets = self.critic_forward( q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True observations=next_observations, actions=next_action_preds, use_target=True
) )
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None: if self.config.num_subsample_critics is not None:

View File

@@ -265,6 +265,9 @@ class ReplayBuffer:
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device self.device
) )
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# Return a BatchTransition typed dict # Return a BatchTransition typed dict
return BatchTransition( return BatchTransition(