From 1f23ef78891197e9fe22c3993b722b31f3a746da Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Mon, 17 Mar 2025 10:50:28 +0000 Subject: [PATCH 1/6] Enhance SAC configuration and policy with gradient clipping and temperature management - Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping - Updated SACPolicy to store temperature as an instance variable for consistent usage - Modified loss calculations in SACPolicy to utilize the instance temperature - Enhanced MLP and CriticHead to support a customizable final activation function - Implemented gradient clipping in the learner server during training steps for both actor and critic - Added tracking for gradient norms in training information --- .../common/policies/sac/configuration_sac.py | 2 ++ lerobot/common/policies/sac/modeling_sac.py | 34 ++++++++++++++----- lerobot/scripts/server/learner_server.py | 33 ++++++++++++++++++ 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index b834896e..61e08df4 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -84,10 +84,12 @@ class SACConfig: latent_dim: int = 256 target_entropy: float | None = None use_backup_entropy: bool = True + grad_clip_norm: float = 40.0 critic_network_kwargs: dict[str, Any] = field( default_factory=lambda: { "hidden_dims": [256, 256], "activate_final": True, + "final_activation": None, } ) actor_network_kwargs: dict[str, Any] = field( diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index afbbc945..2c4bad5f 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -330,7 +330,7 @@ class SACPolicy( observation_features: Tensor | None = None, next_observation_features: Tensor | None = None, ) -> Tensor: - temperature = self.log_alpha.exp().item() + self.temperature = self.log_alpha.exp().item() with torch.no_grad(): next_action_preds, next_log_probs, _ = self.actor( next_observations, next_observation_features @@ -358,7 +358,7 @@ class SACPolicy( # critics subsample size min_q, _ = q_targets.min(dim=0) # Get values from min operation if self.config.use_backup_entropy: - min_q = min_q - (temperature * next_log_probs) + min_q = min_q - (self.temperature * next_log_probs) td_target = rewards + (1 - done) * self.config.discount * min_q @@ -398,7 +398,7 @@ class SACPolicy( def compute_loss_actor( self, observations, observation_features: Tensor | None = None ) -> Tensor: - temperature = self.log_alpha.exp().item() + self.temperature = self.log_alpha.exp().item() actions_pi, log_probs, _ = self.actor(observations, observation_features) @@ -413,7 +413,7 @@ class SACPolicy( ) min_q_preds = q_preds.min(dim=0)[0] - actor_loss = ((temperature * log_probs) - min_q_preds).mean() + actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() return actor_loss @@ -425,6 +425,7 @@ class MLP(nn.Module): activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), activate_final: bool = False, dropout_rate: Optional[float] = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, ): super().__init__() self.activate_final = activate_final @@ -451,11 +452,24 @@ class MLP(nn.Module): if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[i])) - layers.append( - 
activations - if isinstance(activations, nn.Module) - else getattr(nn, activations)() - ) + + # If we're at the final layer and a final activation is specified, use it + if ( + i + 1 == len(hidden_dims) + and activate_final + and final_activation is not None + ): + layers.append( + final_activation + if isinstance(final_activation, nn.Module) + else getattr(nn, final_activation)() + ) + else: + layers.append( + activations + if isinstance(activations, nn.Module) + else getattr(nn, activations)() + ) self.net = nn.Sequential(*layers) @@ -516,6 +530,7 @@ class CriticHead(nn.Module): activate_final: bool = False, dropout_rate: Optional[float] = None, init_final: Optional[float] = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, ): super().__init__() self.net = MLP( @@ -524,6 +539,7 @@ class CriticHead(nn.Module): activations=activations, activate_final=activate_final, dropout_rate=dropout_rate, + final_activation=final_activation, ) self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) if init_final is not None: diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 7bd4aee0..580eed1a 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -390,6 +390,10 @@ def add_actor_information_and_train( if cfg.resume else None, ) + + # Update the policy config with the grad_clip_norm value from training config if it exists + clip_grad_norm_value = cfg.training.grad_clip_norm + # compile policy policy = torch.compile(policy) assert isinstance(policy, nn.Module) @@ -507,6 +511,12 @@ def add_actor_information_and_train( ) optimizers["critic"].zero_grad() loss_critic.backward() + + # clip gradients + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + policy.critic_ensemble.parameters(), clip_grad_norm_value + ) + optimizers["critic"].step() batch = replay_buffer.sample(batch_size) @@ -541,10 +551,17 @@ def add_actor_information_and_train( ) optimizers["critic"].zero_grad() loss_critic.backward() + + # clip gradients + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + policy.critic_ensemble.parameters(), clip_grad_norm_value + ).item() + optimizers["critic"].step() training_infos = {} training_infos["loss_critic"] = loss_critic.item() + training_infos["critic_grad_norm"] = critic_grad_norm if optimization_step % cfg.training.policy_update_freq == 0: for _ in range(cfg.training.policy_update_freq): @@ -555,19 +572,35 @@ def add_actor_information_and_train( optimizers["actor"].zero_grad() loss_actor.backward() + + # clip gradients + actor_grad_norm = torch.nn.utils.clip_grad_norm_( + policy.actor.parameters_to_optimize, clip_grad_norm_value + ).item() + optimizers["actor"].step() training_infos["loss_actor"] = loss_actor.item() + training_infos["actor_grad_norm"] = actor_grad_norm + # Temperature optimization loss_temperature = policy.compute_loss_temperature( observations=observations, observation_features=observation_features, ) optimizers["temperature"].zero_grad() loss_temperature.backward() + + # clip gradients + temp_grad_norm = torch.nn.utils.clip_grad_norm_( + [policy.log_alpha], clip_grad_norm_value + ).item() + optimizers["temperature"].step() training_infos["loss_temperature"] = loss_temperature.item() + training_infos["temperature_grad_norm"] = temp_grad_norm + training_infos["temperature"] = policy.temperature if ( time.time() - last_time_policy_pushed From 9e3c8461cac22e1c1463f7c70937741a05ac9053 Mon Sep 17 00:00:00 2001 From: Michel 
Aractingi Date: Mon, 17 Mar 2025 14:22:33 +0100 Subject: [PATCH 2/6] Add end effector action space to hil-serl (#861) Co-authored-by: Adil Zouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- lerobot/common/envs/factory.py | 1 + .../common/policies/sac/configuration_sac.py | 2 +- lerobot/common/robot_devices/control_utils.py | 2 +- lerobot/configs/env/so100_real.yaml | 44 +- lerobot/configs/policy/sac_real.yaml | 24 +- lerobot/configs/robot/so100.yaml | 14 +- lerobot/scripts/server/actor_server.py | 18 +- lerobot/scripts/server/crop_dataset_roi.py | 5 - .../server/end_effector_control_utils.py | 797 ++++++++++++++++++ lerobot/scripts/server/find_joint_limits.py | 73 +- lerobot/scripts/server/gym_manipulator.py | 694 +++++++++++++-- lerobot/scripts/server/kinematics.py | 543 ++++++++++++ lerobot/scripts/server/learner_server.py | 60 +- 13 files changed, 2138 insertions(+), 139 deletions(-) create mode 100644 lerobot/scripts/server/end_effector_control_utils.py create mode 100644 lerobot/scripts/server/kinematics.py diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 457b7af6..c23dcd1d 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -137,6 +137,7 @@ class PixelWrapper(gym.Wrapper): return self._get_obs(obs), reward, terminated, truncated, info +# TODO: Remove this class ConvertToLeRobotEnv(gym.Wrapper): def __init__(self, env, num_envs): super().__init__(env) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 61e08df4..3f1a7fbb 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -103,6 +103,6 @@ class SACConfig: "use_tanh_squash": True, "log_std_min": -5, "log_std_max": 2, - "init_final": 0.005, + "init_final": 0.05, } ) diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index ae25f7ae..08429fc1 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -367,7 +367,7 @@ def reset_environment(robot, events, reset_time_s): def reset_follower_position(robot: Robot, target_position): current_position = robot.follower_arms["main"].read("Present_Position") trajectory = torch.from_numpy( - np.linspace(current_position, target_position, 30) + np.linspace(current_position, target_position, 50) ) # NOTE: 30 is just an aribtrary number for pose in trajectory: robot.send_action(pose) diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml index dc30224c..5265a252 100644 --- a/lerobot/configs/env/so100_real.yaml +++ b/lerobot/configs/env/so100_real.yaml @@ -5,26 +5,46 @@ fps: 10 env: name: real_world task: null - state_dim: 6 - action_dim: 6 + state_dim: 15 + action_dim: 3 fps: ${fps} device: mps wrapper: crop_params_dict: - observation.images.front: [102, 43, 358, 523] - observation.images.side: [92, 123, 379, 349] - # observation.images.front: [109, 37, 361, 557] - # observation.images.side: [94, 161, 372, 315] + observation.images.front: [171, 207, 116, 251] + observation.images.side: [232, 200, 142, 204] resize_size: [128, 128] - control_time_s: 20 - reset_follower_pos: true + control_time_s: 10 + reset_follower_pos: false use_relative_joint_positions: true reset_time_s: 5 display_cameras: false - delta_action: 0.1 - joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper + delta_action: null 
#0.3 + joint_masking_action_space: null #[1, 1, 1, 1, 0, 0] # disable wrist and gripper + add_joint_velocity_to_observation: true + add_ee_pose_to_observation: true + + # If null then the teleoperation will be used to reset the robot + # Bounds for pushcube_gamepad_lerobot15 dataset and experiments + # fixed_reset_joint_positions: [-19.86, 103.19, 117.33, 42.7, 13.89, 0.297] + # ee_action_space_params: # If null then ee_action_space is not used + # bounds: + # max: [0.291, 0.147, 0.074] + # min: [0.139, -0.143, 0.03] + + # Bounds for insertcube_gamepad dataset and experiments + fixed_reset_joint_positions: [20.0, 90., 90., 75., -0.7910156, -0.5673759] + ee_action_space_params: + bounds: + max: [0.25295413, 0.07498981, 0.06862044] + min: [0.2010096, -0.12, 0.0433196] + + use_gamepad: true + x_step_size: 0.03 + y_step_size: 0.03 + z_step_size: 0.03 reward_classifier: - pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model - config_path: lerobot/configs/policy/hilserl_classifier.yaml + pretrained_path: null # outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model + config_path: null # lerobot/configs/policy/hilserl_classifier.yaml diff --git a/lerobot/configs/policy/sac_real.yaml b/lerobot/configs/policy/sac_real.yaml index 139463f9..039fe0f0 100644 --- a/lerobot/configs/policy/sac_real.yaml +++ b/lerobot/configs/policy/sac_real.yaml @@ -8,8 +8,7 @@ # env.gym.obs_type=environment_state_agent_pos \ seed: 1 -dataset_repo_id: aractingi/push_cube_overfit_cropped_resized -#aractingi/push_cube_square_offline_demo_cropped_resized +dataset_repo_id: aractingi/insertcube_simple training: # Offline training dataloader @@ -30,7 +29,7 @@ training: online_steps_between_rollouts: 1000 online_sampling_ratio: 1.0 online_env_seed: 10000 - online_buffer_capacity: 1000000 + online_buffer_capacity: 10000 online_buffer_seed_size: 0 online_step_before_learning: 100 #5000 do_online_rollout_async: false @@ -62,7 +61,7 @@ policy: observation.images.side: [3, 128, 128] # observation.image: [3, 128, 128] output_shapes: - action: [4] # ["${env.action_dim}"] + action: ["${env.action_dim}"] # Normalization / Unnormalization input_normalization_modes: @@ -77,23 +76,16 @@ policy: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] observation.state: - min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786] - max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01] - - # min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274] - # max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685] - # min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274] - # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792] + # 6- joint positions, 6- joint velocities, 3- ee position + max: [ 52.822266, 136.14258, 142.03125, 72.1582, 22.675781, -0.5673759, 100., 100., 100., 100., 100., 100., 0.25295413, 0.07498981, 0.06862044] + min: [-2.6367188, 86.572266, 89.82422, 12.392578, -26.015625, -0.5673759, -100., -100., -100., -100., -100., -100., 0.2010096, -0.12, 0.0433196] output_normalization_modes: action: min_max output_normalization_params: - # action: - # min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0] - # max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] action: - min: [-149.23828125, -97.734375, -100.1953125, -73.740234375] - max: [149.23828125, 97.734375, 100.1953125, 73.740234375] + min: [-0.03, -0.03, -0.01] + max: [0.03, 0.03, 0.03] # Architecture / modeling. # Neural networks. 
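
Note on the sac_real.yaml hunk above: it switches the policy output to a 3-D end-effector delta action normalized with min_max over the listed bounds. Below is a minimal standalone sketch of the mapping those output_normalization_params imply, assuming LeRobot's min_max mode rescales each dimension linearly to [-1, 1]; the bound values are copied from the config, and the helper names are illustrative only, not LeRobot APIs.

import numpy as np

# Delta end-effector action bounds (meters), copied from output_normalization_params above.
ACTION_MIN = np.array([-0.03, -0.03, -0.01], dtype=np.float32)
ACTION_MAX = np.array([0.03, 0.03, 0.03], dtype=np.float32)

def normalize_action(raw_delta: np.ndarray) -> np.ndarray:
    # Rescale a raw (dx, dy, dz) command into [-1, 1] per dimension (assumed min_max behavior).
    return 2.0 * (raw_delta - ACTION_MIN) / (ACTION_MAX - ACTION_MIN) - 1.0

def unnormalize_action(policy_output: np.ndarray) -> np.ndarray:
    # Map a policy output in [-1, 1] back to a raw delta the EE action wrapper can execute.
    return (policy_output + 1.0) / 2.0 * (ACTION_MAX - ACTION_MIN) + ACTION_MIN

# Example: the largest allowed +x move normalizes to 1.0, and a zero policy output
# maps back to (0, 0, 0.01) because the z bounds are asymmetric.
print(normalize_action(np.array([0.03, 0.0, -0.01], dtype=np.float32)))  # [ 1.  0. -1.]
print(unnormalize_action(np.zeros(3, dtype=np.float32)))                 # [0.   0.   0.01]
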
diff --git a/lerobot/configs/robot/so100.yaml b/lerobot/configs/robot/so100.yaml index 459308ae..59ecfa0b 100644 --- a/lerobot/configs/robot/so100.yaml +++ b/lerobot/configs/robot/so100.yaml @@ -14,9 +14,13 @@ calibration_dir: .cache/calibration/so100 # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as # the number of motors in your follower arms. max_relative_target: null -joint_position_relative_bounds: - max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01] - min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786] +joint_position_relative_bounds: null + # max: [100, 100, 100, 100, 100, 100] + # min: [-100, -100, -100, -100, -100, -100] + # max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01] + # min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786] + # max: [ 35.06836 , 103.18359 , 127.61719 , 75.58594 , 0., 0.] + # min: [ -8.876953 , 63.808594 , 90.49805 , 49.48242 , 0., 0.] leader_arms: main: @@ -47,13 +51,13 @@ follower_arms: cameras: front: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera - camera_index: 0 + camera_index: 1 fps: 30 width: 640 height: 480 side: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera - camera_index: 1 + camera_index: 0 fps: 30 width: 640 height: 480 diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index 24d8356d..45fd34a3 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -54,6 +54,7 @@ from lerobot.scripts.server.network_utils import ( ) from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env from lerobot.scripts.server import learner_service +from lerobot.common.robot_devices.utils import busy_wait from torch.multiprocessing import Queue, Event from queue import Empty @@ -312,17 +313,6 @@ def act_with_policy( logging.info("make_policy") - # HACK: This is an ugly hack to pass the normalization parameters to the policy - # Because the action space is dynamic so we override the output normalization parameters - # it's ugly, we know ... 
and we will fix it - min_action_space: list = online_env.action_space.spaces[0].low.tolist() - max_action_space: list = online_env.action_space.spaces[0].high.tolist() - output_normalization_params: dict[dict[str, list]] = { - "action": {"min": min_action_space, "max": max_action_space} - } - cfg.policy.output_normalization_params = output_normalization_params - cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape - ### Instantiate the policy in both the actor and learner processes ### To avoid sending a SACPolicy object through the port, we create a policy intance ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters @@ -347,6 +337,7 @@ def act_with_policy( episode_intervention = False for interaction_step in range(cfg.training.online_steps): + start_time = time.perf_counter() if shutdown_event.is_set(): logging.info("[ACTOR] Shutting down act_with_policy") return @@ -408,7 +399,6 @@ def act_with_policy( complementary_info=info, # TODO Handle information for the transition, is_demonstraction: bool ) ) - # assign obs to the next obs and continue the rollout obs = next_obs @@ -449,6 +439,10 @@ def act_with_policy( episode_intervention = False obs, info = online_env.reset() + if cfg.fps is not None: + dt_time = time.perf_counter() - start_time + busy_wait(1 / cfg.fps - dt_time) + def push_transitions_to_transport_queue(transitions: list, transitions_queue): """Send transitions to learner in smaller chunks to avoid network issues. diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py index 8bb414fe..d6c3dd51 100644 --- a/lerobot/scripts/server/crop_dataset_roi.py +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -263,11 +263,6 @@ if __name__ == "__main__": with open(args.crop_params_path) as f: rois = json.load(f) - # rois = { - # "observation.images.front": [102, 43, 358, 523], - # "observation.images.side": [92, 123, 379, 349], - # } - # Print the selected rectangular ROIs print("\nSelected Rectangular Regions of Interest (top, left, height, width):") for key, roi in rois.items(): diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py new file mode 100644 index 00000000..253a8ebd --- /dev/null +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -0,0 +1,797 @@ +from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot.common.utils.utils import init_hydra_config +from lerobot.common.robot_devices.utils import busy_wait +from lerobot.scripts.server.kinematics import RobotKinematics +import logging +import time +import torch +import numpy as np +import argparse + + +logging.basicConfig(level=logging.INFO) + + +class InputController: + """Base class for input controllers that generate motion deltas.""" + + def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01): + """ + Initialize the controller. 
+ + Args: + x_step_size: Base movement step size in meters + y_step_size: Base movement step size in meters + z_step_size: Base movement step size in meters + """ + self.x_step_size = x_step_size + self.y_step_size = y_step_size + self.z_step_size = z_step_size + self.running = True + self.episode_end_status = None # None, "success", or "failure" + + def start(self): + """Start the controller and initialize resources.""" + pass + + def stop(self): + """Stop the controller and release resources.""" + pass + + def get_deltas(self): + """Get the current movement deltas (dx, dy, dz) in meters.""" + return 0.0, 0.0, 0.0 + + def should_quit(self): + """Return True if the user has requested to quit.""" + return not self.running + + def update(self): + """Update controller state - call this once per frame.""" + pass + + def __enter__(self): + """Support for use in 'with' statements.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Ensure resources are released when exiting 'with' block.""" + self.stop() + + def get_episode_end_status(self): + """ + Get the current episode end status. + + Returns: + None if episode should continue, "success" or "failure" otherwise + """ + status = self.episode_end_status + self.episode_end_status = None # Reset after reading + return status + + +class KeyboardController(InputController): + """Generate motion deltas from keyboard input.""" + + def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01): + super().__init__(x_step_size, y_step_size, z_step_size) + self.key_states = { + "forward_x": False, + "backward_x": False, + "forward_y": False, + "backward_y": False, + "forward_z": False, + "backward_z": False, + "quit": False, + "success": False, + "failure": False, + } + self.listener = None + + def start(self): + """Start the keyboard listener.""" + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = True + elif key == keyboard.Key.down: + self.key_states["backward_x"] = True + elif key == keyboard.Key.left: + self.key_states["forward_y"] = True + elif key == keyboard.Key.right: + self.key_states["backward_y"] = True + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = True + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = True + elif key == keyboard.Key.esc: + self.key_states["quit"] = True + self.running = False + return False + elif key == keyboard.Key.enter: + self.key_states["success"] = True + self.episode_end_status = "success" + elif key == keyboard.Key.backspace: + self.key_states["failure"] = True + self.episode_end_status = "failure" + except AttributeError: + pass + + def on_release(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = False + elif key == keyboard.Key.down: + self.key_states["backward_x"] = False + elif key == keyboard.Key.left: + self.key_states["forward_y"] = False + elif key == keyboard.Key.right: + self.key_states["backward_y"] = False + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = False + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = False + elif key == keyboard.Key.enter: + self.key_states["success"] = False + elif key == keyboard.Key.backspace: + self.key_states["failure"] = False + except AttributeError: + pass + + self.listener = keyboard.Listener(on_press=on_press, on_release=on_release) + self.listener.start() + + print("Keyboard controls:") + print(" Arrow keys: Move in X-Y plane") + print(" Shift and 
Shift_R: Move in Z axis") + print(" Enter: End episode with SUCCESS") + print(" Backspace: End episode with FAILURE") + print(" ESC: Exit") + + def stop(self): + """Stop the keyboard listener.""" + if self.listener and self.listener.is_alive(): + self.listener.stop() + + def get_deltas(self): + """Get the current movement deltas from keyboard state.""" + delta_x = delta_y = delta_z = 0.0 + + if self.key_states["forward_x"]: + delta_x += self.x_step_size + if self.key_states["backward_x"]: + delta_x -= self.x_step_size + if self.key_states["forward_y"]: + delta_y += self.y_step_size + if self.key_states["backward_y"]: + delta_y -= self.y_step_size + if self.key_states["forward_z"]: + delta_z += self.z_step_size + if self.key_states["backward_z"]: + delta_z -= self.z_step_size + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if ESC was pressed.""" + return self.key_states["quit"] + + def should_save(self): + """Return True if Enter was pressed (save episode).""" + return self.key_states["success"] or self.key_states["failure"] + + +class GamepadController(InputController): + """Generate motion deltas from gamepad input.""" + + def __init__( + self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1 + ): + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.joystick = None + self.intervention_flag = False + + def start(self): + """Initialize pygame and the gamepad.""" + import pygame + + pygame.init() + pygame.joystick.init() + + if pygame.joystick.get_count() == 0: + logging.error( + "No gamepad detected. Please connect a gamepad and try again." + ) + self.running = False + return + + self.joystick = pygame.joystick.Joystick(0) + self.joystick.init() + logging.info(f"Initialized gamepad: {self.joystick.get_name()}") + + print("Gamepad controls:") + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick (vertical): Move in Z axis") + print(" B/Circle button: Exit") + print(" Y/Triangle button: End episode with SUCCESS") + print(" A/Cross button: End episode with FAILURE") + print(" X/Square button: Rerecord episode") + + def stop(self): + """Clean up pygame resources.""" + import pygame + + if pygame.joystick.get_init(): + if self.joystick: + self.joystick.quit() + pygame.joystick.quit() + pygame.quit() + + def update(self): + """Process pygame events to get fresh gamepad readings.""" + import pygame + + for event in pygame.event.get(): + if event.type == pygame.JOYBUTTONDOWN: + if event.button == 3: + self.episode_end_status = "success" + # A button (1) for failure + elif event.button == 1: + self.episode_end_status = "failure" + # X button (0) for rerecord + elif event.button == 0: + self.episode_end_status = "rerecord_episode" + + # Reset episode status on button release + elif event.type == pygame.JOYBUTTONUP: + if event.button in [0, 2, 3]: + self.episode_end_status = None + + # Check for RB button (typically button 5) for intervention flag + if self.joystick.get_button(5): + self.intervention_flag = True + else: + self.intervention_flag = False + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + import pygame + + try: + # Read joystick axes + # Left stick X and Y (typically axes 0 and 1) + x_input = self.joystick.get_axis(0) # Left/Right + y_input = self.joystick.get_axis(1) # Up/Down (often inverted) + + # Right stick Y (typically axis 3 or 4) + z_input = self.joystick.get_axis(3) # Up/Down for Z + + # Apply deadzone to avoid drift + x_input = 0 
if abs(x_input) < self.deadzone else x_input + y_input = 0 if abs(y_input) < self.deadzone else y_input + z_input = 0 if abs(z_input) < self.deadzone else z_input + + # Calculate deltas (note: may need to invert axes depending on controller) + delta_x = -y_input * self.y_step_size # Forward/backward + delta_y = -x_input * self.x_step_size # Left/right + delta_z = -z_input * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + except pygame.error: + logging.error("Error reading gamepad. Is it still connected?") + return 0.0, 0.0, 0.0 + + def should_intervene(self): + """Return True if intervention flag was set.""" + return self.intervention_flag + + +class GamepadControllerHID(InputController): + """Generate motion deltas from gamepad input using HIDAPI.""" + + def __init__( + self, + x_step_size=0.01, + y_step_size=0.01, + z_step_size=0.01, + deadzone=0.1, + vendor_id=0x046D, + product_id=0xC219, + ): + """ + Initialize the HID gamepad controller. + + Args: + step_size: Base movement step size in meters + z_scale: Scaling factor for Z-axis movement + deadzone: Joystick deadzone to prevent drift + vendor_id: USB vendor ID of the gamepad (default: Logitech) + product_id: USB product ID of the gamepad (default: RumblePad 2) + """ + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.vendor_id = vendor_id + self.product_id = product_id + self.device = None + self.device_info = None + + # Movement values (normalized from -1.0 to 1.0) + self.left_x = 0.0 + self.left_y = 0.0 + self.right_x = 0.0 + self.right_y = 0.0 + + # Button states + self.buttons = {} + self.quit_requested = False + self.save_requested = False + self.intervention_flag = False + + def find_device(self): + """Look for the gamepad device by vendor and product ID.""" + import hid + + devices = hid.enumerate() + for device in devices: + if ( + device["vendor_id"] == self.vendor_id + and device["product_id"] == self.product_id + ): + logging.info( + f"Found gamepad: {device.get('product_string', 'Unknown')}" + ) + return device + + logging.error( + f"No gamepad with vendor ID 0x{self.vendor_id:04X} and " + f"product ID 0x{self.product_id:04X} found" + ) + return None + + def start(self): + """Connect to the gamepad using HIDAPI.""" + import hid + + self.device_info = self.find_device() + if not self.device_info: + self.running = False + return + + try: + logging.info(f"Connecting to gamepad at path: {self.device_info['path']}") + self.device = hid.device() + self.device.open_path(self.device_info["path"]) + self.device.set_nonblocking(1) + + manufacturer = self.device.get_manufacturer_string() + product = self.device.get_product_string() + logging.info(f"Connected to {manufacturer} {product}") + + logging.info("Gamepad controls (HID mode):") + logging.info(" Left analog stick: Move in X-Y plane") + logging.info(" Right analog stick: Move in Z axis (vertical)") + logging.info(" Button 1/B/Circle: Exit") + logging.info(" Button 2/A/Cross: End episode with SUCCESS") + logging.info(" Button 3/X/Square: End episode with FAILURE") + + except OSError as e: + logging.error(f"Error opening gamepad: {e}") + logging.error( + "You might need to run this with sudo/admin privileges on some systems" + ) + self.running = False + + def stop(self): + """Close the HID device connection.""" + if self.device: + self.device.close() + self.device = None + + def update(self): + """ + Read and process the latest gamepad data. 
+ Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading + """ + for _ in range(10): + self._update() + + def _update(self): + """Read and process the latest gamepad data.""" + if not self.device or not self.running: + return + + try: + # Read data from the gamepad + data = self.device.read(64) + if data: + # Interpret gamepad data - this will vary by controller model + # These offsets are for the Logitech RumblePad 2 + if len(data) >= 8: + # Normalize joystick values from 0-255 to -1.0-1.0 + self.left_x = (data[1] - 128) / 128.0 + self.left_y = (data[2] - 128) / 128.0 + self.right_x = (data[3] - 128) / 128.0 + self.right_y = (data[4] - 128) / 128.0 + + # Apply deadzone + self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x + self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y + self.right_x = ( + 0 if abs(self.right_x) < self.deadzone else self.right_x + ) + self.right_y = ( + 0 if abs(self.right_y) < self.deadzone else self.right_y + ) + + # Parse button states (byte 5 in the Logitech RumblePad 2) + buttons = data[5] + + # Check if RB is pressed then the intervention flag should be set + self.intervention_flag = data[6] == 2 + + # Check if Y/Triangle button (bit 7) is pressed for saving + # Check if X/Square button (bit 5) is pressed for failure + # Check if A/Cross button (bit 4) is pressed for rerecording + if buttons & 1 << 7: + self.episode_end_status = "success" + elif buttons & 1 << 5: + self.episode_end_status = "failure" + elif buttons & 1 << 4: + self.episode_end_status = "rerecord_episode" + else: + self.episode_end_status = None + + except OSError as e: + logging.error(f"Error reading from gamepad: {e}") + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + # Calculate deltas - invert as needed based on controller orientation + delta_x = -self.left_y * self.x_step_size # Forward/backward + delta_y = -self.left_x * self.y_step_size # Left/right + delta_z = -self.right_y * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if quit button was pressed.""" + return self.quit_requested + + def should_save(self): + """Return True if save button was pressed.""" + return self.save_requested + + def should_intervene(self): + """Return True if intervention flag was set.""" + return self.intervention_flag + + +def test_forward_kinematics(robot, fps=10): + logging.info("Testing Forward Kinematics") + timestep = time.perf_counter() + while time.perf_counter() - timestep < 60.0: + loop_start_time = time.perf_counter() + robot.teleop_step() + obs = robot.capture_observation() + joint_positions = obs["observation.state"].cpu().numpy() + ee_pos = RobotKinematics.fk_gripper_tip(joint_positions) + logging.info(f"EE Position: {ee_pos[:3,3]}") + busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) + + +def test_inverse_kinematics(robot, fps=10): + logging.info("Testing Inverse Kinematics") + timestep = time.perf_counter() + while time.perf_counter() - timestep < 60.0: + loop_start_time = time.perf_counter() + obs = robot.capture_observation() + joint_positions = obs["observation.state"].cpu().numpy() + ee_pos = RobotKinematics.fk_gripper_tip(joint_positions) + desired_ee_pos = ee_pos + target_joint_state = RobotKinematics.ik( + joint_positions, desired_ee_pos, position_only=True + ) + robot.send_action(torch.from_numpy(target_joint_state)) + logging.info(f"Target Joint State: {target_joint_state}") + 
busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) + + +def teleoperate_inverse_kinematics_with_leader(robot, fps=10): + logging.info("Testing Inverse Kinematics") + fk_func = RobotKinematics.fk_gripper_tip + timestep = time.perf_counter() + while time.perf_counter() - timestep < 60.0: + loop_start_time = time.perf_counter() + obs = robot.capture_observation() + joint_positions = obs["observation.state"].cpu().numpy() + ee_pos = fk_func(joint_positions) + + leader_joint_positions = robot.leader_arms["main"].read("Present_Position") + leader_ee = fk_func(leader_joint_positions) + + desired_ee_pos = leader_ee + target_joint_state = RobotKinematics.ik( + joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func + ) + robot.send_action(torch.from_numpy(target_joint_state)) + logging.info(f"Leader EE: {leader_ee[:3,3]}, Follower EE: {ee_pos[:3,3]}") + busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) + + +def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): + logging.info("Testing Delta End-Effector Control") + timestep = time.perf_counter() + + # Initial position capture + obs = robot.capture_observation() + joint_positions = obs["observation.state"].cpu().numpy() + + fk_func = RobotKinematics.fk_gripper_tip + + leader_joint_positions = robot.leader_arms["main"].read("Present_Position") + initial_leader_ee = fk_func(leader_joint_positions) + + desired_ee_pos = np.diag(np.ones(4)) + + while time.perf_counter() - timestep < 60.0: + loop_start_time = time.perf_counter() + + # Get leader state for teleoperation + leader_joint_positions = robot.leader_arms["main"].read("Present_Position") + leader_ee = fk_func(leader_joint_positions) + + # Get current state + # obs = robot.capture_observation() + # joint_positions = obs["observation.state"].cpu().numpy() + joint_positions = robot.follower_arms["main"].read("Present_Position") + current_ee_pos = fk_func(joint_positions) + + # Calculate delta between leader and follower end-effectors + # Scaling factor can be adjusted for sensitivity + scaling_factor = 1.0 + ee_delta = (leader_ee - initial_leader_ee) * scaling_factor + + # Apply delta to current position + desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3] + desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3] + desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3] + + if np.any(np.abs(ee_delta[:3, 3]) > 0.01): + # Compute joint targets via inverse kinematics + target_joint_state = RobotKinematics.ik( + joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func + ) + + initial_leader_ee = leader_ee.copy() + + # Send command to robot + robot.send_action(torch.from_numpy(target_joint_state)) + + # Logging + logging.info( + f"Current EE: {current_ee_pos[:3,3]}, Desired EE: {desired_ee_pos[:3,3]}" + ) + logging.info(f"Delta EE: {ee_delta[:3,3]}") + + busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) + + +def teleoperate_delta_inverse_kinematics( + robot, controller, fps=10, bounds=None, fk_func=None +): + """ + Control a robot using delta end-effector movements from any input controller. + + Args: + robot: Robot instance to control + controller: InputController instance (keyboard, gamepad, etc.) 
+ fps: Control frequency in Hz + bounds: Optional position limits + fk_func: Forward kinematics function to use + """ + if fk_func is None: + fk_func = RobotKinematics.fk_gripper_tip + + logging.info( + f"Testing Delta End-Effector Control with {controller.__class__.__name__}" + ) + + # Initial position capture + obs = robot.capture_observation() + joint_positions = obs["observation.state"].cpu().numpy() + current_ee_pos = fk_func(joint_positions) + + # Initialize desired position with current position + desired_ee_pos = np.eye(4) # Identity matrix + + timestep = time.perf_counter() + with controller: + while not controller.should_quit() and time.perf_counter() - timestep < 60.0: + loop_start_time = time.perf_counter() + + # Process input events + controller.update() + + # Get currrent robot state + joint_positions = robot.follower_arms["main"].read("Present_Position") + current_ee_pos = fk_func(joint_positions) + + # Get movement deltas from the controller + delta_x, delta_y, delta_z = controller.get_deltas() + + # Update desired position + desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x + desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y + desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z + + # Apply bounds if provided + if bounds is not None: + desired_ee_pos[:3, 3] = np.clip( + desired_ee_pos[:3, 3], bounds["min"], bounds["max"] + ) + + # Only send commands if there's actual movement + if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]): + # Compute joint targets via inverse kinematics + target_joint_state = RobotKinematics.ik( + joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func + ) + + # Send command to robot + robot.send_action(torch.from_numpy(target_joint_state)) + + busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) + + +def teleoperate_gym_env(env, controller, fps: int = 30): + """ + Control a robot through a gym environment using keyboard inputs. 
+ + Args: + env: A gym environment created with make_robot_env + fps: Target control frequency + """ + + logging.info("Testing Keyboard Control of Gym Environment") + print("Keyboard controls:") + print(" Arrow keys: Move in X-Y plane") + print(" Shift and Shift_R: Move in Z axis") + print(" ESC: Exit") + + # Reset the environment to get initial observation + obs, info = env.reset() + + try: + with controller: + while not controller.should_quit(): + loop_start_time = time.perf_counter() + + # Process input events + controller.update() + + # Get movement deltas from the controller + delta_x, delta_y, delta_z = controller.get_deltas() + + # Create the action vector + action = np.array([delta_x, delta_y, delta_z]) + + # Skip if no movement + if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]): + # Step the environment - pass action as a tensor with intervention flag + action_tensor = torch.from_numpy(action.astype(np.float32)) + obs, reward, terminated, truncated, info = env.step( + (action_tensor, False) + ) + + # Log information + logging.info( + f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]" + ) + logging.info(f"Reward: {reward}") + + # Reset if episode ended + if terminated or truncated: + logging.info("Episode ended, resetting environment") + obs, info = env.reset() + + # Maintain target frame rate + busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) + + finally: + # Close the environment + env.close() + + +def make_robot_from_config(config_path, overrides=None): + """Helper function to create a robot from a config file.""" + if overrides is None: + overrides = [] + robot_cfg = init_hydra_config(config_path, overrides) + return make_robot(robot_cfg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test end-effector control") + parser.add_argument( + "--mode", + type=str, + default="keyboard", + choices=[ + "keyboard", + "gamepad", + "keyboard_gym", + "gamepad_gym", + "leader", + "leader_abs", + ], + help="Control mode to use", + ) + parser.add_argument( + "--task", + type=str, + default="Robot manipulation task", + help="Description of the task being performed", + ) + parser.add_argument( + "--push-to-hub", + default=True, + type=bool, + help="Push the dataset to Hugging Face Hub", + ) + # Add the rest of your existing arguments + args = parser.parse_args() + + robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", []) + + if not robot.is_connected: + robot.connect() + + # Example bounds + bounds = { + "max": np.array([0.32170487, 0.201285, 0.10273342]), + "min": np.array([0.16631757, -0.08237468, 0.03364977]), + } + + try: + # Determine controller type based on mode prefix + controller = None + if args.mode.startswith("keyboard"): + controller = KeyboardController( + x_step_size=0.01, y_step_size=0.01, z_step_size=0.05 + ) + elif args.mode.startswith("gamepad"): + controller = GamepadController( + x_step_size=0.02, y_step_size=0.02, z_step_size=0.05 + ) + + # Handle mode categories + if args.mode in ["keyboard", "gamepad"]: + # Direct robot control modes + teleoperate_delta_inverse_kinematics( + robot, controller, bounds=bounds, fps=10 + ) + + elif args.mode in ["keyboard_gym", "gamepad_gym"]: + # Gym environment control modes + from lerobot.scripts.server.gym_manipulator import make_robot_env + + cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", []) + cfg.env.wrapper.ee_action_space_params.use_gamepad = False + env = make_robot_env(robot, None, cfg) + teleoperate_gym_env(env, controller) + + elif 
args.mode == "leader": + # Leader-follower modes don't use controllers + teleoperate_delta_inverse_kinematics_with_leader(robot) + + elif args.mode == "leader_abs": + teleoperate_inverse_kinematics_with_leader(robot) + + finally: + if robot.is_connected: + robot.disconnect() diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py index d5870027..7834f821 100644 --- a/lerobot/scripts/server/find_joint_limits.py +++ b/lerobot/scripts/server/find_joint_limits.py @@ -7,25 +7,26 @@ import numpy as np from lerobot.common.robot_devices.control_utils import is_headless from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.utils.utils import init_hydra_config +from lerobot.scripts.server.kinematics import RobotKinematics def find_joint_bounds( robot, - control_time_s=20, + control_time_s=30, display_cameras=False, ): - # TODO(rcadene): Add option to record logs if not robot.is_connected: robot.connect() - control_time_s = float("inf") - - timestamp = 0 start_episode_t = time.perf_counter() pos_list = [] - while timestamp < control_time_s: + while True: observation, action = robot.teleop_step(record_data=True) + # Wait for 5 seconds to stabilize the robot initial position + if time.perf_counter() - start_episode_t < 5: + continue + pos_list.append(robot.follower_arms["main"].read("Present_Position")) if display_cameras and not is_headless(): @@ -36,8 +37,7 @@ def find_joint_bounds( ) cv2.waitKey(1) - timestamp = time.perf_counter() - start_episode_t - if timestamp > 60: + if time.perf_counter() - start_episode_t > control_time_s: max = np.max(np.stack(pos_list), 0) min = np.min(np.stack(pos_list), 0) print(f"Max angle position per joint {max}") @@ -45,6 +45,43 @@ def find_joint_bounds( break +def find_ee_bounds( + robot, + control_time_s=30, + display_cameras=False, +): + if not robot.is_connected: + robot.connect() + + start_episode_t = time.perf_counter() + ee_list = [] + while True: + observation, action = robot.teleop_step(record_data=True) + + # Wait for 5 seconds to stabilize the robot initial position + if time.perf_counter() - start_episode_t < 5: + continue + + joint_positions = robot.follower_arms["main"].read("Present_Position") + print(f"Joint positions: {joint_positions}") + ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3]) + + if display_cameras and not is_headless(): + image_keys = [key for key in observation if "image" in key] + for key in image_keys: + cv2.imshow( + key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + ) + cv2.waitKey(1) + + if time.perf_counter() - start_episode_t > control_time_s: + max = np.max(np.stack(ee_list), 0) + min = np.min(np.stack(ee_list), 0) + print(f"Max ee position {max}") + print(f"Min ee position {min}") + break + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -59,14 +96,26 @@ if __name__ == "__main__": nargs="*", help="Any key=value arguments to override config values (use dots for.nested=overrides)", ) + parser.add_argument( + "--mode", + type=str, + default="joint", + choices=["joint", "ee"], + help="Mode to run the script in. 
Can be 'joint' or 'ee'.", + ) parser.add_argument( "--control-time-s", - type=float, - default=20, - help="Maximum episode length in seconds", + type=int, + default=30, + help="Time step to use for control.", ) args = parser.parse_args() robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) robot = make_robot(robot_cfg) - find_joint_bounds(robot, control_time_s=args.control_time_s) + if args.mode == "joint": + find_joint_bounds(robot, args.control_time_s) + elif args.mode == "ee": + find_ee_bounds(robot, args.control_time_s) + if robot.is_connected: + robot.disconnect() diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index c1a7c88c..728afdfa 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1,19 +1,26 @@ import argparse +import sys + import logging import time from threading import Lock -from typing import Annotated, Any, Callable, Dict, Optional, Tuple - +from typing import Annotated, Any, Dict, Tuple import gymnasium as gym import numpy as np import torch import torchvision.transforms.functional as F # noqa: N812 from lerobot.common.envs.utils import preprocess_observation -from lerobot.common.robot_devices.control_utils import busy_wait, is_headless +from lerobot.common.robot_devices.control_utils import ( + busy_wait, + is_headless, + reset_follower_position, +) from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.utils.utils import init_hydra_config, log_say +from lerobot.scripts.server.kinematics import RobotKinematics + logging.basicConfig(level=logging.INFO) @@ -76,13 +83,19 @@ class HILSerlRobotEnv(gym.Env): # Retrieve the size of the joint position interval bound. self.relative_bounds_size = ( - self.robot.config.joint_position_relative_bounds["max"] - - self.robot.config.joint_position_relative_bounds["min"] + ( + self.robot.config.joint_position_relative_bounds["max"] + - self.robot.config.joint_position_relative_bounds["min"] + ) + if self.robot.config.joint_position_relative_bounds is not None + else None ) - self.delta_relative_bounds_size = self.relative_bounds_size * self.delta - - self.robot.config.max_relative_target = self.delta_relative_bounds_size.float() + self.robot.config.max_relative_target = ( + self.relative_bounds_size.float() + if self.relative_bounds_size is not None + else None + ) # Dynamically configure the observation and action spaces. self._setup_spaces() @@ -99,26 +112,23 @@ class HILSerlRobotEnv(gym.Env): - The action space is defined as a Tuple where: • The first element is a Box space representing joint position commands. It is defined as relative (delta) or absolute, based on the configuration. - • The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation). + • ThE SECONd element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation). """ example_obs = self.robot.capture_observation() # Define observation spaces for images and other states. 
image_keys = [key for key in example_obs if "image" in key] - state_keys = [key for key in example_obs if "image" not in key] observation_spaces = { key: gym.spaces.Box( low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8 ) for key in image_keys } - observation_spaces["observation.state"] = gym.spaces.Dict( - { - key: gym.spaces.Box( - low=0, high=10, shape=example_obs[key].shape, dtype=np.float32 - ) - for key in state_keys - } + observation_spaces["observation.state"] = gym.spaces.Box( + low=0, + high=10, + shape=example_obs["observation.state"].shape, + dtype=np.float32, ) self.observation_space = gym.spaces.Dict(observation_spaces) @@ -126,20 +136,31 @@ class HILSerlRobotEnv(gym.Env): # Define the action space for joint positions along with setting an intervention flag. action_dim = len(self.robot.follower_arms["main"].read("Present_Position")) if self.use_delta_action_space: + bounds = ( + self.relative_bounds_size + if self.relative_bounds_size is not None + else np.ones(action_dim) * 1000 + ) action_space_robot = gym.spaces.Box( - low=-self.relative_bounds_size.cpu().numpy(), - high=self.relative_bounds_size.cpu().numpy(), + low=-bounds, + high=bounds, shape=(action_dim,), dtype=np.float32, ) else: + bounds_min = ( + self.robot.config.joint_position_relative_bounds["min"].cpu().numpy() + if self.robot.config.joint_position_relative_bounds is not None + else np.ones(action_dim) * -1000 + ) + bounds_max = ( + self.robot.config.joint_position_relative_bounds["max"].cpu().numpy() + if self.robot.config.joint_position_relative_bounds is not None + else np.ones(action_dim) * 1000 + ) action_space_robot = gym.spaces.Box( - low=self.robot.config.joint_position_relative_bounds["min"] - .cpu() - .numpy(), - high=self.robot.config.joint_position_relative_bounds["max"] - .cpu() - .numpy(), + low=bounds_min, + high=bounds_max, shape=(action_dim,), dtype=np.float32, ) @@ -176,7 +197,7 @@ class HILSerlRobotEnv(gym.Env): self.current_step = 0 self.episode_data = None - return observation, {"initial_position": self.initial_follower_position} + return observation, {} def step( self, action: Tuple[np.ndarray, bool] @@ -218,6 +239,7 @@ class HILSerlRobotEnv(gym.Env): policy_action = np.clip( policy_action, self.action_space[0].low, self.action_space[0].high ) + if not intervention_bool: if self.use_delta_action_space: target_joint_positions = ( @@ -238,8 +260,9 @@ class HILSerlRobotEnv(gym.Env): teleop_action = ( teleop_action - self.current_joint_positions ) / self.delta - if torch.any(teleop_action < -self.relative_bounds_size) and torch.any( - teleop_action > self.relative_bounds_size + if self.relative_bounds_size is not None and ( + torch.any(teleop_action < -self.relative_bounds_size) + and torch.any(teleop_action > self.relative_bounds_size) ): logging.debug( f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n" @@ -299,6 +322,46 @@ class HILSerlRobotEnv(gym.Env): self.robot.disconnect() +class AddJointVelocityToObservation(gym.ObservationWrapper): + def __init__(self, env, joint_velocity_limits=100.0, fps=30): + super().__init__(env) + + # Extend observation space to include joint velocities + old_low = self.observation_space["observation.state"].low + old_high = self.observation_space["observation.state"].high + old_shape = self.observation_space["observation.state"].shape + + self.last_joint_positions = np.zeros(old_shape) + + new_low = np.concatenate( + [old_low, np.ones_like(old_low) * -joint_velocity_limits] + ) + new_high = 
np.concatenate( + [old_high, np.ones_like(old_high) * joint_velocity_limits] + ) + + new_shape = (old_shape[0] * 2,) + + self.observation_space["observation.state"] = gym.spaces.Box( + low=new_low, + high=new_high, + shape=new_shape, + dtype=np.float32, + ) + + self.dt = 1.0 / fps + + def observation(self, observation): + joint_velocities = ( + observation["observation.state"] - self.last_joint_positions + ) / self.dt + self.last_joint_positions = observation["observation.state"].clone() + observation["observation.state"] = torch.cat( + [observation["observation.state"], joint_velocities], dim=-1 + ) + return observation + + class ActionRepeatWrapper(gym.Wrapper): def __init__(self, env, nb_repeat: int = 1): super().__init__(env) @@ -347,8 +410,6 @@ class RewardWrapper(gym.Wrapper): ) info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time) - # logging.info(f"Reward: {reward}") - if reward == 1.0: terminated = True return observation, reward, terminated, truncated, info @@ -465,9 +526,7 @@ class TimeLimitWrapper(gym.Wrapper): if 1.0 / time_since_last_step < self.fps: logging.debug(f"Current timestep exceeded expected fps {self.fps}") - if self.episode_time_in_s > self.control_time_s: - # if self.current_step >= self.max_episode_steps: - # Terminated = True + if self.current_step >= self.max_episode_steps: terminated = True return obs, reward, terminated, truncated, info @@ -508,7 +567,20 @@ class ImageCropResizeWrapper(gym.Wrapper): obs, reward, terminated, truncated, info = self.env.step(action) for k in self.crop_params_dict: device = obs[k].device + if obs[k].dim() >= 3: + # Reshape to combine height and width dimensions for easier calculation + batch_size = obs[k].size(0) + channels = obs[k].size(1) + flattened_spatial_dims = obs[k].view(batch_size, channels, -1) + # Calculate standard deviation across spatial dimensions (H, W) + std_per_channel = torch.std(flattened_spatial_dims, dim=2) + + # If any channel has std=0, all pixels in that channel have the same value + if (std_per_channel <= 0.02).any(): + logging.warning( + f"Potential hardware issue detected: All pixels have the same value in observation {k}" + ) # Check for NaNs before processing if torch.isnan(obs[k]).any(): logging.error( @@ -703,19 +775,21 @@ class ResetWrapper(gym.Wrapper): def __init__( self, env: HILSerlRobotEnv, - reset_fn: Optional[Callable[[], None]] = None, + reset_pose: np.ndarray | None = None, reset_time_s: float = 5, ): super().__init__(env) - self.reset_fn = reset_fn self.reset_time_s = reset_time_s - + self.reset_pose = reset_pose self.robot = self.unwrapped.robot - self.init_pos = self.unwrapped.initial_follower_position def reset(self, *, seed=None, options=None): - if self.reset_fn is not None: - self.reset_fn(self.env) + if self.reset_pose is not None: + start_time = time.perf_counter() + log_say("Reset the environment.", play_sounds=True) + reset_follower_position(self.robot, self.reset_pose) + busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + log_say("Reset the environment done.", play_sounds=True) else: log_say( f"Manually reset the environment for {self.reset_time_s} seconds.", @@ -741,10 +815,297 @@ class BatchCompitableWrapper(gym.ObservationWrapper): observation[key] = observation[key].unsqueeze(0) if "state" in key and observation[key].dim() == 1: observation[key] = observation[key].unsqueeze(0) + if "velocity" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) return observation -# TODO: REMOVE TH +class 
EEActionWrapper(gym.ActionWrapper): + def __init__(self, env, ee_action_space_params=None): + super().__init__(env) + self.ee_action_space_params = ee_action_space_params + + # Initialize kinematics instance for the appropriate robot type + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100") + self.kinematics = RobotKinematics(robot_type) + self.fk_function = self.kinematics.fk_gripper_tip + + action_space_bounds = np.array( + [ + ee_action_space_params.x_step_size, + ee_action_space_params.y_step_size, + ee_action_space_params.z_step_size, + ] + ) + ee_action_space = gym.spaces.Box( + low=-action_space_bounds, + high=action_space_bounds, + shape=(3,), + dtype=np.float32, + ) + if isinstance(self.action_space, gym.spaces.Tuple): + self.action_space = gym.spaces.Tuple( + (ee_action_space, self.action_space[1]) + ) + else: + self.action_space = ee_action_space + + self.bounds = ee_action_space_params.bounds + + def action(self, action): + is_intervention = False + desired_ee_pos = np.eye(4) + if isinstance(action, tuple): + action, _ = action + + current_joint_pos = self.unwrapped.robot.follower_arms["main"].read( + "Present_Position" + ) + current_ee_pos = self.fk_function(current_joint_pos) + if isinstance(action, torch.Tensor): + action = action.cpu().numpy() + desired_ee_pos[:3, 3] = np.clip( + current_ee_pos[:3, 3] + action, + self.bounds["min"], + self.bounds["max"], + ) + target_joint_pos = self.kinematics.ik( + current_joint_pos, + desired_ee_pos, + position_only=True, + fk_func=self.fk_function, + ) + return target_joint_pos, is_intervention + + +class EEObservationWrapper(gym.ObservationWrapper): + def __init__(self, env, ee_pose_limits): + super().__init__(env) + + # Extend observation space to include end effector pose + prev_space = self.observation_space["observation.state"] + + self.observation_space["observation.state"] = gym.spaces.Box( + low=np.concatenate([prev_space.low, ee_pose_limits["min"]]), + high=np.concatenate([prev_space.high, ee_pose_limits["max"]]), + shape=(prev_space.shape[0] + 3,), + dtype=np.float32, + ) + + # Initialize kinematics instance for the appropriate robot type + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100") + self.kinematics = RobotKinematics(robot_type) + self.fk_function = self.kinematics.fk_gripper_tip + + def observation(self, observation): + current_joint_pos = self.unwrapped.robot.follower_arms["main"].read( + "Present_Position" + ) + current_ee_pos = self.fk_function(current_joint_pos) + observation["observation.state"] = torch.cat( + [ + observation["observation.state"], + torch.from_numpy(current_ee_pos[:3, 3]), + ], + dim=-1, + ) + return observation + + +class GamepadControlWrapper(gym.Wrapper): + """ + Wrapper that allows controlling a gym environment with a gamepad. + + This wrapper intercepts the step method and allows human input via gamepad + to override the agent's actions when desired. + """ + + def __init__( + self, + env, + x_step_size=1.0, + y_step_size=1.0, + z_step_size=1.0, + auto_reset=False, + input_threshold=0.001, + ): + """ + Initialize the gamepad controller wrapper. 
+ + Args: + env: The environment to wrap + x_step_size: Base movement step size for X axis in meters + y_step_size: Base movement step size for Y axis in meters + z_step_size: Base movement step size for Z axis in meters + vendor_id: USB vendor ID of the gamepad (default: Logitech) + product_id: USB product ID of the gamepad (default: RumblePad 2) + auto_reset: Whether to auto reset the environment when episode ends + input_threshold: Minimum movement delta to consider as active input + """ + super().__init__(env) + from lerobot.scripts.server.end_effector_control_utils import ( + GamepadControllerHID, + GamepadController, + ) + + # use HidApi for macos + if sys.platform == "darwin": + self.controller = GamepadControllerHID( + x_step_size=x_step_size, + y_step_size=y_step_size, + z_step_size=z_step_size, + ) + else: + self.controller = GamepadController( + x_step_size=x_step_size, + y_step_size=y_step_size, + z_step_size=z_step_size, + ) + self.auto_reset = auto_reset + self.input_threshold = input_threshold + self.controller.start() + + logging.info("Gamepad control wrapper initialized") + print("Gamepad controls:") + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick: Move in Z axis (up/down)") + print(" X/Square button: End episode (FAILURE)") + print(" Y/Triangle button: End episode (SUCCESS)") + print(" B/Circle button: Exit program") + + def get_gamepad_action(self): + """ + Get the current action from the gamepad if any input is active. + + Returns: + Tuple of (is_active, action, terminate_episode, success) + """ + # Update the controller to get fresh inputs + self.controller.update() + + # Get movement deltas from the controller + delta_x, delta_y, delta_z = self.controller.get_deltas() + + intervention_is_active = self.controller.should_intervene() + + # Create action from gamepad input + gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32) + + # Check episode ending buttons + # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None + episode_end_status = self.controller.get_episode_end_status() + terminate_episode = episode_end_status is not None + success = episode_end_status == "success" + rerecord_episode = episode_end_status == "rerecord_episode" + + return ( + intervention_is_active, + gamepad_action, + terminate_episode, + success, + rerecord_episode, + ) + + def step(self, action): + """ + Step the environment, using gamepad input to override actions when active. 
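+
+        When the gamepad is actively intervening, its deltas replace the
+        agent's action; otherwise the original action is passed through
+        unchanged. The gamepad buttons can also request episode termination
+        and mark the episode as a success.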
+ + Args: + action: Original action from agent + + Returns: + observation, reward, terminated, truncated, info + """ + # Get gamepad state and action + ( + is_intervention, + gamepad_action, + terminate_episode, + success, + rerecord_episode, + ) = self.get_gamepad_action() + + # Update episode ending state if requested + if terminate_episode: + logging.info( + f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}" + ) + + # Only override the action if gamepad is active + if is_intervention: + # Format according to the expected action type + if isinstance(self.action_space, gym.spaces.Tuple): + # For environments that use (action, is_intervention) tuples + final_action = (torch.from_numpy(gamepad_action), False) + else: + final_action = torch.from_numpy(gamepad_action) + else: + # Use the original action + final_action = action + + # Step the environment + obs, reward, terminated, truncated, info = self.env.step(final_action) + + # Add episode ending if requested via gamepad + terminated = terminated or truncated or terminate_episode + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + info["is_intervention"] = is_intervention + action_intervention = ( + final_action[0] if isinstance(final_action, Tuple) else final_action + ) + if isinstance(action_intervention, np.ndarray): + action_intervention = torch.from_numpy(action_intervention) + info["action_intervention"] = action_intervention + info["rerecord_episode"] = rerecord_episode + + # If episode ended, reset the state + if terminated or truncated: + # Add success/failure information to info dict + info["next.success"] = success + + # Auto reset if configured + if self.auto_reset: + obs, reset_info = self.reset() + info.update(reset_info) + + return obs, reward, terminated, truncated, info + + def close(self): + """Clean up resources when environment closes.""" + # Stop the controller + if hasattr(self, "controller"): + self.controller.stop() + + # Call the parent close method + return self.env.close() + + +class ActionScaleWrapper(gym.ActionWrapper): + def __init__(self, env, ee_action_space_params=None): + super().__init__(env) + assert ( + ee_action_space_params is not None + ), "TODO: method implemented for ee action space only so far" + self.scale_vector = np.array( + [ + [ + ee_action_space_params.x_step_size, + ee_action_space_params.y_step_size, + ee_action_space_params.z_step_size, + ] + ] + ) + + def action(self, action): + is_intervention = False + if isinstance(action, tuple): + action, is_intervention = action + + return action * self.scale_vector, is_intervention def make_robot_env( @@ -779,11 +1140,20 @@ def make_robot_env( robot=robot, display_cameras=cfg.env.wrapper.display_cameras, delta=cfg.env.wrapper.delta_action, - use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions, + use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions + and cfg.env.wrapper.ee_action_space_params is None, ) # Add observation and image processing - env = ConvertToLeRobotObservation(env=env, device=cfg.device) + if cfg.env.wrapper.add_joint_velocity_to_observation: + env = AddJointVelocityToObservation(env=env, fps=cfg.fps) + if cfg.env.wrapper.add_ee_pose_to_observation: + env = EEObservationWrapper( + env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds + ) + + env = ConvertToLeRobotObservation(env=env, device=cfg.env.device) + if cfg.env.wrapper.crop_params_dict is not None: env = ImageCropResizeWrapper( env=env, @@ -792,23 +1162,44 @@ def 
make_robot_env( ) # Add reward computation and control wrappers - env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) + # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = TimeLimitWrapper( env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps ) - env = KeyboardInterfaceWrapper(env=env) + if cfg.env.wrapper.ee_action_space_params is not None: + env = EEActionWrapper( + env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params + ) + if ( + cfg.env.wrapper.ee_action_space_params is not None + and cfg.env.wrapper.ee_action_space_params.use_gamepad + ): + # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params) + env = GamepadControlWrapper( + env=env, + x_step_size=cfg.env.wrapper.ee_action_space_params.x_step_size, + y_step_size=cfg.env.wrapper.ee_action_space_params.y_step_size, + z_step_size=cfg.env.wrapper.ee_action_space_params.z_step_size, + ) + else: + env = KeyboardInterfaceWrapper(env=env) + env = ResetWrapper( - env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s - ) - env = JointMaskingActionSpace( - env=env, mask=cfg.env.wrapper.joint_masking_action_space + env=env, + reset_pose=cfg.env.wrapper.fixed_reset_joint_positions, + reset_time_s=cfg.env.wrapper.reset_time_s, ) + if ( + cfg.env.wrapper.ee_action_space_params is None + and cfg.env.wrapper.joint_masking_action_space is not None + ): + env = JointMaskingActionSpace( + env=env, mask=cfg.env.wrapper.joint_masking_action_space + ) env = BatchCompitableWrapper(env=env) return env - # batched version of the env that returns an observation of shape (b, c) - def get_classifier(pretrained_path, config_path, device="mps"): if pretrained_path is None or config_path is None: @@ -834,6 +1225,134 @@ def get_classifier(pretrained_path, config_path, device="mps"): return model +def record_dataset( + env, + repo_id, + root=None, + num_episodes=1, + control_time_s=20, + fps=30, + push_to_hub=True, + task_description="", + policy=None, +): + """ + Record a dataset of robot interactions using either a policy or teleop. 
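+
+    Each control step stores the observation, the executed (or intervened)
+    action, the reward and the done flag as one dataset frame; episodes
+    flagged for re-recording are discarded and recorded again.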
+ + Args: + env: The environment to record from + repo_id: Repository ID for dataset storage + root: Local root directory for dataset (optional) + num_episodes: Number of episodes to record + control_time_s: Maximum episode length in seconds + fps: Frames per second for recording + push_to_hub: Whether to push dataset to Hugging Face Hub + task_description: Description of the task being recorded + policy: Optional policy to generate actions (if None, uses teleop) + """ + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + # Setup initial action (zero action if using teleop) + dummy_action = env.action_space.sample() + dummy_action = (torch.from_numpy(dummy_action[0] * 0.0), False) + action = dummy_action + + # Configure dataset features based on environment spaces + features = { + "observation.state": { + "dtype": "float32", + "shape": env.observation_space["observation.state"].shape, + "names": None, + }, + "action": { + "dtype": "float32", + "shape": env.action_space[0].shape, + "names": None, + }, + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + } + + # Add image features + for key in env.observation_space: + if "image" in key: + features[key] = { + "dtype": "video", + "shape": env.observation_space[key].shape, + "names": None, + } + + # Create dataset + dataset = LeRobotDataset.create( + repo_id, + fps, + root=root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) + + # Record episodes + episode_index = 0 + while episode_index < num_episodes: + obs, _ = env.reset() + start_episode_t = time.perf_counter() + log_say(f"Recording episode {episode_index}", play_sounds=True) + + # Run episode steps + while time.perf_counter() - start_episode_t < control_time_s: + start_loop_t = time.perf_counter() + + # Get action from policy if available + if policy is not None: + action = policy.select_action(obs) + + # Step environment + obs, reward, terminated, truncated, info = env.step(action) + + # Check if episode needs to be rerecorded + if info.get("rerecord_episode", False): + break + + # For teleop, get action from intervention + if policy is None: + action = { + "action": info["action_intervention"].cpu().squeeze(0).float() + } + + # Process observation for dataset + obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + + # Add frame to dataset + frame = {**obs, **action} + frame["next.reward"] = reward + frame["next.done"] = terminated or truncated + dataset.add_frame(frame) + + # Maintain consistent timing + if fps: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + if terminated or truncated: + break + + # Handle episode recording + if info.get("rerecord_episode", False): + dataset.clear_episode_buffer() + logging.info(f"Re-recording episode {episode_index}") + continue + + dataset.save_episode(task_description) + episode_index += 1 + + # Finalize dataset + dataset.consolidate(run_compute_stats=True) + if push_to_hub: + dataset.push_to_hub(repo_id) + + def replay_episode(env, repo_id, root=None, episode=0): from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -841,14 +1360,16 @@ def replay_episode(env, repo_id, root=None, episode=0): dataset = LeRobotDataset( repo_id, root=root, episodes=[episode], local_files_only=local_files_only ) + env.reset() + actions = dataset.hf_dataset.select_columns("action") for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() action = 
actions[idx]["action"][:4] - print(action) - env.step((action / env.unwrapped.delta, False)) + env.step((action, False)) + # env.step((action / env.unwrapped.delta, False)) dt_s = time.perf_counter() - start_episode_t busy_wait(1 / 10 - dt_s) @@ -875,14 +1396,6 @@ if __name__ == "__main__": help=( "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch " - "(useful for debugging). This argument is mutually exclusive with `--config`." - ), - ) - parser.add_argument( - "--config", - help=( - "Path to a yaml config you want to use for initializing a policy from scratch (useful for " - "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." ), ) parser.add_argument( @@ -929,11 +1442,30 @@ if __name__ == "__main__": help="Repo ID of the episode to replay", ) parser.add_argument( - "--replay-root", type=str, default=None, help="Root of the dataset to replay" + "--dataset-root", type=str, default=None, help="Root of the dataset to replay" ) parser.add_argument( "--replay-episode", type=int, default=0, help="Episode to replay" ) + parser.add_argument( + "--record-repo-id", + type=str, + default=None, + help="Repo ID of the dataset to record", + ) + parser.add_argument( + "--record-num-episodes", + type=int, + default=1, + help="Number of episodes to record", + ) + parser.add_argument( + "--record-episode-task", + type=str, + default="", + help="Single line description of the task to record", + ) + args = parser.parse_args() robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) @@ -948,17 +1480,40 @@ if __name__ == "__main__": env = make_robot_env( robot, reward_classifier, - cfg.env, # .wrapper, + cfg, # .wrapper, ) - env.reset() + if args.record_repo_id is not None: + policy = None + if args.pretrained_policy_name_or_path is not None: + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + policy = SACPolicy.from_pretrained(args.pretrained_policy_name_or_path) + policy.to(cfg.device) + policy.eval() + + record_dataset( + env, + args.record_repo_id, + root=args.dataset_root, + num_episodes=args.record_num_episodes, + fps=args.fps, + task_description=args.record_episode_task, + policy=policy, + ) + exit() if args.replay_repo_id is not None: replay_episode( - env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode + env, + args.replay_repo_id, + root=args.dataset_root, + episode=args.replay_episode, ) exit() + env.reset() + # Retrieve the robot's action space for joint commands. action_space_robot = env.action_space.spaces[0] @@ -967,9 +1522,11 @@ if __name__ == "__main__": # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. - alpha = 0.4 + alpha = 1.0 - while True: + num_episode = 0 + sucesses = [] + while num_episode < 20: start_loop_s = time.perf_counter() # Sample a new random action from the robot's action space. 
new_random_action = action_space_robot.sample() @@ -981,7 +1538,12 @@ if __name__ == "__main__": (torch.from_numpy(smoothed_action), False) ) if terminated or truncated: + sucesses.append(reward) env.reset() + num_episode += 1 dt_s = time.perf_counter() - start_loop_s busy_wait(1 / args.fps - dt_s) + + logging.info(f"Success after 20 steps {sucesses}") + logging.info(f"success rate {sum(sucesses)/ len(sucesses)}") diff --git a/lerobot/scripts/server/kinematics.py b/lerobot/scripts/server/kinematics.py new file mode 100644 index 00000000..6622fe76 --- /dev/null +++ b/lerobot/scripts/server/kinematics.py @@ -0,0 +1,543 @@ +import numpy as np +from scipy.spatial.transform import Rotation + + +def skew_symmetric(w): + """Creates the skew-symmetric matrix from a 3D vector.""" + return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]]) + + +def rodrigues_rotation(w, theta): + """Computes the rotation matrix using Rodrigues' formula.""" + w_hat = skew_symmetric(w) + return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + + +def screw_axis_to_transform(S, theta): + """Converts a screw axis to a 4x4 transformation matrix.""" + S_w = S[:3] + S_v = S[3:] + if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation + T = np.eye(4) + T[:3, 3] = S_v * theta + elif np.linalg.norm(S_w) == 1: # Rotation and translation + w_hat = skew_symmetric(S_w) + R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + t = ( + np.eye(3) * theta + + (1 - np.cos(theta)) * w_hat + + (theta - np.sin(theta)) * w_hat @ w_hat + ) @ S_v + T = np.eye(4) + T[:3, :3] = R + T[:3, 3] = t + else: + raise ValueError("Invalid screw axis parameters") + return T + + +def pose_difference_se3(pose1, pose2): + """ + Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices. + + pose1 - pose2 + + Args: + pose1: A 4x4 numpy array representing the first pose. + pose2: A 4x4 numpy array representing the second pose. + + Returns: + A tuple (translation_diff, rotation_diff) where: + - translation_diff is a 3x1 numpy array representing the translational difference. + - rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation. 
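+
+    Example (illustrative; the translational and rotational parts are returned
+    concatenated as a single 6-vector):
+        >>> import numpy as np
+        >>> pose1 = np.eye(4); pose1[:3, 3] = [0.1, 0.0, 0.2]
+        >>> pose2 = np.eye(4)
+        >>> pose_difference_se3(pose1, pose2).shape
+        (6,)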
+ """ + + # Extract rotation matrices from poses + R1 = pose1[:3, :3] + R2 = pose2[:3, :3] + + # Calculate translational difference + translation_diff = pose1[:3, 3] - pose2[:3, 3] + + # Calculate rotational difference using scipy's Rotation library + R_diff = Rotation.from_matrix(R1 @ R2.T) + rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation + + return np.concatenate([translation_diff, rotation_diff]) + + +def se3_error(target_pose, current_pose): + pos_error = target_pose[:3, 3] - current_pose[:3, 3] + R_target = target_pose[:3, :3] + R_current = current_pose[:3, :3] + R_error = R_target @ R_current.T + rot_error = Rotation.from_matrix(R_error).as_rotvec() + return np.concatenate([pos_error, rot_error]) + + +class RobotKinematics: + """Robot kinematics class supporting multiple robot models.""" + + # Robot measurements dictionary + ROBOT_MEASUREMENTS = { + "koch": { + "gripper": [0.239, -0.001, 0.024], + "wrist": [0.209, 0, 0.024], + "forearm": [0.108, 0, 0.02], + "humerus": [0, 0, 0.036], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "so100": { + "gripper": [0.320, 0, 0.050], + "wrist": [0.278, 0, 0.050], + "forearm": [0.143, 0, 0.044], + "humerus": [0.031, 0, 0.072], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "moss": { + "gripper": [0.246, 0.013, 0.111], + "wrist": [0.245, 0.002, 0.064], + "forearm": [0.122, 0, 0.064], + "humerus": [0.001, 0.001, 0.063], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + } + + def __init__(self, robot_type="so100"): + """Initialize kinematics for the specified robot type. + + Args: + robot_type: String specifying the robot model ("koch", "so100", or "moss") + """ + if robot_type not in self.ROBOT_MEASUREMENTS: + raise ValueError( + f"Unknown robot type: {robot_type}. 
Available types: {list(self.ROBOT_MEASUREMENTS.keys())}" + ) + + self.robot_type = robot_type + self.measurements = self.ROBOT_MEASUREMENTS[robot_type] + + # Initialize all transformation matrices and screw axes + self._setup_transforms() + + def _create_translation_matrix(self, x=0, y=0, z=0): + """Create a 4x4 translation matrix.""" + return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]) + + def _setup_transforms(self): + """Setup all transformation matrices and screw axes for the robot.""" + # Set up rotation matrices (constant across robot types) + + # Gripper orientation + self.gripper_X0 = np.array( + [ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, -1, 0, 0], + [0, 0, 0, 1], + ] + ) + + # Wrist orientation + self.wrist_X0 = np.array( + [ + [0, -1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + ) + + # Base orientation + self.base_X0 = np.array( + [ + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ] + ) + + # Gripper + # Screw axis of gripper frame wrt base frame + self.S_BG = np.array( + [ + 1, + 0, + 0, + 0, + self.measurements["gripper"][2], + -self.measurements["gripper"][1], + ] + ) + + # Gripper origin to centroid transform + self.X_GoGc = self._create_translation_matrix(x=0.07) + + # Gripper origin to tip transform + self.X_GoGt = self._create_translation_matrix(x=0.12) + + # 0-position gripper frame pose wrt base + self.X_BoGo = self._create_translation_matrix( + x=self.measurements["gripper"][0], + y=self.measurements["gripper"][1], + z=self.measurements["gripper"][2], + ) + + # Wrist + # Screw axis of wrist frame wrt base frame + self.S_BR = np.array( + [0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]] + ) + + # 0-position origin to centroid transform + self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002) + + # 0-position wrist frame pose wrt base + self.X_BR = self._create_translation_matrix( + x=self.measurements["wrist"][0], + y=self.measurements["wrist"][1], + z=self.measurements["wrist"][2], + ) + + # Forearm + # Screw axis of forearm frame wrt base frame + self.S_BF = np.array( + [ + 0, + 1, + 0, + -self.measurements["forearm"][2], + 0, + self.measurements["forearm"][0], + ] + ) + + # Forearm origin + centroid transform + self.X_FoFc = self._create_translation_matrix(x=0.036) + + # 0-position forearm frame pose wrt base + self.X_BF = self._create_translation_matrix( + x=self.measurements["forearm"][0], + y=self.measurements["forearm"][1], + z=self.measurements["forearm"][2], + ) + + # Humerus + # Screw axis of humerus frame wrt base frame + self.S_BH = np.array( + [ + 0, + -1, + 0, + self.measurements["humerus"][2], + 0, + -self.measurements["humerus"][0], + ] + ) + + # Humerus origin to centroid transform + self.X_HoHc = self._create_translation_matrix(x=0.0475) + + # 0-position humerus frame pose wrt base + self.X_BH = self._create_translation_matrix( + x=self.measurements["humerus"][0], + y=self.measurements["humerus"][1], + z=self.measurements["humerus"][2], + ) + + # Shoulder + # Screw axis of shoulder frame wrt Base frame + self.S_BS = np.array([0, 0, -1, 0, 0, 0]) + + # Shoulder origin to centroid transform + self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235) + + # 0-position shoulder frame pose wrt base + self.X_BS = self._create_translation_matrix( + x=self.measurements["shoulder"][0], + y=self.measurements["shoulder"][1], + z=self.measurements["shoulder"][2], + ) + + # Base + # Base origin to centroid transform + self.X_BoBc = 
self._create_translation_matrix(y=0.015) + + # World to base transform + self.X_WoBo = self._create_translation_matrix( + x=self.measurements["base"][0], + y=self.measurements["base"][1], + z=self.measurements["base"][2], + ) + + # Pre-compute gripper post-multiplication matrix + self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0 + + def fk_base(self): + """Forward kinematics for the base frame.""" + return self.X_WoBo @ self.X_BoBc @ self.base_X0 + + def fk_shoulder(self, robot_pos_deg): + """Forward kinematics for the shoulder frame.""" + robot_pos_rad = robot_pos_deg / 180 * np.pi + return ( + self.X_WoBo + @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) + @ self.X_SoSc + @ self.X_BS + ) + + def fk_humerus(self, robot_pos_deg): + """Forward kinematics for the humerus frame.""" + robot_pos_rad = robot_pos_deg / 180 * np.pi + return ( + self.X_WoBo + @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) + @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) + @ self.X_HoHc + @ self.X_BH + ) + + def fk_forearm(self, robot_pos_deg): + """Forward kinematics for the forearm frame.""" + robot_pos_rad = robot_pos_deg / 180 * np.pi + return ( + self.X_WoBo + @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) + @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) + @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) + @ self.X_FoFc + @ self.X_BF + ) + + def fk_wrist(self, robot_pos_deg): + """Forward kinematics for the wrist frame.""" + robot_pos_rad = robot_pos_deg / 180 * np.pi + return ( + self.X_WoBo + @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) + @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) + @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) + @ screw_axis_to_transform(self.S_BR, robot_pos_rad[3]) + @ self.X_RoRc + @ self.X_BR + @ self.wrist_X0 + ) + + def fk_gripper(self, robot_pos_deg): + """Forward kinematics for the gripper frame.""" + robot_pos_rad = robot_pos_deg / 180 * np.pi + return ( + self.X_WoBo + @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) + @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) + @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) + @ screw_axis_to_transform(self.S_BR, robot_pos_rad[3]) + @ screw_axis_to_transform(self.S_BG, robot_pos_rad[4]) + @ self._fk_gripper_post + ) + + def fk_gripper_tip(self, robot_pos_deg): + """Forward kinematics for the gripper tip frame.""" + robot_pos_rad = robot_pos_deg / 180 * np.pi + return ( + self.X_WoBo + @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) + @ screw_axis_to_transform(self.S_BH, robot_pos_rad[1]) + @ screw_axis_to_transform(self.S_BF, robot_pos_rad[2]) + @ screw_axis_to_transform(self.S_BR, robot_pos_rad[3]) + @ screw_axis_to_transform(self.S_BG, robot_pos_rad[4]) + @ self.X_GoGt + @ self.X_BoGo + @ self.gripper_X0 + ) + + def compute_jacobian(self, robot_pos_deg, fk_func=None): + """Finite differences to compute the Jacobian. + J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change + in the jth joint's velocity. 
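+        The columns are estimated with central finite differences: every joint
+        except the last entry is perturbed by +/- eps/2 and the resulting pose
+        difference is divided by eps.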
+ + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + if fk_func is None: + fk_func = self.fk_gripper + + eps = 1e-8 + jac = np.zeros(shape=(6, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + Sdot = ( + pose_difference_se3( + fk_func(robot_pos_deg[:-1] + delta), + fk_func(robot_pos_deg[:-1] - delta), + ) + / eps + ) + jac[:, el_ix] = Sdot + return jac + + def compute_positional_jacobian(self, robot_pos_deg, fk_func=None): + """Finite differences to compute the positional Jacobian. + J(i, j) represents how the ith component of the end-effector's position changes wrt a small change + in the jth joint's velocity. + + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + if fk_func is None: + fk_func = self.fk_gripper + + eps = 1e-8 + jac = np.zeros(shape=(3, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + Sdot = ( + fk_func(robot_pos_deg[:-1] + delta)[:3, 3] + - fk_func(robot_pos_deg[:-1] - delta)[:3, 3] + ) / eps + jac[:, el_ix] = Sdot + return jac + + def ik( + self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None + ): + """Inverse kinematics using gradient descent. + + Args: + current_joint_state: Initial joint positions in degrees + desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix + position_only: If True, only match end-effector position, not orientation + fk_func: Forward kinematics function to use (defaults to fk_gripper) + + Returns: + Joint positions in degrees that achieve the desired end-effector pose + """ + if fk_func is None: + fk_func = self.fk_gripper + + # Do gradient descent. 
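+        # Each iteration linearizes the forward kinematics around the current
+        # joint configuration and takes a pseudo-inverse step:
+        #     delta_angles = pinv(J) @ error
+        #     current_joint_state[:-1] += learning_rate * delta_angles
+        # stopping early once the error norm drops below 5e-3; the last joint
+        # entry is never updated.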
+ max_iterations = 5 + learning_rate = 1 + for _ in range(max_iterations): + current_ee_pose = fk_func(current_joint_state) + if not position_only: + error = se3_error(desired_ee_pose, current_ee_pose) + jac = self.compute_jacobian(current_joint_state, fk_func) + else: + error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3] + jac = self.compute_positional_jacobian(current_joint_state, fk_func) + delta_angles = np.linalg.pinv(jac) @ error + current_joint_state[:-1] += learning_rate * delta_angles + + if np.linalg.norm(error) < 5e-3: + return current_joint_state + return current_joint_state + + +if __name__ == "__main__": + import time + + def run_test(robot_type): + """Run test suite for a specific robot type.""" + print(f"\n--- Testing {robot_type.upper()} Robot ---") + + # Initialize kinematics for this robot + robot = RobotKinematics(robot_type) + + # Test 1: Forward kinematics consistency + print("Test 1: Forward kinematics consistency") + test_angles = np.array( + [30, 45, -30, 20, 10, 0] + ) # Example joint angles in degrees + + # Calculate FK for different joints + shoulder_pose = robot.fk_shoulder(test_angles) + humerus_pose = robot.fk_humerus(test_angles) + forearm_pose = robot.fk_forearm(test_angles) + wrist_pose = robot.fk_wrist(test_angles) + gripper_pose = robot.fk_gripper(test_angles) + gripper_tip_pose = robot.fk_gripper_tip(test_angles) + + # Check that poses form a consistent kinematic chain (positions should be progressively further from origin) + distances = [ + np.linalg.norm(shoulder_pose[:3, 3]), + np.linalg.norm(humerus_pose[:3, 3]), + np.linalg.norm(forearm_pose[:3, 3]), + np.linalg.norm(wrist_pose[:3, 3]), + np.linalg.norm(gripper_pose[:3, 3]), + np.linalg.norm(gripper_tip_pose[:3, 3]), + ] + + # Check if distances generally increase along the chain + is_consistent = all( + distances[i] <= distances[i + 1] for i in range(len(distances) - 1) + ) + print(f" Pose distances from origin: {[round(d, 3) for d in distances]}") + print( + f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}" + ) + + # Test 2: Jacobian computation + print("Test 2: Jacobian computation") + jacobian = robot.compute_jacobian(test_angles) + positional_jacobian = robot.compute_positional_jacobian(test_angles) + + # Check shapes + jacobian_shape_ok = jacobian.shape == (6, 5) + pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5) + + print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}") + print( + f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}" + ) + + # Test 3: Inverse kinematics + print("Test 3: Inverse kinematics (position only)") + + # Generate target pose from known joint angles + original_angles = np.array([10, 20, 30, -10, 5, 0]) + target_pose = robot.fk_gripper(original_angles) + + # Start IK from a different position + initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + # Measure IK performance + start_time = time.time() + computed_angles = robot.ik(initial_guess.copy(), target_pose) + ik_time = time.time() - start_time + + # Compute resulting pose from IK solution + result_pose = robot.fk_gripper(computed_angles) + + # Calculate position error + pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3]) + passed = pos_error < 0.01 # Accept errors less than 1cm + + print(f" IK computation time: {ik_time:.4f} seconds") + print(f" Position error: {pos_error:.4f}") + print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}") + + return is_consistent and jacobian_shape_ok and 
pos_jacobian_shape_ok and passed + + # Run tests for all robot types + results = {} + for robot_type in ["koch", "so100", "moss"]: + results[robot_type] = run_test(robot_type) + + # Print overall summary + print("\n=== Test Summary ===") + all_passed = all(results.values()) + for robot_type, passed in results.items(): + print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}") + print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}") diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 580eed1a..eb04effd 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -315,16 +315,49 @@ def start_learner_server( def check_nan_in_transition( - observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor -): - for k in observations: - if torch.isnan(observations[k]).any(): - logging.error(f"observations[{k}] contains NaN values") - for k in next_state: - if torch.isnan(next_state[k]).any(): - logging.error(f"next_state[{k}] contains NaN values") + observations: torch.Tensor, + actions: torch.Tensor, + next_state: torch.Tensor, + raise_error: bool = False, +) -> bool: + """ + Check for NaN values in transition data. + + Args: + observations: Dictionary of observation tensors + actions: Action tensor + next_state: Dictionary of next state tensors + raise_error: If True, raises ValueError when NaN is detected + + Returns: + bool: True if NaN values were detected, False otherwise + """ + nan_detected = False + + # Check observations + for key, tensor in observations.items(): + if torch.isnan(tensor).any(): + logging.error(f"observations[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in observations[{key}]") + + # Check next state + for key, tensor in next_state.items(): + if torch.isnan(tensor).any(): + logging.error(f"next_state[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in next_state[{key}]") + + # Check actions if torch.isnan(actions).any(): logging.error("actions contains NaN values") + nan_detected = True + if raise_error: + raise ValueError("NaN detected in actions") + + return nan_detected def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): @@ -460,9 +493,18 @@ def add_actor_information_and_train( for transition in transition_list: transition = move_transition_to_device(transition, device=device) + if check_nan_in_transition( + transition["state"], transition["action"], transition["next_state"] + ): + logging.warning("NaN detected in transition, skipping") + continue replay_buffer.add(**transition) - if transition.get("complementary_info", {}).get("is_intervention"): + + if cfg.dataset_repo_id is not None and transition.get( + "complementary_info", {} + ).get("is_intervention"): offline_replay_buffer.add(**transition) + logging.debug("[LEARNER] Received transitions") logging.debug("[LEARNER] Waiting for interactions") while not interaction_message_queue.empty() and not shutdown_event.is_set(): From 17ec837a7a3bada51d4648bdb190030913d8d5a3 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 18 Mar 2025 14:57:15 +0000 Subject: [PATCH 3/6] Refactor SACPolicy and learner server for improved replay buffer management - Updated SACPolicy to create critic heads using a list comprehension for better readability. 
- Simplified the saving and loading of models using `save_model` and `load_model` functions from the safetensors library. - Introduced `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization. - Enhanced logging for dataset loading processes to improve traceability during training. --- lerobot/common/policies/sac/modeling_sac.py | 237 +++++++++----------- lerobot/scripts/server/learner_server.py | 60 ++++- 2 files changed, 159 insertions(+), 138 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 2c4bad5f..646da874 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -18,7 +18,7 @@ # TODO: (1) better device management from copy import deepcopy -from typing import Callable, Optional, Tuple, Union, Dict +from typing import Callable, Optional, Tuple, Union, Dict, List from pathlib import Path import einops @@ -88,33 +88,33 @@ class SACPolicy( encoder_critic = SACObservationEncoder(config, self.normalize_inputs) encoder_actor = SACObservationEncoder(config, self.normalize_inputs) + # Create a list of critic heads + critic_heads = [ + CriticHead( + input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], + **config.critic_network_kwargs, + ) + for _ in range(config.num_critics) + ] + self.critic_ensemble = CriticEnsemble( encoder=encoder_critic, - ensemble=Ensemble( - [ - CriticHead( - input_dim=encoder_critic.output_dim - + config.output_shapes["action"][0], - **config.critic_network_kwargs, - ) - for _ in range(config.num_critics) - ] - ), + ensemble=critic_heads, output_normalization=self.normalize_targets, ) + # Create target critic heads as deepcopies of the original critic heads + target_critic_heads = [ + CriticHead( + input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], + **config.critic_network_kwargs, + ) + for _ in range(config.num_critics) + ] + self.critic_target = CriticEnsemble( encoder=encoder_critic, - ensemble=Ensemble( - [ - CriticHead( - input_dim=encoder_critic.output_dim - + config.output_shapes["action"][0], - **config.critic_network_kwargs, - ) - for _ in range(config.num_critics) - ] - ), + ensemble=target_critic_heads, output_normalization=self.normalize_targets, ) @@ -149,19 +149,9 @@ class SACPolicy( import json from dataclasses import asdict from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME - from safetensors.torch import save_file + from safetensors.torch import save_model - # NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap - # implies one side effect: the __batch_size parameters are saved in the state_dict - # __batch_size is torch.Size or safetensor save only torch.Tensor - # so we need to filter them out before saving - simplified_state_dict = {} - - for name, param in self.named_parameters(): - simplified_state_dict[name] = param - save_file( - simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE) - ) + save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)) # Save config config_dict = asdict(self.config) @@ -191,7 +181,7 @@ class SACPolicy( from pathlib import Path from huggingface_hub import hf_hub_download from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME - from safetensors.torch import load_file + from safetensors.torch import load_model from lerobot.common.policies.sac.configuration_sac 
import SACConfig # Check if model_id is a local path or a hub model ID @@ -243,28 +233,7 @@ class SACPolicy( # Load state dict from safetensors file if os.path.exists(safetensors_file): - # Note: The load_file function returns a dict with the parameters, but __batch_size - # is not loaded so we need to copy it from the model state_dict - # Load the parameters only - loaded_state_dict = load_file(safetensors_file, device=map_location) - - # Copy batch size parameters - find_and_copy_params( - original_state_dict=model.state_dict(), - loaded_state_dict=loaded_state_dict, - pattern="__batch_size", - match_type="endswith", - ) - - # Copy normalization buffer parameters - find_and_copy_params( - original_state_dict=model.state_dict(), - loaded_state_dict=loaded_state_dict, - pattern="_orig_mod.output_normalization.buffer_action", - match_type="contains", - ) - - model.load_state_dict(loaded_state_dict, strict=False) + load_model(model, filename=safetensors_file, device=map_location) return model @@ -594,21 +563,21 @@ class CriticEnsemble(nn.Module): def __init__( self, encoder: Optional[nn.Module], - ensemble: "Ensemble[CriticHead]", + ensemble: List[CriticHead], output_normalization: nn.Module, init_final: Optional[float] = None, ): super().__init__() self.encoder = encoder - self.ensemble = ensemble self.init_final = init_final self.output_normalization = output_normalization + self.critics = nn.ModuleList(ensemble) self.parameters_to_optimize = [] # Handle the case where a part of the encoder if frozen if self.encoder is not None: self.parameters_to_optimize += list(self.encoder.parameters_to_optimize) - self.parameters_to_optimize += list(self.ensemble.parameters()) + self.parameters_to_optimize += list(self.critics.parameters()) def forward( self, @@ -632,8 +601,15 @@ class CriticEnsemble(nn.Module): ) inputs = torch.cat([obs_enc, actions], dim=-1) - q_values = self.ensemble(inputs) # [num_critics, B, 1] - return q_values.squeeze(-1) # [num_critics, B] + + # Loop through critics and collect outputs + q_values = [] + for critic in self.critics: + q_values.append(critic(inputs)) + + # Stack outputs to match expected shape [num_critics, batch_size] + q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) + return q_values class Policy(nn.Module): @@ -706,9 +682,9 @@ class Policy(nn.Module): # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - assert not torch.isnan( - log_std - ).any(), "[ERROR] log_std became NaN after std_layer!" + assert not torch.isnan(log_std).any(), ( + "[ERROR] log_std became NaN after std_layer!" 
+ ) if self.use_tanh_squash: log_std = torch.tanh(log_std) @@ -1025,73 +1001,78 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: if __name__ == "__main__": - # Test the SACObservationEncoder + # Benchmark the CriticEnsemble performance import time - config = SACConfig() - config.num_critics = 10 - config.vision_encoder_name = None - encoder = SACObservationEncoder(config, nn.Identity()) - # actor_encoder = SACObservationEncoder(config) - # encoder = torch.compile(encoder) + # Configuration + num_critics = 10 + batch_size = 32 + action_dim = 7 + obs_dim = 64 + hidden_dims = [256, 256] + num_iterations = 100 + + print("Creating test environment...") + + # Create a simple dummy encoder + class DummyEncoder(nn.Module): + def __init__(self): + super().__init__() + self.output_dim = obs_dim + self.parameters_to_optimize = [] + + def forward(self, obs): + # Just return a random tensor of the right shape + # In practice, this would encode the observations + return torch.randn(batch_size, obs_dim, device=device) + + # Create critic heads + print(f"Creating {num_critics} critic heads...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + critic_heads = [ + CriticHead( + input_dim=obs_dim + action_dim, + hidden_dims=hidden_dims, + ).to(device) + for _ in range(num_critics) + ] + + # Create the critic ensemble + print("Creating CriticEnsemble...") critic_ensemble = CriticEnsemble( - encoder=encoder, - ensemble=Ensemble( - [ - CriticHead( - input_dim=encoder.output_dim + config.output_shapes["action"][0], - **config.critic_network_kwargs, - ) - for _ in range(config.num_critics) - ] - ), + encoder=DummyEncoder().to(device), + ensemble=critic_heads, output_normalization=nn.Identity(), - ) - # actor = Policy( - # encoder=actor_encoder, - # network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs), - # action_dim=config.output_shapes["action"][0], - # encoder_is_shared=config.shared_encoder, - # **config.policy_kwargs, - # ) - # encoder = encoder.to("cuda:0") - # critic_ensemble = torch.compile(critic_ensemble) - critic_ensemble = critic_ensemble.to("cuda:0") - # actor = torch.compile(actor) - # actor = actor.to("cuda:0") + ).to(device) + + # Create random input data + print("Creating input data...") obs_dict = { - "observation.image": torch.randn(8, 3, 84, 84), - "observation.state": torch.randn(8, 4), + "observation.state": torch.randn(batch_size, obs_dim, device=device), } - actions = torch.randn(8, 2).to("cuda:0") - # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()} - # print("compiling...") - q_value = critic_ensemble(obs_dict, actions) - print(q_value.size()) - # action = actor(obs_dict) - # print("compiled") - # start = time.perf_counter() - # for _ in range(1000): - # # features = encoder(obs_dict) - # action = actor(obs_dict) - # # q_value = critic_ensemble(obs_dict, actions) - # print("Time taken:", time.perf_counter() - start) - # Compare the performance of the ensemble vs a for loop of 16 MLPs - ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)]) - ensemble = ensemble.to("cuda:0") - critic = CriticHead(256, [256, 256]) - critic = critic.to("cuda:0") - data_ensemble = torch.randn(8, 256).to("cuda:0") - ensemble = torch.compile(ensemble) - # critic = torch.compile(critic) - print(ensemble(data_ensemble).size()) - print(critic(data_ensemble).size()) - start = time.perf_counter() - for _ in range(1000): - ensemble(data_ensemble) - print("Time taken:", time.perf_counter() - start) - start = 
time.perf_counter() - for _ in range(1000): - for i in range(2): - critic(data_ensemble) - print("Time taken:", time.perf_counter() - start) + actions = torch.randn(batch_size, action_dim, device=device) + + # Warmup run + print("Warming up...") + _ = critic_ensemble(obs_dict, actions) + + # Time the forward pass + print(f"Running benchmark with {num_iterations} iterations...") + start_time = time.perf_counter() + for _ in range(num_iterations): + q_values = critic_ensemble(obs_dict, actions) + end_time = time.perf_counter() + + # Print results + elapsed_time = end_time - start_time + print(f"Total time: {elapsed_time:.4f} seconds") + print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms") + print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size] + + # Verify that all critic heads produce different outputs + # This confirms each critic head is unique + # print("\nVerifying critic outputs are different:") + # for i in range(num_critics): + # for j in range(i + 1, num_critics): + # diff = torch.abs(q_values[i] - q_values[j]).mean().item() + # print(f"Mean difference between critic {i} and {j}: {diff:.6f}") diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index eb04effd..d1235980 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -121,7 +121,7 @@ def load_training_state( return None, None training_state = torch.load( - logger.last_checkpoint_dir / logger.training_state_file_name + logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False ) if isinstance(training_state["optimizer"], dict): @@ -160,6 +160,7 @@ def initialize_replay_buffer( optimize_memory=True, ) + logging.info("Resume training load the online dataset") dataset = LeRobotDataset( repo_id=cfg.dataset_repo_id, local_files_only=True, @@ -174,6 +175,37 @@ def initialize_replay_buffer( ) +def initialize_offline_replay_buffer( + cfg: DictConfig, + logger: Logger, + device: str, + storage_device: str, + active_action_dims: list[int] | None = None, +) -> ReplayBuffer: + if not cfg.resume: + logging.info("make_dataset offline buffer") + offline_dataset = make_dataset(cfg) + if cfg.resume: + logging.info("load offline dataset") + offline_dataset = LeRobotDataset( + repo_id=cfg.dataset_repo_id, + local_files_only=True, + root=logger.log_dir / "dataset_offline", + ) + + logging.info("Convert to a offline replay buffer") + offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( + offline_dataset, + device=device, + state_keys=cfg.policy.input_shapes.keys(), + action_mask=active_action_dims, + action_delta=cfg.env.wrapper.delta_action, + storage_device=storage_device, + optimize_memory=True, + ) + return offline_replay_buffer + + def get_observation_features( policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor ) -> tuple[torch.Tensor | None, torch.Tensor | None]: @@ -447,9 +479,6 @@ def add_actor_information_and_train( offline_replay_buffer = None if cfg.dataset_repo_id is not None: - logging.info("make_dataset offline buffer") - offline_dataset = make_dataset(cfg) - logging.info("Convertion to a offline replay buffer") active_action_dims = None if cfg.env.wrapper.joint_masking_action_space is not None: active_action_dims = [ @@ -457,14 +486,12 @@ def add_actor_information_and_train( for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask ] - offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( - offline_dataset, + 
offline_replay_buffer = initialize_offline_replay_buffer( + cfg=cfg, + logger=logger, device=device, - state_keys=cfg.policy.input_shapes.keys(), - action_mask=active_action_dims, - action_delta=cfg.env.wrapper.delta_action, storage_device=storage_device, - optimize_memory=True, + active_action_dims=active_action_dims, ) batch_size: int = batch_size // 2 # We will sample from both replay buffer @@ -714,6 +741,19 @@ def add_actor_information_and_train( replay_buffer.to_lerobot_dataset( cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset" ) + if offline_replay_buffer is not None: + dataset_dir = logger.log_dir / "dataset_offline" + + if dataset_dir.exists() and dataset_dir.is_dir(): + shutil.rmtree( + dataset_dir, + ) + + offline_replay_buffer.to_lerobot_dataset( + cfg.dataset_repo_id, + fps=cfg.fps, + root=logger.log_dir / "dataset_offline", + ) logging.info("Resume training") From f899edb57fe3c4bc1caec9a4053aa90f7899e2bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Mar 2025 14:57:57 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/policies/sac/modeling_sac.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 646da874..e634af3f 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -682,9 +682,9 @@ class Policy(nn.Module): # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - assert not torch.isnan(log_std).any(), ( - "[ERROR] log_std became NaN after std_layer!" - ) + assert not torch.isnan( + log_std + ).any(), "[ERROR] log_std became NaN after std_layer!" if self.use_tanh_squash: log_std = torch.tanh(log_std) From b7bd13570f0953b4766271fef9a4c4149c2e59f7 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 19 Mar 2025 09:54:46 +0000 Subject: [PATCH 5/6] Update configuration files for improved performance and flexibility - Increased frame rate in `maniskill_example.yaml` from 20 to 400 for enhanced simulation speed. - Updated `sac_maniskill.yaml` to set `dataset_repo_id` to null and adjusted `grad_clip_norm` from 10.0 to 40.0. - Changed `storage_device` from "cpu" to "cuda" for better resource utilization. - Modified `save_freq` from 2000000 to 1000000 to optimize saving intervals. - Enhanced input normalization parameters for `observation.state` and `observation.image` in SAC policy. - Adjusted `num_critics` from 10 to 2 and `policy_parameters_push_frequency` from 1 to 4 for improved training dynamics. - Updated `learner_server.py` to utilize `offline_buffer_capacity` for replay buffer initialization. - Changed action multiplier in `maniskill_manipulator.py` from 1 to 0.03 for finer control over actions. 
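
For context, under the usual min-max convention (policy outputs in [-1, 1]
mapped back to [min, max]), the tighter output bounds cap each action
dimension at +/-0.03. A minimal sketch of that mapping; the
`unnormalize_min_max` helper below is illustrative only, not the library API:

    import numpy as np

    def unnormalize_min_max(action, low=-0.03, high=0.03):
        # Map a normalized action in [-1, 1] to the environment range [low, high].
        return (action + 1.0) / 2.0 * (high - low) + low

    print(unnormalize_min_max(np.array([-1.0, 0.0, 1.0])))  # -> [-0.03, 0.0, 0.03]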
--- lerobot/configs/env/maniskill_example.yaml | 2 +- lerobot/configs/policy/sac_maniskill.yaml | 63 ++++++++++--------- lerobot/scripts/server/learner_server.py | 1 + .../scripts/server/maniskill_manipulator.py | 2 +- 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml index 3df23b2e..2beaa8a6 100644 --- a/lerobot/configs/env/maniskill_example.yaml +++ b/lerobot/configs/env/maniskill_example.yaml @@ -1,6 +1,6 @@ # @package _global_ -fps: 20 +fps: 400 env: name: maniskill/pushcube diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml index c9bbca44..cf20d059 100644 --- a/lerobot/configs/policy/sac_maniskill.yaml +++ b/lerobot/configs/policy/sac_maniskill.yaml @@ -8,22 +8,23 @@ # env.gym.obs_type=environment_state_agent_pos \ seed: 1 -dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium" +# dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium" +dataset_repo_id: null training: # Offline training dataloader num_workers: 4 batch_size: 512 - grad_clip_norm: 10.0 + grad_clip_norm: 40.0 lr: 3e-4 - storage_device: "cpu" + storage_device: "cuda" eval_freq: 2500 log_freq: 10 - save_freq: 2000000 + save_freq: 1000000 online_steps: 1000000 online_rollout_n_episodes: 10 @@ -32,17 +33,12 @@ training: online_sampling_ratio: 1.0 online_env_seed: 10000 online_buffer_capacity: 200000 + offline_buffer_capacity: 100000 online_buffer_seed_size: 0 online_step_before_learning: 500 do_online_rollout_async: false policy_update_freq: 1 - # delta_timestamps: - # observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" - # observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" - # action: "[i / ${fps} for i in range(${policy.horizon})]" - # next.reward: "[i / ${fps} for i in range(${policy.horizon})]" - policy: name: sac @@ -68,28 +64,33 @@ policy: camera_number: 1 # Normalization / Unnormalization - input_normalization_modes: null - # input_normalization_modes: - # observation.state: min_max - input_normalization_params: null - # observation.state: - # min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01, - # 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00, - # -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00, - # -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01, - # 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] + # input_normalization_modes: null + input_normalization_modes: + observation.state: min_max + observation.image: mean_std + # input_normalization_params: null + input_normalization_params: + observation.state: + min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01, + 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00, + -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00, + -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01, + 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] + max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, + 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, + 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135, + 0.4001] - # max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, - # 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, - # 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135, - # 0.4001] + observation.image: + mean: [0.485, 0.456, 
0.406] + std: [0.229, 0.224, 0.225] output_normalization_modes: action: min_max output_normalization_params: action: - min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0] - max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + min: [-0.03, -0.03, -0.03, -0.03, -0.03, -0.03, -0.03] + max: [0.03, 0.03, 0.03, 0.03, 0.03, 0.03, 0.03] output_normalization_shapes: action: [7] @@ -99,8 +100,8 @@ policy: # discount: 0.99 discount: 0.80 temperature_init: 1.0 - num_critics: 10 #10 - num_subsample_critics: 2 + num_critics: 2 #10 + num_subsample_critics: null critic_lr: 3e-4 actor_lr: 3e-4 temperature_lr: 3e-4 @@ -111,7 +112,7 @@ policy: actor_learner_config: learner_host: "127.0.0.1" learner_port: 50051 - policy_parameters_push_frequency: 1 + policy_parameters_push_frequency: 4 concurrency: - actor: 'processes' - learner: 'processes' + actor: 'threads' + learner: 'threads' diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index d1235980..713fc2a8 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -202,6 +202,7 @@ def initialize_offline_replay_buffer( action_delta=cfg.env.wrapper.delta_action, storage_device=storage_device, optimize_memory=True, + capacity=cfg.training.offline_buffer_capacity, ) return offline_replay_buffer diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index e4d55955..495042de 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -159,7 +159,7 @@ def make_maniskill( env.unwrapped.metadata["render_fps"] = 20 env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) - env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1) + env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) return env From 6fa3e5f9ad2cbee6e5ebe10f7b42c13d92d99b12 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 19 Mar 2025 13:16:31 +0000 Subject: [PATCH 6/6] Enhance training information logging in learner server - Added tracking for replay buffer size and offline replay buffer size during training steps. --- lerobot/scripts/server/learner_server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 713fc2a8..2b19fea2 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -681,6 +681,11 @@ def add_actor_information_and_train( policy.update_target_networks() if optimization_step % cfg.training.log_freq == 0: + training_infos["replay_buffer_size"] = len(replay_buffer) + if offline_replay_buffer is not None: + training_infos["offline_replay_buffer_size"] = len( + offline_replay_buffer + ) training_infos["Optimization step"] = optimization_step logger.log_dict( d=training_infos, mode="train", custom_step_key="Optimization step"