Fix init temp

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine 2025-04-16 14:43:47 +00:00 committed by Adil Zouitine
parent a6f612e2e3
commit 7191bbbff9
1 changed files with 4 additions and 2 deletions

View File

@ -27,7 +27,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@ -156,7 +156,9 @@ class SACPolicy(
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
discrete_actions_dim: Literal[1] | Literal[0] = (
1 if config.num_discrete_actions is not None else 0
)
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed