From 8d1936ffe0410374d3bdfe5d5a8a4a3a7ceca09c Mon Sep 17 00:00:00 2001 From: s1lent4gnt Date: Mon, 31 Mar 2025 17:38:16 +0200 Subject: [PATCH] Add gripper penalty wrapper --- lerobot/scripts/server/gym_manipulator.py | 24 +++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 92e8dcbc..26ed1991 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1069,6 +1069,29 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention +class GripperPenaltyWrapper(gym.Wrapper): + def __init__(self, env, penalty=-0.05): + super().__init__(env) + self.penalty = penalty + self.last_gripper_pos = None + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper + return obs, info + + def step(self, action): + observation, reward, terminated, truncated, info = self.env.step(action) + + if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9): + info["grasp_penalty"] = self.penalty + else: + info["grasp_penalty"] = 0.0 + + self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper + return observation, reward, terminated, truncated, info + + def make_robot_env(cfg) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. @@ -1144,6 +1167,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) + env= GripperPenaltyWrapper(env=env) return env