Add gripper penalty wrapper
This commit is contained in:
parent
cef944e1b1
commit
8d1936ffe0
|
@ -1069,6 +1069,29 @@ class ActionScaleWrapper(gym.ActionWrapper):
|
||||||
return action * self.scale_vector, is_intervention
|
return action * self.scale_vector, is_intervention
|
||||||
|
|
||||||
|
|
||||||
|
class GripperPenaltyWrapper(gym.Wrapper):
    """Report a penalty in ``info["grasp_penalty"]`` for gripper toggles.

    On every step the wrapper compares the gripper command (the last entry
    of the action) with the gripper position cached from the previous
    observation. If the command pushes the gripper away from its current
    state (close command while the reading is above the position threshold,
    or open command while it is below), ``info["grasp_penalty"]`` is set to
    ``penalty``; otherwise it is set to ``0.0``. The environment reward
    itself is passed through unchanged.

    Args:
        env: The environment to wrap. Its observations must expose
            ``obs["observation.state"]`` indexable as ``[0, 0]``
            (presumably a leading batch dimension followed by the state
            vector, with the gripper at index 0 -- confirm against the
            wrapped env).
        penalty: Value written to ``info["grasp_penalty"]`` when a toggle
            is detected.
    """

    # Magnitude the gripper command must exceed to count as open/close.
    _ACTION_THRESHOLD = 0.5
    # Gripper reading above this value is treated as one state, below as the other.
    _POSITION_THRESHOLD = 0.9

    def __init__(self, env, penalty=-0.05):
        super().__init__(env)
        self.penalty = penalty
        # Gripper reading from the most recent observation; set by reset().
        self.last_gripper_pos = None

    def reset(self, **kwargs):
        """Reset the wrapped env and cache the initial gripper position."""
        obs, info = self.env.reset(**kwargs)
        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
        return obs, info

    def step(self, action):
        """Step the wrapped env and annotate ``info`` with the grasp penalty.

        ``action`` may be a plain action vector or an
        ``(action, is_intervention)`` tuple; in the tuple form the gripper
        command is read from the action vector, not from the flag.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)

        # Other wrappers in this file pass actions around as
        # (action, is_intervention). Indexing such a tuple with [-1] would
        # pick the intervention flag instead of the gripper command.
        action_vector = action[0] if isinstance(action, tuple) else action
        gripper_command = action_vector[-1]

        # Compare against the position seen *before* this step was applied.
        closing_while_above = (
            gripper_command < -self._ACTION_THRESHOLD
            and self.last_gripper_pos > self._POSITION_THRESHOLD
        )
        opening_while_below = (
            gripper_command > self._ACTION_THRESHOLD
            and self.last_gripper_pos < self._POSITION_THRESHOLD
        )
        info["grasp_penalty"] = (
            self.penalty if (closing_while_above or opening_while_below) else 0.0
        )

        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
        return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
|
||||||
def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||||
"""
|
"""
|
||||||
Factory function to create a vectorized robot environment.
|
Factory function to create a vectorized robot environment.
|
||||||
|
@ -1144,6 +1167,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||||
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
||||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||||
env = BatchCompitableWrapper(env=env)
|
env = BatchCompitableWrapper(env=env)
|
||||||
|
env= GripperPenaltyWrapper(env=env)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue