Remove torch.no_grad decorator and optimize next action prediction in SAC policy

- Removed `@torch.no_grad` decorator from Unnormalize forward method

- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
This commit is contained in:
AdilZouitine 2025-03-10 10:31:38 +00:00
parent d711e20b5f
commit 25b88f3b86
2 changed files with 9 additions and 4 deletions

View File

@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
# @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():

View File

@ -210,6 +210,11 @@ class SACPolicy(
next_observations, next_observation_features
)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
"action"
]
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
@ -512,9 +517,9 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
assert not torch.isnan(log_std).any(), (
"[ERROR] log_std became NaN after std_layer!"
)
if self.use_tanh_squash:
log_std = torch.tanh(log_std)