diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 43221d5c..3e99dbd4 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -17,7 +17,6 @@ # TODO: (1) better device management -from copy import deepcopy from typing import Callable, Optional, Tuple, Union, Dict, List from pathlib import Path @@ -637,9 +636,9 @@ class Policy(nn.Module): # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - assert not torch.isnan(log_std).any(), ( - "[ERROR] log_std became NaN after std_layer!" - ) + assert not torch.isnan( + log_std + ).any(), "[ERROR] log_std became NaN after std_layer!" if self.use_tanh_squash: log_std = torch.tanh(log_std)