SAC works
This commit is contained in:
parent
1e9bafc852
commit
bd8c768f62
|
@ -271,6 +271,9 @@ class SACPolicy(
|
||||||
q_targets = self.critic_forward(
|
q_targets = self.critic_forward(
|
||||||
observations=next_observations, actions=next_action_preds, use_target=True
|
observations=next_observations, actions=next_action_preds, use_target=True
|
||||||
)
|
)
|
||||||
|
q_targets = self.critic_forward(
|
||||||
|
observations=next_observations, actions=next_action_preds, use_target=True
|
||||||
|
)
|
||||||
|
|
||||||
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
|
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
|
||||||
if self.config.num_subsample_critics is not None:
|
if self.config.num_subsample_critics is not None:
|
||||||
|
|
|
@ -265,6 +265,9 @@ class ReplayBuffer:
|
||||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
|
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
|
|
||||||
# Return a BatchTransition typed dict
|
# Return a BatchTransition typed dict
|
||||||
return BatchTransition(
|
return BatchTransition(
|
||||||
|
|
Loading…
Reference in New Issue