From 25b88f3b86ae9c19fe8db499a0118a95e8bcf8d4 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Mon, 10 Mar 2025 10:31:38 +0000
Subject: [PATCH] Remove torch.no_grad decorator and unnormalize next action
 predictions in SAC policy

- Commented out the `@torch.no_grad` decorator on the Unnormalize forward
  method so gradients can propagate through unnormalization
- Unnormalized the next action predictions in the SAC policy before the
  Q-target computation, with a TODO to make this step more efficient
- Minor formatting adjustment in the NaN assertion for the log standard
  deviation

Co-authored-by: Yoel Chornton
---
 lerobot/common/policies/normalize.py        |  2 +-
 lerobot/common/policies/sac/modeling_sac.py | 11 ++++++++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index 2e0b266e..8dbe048d 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, mode in self.modes.items():
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 9eb864ec..4baf7d88 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -210,6 +210,11 @@ class SACPolicy(
             next_observations, next_observation_features
         )
 
+        # TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do this
+        next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
+            "action"
+        ]
+
         # 2- compute q targets
         q_targets = self.critic_forward(
             observations=next_observations,
@@ -512,9 +517,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
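-- 
Note on the normalize.py hunk: with `@torch.no_grad` active,
`Unnormalize.forward` returned tensors detached from the autograd graph,
so a loss computed on unnormalized outputs could not backpropagate into
the network that produced them; commenting the decorator out restores
gradient flow. Below is a minimal sketch of that effect. It is not
LeRobot's actual Unnormalize class: the `MeanStdUnnormalize` name and
the single mean/std "action" mode are assumptions made for the example.

import torch
from torch import Tensor, nn


class MeanStdUnnormalize(nn.Module):
    """Toy stand-in for Unnormalize: maps normalized values back via x * std + mean."""

    def __init__(self, mean: Tensor, std: Tensor):
        super().__init__()
        # Buffers follow the module across devices but are never trained.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        batch["action"] = batch["action"] * self.std + self.mean
        return batch


unnormalize = MeanStdUnnormalize(mean=torch.tensor([0.5]), std=torch.tensor([2.0]))
normalized_action = torch.tensor([[0.1]], requires_grad=True)
out = unnormalize({"action": normalized_action})["action"]
out.sum().backward()
print(normalized_action.grad)  # tensor([[2.]]); with @torch.no_grad on forward, backward() would raise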
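Note on the modeling_sac.py hunk: the added lines convert the policy's
next-action predictions from normalized space back to the action scale
the critics consume, before the Q targets are computed. The sketch below
shows where that call sits in a generic SAC target computation. Only
`unnormalize_outputs` and `critic_forward` mirror names from the patch;
the dummy policy, the three-tuple actor return, the `use_target` flag,
the `temperature` attribute, and the min-over-critics backup are
assumptions for illustration, not LeRobot's verbatim implementation.

import torch


class DummySACPolicy:
    """Illustrative stand-in exposing just what the target computation needs."""

    temperature = 0.2  # assumed fixed entropy coefficient

    def actor(self, observations):
        # Pretend policy head: normalized actions in [-1, 1], log-probs, and means.
        n = observations.shape[0]
        actions = torch.tanh(torch.randn(n, 2))
        return actions, torch.randn(n), actions

    def unnormalize_outputs(self, batch):
        batch = dict(batch)  # same shallow-copy pattern as Unnormalize.forward
        batch["action"] = batch["action"] * 2.0  # assumed action scale
        return batch

    def critic_forward(self, observations, actions, use_target=False):
        # (num_critics, batch) Q-values from (here, fake) target critics.
        return torch.randn(2, observations.shape[0])


def compute_q_targets(policy, next_observations, rewards, dones, gamma=0.99):
    with torch.no_grad():
        next_action_preds, next_log_probs, _ = policy.actor(next_observations)
        # Map normalized predictions back to env scale before querying the
        # critics -- the step the patch adds (and its TODO flags as slow).
        next_action_preds = policy.unnormalize_outputs({"action": next_action_preds})["action"]
        q_targets = policy.critic_forward(
            observations=next_observations, actions=next_action_preds, use_target=True
        )
        # Standard SAC backup: pessimistic min over critics minus the entropy term.
        min_q = q_targets.min(dim=0).values
        return rewards + gamma * (1.0 - dones) * (min_q - policy.temperature * next_log_probs)


policy = DummySACPolicy()
obs = torch.randn(4, 3)
print(compute_q_targets(policy, obs, rewards=torch.randn(4), dones=torch.zeros(4)).shape)  # torch.Size([4])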