Fix init temp
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
parent
23c9441d5f
commit
dc1548fe1a
|
@ -27,7 +27,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform
|
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||||
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
@ -156,7 +156,9 @@ class SACPolicy(
|
||||||
**asdict(config.policy_kwargs),
|
**asdict(config.policy_kwargs),
|
||||||
)
|
)
|
||||||
if config.target_entropy is None:
|
if config.target_entropy is None:
|
||||||
discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
|
discrete_actions_dim: Literal[1] | Literal[0] = (
|
||||||
|
1 if config.num_discrete_actions is not None else 0
|
||||||
|
)
|
||||||
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
|
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
|
||||||
|
|
||||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||||
|
|
Loading…
Reference in New Issue