diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 02911332..77220c3c 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -202,12 +202,11 @@ class EnvWrapperConfig:
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
     gripper_quantization_threshold: float | None = 0.8
-    gripper_penalty: float = 0.0 
+    gripper_penalty: float = 0.0
     gripper_penalty_in_reward: bool = False
     open_gripper_on_reset: bool = False
-


 @EnvConfig.register_subclass(name="gym_manipulator")
 @dataclass
 class HILSerlRobotEnvConfig(EnvConfig):
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 44bbcf9b..6a8e848a 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -718,7 +718,7 @@ class ResetWrapper(gym.Wrapper):
         env: HILSerlRobotEnv,
         reset_pose: np.ndarray | None = None,
         reset_time_s: float = 5,
-        open_gripper_on_reset: bool = False
+        open_gripper_on_reset: bool = False,
     ):
         super().__init__(env)
         self.reset_time_s = reset_time_s
@@ -727,8 +727,6 @@ class ResetWrapper(gym.Wrapper):
         self.open_gripper_on_reset = open_gripper_on_reset

     def reset(self, *, seed=None, options=None):
-
-
         if self.reset_pose is not None:
             start_time = time.perf_counter()
             log_say("Reset the environment.", play_sounds=True)
@@ -777,12 +775,11 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         self.penalty = penalty
         self.gripper_penalty_in_reward = gripper_penalty_in_reward
         self.last_gripper_state = None
-

     def reward(self, reward, action):
         gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

-        action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
+        action_normalized = action - 1.0  # action / MAX_GRIPPER_COMMAND

         gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
             gripper_state_normalized > 0.75 and action_normalized < -0.5
@@ -803,7 +800,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
             reward += gripper_penalty
         else:
             info["gripper_penalty"] = gripper_penalty
-        
+
         return obs, reward, terminated, truncated, info

     def reset(self, **kwargs):
@@ -813,6 +810,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         info["gripper_penalty"] = 0.0
         return obs, info

+
 class GripperActionWrapper(gym.ActionWrapper):
     def __init__(self, env, quantization_threshold: float = 0.2):
         super().__init__(env)
@@ -1189,11 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.use_gripper:
-        env = GripperActionWrapper(
-            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
-        )
+        env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
         if cfg.wrapper.gripper_penalty is not None:
-            env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
+            env = GripperPenaltyWrapper(
+                env=env,
+                penalty=cfg.wrapper.gripper_penalty,
+                gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
+            )

     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
@@ -1218,7 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env=env,
         reset_pose=cfg.wrapper.fixed_reset_joint_positions,
         reset_time_s=cfg.wrapper.reset_time_s,
-        open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
+        open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
     )
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index e4bcc620..5b39d0d3 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -406,7 +406,7 @@ def add_actor_information_and_train(
                 "next_state": next_observations,
                 "done": done,
                 "observation_feature": observation_features,
-                "next_observation_feature": next_observation_features, 
+                "next_observation_feature": next_observation_features,
                 "complementary_info": batch["complementary_info"],
             }