diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 79d513b8..24d53600 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform +from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -156,7 +156,9 @@ class SACPolicy( **asdict(config.policy_kwargs), ) if config.target_entropy is None: - discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0 + discrete_actions_dim: Literal[1] | Literal[0] = ( + 1 if config.num_discrete_actions is not None else 0 + ) config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed