diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 825fa162..440512c3 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -203,6 +203,9 @@ class EnvWrapperConfig:
     joint_masking_action_space: Optional[Any] = None
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
+    gripper_quantization_threshold: float = 0.8
+    gripper_penalty: float = 0.0
+    open_gripper_on_reset: bool = False
 
 
 @EnvConfig.register_subclass(name="gym_manipulator")
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index bf65b1ec..2af3995e 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -245,10 +245,6 @@ class ReplayBuffer:
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
 
-        # Initialize complementary_info storage
-        self.complementary_info_keys = []
-        self.complementary_info_storage = {}
-
         self.initialized = True
 
     def __len__(self):
@@ -282,28 +278,6 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated
 
-        # Store complementary info if provided
-        if complementary_info is not None:
-            # Initialize storage for new keys on first encounter
-            for key, value in complementary_info.items():
-                if key not in self.complementary_info_keys:
-                    self.complementary_info_keys.append(key)
-                    if isinstance(value, torch.Tensor):
-                        shape = value.shape if value.ndim > 0 else (1,)
-                        self.complementary_info_storage[key] = torch.zeros(
-                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
-                        )
-
-                # Store the value
-                if key in self.complementary_info_storage:
-                    if isinstance(value, torch.Tensor):
-                        self.complementary_info_storage[key][self.position] = value
-                    else:
-                        # For non-tensor values (like grasp_penalty)
-                        self.complementary_info_storage[key][self.position] = torch.tensor(
-                            value, device=self.storage_device
-                        )
-
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
 
@@ -362,13 +336,6 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()
 
-        # Add complementary_info to batch if it exists
-        batch_complementary_info = {}
-        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
-            for key in self.complementary_info_keys:
-                if key in self.complementary_info_storage:
-                    batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
-
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -376,7 +343,6 @@
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
-            complementary_info=batch_complementary_info if batch_complementary_info else None,
         )
 
     @classmethod
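Note: with `complementary_info` gone from `ReplayBuffer` and `BatchTransition`, the grasp penalty no longer travels through the buffer as `complementary_info["grasp_penalty"]`; the rewritten `GripperPenaltyWrapper` in gym_manipulator.py (below) folds it directly into the scalar reward. A minimal standalone sketch of that computation, using a placeholder value for `MAX_GRIPPER_COMMAND` (the real constant is defined in gym_manipulator.py):

```python
# Standalone illustration, not lerobot code. MAX_GRIPPER_COMMAND is a
# placeholder value chosen only so the example runs.
MAX_GRIPPER_COMMAND = 2.0


def penalized_reward(reward, gripper_state, gripper_command, penalty=-0.1):
    state_n = gripper_state / MAX_GRIPPER_COMMAND
    command_n = gripper_command / MAX_GRIPPER_COMMAND
    # Penalize only commands that flip the gripper from one extreme to the other.
    mismatch = (state_n < 0.1 and command_n > 0.9) or (state_n > 0.9 and command_n < 0.1)
    return reward + penalty * mismatch


assert penalized_reward(1.0, gripper_state=0.0, gripper_command=2.0) == 0.9  # penalized
assert penalized_reward(1.0, gripper_state=2.0, gripper_command=2.0) == 1.0  # unchanged
```

Because the penalty is baked into the reward before the transition is stored, the learner needs no extra plumbing, which is what the buffer.py and learner_server.py deletions reflect.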
- """ - import pygame - - try: - # Check if buttons are pressed - l3_pressed = self.joystick.get_button(9) - r3_pressed = self.joystick.get_button(10) - - # Determine action based on button presses - if r3_pressed: - return 1.0 # Close gripper - elif l3_pressed: - return -1.0 # Open gripper - else: - return 0.0 # No change - - except pygame.error: - logging.error("Error reading gamepad. Is it still connected?") - return 0.0 - class GamepadControllerHID(InputController): """Generate motion deltas from gamepad input using HIDAPI.""" diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index ac3bbb0a..3aa75466 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper): return observation +class GripperPenaltyWrapper(gym.RewardWrapper): + def __init__(self, env, penalty: float = -0.1): + super().__init__(env) + self.penalty = penalty + self.last_gripper_state = None + + def reward(self, reward, action): + gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND + + if isinstance(action, tuple): + action = action[0] + action_normalized = action[-1] / MAX_GRIPPER_COMMAND + + gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( + gripper_state_normalized > 0.9 and action_normalized < 0.1 + ) + breakpoint() + + return reward + self.penalty * gripper_penalty_bool + + def step(self, action): + self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + obs, reward, terminated, truncated, info = self.env.step(action) + reward = self.reward(reward, action) + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + self.last_gripper_state = None + return super().reset(**kwargs) + + +class GripperQuantizationWrapper(gym.ActionWrapper): + def __init__(self, env, quantization_threshold: float = 0.2): + super().__init__(env) + self.quantization_threshold = quantization_threshold + + def action(self, action): + is_intervention = False + if isinstance(action, tuple): + action, is_intervention = action + + gripper_command = action[-1] + # Quantize gripper command to -1, 0 or 1 + if gripper_command < -self.quantization_threshold: + gripper_command = -MAX_GRIPPER_COMMAND + elif gripper_command > self.quantization_threshold: + gripper_command = MAX_GRIPPER_COMMAND + else: + gripper_command = 0.0 + + gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) + action[-1] = gripper_action.item() + return action, is_intervention + + class EEActionWrapper(gym.ActionWrapper): def __init__(self, env, ee_action_space_params=None, use_gripper=False): super().__init__(env) @@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper): fk_func=self.fk_function, ) if self.use_gripper: - # Quantize gripper command to -1, 0 or 1 - if gripper_command < -0.2: - gripper_command = -1.0 - elif gripper_command > 0.2: - gripper_command = 1.0 - else: - gripper_command = 0.0 - - gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] - gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) - target_joint_pos[-1] = gripper_action + target_joint_pos[-1] = gripper_command return target_joint_pos, is_intervention @@ -1069,31 +1115,6 @@ class ActionScaleWrapper(gym.ActionWrapper): return action 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 0f760dc5..15de2cb7 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -375,7 +375,6 @@ def add_actor_information_and_train(
         observations = batch["state"]
         next_observations = batch["next_state"]
         done = batch["done"]
-        complementary_info = batch["complementary_info"]
         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
         observation_features, next_observation_features = get_observation_features(
@@ -391,7 +390,6 @@
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
-            "complementary_info": complementary_info,
         }
 
         # Use the forward method for critic loss
@@ -450,7 +448,6 @@
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
-            "complementary_info": complementary_info,
         }
 
         # Use the forward method for critic loss
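For reference, a hedged sketch of setting the new `EnvWrapperConfig` knobs end to end; it assumes the remaining dataclass fields are defaulted, as the ones visible in the first hunk are:

```python
# Illustrative only: field values here are examples, and any EnvWrapperConfig
# fields not shown in the first hunk are assumed to keep their defaults.
from lerobot.common.envs.configs import EnvWrapperConfig

wrapper_cfg = EnvWrapperConfig(
    use_gripper=True,
    gripper_quantization_threshold=0.8,  # dead-zone consumed by GripperQuantizationWrapper
    gripper_penalty=-0.1,  # reserved for GripperPenaltyWrapper, still commented out above
    open_gripper_on_reset=False,  # added by this diff but not yet consumed by any wrapper shown
)
```

Note that `gripper_penalty` and `open_gripper_on_reset` are plumbed through the config but not yet used: the `GripperPenaltyWrapper` hookup in `make_robot_env` remains commented out in this diff.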