Fix init temp: count the discrete action dimension in the default target entropy

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine 2025-04-16 14:43:47 +00:00 committed by Adil Zouitine
parent 23c9441d5f
commit dc1548fe1a
1 changed file with 4 additions and 2 deletions

View File

@@ -27,7 +27,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from torch import Tensor from torch import Tensor
from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -156,7 +156,9 @@ class SACPolicy(
**asdict(config.policy_kwargs), **asdict(config.policy_kwargs),
) )
if config.target_entropy is None: if config.target_entropy is None:
discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0 discrete_actions_dim: Literal[1] | Literal[0] = (
1 if config.num_discrete_actions is not None else 0
)
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2) config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature parameter is a fixed # TODO (azouitine): Handle the case where the temperature parameter is a fixed