Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from Unnormalize forward method - Added TODO comment for optimizing next action prediction in SAC policy - Minor formatting adjustment in NaN assertion for log standard deviation Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
This commit is contained in:
parent
d711e20b5f
commit
25b88f3b86
|
@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
# @torch.no_grad
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, mode in self.modes.items():
|
for key, mode in self.modes.items():
|
||||||
|
|
|
@ -210,6 +210,11 @@ class SACPolicy(
|
||||||
next_observations, next_observation_features
|
next_observations, next_observation_features
|
||||||
)
|
)
|
||||||
|
|
||||||
|
 # TODO: (maractingi, azouitine) This is too slow; we should find a way to do this more efficiently
|
||||||
|
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
|
||||||
|
"action"
|
||||||
|
]
|
||||||
|
|
||||||
# 2- compute q targets
|
# 2- compute q targets
|
||||||
q_targets = self.critic_forward(
|
q_targets = self.critic_forward(
|
||||||
observations=next_observations,
|
observations=next_observations,
|
||||||
|
@ -512,9 +517,9 @@ class Policy(nn.Module):
|
||||||
# Compute standard deviations
|
# Compute standard deviations
|
||||||
if self.fixed_std is None:
|
if self.fixed_std is None:
|
||||||
log_std = self.std_layer(outputs)
|
log_std = self.std_layer(outputs)
|
||||||
assert not torch.isnan(
|
assert not torch.isnan(log_std).any(), (
|
||||||
log_std
|
"[ERROR] log_std became NaN after std_layer!"
|
||||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
)
|
||||||
|
|
||||||
if self.use_tanh_squash:
|
if self.use_tanh_squash:
|
||||||
log_std = torch.tanh(log_std)
|
log_std = torch.tanh(log_std)
|
||||||
|
|
Loading…
Reference in New Issue