Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from Unnormalize forward method
- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation

Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
This commit is contained in:
parent d711e20b5f
commit 25b88f3b86
```diff
@@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
                 setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, mode in self.modes.items():
```
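Commenting out the decorator matters because the SAC policy now unnormalizes predicted actions inside its update (see the next hunk), where gradients may need to flow through the transform. A minimal sketch of the effect, using illustrative mean/std tensors rather than the module's real registered buffers:

```python
import torch

# Illustrative action statistics; the real module stores these as
# registered buffers, one per feature key.
mean, std = torch.zeros(3), torch.ones(3)

def unnormalize(action: torch.Tensor) -> torch.Tensor:
    # Plain tensor ops keep the autograd graph intact; a @torch.no_grad
    # decorator here would detach the output from the actor's graph.
    return action * std + mean

a = torch.randn(1, 3, requires_grad=True)
unnormalize(a).sum().backward()  # gradients reach `a`
assert a.grad is not None
```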
```diff
@@ -210,6 +210,11 @@ class SACPolicy(
                 next_observations, next_observation_features
             )
 
+            # TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way
+            next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
+                "action"
+            ]
+
             # 2- compute q targets
             q_targets = self.critic_forward(
                 observations=next_observations,
```
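One way the TODO could go, assuming the action statistics reduce to plain mean/std tensors (the class name and shapes below are hypothetical, not the repo's API): bypass the generic dict-walking Unnormalize module on this hot path and apply the affine transform directly.

```python
import torch

class FastUnnormalize:
    """Hypothetical affine unnormalizer specialized to a single 'action' key."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean, self.std = mean, std

    def __call__(self, action: torch.Tensor) -> torch.Tensor:
        # Single fused multiply-add (mean + action * std); no dict copy,
        # no per-key attribute lookup.
        return torch.addcmul(self.mean, action, self.std)

unnorm = FastUnnormalize(mean=torch.zeros(7), std=torch.ones(7))
next_action_preds = unnorm(torch.randn(256, 7))  # broadcasts over the batch
```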
```diff
@@ -512,9 +517,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
 
         if self.use_tanh_squash:
             log_std = torch.tanh(log_std)
```
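The reformatted assertion is behavior-preserving; only the message parenthesization changes. One caveat worth keeping in mind: `assert` statements are stripped when Python runs with `-O`, so a guard that must survive optimized runs would need an explicit raise. A sketch, not the repo's convention:

```python
import torch

log_std = torch.tensor([0.1, float("nan")])  # illustrative std_layer output
if torch.isnan(log_std).any():
    raise ValueError("[ERROR] log_std became NaN after std_layer!")
```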