Add rounding for safety

This commit is contained in:
AdilZouitine 2025-04-08 08:50:02 +00:00
parent a7be613ee8
commit a8135629b4
1 changed files with 1 additions and 0 deletions

View File

@ -421,6 +421,7 @@ class SACPolicy(
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
if complementary_info is not None: