[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-04-09 15:05:17 +00:00
parent 5428ab96f5
commit ba09f44eb7
3 changed files with 13 additions and 14 deletions

View File

@ -207,7 +207,6 @@ class EnvWrapperConfig:
open_gripper_on_reset: bool = False open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator") @EnvConfig.register_subclass(name="gym_manipulator")
@dataclass @dataclass
class HILSerlRobotEnvConfig(EnvConfig): class HILSerlRobotEnvConfig(EnvConfig):

View File

@ -718,7 +718,7 @@ class ResetWrapper(gym.Wrapper):
env: HILSerlRobotEnv, env: HILSerlRobotEnv,
reset_pose: np.ndarray | None = None, reset_pose: np.ndarray | None = None,
reset_time_s: float = 5, reset_time_s: float = 5,
open_gripper_on_reset: bool = False open_gripper_on_reset: bool = False,
): ):
super().__init__(env) super().__init__(env)
self.reset_time_s = reset_time_s self.reset_time_s = reset_time_s
@ -727,8 +727,6 @@ class ResetWrapper(gym.Wrapper):
self.open_gripper_on_reset = open_gripper_on_reset self.open_gripper_on_reset = open_gripper_on_reset
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
if self.reset_pose is not None: if self.reset_pose is not None:
start_time = time.perf_counter() start_time = time.perf_counter()
log_say("Reset the environment.", play_sounds=True) log_say("Reset the environment.", play_sounds=True)
@ -778,7 +776,6 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
self.gripper_penalty_in_reward = gripper_penalty_in_reward self.gripper_penalty_in_reward = gripper_penalty_in_reward
self.last_gripper_state = None self.last_gripper_state = None
def reward(self, reward, action): def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
@ -813,6 +810,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
info["gripper_penalty"] = 0.0 info["gripper_penalty"] = 0.0
return obs, info return obs, info
class GripperActionWrapper(gym.ActionWrapper): class GripperActionWrapper(gym.ActionWrapper):
def __init__(self, env, quantization_threshold: float = 0.2): def __init__(self, env, quantization_threshold: float = 0.2):
super().__init__(env) super().__init__(env)
@ -1189,11 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper: if cfg.wrapper.use_gripper:
env = GripperActionWrapper( env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
)
if cfg.wrapper.gripper_penalty is not None: if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward) env = GripperPenaltyWrapper(
env=env,
penalty=cfg.wrapper.gripper_penalty,
gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
)
if cfg.wrapper.ee_action_space_params is not None: if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper( env = EEActionWrapper(
@ -1218,7 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env=env, env=env,
reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_pose=cfg.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.wrapper.reset_time_s, reset_time_s=cfg.wrapper.reset_time_s,
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
) )
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)