From 97513287838b33f1bba90af3de66d6ad2ad37320 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 8 Apr 2025 08:50:02 +0000 Subject: [PATCH] Add rounding for safety --- lerobot/common/policies/sac/modeling_sac.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9b909813..e3d83d36 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -421,6 +421,7 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) actions_discrete = actions_discrete.long() if complementary_info is not None: