From a7db3959f50bfd9e5f5072856446d8442323fe5d Mon Sep 17 00:00:00 2001
From: Michel Aractingi <michel.aractingi@huggingface.co>
Date: Tue, 11 Feb 2025 11:34:46 +0100
Subject: [PATCH] - Added JointMaskingActionSpace wrapper in `gym_manipulator`
 in order to select which joints will be controlled. For example, we can
 disable the gripper actions for some tasks. - Added Nan detection mechanisms
 in the actor, learner and gym_manipulator for the case where we encounter
 nans in the loop. - changed the non-blocking in the `.to(device)` functions
 to only work for the case of cuda because they were causing nans when running
 the policy on mps - Added some joint clipping and limits in the env, robot
 and policy configs. TODO clean this part and make the limits in one config
 file only.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
---
 .../hilserl/classifier/modeling_classifier.py |   4 +-
 lerobot/configs/env/so100_real.yaml           |   4 +-
 lerobot/configs/policy/sac_real.yaml          |  10 +-
 lerobot/configs/robot/so100.yaml              |   9 +-
 lerobot/scripts/server/actor_server.py        |  21 +++-
 lerobot/scripts/server/buffer.py              |  17 ++-
 lerobot/scripts/server/find_joint_limits.py   |   1 +
 lerobot/scripts/server/gym_manipulator.py     | 106 ++++++++++++++++--
 lerobot/scripts/server/learner_server.py      |  20 +++-
 9 files changed, 161 insertions(+), 31 deletions(-)

diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
index 58532302..c5485227 100644
--- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
+++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
@@ -145,8 +145,8 @@ class Classifier(
 
         return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
 
-    def predict_reward(self, x):
+    def predict_reward(self, x, threshold=0.6):
         if self.config.num_classes == 2:
-            return (self.forward(x).probabilities > 0.6).float()
+            return (self.forward(x).probabilities > threshold).float()
         else:
             return torch.argmax(self.forward(x).probabilities, dim=1)
diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml
index 862ea951..82dcfeea 100644
--- a/lerobot/configs/env/so100_real.yaml
+++ b/lerobot/configs/env/so100_real.yaml
@@ -18,10 +18,12 @@ env:
     control_time_s: 20
     reset_follower_pos: true
     use_relative_joint_positions: true
-    reset_time_s: 10
+    reset_time_s: 5
     display_cameras: false
     delta_action: 0.1
+    joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
 
   reward_classifier:
     pretrained_path: outputs/classifier/checkpoints/best/pretrained_model
     config_path: lerobot/configs/policy/hilserl_classifier.yaml
+    
\ No newline at end of file
diff --git a/lerobot/configs/policy/sac_real.yaml b/lerobot/configs/policy/sac_real.yaml
index cbce0b00..de0ffe9b 100644
--- a/lerobot/configs/policy/sac_real.yaml
+++ b/lerobot/configs/policy/sac_real.yaml
@@ -20,7 +20,7 @@ training:
   lr: 3e-4
 
   eval_freq: 2500
-  log_freq: 500
+  log_freq: 1
   save_freq: 2000000
 
   online_steps: 1000000
@@ -31,7 +31,7 @@ training:
   online_env_seed: 10000
   online_buffer_capacity: 1000000
   online_buffer_seed_size: 0
-  online_step_before_learning: 100 #5000
+  online_step_before_learning: 1000 #5000
   do_online_rollout_async: false
   policy_update_freq: 1
 
@@ -76,8 +76,10 @@ policy:
       mean: [0.485, 0.456, 0.406]
       std: [0.229, 0.224, 0.225]
     observation.state:
-      min: [-88.50586,  23.81836, 0.87890625, -32.16797, 78.66211,   0.53691274]
-      max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156,  88.18792]
+      min: [-87.09961,     62.402344,    67.23633,     36.035156,    77.34375,0.53691274] 
+      max: [58.183594,   131.83594,    145.98633,     82.08984,     78.22266, 0.60402685]
+      # min: [-88.50586,  23.81836, 0.87890625, -32.16797, 78.66211,   0.53691274]
+      # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156,  88.18792]
 
   output_normalization_modes:
     action: min_max
diff --git a/lerobot/configs/robot/so100.yaml b/lerobot/configs/robot/so100.yaml
index d57ae721..59c52a6d 100644
--- a/lerobot/configs/robot/so100.yaml
+++ b/lerobot/configs/robot/so100.yaml
@@ -15,8 +15,13 @@ calibration_dir: .cache/calibration/so100
 # the number of motors in your follower arms.
 max_relative_target: null
 joint_position_relative_bounds: 
-  min: [-88.50586,  23.81836, 0.87890625, -32.16797, 78.66211,   0.53691274]
-  max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156,  88.18792]
+   min: [-87.09961,     62.402344,    67.23633,     36.035156,    77.34375,
+   0.53691274] 
+   max: [58.183594,   131.83594,    145.98633,     82.08984,     78.22266,
+   0.60402685]
+   
+  # min: [-88.50586,  23.81836, 0.87890625, -32.16797, 78.66211,   0.53691274]
+  # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156,  88.18792]
 
 leader_arms:
   main:
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 28c582d2..7ee91b2c 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -101,10 +101,14 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
             message = message_queue.get(block=True)
 
             if message.transition is not None:
-                transition_to_send_to_learner = [
-                    move_transition_to_device(T, device="cpu") for T in message.transition
+                transition_to_send_to_learner: list[Transition] = [
+                    move_transition_to_device(transition=T, device="cpu") for T in message.transition
                 ]
-
+                # Check for NaNs in transitions before sending to learner
+                for transition in transition_to_send_to_learner:
+                    for key, value in transition["state"].items():
+                        if torch.isnan(value).any():
+                            logging.warning(f"Found NaN values in transition {key}")
                 buf = io.BytesIO()
                 torch.save(transition_to_send_to_learner, buf)
                 transition_bytes = buf.getvalue()
@@ -226,7 +230,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
             with TimerManager(
                 elapsed_time_list=list_policy_time, label="Policy inference time", log=False
             ) as timer:  # noqa: F841
-                action = policy.select_action(batch=obs) * 0.0
+                action = policy.select_action(batch=obs)
             policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
 
             log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -238,7 +242,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
             next_obs, reward, done, truncated, info = online_env.step(action)
 
             # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
-            action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0)
+            action = (
+                torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
+            )
 
         sum_reward_episode += float(reward)
 
@@ -247,6 +253,11 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
             # TODO: Check the shape
             action = info["action_intervention"]
 
+        # Check for NaN values in observations
+        for key, tensor in obs.items():
+            if torch.isnan(tensor).any():
+                logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
+
         list_transition_to_send_to_learner.append(
             Transition(
                 state=obs,
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 8be21365..6caa9df7 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -43,16 +43,27 @@ class BatchTransition(TypedDict):
 
 def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
     # Move state tensors to CPU
-    transition["state"] = {key: val.to(device, non_blocking=True) for key, val in transition["state"].items()}
+    device = torch.device(device)
+    transition["state"] = {
+        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+    }
 
     # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=True)
+    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
 
     # No need to move reward or done, as they are float and bool
 
+    # No need to move reward or done, as they are float and bool
+    if isinstance(transition["reward"], torch.Tensor):
+        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+
+    if isinstance(transition["done"], torch.Tensor):
+        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
+
     # Move next_state tensors to CPU
     transition["next_state"] = {
-        key: val.to(device, non_blocking=True) for key, val in transition["next_state"].items()
+        key: val.to(device, non_blocking=device.type == "cuda")
+        for key, val in transition["next_state"].items()
     }
 
     # If complementary_info is present, move its tensors to CPU
diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py
index 6ec9d89f..1c2443d6 100644
--- a/lerobot/scripts/server/find_joint_limits.py
+++ b/lerobot/scripts/server/find_joint_limits.py
@@ -40,6 +40,7 @@ def find_joint_bounds(
             min = np.min(np.stack(pos_list), 0)
             print(f"Max angle position per joint {max}")
             print(f"Min angle position per joint {min}")
+            break
 
 
 if __name__ == "__main__":
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 5bf51868..09b979c5 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -296,12 +296,17 @@ class RewardWrapper(gym.Wrapper):
 
         # NOTE: We got 15% speedup by compiling the model
         self.reward_classifier = torch.compile(reward_classifier)
+
+        if isinstance(device, str):
+            device = torch.device(device)
         self.device = device
 
     def step(self, action):
         observation, _, terminated, truncated, info = self.env.step(action)
         images = [
-            observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
+            observation[key].to(self.device, non_blocking=self.device.type == "cuda")
+            for key in observation
+            if "image" in key
         ]
         start_time = time.perf_counter()
         with torch.inference_mode():
@@ -309,12 +314,76 @@ class RewardWrapper(gym.Wrapper):
                 self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
+
+        if reward == 1.0:
+            terminated = True
         return observation, reward, terminated, truncated, info
 
     def reset(self, seed=None, options=None):
         return self.env.reset(seed=seed, options=options)
 
 
+class JointMaskingActionSpace(gym.ActionWrapper):
+    def __init__(self, env, mask):
+        """
+        Wrapper to mask out dimensions of the action space.
+
+        Args:
+            env: The environment to wrap
+            mask: Binary mask array where 0 indicates dimensions to remove
+        """
+        super().__init__(env)
+
+        # Validate mask matches action space
+
+        # Keep only dimensions where mask is 1
+        self.active_dims = np.where(mask)[0]
+
+        if isinstance(env.action_space, gym.spaces.Box):
+            if len(mask) != env.action_space.shape[0]:
+                raise ValueError("Mask length must match action space dimensions")
+            low = env.action_space.low[self.active_dims]
+            high = env.action_space.high[self.active_dims]
+            self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
+
+        if isinstance(env.action_space, gym.spaces.Tuple):
+            if len(mask) != env.action_space[0].shape[0]:
+                raise ValueError("Mask length must match action space 0 dimensions")
+
+            low = env.action_space[0].low[self.active_dims]
+            high = env.action_space[0].high[self.active_dims]
+            action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
+            self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
+            # Create new action space with masked dimensions
+
+    def action(self, action):
+        """
+        Convert masked action back to full action space.
+
+        Args:
+            action: Action in masked space. For Tuple spaces, the first element is masked.
+
+        Returns:
+            Action in original space with masked dims set to 0.
+        """
+
+        # Determine whether we are handling a Tuple space or a Box.
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            # Extract the masked component from the tuple.
+            masked_action = action[0] if isinstance(action, tuple) else action
+            # Create a full action for the Box element.
+            full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
+            full_box_action[self.active_dims] = masked_action
+            # Return a tuple with the reconstructed Box action and the unchanged remainder.
+            return (full_box_action, action[1])
+        else:
+            # For Box action spaces.
+            masked_action = action if not isinstance(action, tuple) else action[0]
+            full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
+            full_action[self.active_dims] = masked_action
+            return full_action
+
+
 class TimeLimitWrapper(gym.Wrapper):
     def __init__(self, env, control_time_s, fps):
         self.env = env
@@ -331,13 +400,12 @@ class TimeLimitWrapper(gym.Wrapper):
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         time_since_last_step = time.perf_counter() - self.last_timestamp
-        # logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
         self.episode_time_in_s += time_since_last_step
         self.last_timestamp = time.perf_counter()
         self.current_step += 1
         # check if last timestep took more time than the expected fps
-        # if 1.0 / time_since_last_step < self.fps:
-        #     logging.warning(f"Current timestep exceeded expected fps {self.fps}")
+        if 1.0 / time_since_last_step < self.fps:
+            logging.debug(f"Current timestep exceeded expected fps {self.fps}")
 
         if self.episode_time_in_s > self.control_time_s:
             # if self.current_step >= self.max_episode_steps:
@@ -360,7 +428,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
         print(f"obs_keys , {self.env.observation_space}")
         print(f"crop params dict {crop_params_dict.keys()}")
         for key_crop in crop_params_dict:
-            if key_crop not in self.env.observation_space.keys():
+            if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                 raise ValueError(f"Key {key_crop} not in observation space")
         for key in crop_params_dict:
             top, left, height, width = crop_params_dict[key]
@@ -375,14 +443,23 @@ class ImageCropResizeWrapper(gym.Wrapper):
         obs, reward, terminated, truncated, info = self.env.step(action)
         for k in self.crop_params_dict:
             device = obs[k].device
+
+            # Check for NaNs before processing
+            if torch.isnan(obs[k]).any():
+                logging.error(f"NaN values detected in observation {k} before crop and resize")
+
             if device == torch.device("mps:0"):
                 obs[k] = obs[k].cpu()
+
             obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
             obs[k] = F.resize(obs[k], self.resize_size)
+
+            # Check for NaNs after processing
+            if torch.isnan(obs[k]).any():
+                logging.error(f"NaN values detected in observation {k} after crop and resize")
+
             obs[k] = obs[k].to(device)
-            # print(f"observation with key {k} with size {obs[k].size()}")
-            # cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
-            # cv2.waitKey(1)
+
         return obs, reward, terminated, truncated, info
 
     def reset(self, seed=None, options=None):
@@ -400,12 +477,18 @@ class ImageCropResizeWrapper(gym.Wrapper):
 class ConvertToLeRobotObservation(gym.ObservationWrapper):
     def __init__(self, env, device):
         super().__init__(env)
+
+        if isinstance(device, str):
+            device = torch.device(device)
         self.device = device
 
     def observation(self, observation):
         observation = preprocess_observation(observation)
 
-        observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
+        observation = {
+            key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
+            for key in observation
+        }
         observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
         return observation
 
@@ -440,7 +523,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
                         if key == keyboard.Key.right or key == keyboard.Key.esc:
                             print("Right arrow key pressed. Exiting loop...")
                             self.events["exit_early"] = True
-                        elif key == keyboard.Key.space:
+                        elif key == keyboard.Key.space and not self.events["exit_early"]:
                             if not self.events["pause_policy"]:
                                 print(
                                     "Space key pressed. Human intervention required.\n"
@@ -587,6 +670,7 @@ def make_robot_env(
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     env = KeyboardInterfaceWrapper(env=env)
     env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
+    env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
 
     return env
@@ -596,7 +680,7 @@ def make_robot_env(
 
 def get_classifier(pretrained_path, config_path, device="mps"):
     if pretrained_path is None or config_path is None:
-        return
+        return None
 
     from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
     from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index bbd70598..1b54e3a9 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -278,12 +278,23 @@ def learner_push_parameters(
         torch.save(params_dict, buf)
         params_bytes = buf.getvalue()
 
-        # Push them to the Actor’s "SendParameters" method
+        # Push them to the Actor's "SendParameters" method
         logging.info("[LEARNER] Publishing parameters to the Actor")
         response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))  # noqa: F841
         time.sleep(seconds_between_pushes)
 
 
+def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
+    for k in observations:
+        if torch.isnan(observations[k]).any():
+            logging.error(f"observations[{k}] contains NaN values")
+    for k in next_state:
+        if torch.isnan(next_state[k]).any():
+            logging.error(f"next_state[{k}] contains NaN values")
+    if torch.isnan(actions).any():
+        logging.error("actions contains NaN values")
+
+
 def add_actor_information_and_train(
     cfg,
     device: str,
@@ -372,6 +383,7 @@ def add_actor_information_and_train(
             observations = batch["state"]
             next_observations = batch["next_state"]
             done = batch["done"]
+            check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
@@ -399,6 +411,8 @@ def add_actor_information_and_train(
         next_observations = batch["next_state"]
         done = batch["done"]
 
+        assert_and_breakpoint(observations=observations, actions=actions, next_state=next_observations)
+
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -497,8 +511,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     It also initializes a learning rate scheduler, though currently, it is set to `None`.
 
     **NOTE:**
-    - If the encoder is shared, its parameters are excluded from the actor’s optimization process.
-    - The policy’s log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
+    - If the encoder is shared, its parameters are excluded from the actor's optimization process.
+    - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
 
     Args:
         cfg: Configuration object containing hyperparameters.