fixed softmax temp
This commit is contained in:
parent
10adadbc71
commit
e36bee7560
lerobot/common/policies/sac
|
@ -129,7 +129,7 @@ class SACPolicy(
|
||||||
encoder=encoder_critic,
|
encoder=encoder_critic,
|
||||||
input_dim=encoder_critic.output_dim,
|
input_dim=encoder_critic.output_dim,
|
||||||
output_dim=config.num_discrete_actions,
|
output_dim=config.num_discrete_actions,
|
||||||
softmax_temperature=1.0,
|
softmax_temperature=.15,
|
||||||
**asdict(config.grasp_critic_network_kwargs),
|
**asdict(config.grasp_critic_network_kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ class SACPolicy(
|
||||||
encoder=encoder_critic,
|
encoder=encoder_critic,
|
||||||
input_dim=encoder_critic.output_dim,
|
input_dim=encoder_critic.output_dim,
|
||||||
output_dim=config.num_discrete_actions,
|
output_dim=config.num_discrete_actions,
|
||||||
softmax_temperature=1.0,
|
softmax_temperature=0.15,
|
||||||
**asdict(config.grasp_critic_network_kwargs),
|
**asdict(config.grasp_critic_network_kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -786,6 +786,7 @@ class GraspCritic(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.output_dim = output_dim
|
self.output_dim = output_dim
|
||||||
|
self.softmax_temperature = softmax_temperature
|
||||||
|
|
||||||
self.net = MLP(
|
self.net = MLP(
|
||||||
input_dim=input_dim,
|
input_dim=input_dim,
|
||||||
|
|
Loading…
Reference in New Issue