Add gripper penalty wrapper

This commit is contained in:
s1lent4gnt 2025-03-31 17:38:16 +02:00 committed by Adil Zouitine
parent cef944e1b1
commit 8d1936ffe0
1 changed files with 24 additions and 0 deletions

View File

@ -1069,6 +1069,29 @@ class ActionScaleWrapper(gym.ActionWrapper):
return action * self.scale_vector, is_intervention
class GripperPenaltyWrapper(gym.Wrapper):
def __init__(self, env, penalty=-0.05):
super().__init__(env)
self.penalty = penalty
self.last_gripper_pos = None
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper
return obs, info
def step(self, action):
observation, reward, terminated, truncated, info = self.env.step(action)
if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9):
info["grasp_penalty"] = self.penalty
else:
info["grasp_penalty"] = 0.0
self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper
return observation, reward, terminated, truncated, info
def make_robot_env(cfg) -> gym.vector.VectorEnv:
"""
Factory function to create a vectorized robot environment.
@ -1144,6 +1167,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
env = BatchCompitableWrapper(env=env)
env= GripperPenaltyWrapper(env=env)
return env