[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-04-09 15:05:17 +00:00
parent 5428ab96f5
commit ba09f44eb7
3 changed files with 13 additions and 14 deletions

@@ -202,12 +202,11 @@ class EnvWrapperConfig:
    ee_action_space_params: Optional[EEActionSpaceConfig] = None
    use_gripper: bool = False
    gripper_quantization_threshold: float | None = 0.8
-   gripper_penalty: float = 0.0
+   gripper_penalty: float = 0.0
    gripper_penalty_in_reward: bool = False
    open_gripper_on_reset: bool = False


@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):

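Note: the fields in this hunk gate all gripper handling downstream. A minimal, illustrative sketch of how such a config might be declared and queried (the class name EnvWrapperConfigSketch and the helper are assumptions, not the repo's API; field names and defaults mirror the hunk above):

from dataclasses import dataclass

@dataclass
class EnvWrapperConfigSketch:
    # Field names and defaults taken from the hunk above.
    use_gripper: bool = False
    gripper_quantization_threshold: float | None = 0.8
    gripper_penalty: float = 0.0
    gripper_penalty_in_reward: bool = False
    open_gripper_on_reset: bool = False

def gripper_penalty_enabled(cfg: EnvWrapperConfigSketch) -> bool:
    # Mirrors the checks in make_robot_env below: the penalty wrapper is
    # only stacked when the gripper is used and a penalty is configured.
    return cfg.use_gripper and cfg.gripper_penalty is not None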
@@ -718,7 +718,7 @@ class ResetWrapper(gym.Wrapper):
        env: HILSerlRobotEnv,
        reset_pose: np.ndarray | None = None,
        reset_time_s: float = 5,
-       open_gripper_on_reset: bool = False
+       open_gripper_on_reset: bool = False,
    ):
        super().__init__(env)
        self.reset_time_s = reset_time_s
@@ -727,8 +727,6 @@ class ResetWrapper(gym.Wrapper):
        self.open_gripper_on_reset = open_gripper_on_reset

    def reset(self, *, seed=None, options=None):
        if self.reset_pose is not None:
            start_time = time.perf_counter()
            log_say("Reset the environment.", play_sounds=True)
@ -777,12 +775,11 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
self.penalty = penalty
self.gripper_penalty_in_reward = gripper_penalty_in_reward
self.last_gripper_state = None
def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND
gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and action_normalized < -0.5
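For readers skimming the hunk, here is a self-contained sketch of the penalty test it reformats (the value of MAX_GRIPPER_COMMAND is an assumption for illustration; the real constant and wiring live in the module being diffed):

MAX_GRIPPER_COMMAND = 2.0  # assumed scale used for normalization

def fights_gripper_state(last_gripper_state: float, action: float) -> bool:
    # Normalize the last gripper state to roughly [0, 1] and shift the
    # discrete gripper action {0, 1, 2} to {-1, 0, 1}, as in the hunk above.
    gripper_state_normalized = last_gripper_state / MAX_GRIPPER_COMMAND
    action_normalized = action - 1.0
    # True when the command pushes against the current state: a strong
    # positive command while the state is low, or a strong negative
    # command while the state is high.
    return (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
        gripper_state_normalized > 0.75 and action_normalized < -0.5
    )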
@@ -803,7 +800,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
            reward += gripper_penalty
        else:
            info["gripper_penalty"] = gripper_penalty

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
@@ -813,6 +810,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
        info["gripper_penalty"] = 0.0
        return obs, info


class GripperActionWrapper(gym.ActionWrapper):
    def __init__(self, env, quantization_threshold: float = 0.2):
        super().__init__(env)
@@ -1189,11 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
    # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
    if cfg.wrapper.use_gripper:
-       env = GripperActionWrapper(
-           env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
-       )
+       env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
        if cfg.wrapper.gripper_penalty is not None:
-           env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
+           env = GripperPenaltyWrapper(
+               env=env,
+               penalty=cfg.wrapper.gripper_penalty,
+               gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
+           )
    if cfg.wrapper.ee_action_space_params is not None:
        env = EEActionWrapper(
@@ -1218,7 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
        env=env,
        reset_pose=cfg.wrapper.fixed_reset_joint_positions,
        reset_time_s=cfg.wrapper.reset_time_s,
-       open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
+       open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
    )
    if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
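Taken together, the make_robot_env hunks only reformat a conditional wrapper stack. A runnable sketch of that pattern with a stand-in wrapper (everything here is illustrative; only the stacking order comes from the diff):

import gymnasium as gym

class PassThrough(gym.Wrapper):
    """Stand-in for the real wrappers in the diff; passes everything through."""

def stack_sketch(env: gym.Env, use_gripper: bool, gripper_penalty: float | None) -> gym.Env:
    # Mirrors the diff: gripper action handling first, the penalty wrapper
    # only on top of it, and reset behaviour applied afterwards.
    if use_gripper:
        env = PassThrough(env)  # GripperActionWrapper in the real code
        if gripper_penalty is not None:
            env = PassThrough(env)  # GripperPenaltyWrapper in the real code
    env = PassThrough(env)  # ResetWrapper in the real code
    return env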

@@ -406,7 +406,7 @@ def add_actor_information_and_train(
            "next_state": next_observations,
            "done": done,
            "observation_feature": observation_features,
-           "next_observation_feature": next_observation_features,
+           "next_observation_feature": next_observation_features,
            "complementary_info": batch["complementary_info"],
        }
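The last hunk only strips trailing whitespace in the transition dict handed to the learner. A sketch of that construction, assuming the inputs come from a sampled replay-buffer batch (key names are from the hunk; the helper itself is hypothetical):

def build_transition(next_observations, done, observation_features, next_observation_features, batch):
    # One entry per field the learner consumes, including the precomputed
    # encoder features for the current and next observations.
    return {
        "next_state": next_observations,
        "done": done,
        "observation_feature": observation_features,
        "next_observation_feature": next_observation_features,
        "complementary_info": batch["complementary_info"],
    }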