diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index a6eda93b..02911332 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -171,7 +171,6 @@ class VideoRecordConfig: class WrapperConfig: """Configuration for environment wrappers.""" - delta_action: float | None = None joint_masking_action_space: list[bool] | None = None @@ -191,7 +190,6 @@ class EnvWrapperConfig: """Configuration for environment wrappers.""" display_cameras: bool = False - delta_action: float = 0.1 use_relative_joint_positions: bool = True add_joint_velocity_to_observation: bool = False add_ee_pose_to_observation: bool = False @@ -203,11 +201,13 @@ class EnvWrapperConfig: joint_masking_action_space: Optional[Any] = None ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False - gripper_quantization_threshold: float = 0.8 - gripper_penalty: float = 0.0 + gripper_quantization_threshold: float | None = 0.8 + gripper_penalty: float = 0.0 + gripper_penalty_in_reward: bool = False open_gripper_on_reset: bool = False + @EnvConfig.register_subclass(name="gym_manipulator") @dataclass class HILSerlRobotEnvConfig(EnvConfig): diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b51f9b8f..b8827a1b 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -428,6 +428,7 @@ class SACPolicy( actions_discrete = torch.round(actions_discrete) actions_discrete = actions_discrete.long() + gripper_penalties: Tensor | None = None if complementary_info is not None: gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty") diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index c834e9e9..170a35de 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -221,7 +221,6 @@ def record_episode( events=events, policy=policy, fps=fps, - # record_delta_actions=record_delta_actions, teleoperate=policy is None, single_task=single_task, ) @@ -267,8 +266,6 @@ def control_loop( if teleoperate: observation, action = robot.teleop_step(record_data=True) - # if record_delta_actions: - # action["action"] = action["action"] - current_joint_positions else: observation = robot.capture_observation() diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 1013001a..658371a1 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -363,8 +363,6 @@ def replay( start_episode_t = time.perf_counter() action = actions[idx]["action"] - # if replay_delta_actions: - # action = action + current_joint_positions robot.send_action(action) dt_s = time.perf_counter() - start_episode_t diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 8db1a82c..92e03d33 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr if isinstance(val, torch.Tensor): transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) elif isinstance(val, (int, float, bool)): - transition["complementary_info"][key] = torch.tensor( - val, device=device, non_blocking=non_blocking - ) + transition["complementary_info"][key] = torch.tensor(val, device=device) else: raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") return transition @@ -505,7 +503,6 @@ class ReplayBuffer: state_keys: Optional[Sequence[str]] = None, capacity: Optional[int] = None, action_mask: Optional[Sequence[int]] = None, - action_delta: Optional[float] = None, image_augmentation_function: Optional[Callable] = None, use_drq: bool = True, storage_device: str = "cpu", @@ -520,7 +517,6 @@ class ReplayBuffer: state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`. capacity (Optional[int]): Buffer capacity. If None, uses dataset length. action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep. - action_delta (Optional[float]): Factor to divide actions by. image_augmentation_function (Optional[Callable]): Function for image augmentation. If None, uses default random shift with pad=4. use_drq (bool): Whether to use DrQ image augmentation when sampling. @@ -565,9 +561,6 @@ class ReplayBuffer: else: first_action = first_action[:, action_mask] - if action_delta is not None: - first_action = first_action / action_delta - # Get complementary info if available first_complementary_info = None if ( @@ -598,9 +591,6 @@ class ReplayBuffer: else: action = action[:, action_mask] - if action_delta is not None: - action = action / action_delta - replay_buffer.add( state=data["state"], action=action, diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 3aa75466..44bbcf9b 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env): self, robot, use_delta_action_space: bool = True, - delta: float | None = None, display_cameras: bool = False, ): """ @@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env): robot: The robot interface object used to connect and interact with the physical robot. use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute joint positions are used. - delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between - 0 and 1 when using a delta action space. display_cameras (bool): If True, the robot's camera feeds will be displayed during execution. """ super().__init__() @@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env): self.current_step = 0 self.episode_data = None - self.delta = delta self.use_delta_action_space = use_delta_action_space self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") @@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper): self.device = device def step(self, action): - observation, _, terminated, truncated, info = self.env.step(action) + observation, reward, terminated, truncated, info = self.env.step(action) images = [ observation[key].to(self.device, non_blocking=self.device.type == "cuda") for key in observation @@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper): ] start_time = time.perf_counter() with torch.inference_mode(): - reward = ( + success = ( self.reward_classifier.predict_reward(images, threshold=0.8) if self.reward_classifier is not None else 0.0 ) info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time) - if reward == 1.0: + if success == 1.0: terminated = True + reward = 1.0 + return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): @@ -720,19 +718,31 @@ class ResetWrapper(gym.Wrapper): env: HILSerlRobotEnv, reset_pose: np.ndarray | None = None, reset_time_s: float = 5, + open_gripper_on_reset: bool = False ): super().__init__(env) self.reset_time_s = reset_time_s self.reset_pose = reset_pose self.robot = self.unwrapped.robot + self.open_gripper_on_reset = open_gripper_on_reset def reset(self, *, seed=None, options=None): + + if self.reset_pose is not None: start_time = time.perf_counter() log_say("Reset the environment.", play_sounds=True) reset_follower_position(self.robot, self.reset_pose) busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) log_say("Reset the environment done.", play_sounds=True) + if self.open_gripper_on_reset: + current_joint_pos = self.robot.follower_arms["main"].read("Present_Position") + current_joint_pos[-1] = MAX_GRIPPER_COMMAND + self.robot.send_action(torch.from_numpy(current_joint_pos)) + busy_wait(0.1) + current_joint_pos[-1] = 0.0 + self.robot.send_action(torch.from_numpy(current_joint_pos)) + busy_wait(0.2) else: log_say( f"Manually reset the environment for {self.reset_time_s} seconds.", @@ -762,37 +772,48 @@ class BatchCompitableWrapper(gym.ObservationWrapper): class GripperPenaltyWrapper(gym.RewardWrapper): - def __init__(self, env, penalty: float = -0.1): + def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True): super().__init__(env) self.penalty = penalty + self.gripper_penalty_in_reward = gripper_penalty_in_reward self.last_gripper_state = None + def reward(self, reward, action): gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND - if isinstance(action, tuple): - action = action[0] - action_normalized = action[-1] / MAX_GRIPPER_COMMAND + action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND - gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( - gripper_state_normalized > 0.9 and action_normalized < 0.1 + gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and action_normalized < -0.5 ) - breakpoint() - return reward + self.penalty * gripper_penalty_bool + return reward + self.penalty * int(gripper_penalty_bool) def step(self, action): self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + if isinstance(action, tuple): + gripper_action = action[0][-1] + else: + gripper_action = action[-1] obs, reward, terminated, truncated, info = self.env.step(action) - reward = self.reward(reward, action) + gripper_penalty = self.reward(reward, gripper_action) + + if self.gripper_penalty_in_reward: + reward += gripper_penalty + else: + info["gripper_penalty"] = gripper_penalty + return obs, reward, terminated, truncated, info def reset(self, **kwargs): self.last_gripper_state = None - return super().reset(**kwargs) + obs, info = super().reset(**kwargs) + if self.gripper_penalty_in_reward: + info["gripper_penalty"] = 0.0 + return obs, info - -class GripperQuantizationWrapper(gym.ActionWrapper): +class GripperActionWrapper(gym.ActionWrapper): def __init__(self, env, quantization_threshold: float = 0.2): super().__init__(env) self.quantization_threshold = quantization_threshold @@ -801,16 +822,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper): is_intervention = False if isinstance(action, tuple): action, is_intervention = action - gripper_command = action[-1] - # Quantize gripper command to -1, 0 or 1 - if gripper_command < -self.quantization_threshold: - gripper_command = -MAX_GRIPPER_COMMAND - elif gripper_command > self.quantization_threshold: - gripper_command = MAX_GRIPPER_COMMAND - else: - gripper_command = 0.0 + # Gripper actions are between 0, 2 + # we want to quantize them to -1, 0 or 1 + gripper_command = gripper_command - 1.0 + + if self.quantization_threshold is not None: + # Quantize gripper command to -1, 0 or 1 + gripper_command = ( + np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 + ) + gripper_command = gripper_command * MAX_GRIPPER_COMMAND gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) action[-1] = gripper_action.item() @@ -836,10 +859,12 @@ class EEActionWrapper(gym.ActionWrapper): ] ) if self.use_gripper: - action_space_bounds = np.concatenate([action_space_bounds, [1.0]]) + # gripper actions open at 2.0, and closed at 0.0 + min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]]) + max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]]) ee_action_space = gym.spaces.Box( - low=-action_space_bounds, - high=action_space_bounds, + low=min_action_space_bounds, + high=max_action_space_bounds, shape=(3 + int(self.use_gripper),), dtype=np.float32, ) @@ -997,11 +1022,11 @@ class GamepadControlWrapper(gym.Wrapper): if self.use_gripper: gripper_command = self.controller.gripper_command() if gripper_command == "open": - gamepad_action = np.concatenate([gamepad_action, [1.0]]) + gamepad_action = np.concatenate([gamepad_action, [2.0]]) elif gripper_command == "close": - gamepad_action = np.concatenate([gamepad_action, [-1.0]]) - else: gamepad_action = np.concatenate([gamepad_action, [0.0]]) + else: + gamepad_action = np.concatenate([gamepad_action, [1.0]]) # Check episode ending buttons # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None @@ -1141,7 +1166,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env = HILSerlRobotEnv( robot=robot, display_cameras=cfg.wrapper.display_cameras, - delta=cfg.wrapper.delta_action, use_delta_action_space=cfg.wrapper.use_relative_joint_positions and cfg.wrapper.ee_action_space_params is None, ) @@ -1165,10 +1189,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) if cfg.wrapper.use_gripper: - env = GripperQuantizationWrapper( + env = GripperActionWrapper( env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold ) - # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty) + if cfg.wrapper.gripper_penalty is not None: + env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward) if cfg.wrapper.ee_action_space_params is not None: env = EEActionWrapper( @@ -1176,6 +1201,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: ee_action_space_params=cfg.wrapper.ee_action_space_params, use_gripper=cfg.wrapper.use_gripper, ) + if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad: # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) env = GamepadControlWrapper( @@ -1192,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env=env, reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_time_s=cfg.wrapper.reset_time_s, + open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset ) if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) @@ -1341,11 +1368,10 @@ def record_dataset(env, policy, cfg): dataset.push_to_hub() -def replay_episode(env, repo_id, root=None, episode=0): +def replay_episode(env, cfg): from lerobot.common.datasets.lerobot_dataset import LeRobotDataset - local_files_only = root is not None - dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only) + dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) env.reset() actions = dataset.hf_dataset.select_columns("action") @@ -1353,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0): for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() - action = actions[idx]["action"][:4] + action = actions[idx]["action"] env.step((action, False)) # env.step((action / env.unwrapped.delta, False)) @@ -1384,9 +1410,7 @@ def main(cfg: EnvConfig): if cfg.mode == "replay": replay_episode( env, - cfg.replay_repo_id, - root=cfg.dataset_root, - episode=cfg.replay_episode, + cfg=cfg, ) exit() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 707547a1..e4bcc620 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -406,7 +406,8 @@ def add_actor_information_and_train( "next_state": next_observations, "done": done, "observation_feature": observation_features, - "next_observation_feature": next_observation_features, + "next_observation_feature": next_observation_features, + "complementary_info": batch["complementary_info"], } # Use the forward method for critic loss (includes both main critic and grasp critic) @@ -992,7 +993,6 @@ def initialize_offline_replay_buffer( device=device, state_keys=cfg.policy.input_features.keys(), action_mask=active_action_dims, - action_delta=cfg.env.wrapper.delta_action, storage_device=storage_device, optimize_memory=True, capacity=cfg.policy.offline_buffer_capacity,