From 06de182448120b5e6f7b6d0c12a81c1ba10a364e Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Mon, 7 Apr 2025 16:51:17 +0200
Subject: [PATCH] fixes for gripper penalty

---
 lerobot/common/policies/sac/modeling_sac.py |  1 +
 lerobot/scripts/server/actor_server.py      |  2 +-
 lerobot/scripts/server/buffer.py            |  2 +-
 lerobot/scripts/server/gym_manipulator.py   | 13 ++++++++-----
 lerobot/scripts/server/learner_server.py    |  1 +
 5 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 7b3c9c41..a5df2533 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -428,6 +428,7 @@ class SACPolicy(
         actions_discrete = torch.round(actions_discrete)
         actions_discrete = actions_discrete.long()
 
+        gripper_penalties: Tensor | None = None
         if complementary_info is not None:
             gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
 
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 7ac8343e..0895c29e 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -281,7 +281,7 @@ def act_with_policy(
             for key, tensor in obs.items():
                 if torch.isnan(tensor).any():
                     logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
-
+
             list_transition_to_send_to_learner.append(
                 Transition(
                     state=obs,
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index e95f8f55..7f0b2429 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -269,7 +269,7 @@ class ReplayBuffer:
                     self.complementary_info[key] = torch.empty(
                         (self.capacity, *value_shape), device=self.storage_device
                     )
-                elif isinstance(value, (int, float)):
+                elif isinstance(value, (int, float, bool)):
                     # Handle scalar values similar to reward
                     self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
                 else:
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 501a71ce..1c1aead3 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -800,18 +800,21 @@
         else:
             gripper_action = action[-1]
         obs, reward, terminated, truncated, info = self.env.step(action)
-        grasp_reward = self.reward(reward, gripper_action)
+        gripper_penalty = self.reward(reward, gripper_action)
 
         if self.gripper_penalty_in_reward:
-            reward += grasp_reward
+            reward += gripper_penalty
         else:
-            info["grasp_reward"] = grasp_reward
+            info["gripper_penalty"] = gripper_penalty
 
         return obs, reward, terminated, truncated, info
 
     def reset(self, **kwargs):
         self.last_gripper_state = None
-        return super().reset(**kwargs)
+        obs, info = super().reset(**kwargs)
+        if self.gripper_penalty_in_reward:
+            info["gripper_penalty"] = 0.0
+        return obs, info
 
 class GripperActionWrapper(gym.ActionWrapper):
     def __init__(self, env, quantization_threshold: float = 0.2):
@@ -1192,7 +1195,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env = GripperActionWrapper(
             env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
         )
-        env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+        env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
 
         if cfg.wrapper.ee_action_space_params is not None:
             env = EEActionWrapper(
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index bcf47787..4a67f07d 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -408,6 +408,7 @@ def add_actor_information_and_train(
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
+            "complementary_info": batch.get("complementary_info", None),
         }
 
         # Use the forward method for critic loss (includes both main critic and grasp critic)
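
The gym_manipulator.py hunk above renames the wrapper's output from grasp_reward to gripper_penalty and reports it either inside the scalar reward or under info["gripper_penalty"], seeding the key at reset. The following is a minimal, self-contained sketch of a wrapper with that interface; the penalty rule inside reward() is an assumption for illustration, since the wrapper's full implementation is not part of this diff.

import gymnasium as gym


class GripperPenaltySketch(gym.RewardWrapper):
    """Minimal sketch of a gripper-penalty wrapper (not the patched class itself)."""

    def __init__(self, env, penalty: float = -0.05, gripper_penalty_in_reward: bool = False):
        super().__init__(env)
        self.penalty = penalty
        self.gripper_penalty_in_reward = gripper_penalty_in_reward
        self.last_gripper_state = None

    def reward(self, reward, gripper_action):
        # Assumed rule: penalize commanding the gripper into the state it already holds.
        redundant = self.last_gripper_state is not None and gripper_action == self.last_gripper_state
        self.last_gripper_state = gripper_action
        return self.penalty if redundant else 0.0

    def step(self, action):
        gripper_action = action[-1]  # assumes the gripper command is the last action dimension
        obs, reward, terminated, truncated, info = self.env.step(action)
        gripper_penalty = self.reward(reward, gripper_action)
        if self.gripper_penalty_in_reward:
            reward += gripper_penalty  # fold the penalty into the scalar reward
        else:
            info["gripper_penalty"] = gripper_penalty  # expose it for the replay buffer instead
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        self.last_gripper_state = None
        obs, info = super().reset(**kwargs)
        if self.gripper_penalty_in_reward:
            info["gripper_penalty"] = 0.0  # keep the key present from the first transition
        return obs, info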
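
On the training side, the buffer.py change lets bool scalars ride along in complementary_info, learner_server.py forwards batch["complementary_info"] into the policy's forward pass, and modeling_sac.py pre-initializes gripper_penalties so the lookup is None-safe when no such info is present. The fragment below sketches that lookup under the assumption that the penalty, when available, is added to the reward used for the discrete gripper critic; the policy's actual loss code is not shown in this patch.

import torch
from torch import Tensor


def gripper_penalty_from_batch(batch: dict, rewards: Tensor) -> Tensor:
    # None-safe lookup mirroring the modeling_sac.py change.
    complementary_info = batch.get("complementary_info", None)

    gripper_penalties: Tensor | None = None
    if complementary_info is not None:
        gripper_penalties = complementary_info.get("gripper_penalty")

    # Fall back to zeros so downstream arithmetic never has to branch on None.
    if gripper_penalties is None:
        gripper_penalties = torch.zeros_like(rewards)
    return gripper_penalties


# Hypothetical usage inside a learner update (names assumed for illustration):
# rewards = batch["reward"] + gripper_penalty_from_batch(batch, batch["reward"])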