Add gripper penalty wrapper
parent cef944e1b1
commit 8d1936ffe0
@@ -1069,6 +1069,29 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
+class GripperPenaltyWrapper(gym.Wrapper):
+    def __init__(self, env, penalty=-0.05):
+        super().__init__(env)
+        self.penalty = penalty
+        self.last_gripper_pos = None
+
+    def reset(self, **kwargs):
+        obs, info = self.env.reset(**kwargs)
+        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
+        return obs, info
+
+    def step(self, action):
+        observation, reward, terminated, truncated, info = self.env.step(action)
+
+        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9):
+            info["grasp_penalty"] = self.penalty
+        else:
+            info["grasp_penalty"] = 0.0
+
+        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
+        return observation, reward, terminated, truncated, info
+
+
 def make_robot_env(cfg) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
@@ -1144,6 +1167,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
+    env = GripperPenaltyWrapper(env=env)
 
     return env
 