Add rounding for safety

This commit is contained in:
AdilZouitine 2025-04-08 08:50:02 +00:00 committed by Michel Aractingi
parent a3ada81816
commit 68c271ad25
1 changed files with 1 additions and 0 deletions

View File

@ -421,6 +421,7 @@ class SACPolicy(
# In the buffer we have the full action space (continuous + discrete) # In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward # We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long() actions_discrete = actions_discrete.long()
if complementary_info is not None: if complementary_info is not None: