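Fix bug: add a trailing dimension to sampled discrete grasp actions (casting to float before concatenation with the continuous actions) and pass softmax_temperature=1.0 to the grasp critic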
This commit is contained in:
parent
9f6f508edb
commit
ab2c2d39fb
|
@ -138,6 +138,7 @@ class SACPolicy(
|
||||||
encoder=encoder_critic,
|
encoder=encoder_critic,
|
||||||
input_dim=encoder_critic.output_dim,
|
input_dim=encoder_critic.output_dim,
|
||||||
output_dim=config.num_discrete_actions,
|
output_dim=config.num_discrete_actions,
|
||||||
|
softmax_temperature=1.0,
|
||||||
**asdict(config.grasp_critic_network_kwargs),
|
**asdict(config.grasp_critic_network_kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -198,7 +199,7 @@ class SACPolicy(
|
||||||
|
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.config.num_discrete_actions is not None:
|
||||||
_, discrete_action_distribution = self.grasp_critic(batch, observations_features)
|
_, discrete_action_distribution = self.grasp_critic(batch, observations_features)
|
||||||
discrete_action = discrete_action_distribution.sample()
|
discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
|
||||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
@ -435,7 +436,7 @@ class SACPolicy(
|
||||||
next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
|
next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
|
||||||
next_observations, use_target=False, observation_features=next_observation_features
|
next_observations, use_target=False, observation_features=next_observation_features
|
||||||
)
|
)
|
||||||
best_next_grasp_action = next_grasp_distribution.sample()
|
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
|
||||||
|
|
||||||
# Get target Q-values from target network
|
# Get target Q-values from target network
|
||||||
target_next_grasp_qs, _ = self.grasp_critic_forward(
|
target_next_grasp_qs, _ = self.grasp_critic_forward(
|
||||||
|
|
Loading…
Reference in New Issue