From 20d31ab8e023ad240a32ecd2cf931e9d27131993 Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Tue, 14 Jan 2025 11:34:52 +0100
Subject: [PATCH] SAC works

---
 lerobot/common/policies/sac/modeling_sac.py | 3 +++
 lerobot/scripts/train_sac.py                | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 23513916..fece59f0 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -271,6 +271,9 @@ class SACPolicy(
         q_targets = self.critic_forward(
             observations=next_observations, actions=next_action_preds, use_target=True
         )
+        q_targets = self.critic_forward(
+            observations=next_observations, actions=next_action_preds, use_target=True
+        )
 
         # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
         if self.config.num_subsample_critics is not None:
diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py
index 942a19ab..eba504d3 100644
--- a/lerobot/scripts/train_sac.py
+++ b/lerobot/scripts/train_sac.py
@@ -265,6 +265,9 @@ class ReplayBuffer:
         batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
             self.device
         )
+        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
+            self.device
+        )
 
         # Return a BatchTransition typed dict
         return BatchTransition(
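
Note on the modeling_sac.py hunk: the code around it computes Q-targets from
the target critics and then subsamples critic heads, which is meant to curb
overfitting at a high UTD (update-to-data) ratio. Below is a minimal sketch of
that pattern. The min-over-subset reduction is an assumption in the spirit of
REDQ-style ensembles, not confirmed by this patch; only `critic_forward`,
`use_target`, and `num_subsample_critics` appear in the diff itself, and the
standalone function signature here is hypothetical.

import torch


def compute_q_targets(
    critic_forward,              # callable returning Q-values of shape (num_critics, batch)
    next_observations,
    next_action_preds,
    num_subsample_critics=None,  # mirrors config.num_subsample_critics in the diff
):
    # Query every target critic head: shape (num_critics, batch_size).
    q_targets = critic_forward(
        observations=next_observations, actions=next_action_preds, use_target=True
    )

    # Subsample critic heads to prevent overfitting at a high UTD ratio.
    if num_subsample_critics is not None:
        idx = torch.randperm(q_targets.shape[0])[:num_subsample_critics]
        q_targets = q_targets[idx]

    # Pessimistic target: elementwise min across the (subsampled) ensemble.
    min_q, _ = q_targets.min(dim=0)
    return min_q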
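
Note on the train_sac.py hunk: ReplayBuffer collates a list of per-step
transition dicts into batched tensors and returns them as a BatchTransition
typed dict. A minimal sketch of that collation is below; the diff only shows
the `batch_dones` line and the BatchTransition return, so the other field
names ("state", "action", "reward", "next_state") and the plain-dict return
are assumptions for illustration.

import random

import torch


def sample_batch(transitions, batch_size, device="cpu"):
    # Draw a uniform random minibatch of stored transitions.
    list_of_transitions = random.sample(transitions, batch_size)

    batch_states = torch.stack([t["state"] for t in list_of_transitions]).to(device)
    batch_actions = torch.stack([t["action"] for t in list_of_transitions]).to(device)
    batch_rewards = torch.tensor(
        [t["reward"] for t in list_of_transitions], dtype=torch.float32
    ).to(device)
    batch_next_states = torch.stack([t["next_state"] for t in list_of_transitions]).to(device)
    # Booleans become float32 so `(1 - done)` can mask bootstrapped targets.
    batch_dones = torch.tensor(
        [t["done"] for t in list_of_transitions], dtype=torch.float32
    ).to(device)

    # Stand-in for the BatchTransition typed dict returned in the diff.
    return {
        "state": batch_states,
        "action": batch_actions,
        "reward": batch_rewards,
        "next_state": batch_next_states,
        "done": batch_dones,
    }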