diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9b909813..e3d83d36 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -421,6 +421,7 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) actions_discrete = actions_discrete.long() if complementary_info is not None: