From dc086dc21f0e99aa1d1f0b9a3d2887755c24ac9c Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Thu, 13 Feb 2025 11:04:49 +0100
Subject: [PATCH] Added logging for interventions to monitor the rate of
 interventions through time

Added an s keyboard command to force success in the case the reward
classifier fails

Co-authored-by: Adil Zouitine
---
 lerobot/configs/env/so100_real.yaml          |  8 +++---
 .../configs/policy/hilserl_classifier.yaml   | 12 ++++-----
 lerobot/scripts/server/actor_server.py       |  4 +++
 lerobot/scripts/server/crop_dataset_roi.py   | 14 ++++++++---
 lerobot/scripts/server/gym_manipulator.py    | 25 +++++++++++++++----
 5 files changed, 45 insertions(+), 18 deletions(-)

diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml
index e6b07c69..b5afea52 100644
--- a/lerobot/configs/env/so100_real.yaml
+++ b/lerobot/configs/env/so100_real.yaml
@@ -12,10 +12,10 @@ env:
 
   wrapper:
     crop_params_dict:
-      observation.images.front: [126, 43, 329, 518]
-      observation.images.side: [93, 69, 381, 434]
-      # observation.images.front: [135, 59, 331, 527]
-      # observation.images.side: [79, 47, 397, 450]
+      observation.images.front: [102, 43, 358, 523]
+      observation.images.side: [92, 123, 379, 349]
+      # observation.images.front: [109, 37, 361, 557]
+      # observation.images.side: [94, 161, 372, 315]
     resize_size: [128, 128]
     control_time_s: 20
     reset_follower_pos: true
diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml
index 9b00d7ef..149eeab2 100644
--- a/lerobot/configs/policy/hilserl_classifier.yaml
+++ b/lerobot/configs/policy/hilserl_classifier.yaml
@@ -4,8 +4,9 @@ defaults:
   - _self_
 
 seed: 13
-dataset_repo_id: aractingi/push_cube_square_reward_cropped_resized
-dataset_root: data/aractingi/push_cube_square_reward_cropped_resized
+dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
+# aractingi/push_cube_square_reward_1_cropped_resized
+dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
 local_files_only: true
 train_split_proportion: 0.8
 
@@ -26,7 +27,6 @@ training:
   eval_freq: 1 # How often to run validation (in epochs)
   save_freq: 1 # How often to save checkpoints (in epochs)
   save_checkpoint: true
-  # image_keys: ["observation.images.top", "observation.images.wrist"]
   image_keys: ["observation.images.front", "observation.images.side"]
   label_key: "next.reward"
   profile_inference_time: false
@@ -37,8 +37,8 @@ eval:
   num_samples_to_log: 30 # Number of validation samples to log in the table
 
 policy:
-  name: "hilserl/classifier/push_cube_square_reward_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_120
-  model_name: "helper2424/resnet10" # "facebook/convnext-base-224" #"helper2424/resnet10"
+  name: "hilserl/classifier"
+  model_name: "helper2424/resnet10" # "facebook/convnext-base-224
   model_type: "cnn"
   num_cameras: 2 # Has to be len(training.image_keys)
 
@@ -50,4 +50,4 @@ wandb:
 
 device: "mps"
 resume: false
-output_dir: "outputs/classifier/resnet10_frozen"
+output_dir: "outputs/classifier/old_trainer_resnet10_frozen"
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 7ee91b2c..7b1866f9 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -223,6 +223,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
     sum_reward_episode = 0
     list_transition_to_send_to_learner = []
     list_policy_time = []
+    episode_intervention = False
 
     for interaction_step in range(cfg.training.online_steps):
         if interaction_step >= cfg.training.online_step_before_learning:
@@ -252,6 +253,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
             if info["is_intervention"]:
                 # TODO: Check the shape
                 action = info["action_intervention"]
+                episode_intervention = True
 
             # Check for NaN values in observations
             for key, tensor in obs.items():
@@ -295,11 +297,13 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
                     interaction_message={
                         "Episodic reward": sum_reward_episode,
                         "Interaction step": interaction_step,
+                        "Episode intervention": int(episode_intervention),
                         **stats,
                     }
                 )
             )
             sum_reward_episode = 0.0
+            episode_intervention = False
 
             obs, info = online_env.reset()
 
diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py
index 53fda473..da1bf96a 100644
--- a/lerobot/scripts/server/crop_dataset_roi.py
+++ b/lerobot/scripts/server/crop_dataset_roi.py
@@ -245,10 +245,18 @@ if __name__ == "__main__":
     images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
     images = {k: (v * 255).astype("uint8") for k, v in images.items()}
 
-    rois = select_square_roi_for_images(images)
+    # rois = select_square_roi_for_images(images)
+    rois = {
+        "observation.images.front": [102, 43, 358, 523],
+        "observation.images.side": [92, 123, 379, 349],
+    }
     # rois = {
-    #     "observation.images.front": [126, 43, 329, 518],
-    #     "observation.images.side": [93, 69, 381, 434],
+    #     "observation.images.side": (92, 123, 379, 349),
+    #     "observation.images.front": (109, 37, 361, 557),
+    # }
+    # rois = {
+    #     "observation.images.front": [109, 37, 361, 557],
+    #     "observation.images.side": [94, 161, 372, 315],
     # }
 
     # Print the selected rectangular ROIs
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index c29450bc..baaa3da9 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -312,7 +312,7 @@ class RewardWrapper(gym.Wrapper):
         start_time = time.perf_counter()
         with torch.inference_mode():
             reward = (
-                self.reward_classifier.predict_reward(images, threshold=0.5)
+                self.reward_classifier.predict_reward(images, threshold=0.6)
                 if self.reward_classifier is not None
                 else 0.0
             )
@@ -507,6 +507,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
             "pause_policy": False,
             "reset_env": False,
             "human_intervention_step": False,
+            "episode_success": False,
         }
         self.event_lock = Lock()  # Thread-safe access to events
         self._init_keyboard_listener()
@@ -528,7 +529,12 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
                     if key == keyboard.Key.right or key == keyboard.Key.esc:
                         print("Right arrow key pressed. Exiting loop...")
                         self.events["exit_early"] = True
-                    elif key == keyboard.Key.space and not self.events["exit_early"]:
+                        return
+                    if hasattr(key, "char") and key.char == "s":
+                        print("Key 's' pressed. Episode success triggered.")
+                        self.events["episode_success"] = True
+                        return
+                    if key == keyboard.Key.space and not self.events["exit_early"]:
                         if not self.events["pause_policy"]:
                             print(
                                 "Space key pressed. Human intervention required.\n"
                             )
                             self.events["pause_policy"] = True
                             log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
-                        elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
+                            return
+                        if self.events["pause_policy"] and not self.events["human_intervention_step"]:
                             self.events["human_intervention_step"] = True
                             print("Space key pressed. Human intervention starting.")
                             log_say("Starting human intervention.", play_sounds=True)
-                        else:
+                            return
+                        if self.events["pause_policy"] and self.events["human_intervention_step"]:
                             self.events["pause_policy"] = False
                             self.events["human_intervention_step"] = False
                             print("Space key pressed for a third time.")
                             log_say("Continuing with policy actions.", play_sounds=True)
+                            return
                 except Exception as e:
                     print(f"Error handling key press: {e}")
@@ -566,7 +575,6 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
         with self.event_lock:
             if self.events["exit_early"]:
                 terminated_by_keyboard = True
-            # If we need to wait for human intervention, we note that outside the lock.
             pause_policy = self.events["pause_policy"]
 
         if pause_policy:
@@ -580,6 +588,13 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
 
         # Execute the step in the underlying environment
        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
+
+        # Override reward and termination if episode success event triggered
+        with self.event_lock:
+            if self.events["episode_success"]:
+                reward = 1
+                terminated_by_keyboard = True
+
         return obs, reward, terminated or terminated_by_keyboard, truncated, info
 
     def reset(self, **kwargs) -> Tuple[Any, Dict]:
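
As an illustrative sketch (not part of the patch), the "press 's' to force success" mechanism added to KeyboardInterfaceWrapper can be reproduced in isolation on a generic gymnasium wrapper using pynput. The class name SuccessKeyWrapper and the plain single-argument env.step() signature below are assumptions made for illustration only; the real wrapper also passes an is_intervention flag and handles pause, reset, and intervention events.

# Minimal sketch of the success-override pattern (assumed names, generic
# gymnasium env; the real KeyboardInterfaceWrapper in gym_manipulator.py does more).
from threading import Lock

import gymnasium as gym
from pynput import keyboard


class SuccessKeyWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self._lock = Lock()  # guards the flag shared with the listener thread
        self._success = False
        self._listener = keyboard.Listener(on_press=self._on_press)
        self._listener.start()

    def _on_press(self, key):
        # Alphanumeric keys expose `.char`; special keys (space, esc, ...) do not.
        if hasattr(key, "char") and key.char == "s":
            with self._lock:
                self._success = True

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        with self._lock:
            if self._success:
                # Operator declared success: max out the reward and end the
                # episode, mirroring the override added in KeyboardInterfaceWrapper.step().
                reward = 1.0
                terminated = True
                self._success = False
        return obs, reward, terminated, truncated, info

Wrapping any environment this way lets an operator end the episode with reward 1 by pressing 's', which is the fallback the patch introduces for cases where the reward classifier misses a success during data collection.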