diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index a6eda93b..d91df43f 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -171,7 +171,6 @@ class VideoRecordConfig: class WrapperConfig: """Configuration for environment wrappers.""" - delta_action: float | None = None joint_masking_action_space: list[bool] | None = None @@ -191,7 +190,6 @@ class EnvWrapperConfig: """Configuration for environment wrappers.""" display_cameras: bool = False - delta_action: float = 0.1 use_relative_joint_positions: bool = True add_joint_velocity_to_observation: bool = False add_ee_pose_to_observation: bool = False @@ -203,7 +201,7 @@ class EnvWrapperConfig: joint_masking_action_space: Optional[Any] = None ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False - gripper_quantization_threshold: float = 0.8 + gripper_quantization_threshold: float | None = None gripper_penalty: float = 0.0 open_gripper_on_reset: bool = False diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 0e6f8fda..29f4b5ff 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -144,8 +144,8 @@ class SACPolicy( self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) - self.grasp_critic = torch.compile(self.grasp_critic) - self.grasp_critic_target = torch.compile(self.grasp_critic_target) + # self.grasp_critic = torch.compile(self.grasp_critic) + # self.grasp_critic_target = torch.compile(self.grasp_critic_target) self.actor = Policy( encoder=encoder_actor, @@ -224,6 +224,10 @@ class SACPolicy( critics = self.critic_target if use_target else self.critic_ensemble q_values = critics(observations, actions, observation_features) + if not use_target: + for name, param in critics.named_parameters(): + if param.requires_grad: + print(f"Critic Ensemble layer {name}, norm {param.data.norm().item()}") return q_values def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor: @@ -239,6 +243,10 @@ class SACPolicy( """ grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic q_values = grasp_critic(observations, observation_features) + if not use_target: + for name, param in grasp_critic.named_parameters(): + if param.requires_grad: + print(f"Grasp critic layer {name}, norm {param.data.norm().item()}") return q_values def forward( diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index c834e9e9..170a35de 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -221,7 +221,6 @@ def record_episode( events=events, policy=policy, fps=fps, - # record_delta_actions=record_delta_actions, teleoperate=policy is None, single_task=single_task, ) @@ -267,8 +266,6 @@ def control_loop( if teleoperate: observation, action = robot.teleop_step(record_data=True) - # if record_delta_actions: - # action["action"] = action["action"] - current_joint_positions else: observation = robot.capture_observation() diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index c76dc003..7ac8343e 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -250,28 +250,18 @@ def act_with_policy( logging.info("[ACTOR] Shutting down act_with_policy") return - if interaction_step >= 
cfg.policy.online_step_before_learning: - # Time policy inference and check if it meets FPS requirement - with TimerManager( - elapsed_time_list=list_policy_time, - label="Policy inference time", - log=False, - ) as timer: # noqa: F841 - action = policy.select_action(batch=obs) - policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) + # Time policy inference and check if it meets FPS requirement + with TimerManager( + elapsed_time_list=list_policy_time, + label="Policy inference time", + log=False, + ) as timer: # noqa: F841 + action = policy.select_action(batch=obs) + policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) - log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) - next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy()) - else: - # TODO (azouitine): Make a custom space for torch tensor - action = online_env.action_space.sample() - next_obs, reward, done, truncated, info = online_env.step(action) - - # HACK: We have only one env but we want to batch it, it will be resolved with the torch box - action = ( - torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0) - ) + next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy()) sum_reward_episode += float(reward) # Increment total steps counter for intervention rate diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 8db1a82c..e95f8f55 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -505,7 +505,6 @@ class ReplayBuffer: state_keys: Optional[Sequence[str]] = None, capacity: Optional[int] = None, action_mask: Optional[Sequence[int]] = None, - action_delta: Optional[float] = None, image_augmentation_function: Optional[Callable] = None, use_drq: bool = True, storage_device: str = "cpu", @@ -520,7 +519,6 @@ class ReplayBuffer: state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`. capacity (Optional[int]): Buffer capacity. If None, uses dataset length. action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep. - action_delta (Optional[float]): Factor to divide actions by. image_augmentation_function (Optional[Callable]): Function for image augmentation. If None, uses default random shift with pad=4. use_drq (bool): Whether to use DrQ image augmentation when sampling. 
@@ -565,8 +563,6 @@ class ReplayBuffer: else: first_action = first_action[:, action_mask] - if action_delta is not None: - first_action = first_action / action_delta # Get complementary info if available first_complementary_info = None @@ -598,8 +594,6 @@ class ReplayBuffer: else: action = action[:, action_mask] - if action_delta is not None: - action = action / action_delta replay_buffer.add( state=data["state"], diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 3bd927b4..e4a6156f 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -258,24 +258,24 @@ class GamepadController(InputController): elif event.button == 0: self.episode_end_status = "rerecord_episode" - # RB button (6) for opening gripper + # LT button for closing gripper elif event.button == 6: - self.open_gripper_command = True - - # LT button (7) for closing gripper - elif event.button == 7: self.close_gripper_command = True + # RB button for opening gripper + elif event.button == 7: + self.open_gripper_command = True + # Reset episode status on button release elif event.type == pygame.JOYBUTTONUP: if event.button in [0, 2, 3]: self.episode_end_status = None - - elif event.button == 6: - self.open_gripper_command = False - - elif event.button == 7: + + if event.button == 6: self.close_gripper_command = False + + if event.button == 7: + self.open_gripper_command = False # Check for RB button (typically button 5) for intervention flag if self.joystick.get_button(5): diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 3aa75466..8eae6fb9 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env): self, robot, use_delta_action_space: bool = True, - delta: float | None = None, display_cameras: bool = False, ): """ @@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env): robot: The robot interface object used to connect and interact with the physical robot. use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute joint positions are used. - delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between - 0 and 1 when using a delta action space. display_cameras (bool): If True, the robot's camera feeds will be displayed during execution. 
""" super().__init__() @@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env): self.current_step = 0 self.episode_data = None - self.delta = delta self.use_delta_action_space = use_delta_action_space self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") @@ -555,6 +551,9 @@ class ImageCropResizeWrapper(gym.Wrapper): # TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1] obs[k] = obs[k].clamp(0.0, 1.0) + # import cv2 + # cv2.imwrite(f"tmp_img/{k}.jpg", cv2.cvtColor(obs[k].squeeze(0).permute(1,2,0).cpu().numpy()*255, cv2.COLOR_RGB2BGR)) + # Check for NaNs after processing if torch.isnan(obs[k]).any(): logging.error(f"NaN values detected in observation {k} after crop and resize") @@ -720,19 +719,31 @@ class ResetWrapper(gym.Wrapper): env: HILSerlRobotEnv, reset_pose: np.ndarray | None = None, reset_time_s: float = 5, + open_gripper_on_reset: bool = False ): super().__init__(env) self.reset_time_s = reset_time_s self.reset_pose = reset_pose self.robot = self.unwrapped.robot + self.open_gripper_on_reset = open_gripper_on_reset def reset(self, *, seed=None, options=None): + + if self.reset_pose is not None: start_time = time.perf_counter() log_say("Reset the environment.", play_sounds=True) reset_follower_position(self.robot, self.reset_pose) busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) log_say("Reset the environment done.", play_sounds=True) + if self.open_gripper_on_reset: + current_joint_pos = self.robot.follower_arms["main"].read("Present_Position") + current_joint_pos[-1] = MAX_GRIPPER_COMMAND + self.robot.send_action(torch.from_numpy(current_joint_pos)) + busy_wait(0.1) + current_joint_pos[-1] = 0.0 + self.robot.send_action(torch.from_numpy(current_joint_pos)) + busy_wait(0.2) else: log_say( f"Manually reset the environment for {self.reset_time_s} seconds.", @@ -777,7 +788,6 @@ class GripperPenaltyWrapper(gym.RewardWrapper): gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( gripper_state_normalized > 0.9 and action_normalized < 0.1 ) - breakpoint() return reward + self.penalty * gripper_penalty_bool @@ -791,8 +801,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper): self.last_gripper_state = None return super().reset(**kwargs) - -class GripperQuantizationWrapper(gym.ActionWrapper): +class GripperActionWrapper(gym.ActionWrapper): def __init__(self, env, quantization_threshold: float = 0.2): super().__init__(env) self.quantization_threshold = quantization_threshold @@ -801,16 +810,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper): is_intervention = False if isinstance(action, tuple): action, is_intervention = action - gripper_command = action[-1] - # Quantize gripper command to -1, 0 or 1 - if gripper_command < -self.quantization_threshold: - gripper_command = -MAX_GRIPPER_COMMAND - elif gripper_command > self.quantization_threshold: - gripper_command = MAX_GRIPPER_COMMAND - else: - gripper_command = 0.0 + # Gripper actions are between 0, 2 + # we want to quantize them to -1, 0 or 1 + gripper_command = gripper_command - 1.0 + + if self.quantization_threshold is not None: + # Quantize gripper command to -1, 0 or 1 + gripper_command = ( + np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 + ) + gripper_command = gripper_command * MAX_GRIPPER_COMMAND gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) action[-1] = 
gripper_action.item() @@ -836,10 +847,12 @@ class EEActionWrapper(gym.ActionWrapper): ] ) if self.use_gripper: - action_space_bounds = np.concatenate([action_space_bounds, [1.0]]) + # gripper actions open at 2.0, and closed at 0.0 + min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]]) + max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]]) ee_action_space = gym.spaces.Box( - low=-action_space_bounds, - high=action_space_bounds, + low=min_action_space_bounds, + high=max_action_space_bounds, shape=(3 + int(self.use_gripper),), dtype=np.float32, ) @@ -997,11 +1010,11 @@ class GamepadControlWrapper(gym.Wrapper): if self.use_gripper: gripper_command = self.controller.gripper_command() if gripper_command == "open": - gamepad_action = np.concatenate([gamepad_action, [1.0]]) + gamepad_action = np.concatenate([gamepad_action, [2.0]]) elif gripper_command == "close": - gamepad_action = np.concatenate([gamepad_action, [-1.0]]) - else: gamepad_action = np.concatenate([gamepad_action, [0.0]]) + else: + gamepad_action = np.concatenate([gamepad_action, [1.0]]) # Check episode ending buttons # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None @@ -1141,7 +1154,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env = HILSerlRobotEnv( robot=robot, display_cameras=cfg.wrapper.display_cameras, - delta=cfg.wrapper.delta_action, use_delta_action_space=cfg.wrapper.use_relative_joint_positions and cfg.wrapper.ee_action_space_params is None, ) @@ -1165,7 +1177,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) if cfg.wrapper.use_gripper: - env = GripperQuantizationWrapper( + env = GripperActionWrapper( env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold ) # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty) @@ -1176,6 +1188,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: ee_action_space_params=cfg.wrapper.ee_action_space_params, use_gripper=cfg.wrapper.use_gripper, ) + if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad: # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) env = GamepadControlWrapper( @@ -1192,6 +1205,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env=env, reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_time_s=cfg.wrapper.reset_time_s, + open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset ) if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) @@ -1341,11 +1355,10 @@ def record_dataset(env, policy, cfg): dataset.push_to_hub() -def replay_episode(env, repo_id, root=None, episode=0): +def replay_episode(env, cfg): from lerobot.common.datasets.lerobot_dataset import LeRobotDataset - local_files_only = root is not None - dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only) + dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) env.reset() actions = dataset.hf_dataset.select_columns("action") @@ -1353,7 +1366,7 @@ def replay_episode(env, repo_id, root=None, episode=0): for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() - action = 
actions[idx]["action"][:4] + action = actions[idx]["action"] env.step((action, False)) # env.step((action / env.unwrapped.delta, False)) @@ -1384,9 +1397,7 @@ def main(cfg: EnvConfig): if cfg.mode == "replay": replay_episode( env, - cfg.replay_repo_id, - root=cfg.dataset_root, - episode=cfg.replay_episode, + cfg=cfg, ) exit() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 5489d6dc..bcf47787 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -380,6 +380,7 @@ def add_actor_information_and_train( for _ in range(utd_ratio - 1): # Sample from the iterators batch = next(online_iterator) + # batch = replay_buffer.sample(batch_size) if dataset_repo_id is not None: batch_offline = next(offline_iterator) @@ -437,9 +438,11 @@ def add_actor_information_and_train( # Sample for the last update in the UTD ratio batch = next(online_iterator) + # batch = replay_buffer.sample(batch_size) if dataset_repo_id is not None: batch_offline = next(offline_iterator) + # batch_offline = offline_replay_buffer.sample(batch_size) batch = concatenate_batch_transitions( left_batch_transitions=batch, right_batch_transition=batch_offline ) @@ -775,9 +778,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): params=policy.actor.parameters_to_optimize, lr=cfg.policy.actor_lr, ) - optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr - ) + optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr) if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( @@ -992,7 +993,6 @@ def initialize_offline_replay_buffer( device=device, state_keys=cfg.policy.input_features.keys(), action_mask=active_action_dims, - action_delta=cfg.env.wrapper.delta_action, storage_device=storage_device, optimize_memory=True, capacity=cfg.policy.offline_buffer_capacity,