Add rounding for safety
This commit is contained in:
parent
a7be613ee8
commit
a8135629b4
|
@ -421,6 +421,7 @@ class SACPolicy(
|
|||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
if complementary_info is not None:
|
||||
|
|
Loading…
Reference in New Issue