class GripperPenaltyWrapper(gym.Wrapper):
    """Report a per-step gripper penalty through the ``info`` dict.

    After every ``step`` the wrapper writes ``info["grasp_penalty"]``:
    ``penalty`` when the gripper command (last element of the action) is
    strongly negative while the previously observed gripper position is above
    ``POSITION_THRESHOLD``, or strongly positive while that position is below
    it; ``0.0`` otherwise. The environment reward itself is returned untouched.

    Args:
        env: Environment to wrap. Its observations must be mappings holding an
            ``"observation.state"`` entry indexable as ``[0, 0]``.
            NOTE(review): presumably a (batch, state_dim) array with the
            gripper at index 0 -- confirm against the observation pipeline.
        penalty: Value stored in ``info["grasp_penalty"]`` when the condition
            above is met.
    """

    # Magnitude the gripper command must exceed before it is considered at all.
    ACTION_THRESHOLD = 0.5
    # Gripper position separating the two penalised cases.
    POSITION_THRESHOLD = 0.9

    def __init__(self, env, penalty=-0.05):
        super().__init__(env)
        self.penalty = penalty
        # Gripper position from the most recent observation; None until reset().
        self.last_gripper_pos = None

    @staticmethod
    def _gripper_pos(observation):
        """Return the gripper entry of ``observation["observation.state"]``."""
        return observation["observation.state"][0, 0]  # first idx for the gripper

    @staticmethod
    def _gripper_command(action):
        """Return the gripper command, i.e. the last element of the action.

        Other wrappers in this file pass actions around as
        ``(action, is_intervention)`` pairs. For such a pair the command lives
        in the first element; indexing the pair with ``[-1]`` would pick up the
        intervention flag instead. A pair is recognised by its first element
        being a sequence, so a flat 2-tuple of scalars is still treated as a
        plain action.
        """
        if isinstance(action, tuple) and len(action) == 2 and hasattr(action[0], "__len__"):
            action = action[0]
        return action[-1]

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the initial gripper position."""
        obs, info = self.env.reset(**kwargs)
        self.last_gripper_pos = self._gripper_pos(obs)
        return obs, info

    def step(self, action):
        """Step the wrapped env and add ``info["grasp_penalty"]``.

        The action is forwarded unchanged. The penalty is decided from the
        gripper position observed *before* this step.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)

        command = self._gripper_command(action)
        previous_pos = self.last_gripper_pos

        if previous_pos is None:
            # step() ran before reset(): no reference position, so no penalty
            # (comparing None with a float would raise TypeError).
            penalized = False
        else:
            penalized = (
                command < -self.ACTION_THRESHOLD and previous_pos > self.POSITION_THRESHOLD
            ) or (
                command > self.ACTION_THRESHOLD and previous_pos < self.POSITION_THRESHOLD
            )

        info["grasp_penalty"] = self.penalty if penalized else 0.0

        self.last_gripper_pos = self._gripper_pos(observation)
        return observation, reward, terminated, truncated, info