From 8fe2c5eaa2296184a9aa5d5a08c4d577020003fe Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Wed, 12 Feb 2025 19:25:41 +0100
Subject: [PATCH] Added possibility to record and replay delta actions during teleoperation rather than absolute actions

Co-authored-by: Adil Zouitine
---
 .../hilserl/classifier/modeling_classifier.py | 4 +-
 lerobot/common/robot_devices/control_utils.py | 8 +
 lerobot/configs/env/so100_real.yaml | 6 +-
 .../configs/policy/hilserl_classifier.yaml | 14 +-
 lerobot/configs/policy/sac_real.yaml | 12 +-
 lerobot/configs/robot/so100.yaml | 4 +-
 lerobot/scripts/control_robot.py | 6 +-
 lerobot/scripts/server/crop_dataset_roi.py | 6 +-
 lerobot/scripts/server/gym_manipulator.py | 13 +-
 .../server/wrappers/gym_manipulator.py | 584 ------------------
 lerobot/scripts/train_hilserl_classifier.py | 24 +-
 11 files changed, 63 insertions(+), 618 deletions(-)
 delete mode 100644 lerobot/scripts/server/wrappers/gym_manipulator.py

diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
index c5485227..a9fbb601 100644
--- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
+++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py
@@ -147,6 +147,8 @@ class Classifier(
     def predict_reward(self, x, threshold=0.6):
         if self.config.num_classes == 2:
-            return (self.forward(x).probabilities > threshold).float()
+            probs = self.forward(x).probabilities
+            logging.info(f"Predicted reward images: {probs}")
+            return (probs > threshold).float()
         else:
             return torch.argmax(self.forward(x).probabilities, dim=1)
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index ab9a86cd..1703a52a 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -221,6 +221,7 @@ def record_episode(
        events=events,
        policy=policy,
        fps=fps,
+        record_delta_actions=record_delta_actions,
        teleoperate=policy is None,
        single_task=single_task,
    )
@@ -262,8 +263,12 @@ def control_loop(
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()
+        current_joint_positions = robot.follower_arms["main"].read("Present_Position")
+
        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
+            if record_delta_actions:
+                action["action"] = action["action"] - current_joint_positions
        else:
            observation = robot.capture_observation()
@@ -280,6 +285,9 @@ def control_loop(
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)
+            # if frame["next.done"]:
+            #     break
+
        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml
index 82dcfeea..e6b07c69 100644
--- a/lerobot/configs/env/so100_real.yaml
+++ b/lerobot/configs/env/so100_real.yaml
@@ -12,8 +12,10 @@ env:
   wrapper:
     crop_params_dict:
-      observation.images.laptop: [58, 89, 357, 455]
-      observation.images.phone: [3, 4, 471, 633]
+      observation.images.front: [126, 43, 329, 518]
+      observation.images.side: [93, 69, 381, 434]
+      # observation.images.front: [135, 59, 331, 527]
+      # observation.images.side: [79, 47, 397, 450]
     resize_size: [128, 128]
     control_time_s: 20
     reset_follower_pos: true
diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml
index 1a95f000..9b00d7ef 100644
--- a/lerobot/configs/policy/hilserl_classifier.yaml
+++
b/lerobot/configs/policy/hilserl_classifier.yaml @@ -4,7 +4,9 @@ defaults: - _self_ seed: 13 -dataset_repo_id: aractingi/push_green_cube_hf_cropped_resized +dataset_repo_id: aractingi/push_cube_square_reward_cropped_resized +dataset_root: data/aractingi/push_cube_square_reward_cropped_resized +local_files_only: true train_split_proportion: 0.8 # Required by logger @@ -14,7 +16,7 @@ env: training: - num_epochs: 5 + num_epochs: 6 batch_size: 16 learning_rate: 1e-4 num_workers: 4 @@ -25,7 +27,7 @@ training: save_freq: 1 # How often to save checkpoints (in epochs) save_checkpoint: true # image_keys: ["observation.images.top", "observation.images.wrist"] - image_keys: ["observation.images.laptop", "observation.images.phone"] + image_keys: ["observation.images.front", "observation.images.side"] label_key: "next.reward" profile_inference_time: false profile_inference_time_iters: 20 @@ -35,8 +37,8 @@ eval: num_samples_to_log: 30 # Number of validation samples to log in the table policy: - name: "hilserl/classifier/push_green_cube_hf_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_1" - model_name: "helper2424/resnet10" + name: "hilserl/classifier/push_cube_square_reward_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_120 + model_name: "helper2424/resnet10" # "facebook/convnext-base-224" #"helper2424/resnet10" model_type: "cnn" num_cameras: 2 # Has to be len(training.image_keys) @@ -48,4 +50,4 @@ wandb: device: "mps" resume: false -output_dir: "outputs/classifier" +output_dir: "outputs/classifier/resnet10_frozen" diff --git a/lerobot/configs/policy/sac_real.yaml b/lerobot/configs/policy/sac_real.yaml index de0ffe9b..afcb408e 100644 --- a/lerobot/configs/policy/sac_real.yaml +++ b/lerobot/configs/policy/sac_real.yaml @@ -57,22 +57,22 @@ policy: input_shapes: # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? observation.state: ["${env.state_dim}"] - observation.images.laptop: [3, 128, 128] - observation.images.phone: [3, 128, 128] + observation.images.front: [3, 128, 128] + observation.images.side: [3, 128, 128] # observation.image: [3, 128, 128] output_shapes: action: ["${env.action_dim}"] # Normalization / Unnormalization input_normalization_modes: - observation.images.laptop: mean_std - observation.images.phone: mean_std + observation.images.front: mean_std + observation.images.side: mean_std observation.state: min_max input_normalization_params: - observation.images.laptop: + observation.images.front: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - observation.images.phone: + observation.images.side: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] observation.state: diff --git a/lerobot/configs/robot/so100.yaml b/lerobot/configs/robot/so100.yaml index 59c52a6d..82689753 100644 --- a/lerobot/configs/robot/so100.yaml +++ b/lerobot/configs/robot/so100.yaml @@ -50,13 +50,13 @@ follower_arms: gripper: [6, "sts3215"] cameras: - laptop: + front: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera camera_index: 0 fps: 30 width: 640 height: 480 - phone: + side: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera camera_index: 1 fps: 30 diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index bb496425..016ce2e9 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -279,7 +279,7 @@ def record( if reset_follower: initial_position = robot.follower_arms["main"].read("Present_Position") - + # Execute a few seconds without recording to: # 1. 
teleoperate the robot to move it in starting position if no policy provided, # 2. give times to the robot devices to connect and start synchronizing, @@ -351,15 +351,17 @@ def replay( dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) actions = dataset.hf_dataset.select_columns("action") - if not robot.is_connected: robot.connect() log_say("Replaying episode", cfg.play_sounds, blocking=True) for idx in range(dataset.num_frames): + current_joint_positions = robot.follower_arms["main"].read("Present_Position") start_episode_t = time.perf_counter() action = actions[idx]["action"] + if replay_delta_actions: + action = action + current_joint_positions robot.send_action(action) dt_s = time.perf_counter() - start_episode_t diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py index 41be58a8..53fda473 100644 --- a/lerobot/scripts/server/crop_dataset_roi.py +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -239,13 +239,17 @@ if __name__ == "__main__": ) args = parser.parse_args() - dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=True) images = get_image_from_lerobot_dataset(dataset) images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} images = {k: (v * 255).astype("uint8") for k, v in images.items()} rois = select_square_roi_for_images(images) + # rois = { + # "observation.images.front": [126, 43, 329, 518], + # "observation.images.side": [93, 69, 381, 434], + # } # Print the selected rectangular ROIs print("\nSelected Rectangular Regions of Interest (top, left, height, width):") diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 09b979c5..c29450bc 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -230,6 +230,8 @@ class HILSerlRobotEnv(gym.Env): if teleop_action.dim() == 1: teleop_action = teleop_action.unsqueeze(0) + # self.render() + self.current_step += 1 reward = 0.0 @@ -255,8 +257,7 @@ class HILSerlRobotEnv(gym.Env): for key in image_keys: cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - - cv2.waitKey(1) + cv2.waitKey(1) def close(self): """ @@ -311,10 +312,14 @@ class RewardWrapper(gym.Wrapper): start_time = time.perf_counter() with torch.inference_mode(): reward = ( - self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0 + self.reward_classifier.predict_reward(images, threshold=0.5) + if self.reward_classifier is not None + else 0.0 ) info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time) + logging.info(f"Reward: {reward}") + if reward == 1.0: terminated = True return observation, reward, terminated, truncated, info @@ -760,7 +765,7 @@ if __name__ == "__main__": env = make_robot_env( robot, reward_classifier, - cfg.wrapper, + cfg.env, # .wrapper, ) env.reset() diff --git a/lerobot/scripts/server/wrappers/gym_manipulator.py b/lerobot/scripts/server/wrappers/gym_manipulator.py deleted file mode 100644 index f95b7731..00000000 --- a/lerobot/scripts/server/wrappers/gym_manipulator.py +++ /dev/null @@ -1,584 +0,0 @@ -import argparse -import logging -import time -from threading import Lock -from typing import Annotated, Any, Callable, Dict, Optional, Tuple - -import gymnasium as gym -import numpy as np -import torch -import torchvision.transforms.functional as F # noqa: N812 - -from lerobot.common.envs.utils 
import preprocess_observation -from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position -from lerobot.common.robot_devices.robots.factory import make_robot -from lerobot.common.utils.utils import init_hydra_config, log_say - -logging.basicConfig(level=logging.INFO) - - -class HILSerlRobotEnv(gym.Env): - """ - Gym-like environment wrapper for robot policy evaluation. - - This wrapper provides a consistent interface for interacting with the robot, - following the OpenAI Gym environment conventions. - """ - - def __init__( - self, - robot, - display_cameras=False, - ): - """ - Initialize the robot environment. - - Args: - robot: The robot interface object - reward_classifier: Optional reward classifier - fps: Frames per second for control - control_time_s: Total control time for each episode - display_cameras: Whether to display camera feeds - """ - super().__init__() - - self.robot = robot - self.display_cameras = display_cameras - - # connect robot - if not self.robot.is_connected: - self.robot.connect() - - # Dynamically determine observation and action spaces - self._setup_spaces() - - self.initial_follower_position = robot.follower_arms["main"].read("Present_Position") - - # Episode tracking - self.current_step = 0 - self.episode_data = None - - def _setup_spaces(self): - """ - Dynamically determine observation and action spaces based on robot capabilities. - - This method should be customized based on the specific robot's observation - and action representations. - """ - # Example space setup - you'll need to adapt this to your specific robot - example_obs = self.robot.capture_observation() - - # Observation space (assuming image-based observations) - image_keys = [key for key in example_obs if "image" in key] - state_keys = [key for key in example_obs if "image" not in key] - observation_spaces = { - key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8) - for key in image_keys - } - observation_spaces["observation.state"] = gym.spaces.Dict( - { - key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32) - for key in state_keys - } - ) - - self.observation_space = gym.spaces.Dict(observation_spaces) - - # Action space (assuming joint positions) - action_dim = len(self.robot.follower_arms["main"].read("Present_Position")) - self.action_space = gym.spaces.Tuple( - ( - gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,), dtype=np.float32), - gym.spaces.Discrete(2), - ), - ) - - def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: - """ - Reset the environment to initial state. - - Returns: - observation (dict): Initial observation - info (dict): Additional information - """ - super().reset(seed=seed, options=options) - - # Capture initial observation - observation = self.robot.capture_observation() - - # Reset tracking variables - self.current_step = 0 - self.episode_data = None - - return observation, {"initial_position": self.initial_follower_position} - - def step( - self, action: Tuple[np.ndarray, bool] - ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: - """ - Take a step in the environment. - - Args: - action tuple(np.ndarray, bool): - Policy action to be executed on the robot and boolean to determine - whether to choose policy action or expert action. 
- - Returns: - observation (dict): Next observation - reward (float): Reward for this step - terminated (bool): Whether the episode has terminated - truncated (bool): Whether the episode was truncated - info (dict): Additional information - """ - # The actions recieved are the in form of a tuple containing the policy action and an intervention bool - # The boolean inidicated whether we will use the expert's actions (through teleoperation) or the policy actions - policy_action, intervention_bool = action - teleop_action = None - if not intervention_bool: - self.robot.send_action(policy_action.cpu()) - observation = self.robot.capture_observation() - else: - observation, teleop_action = self.robot.teleop_step(record_data=True) - teleop_action = teleop_action["action"] # teleop step returns torch tensors but in a dict - - self.current_step += 1 - - reward = 0.0 - terminated = False - truncated = False - - return ( - observation, - reward, - terminated, - truncated, - {"action_intervention": teleop_action, "is_intervention": teleop_action is not None}, - ) - - def render(self): - """ - Render the environment (in this case, display camera feeds). - """ - import cv2 - - observation = self.robot.capture_observation() - image_keys = [key for key in observation if "image" in key] - - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - - cv2.waitKey(1) - - def close(self): - """ - Close the environment and disconnect the robot. - """ - if self.robot.is_connected: - self.robot.disconnect() - - -class ActionRepeatWrapper(gym.Wrapper): - def __init__(self, env, nb_repeat: int = 1): - super().__init__(env) - self.nb_repeat = nb_repeat - - def step(self, action): - for _ in range(self.nb_repeat): - obs, reward, done, truncated, info = self.env.step(action) - if done or truncated: - break - return obs, reward, done, truncated, info - - -class RelativeJointPositionActionWrapper(gym.Wrapper): - def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1): - super().__init__(env) - self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") - self.delta = delta - - def step(self, action): - action_joint = action - self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") - if isinstance(self.env.action_space, gym.spaces.Tuple): - action_joint = action[0] - joint_positions = self.joint_positions + (self.delta * action_joint) - # clip the joint positions to the joint limits with the action space - joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high) - - if isinstance(self.env.action_space, gym.spaces.Tuple): - return self.env.step((joint_positions, action[1])) - - obs, reward, terminated, truncated, info = self.env.step(joint_positions) - if info["is_intervention"]: - # teleop actions are returned in absolute joint space - # If we are using a relative joint position action space, - # there will be a mismatch between the spaces of the policy and teleop actions - # Solution is to transform the teleop actions into relative space. 
- teleop_action = info["action_intervention"] # teleop actions are in absolute joint space - relative_teleop_action = (teleop_action - self.joint_positions) / self.delta - info["action_intervention"] = relative_teleop_action - - return self.env.step(joint_positions) - - -class RewardWrapper(gym.Wrapper): - def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"): - self.env = env - self.reward_classifier = reward_classifier - self.device = device - - def step(self, action): - observation, _, terminated, truncated, info = self.env.step(action) - images = [ - observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key - ] - reward = self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0 - reward = reward.item() - return observation, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - return self.env.reset(seed=seed, options=options) - - -class TimeLimitWrapper(gym.Wrapper): - def __init__(self, env, control_time_s, fps): - self.env = env - self.control_time_s = control_time_s - self.fps = fps - - self.last_timestamp = 0.0 - self.episode_time_in_s = 0.0 - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - time_since_last_step = time.perf_counter() - self.last_timestamp - self.episode_time_in_s += time_since_last_step - self.last_timestamp = time.perf_counter() - - # check if last timestep took more time than the expected fps - if 1.0 / time_since_last_step < self.fps: - logging.warning(f"Current timestep exceeded expected fps {self.fps}") - - if self.episode_time_in_s > self.control_time_s: - # Terminated = True - terminated = True - return obs, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - self.episode_time_in_s = 0.0 - self.last_timestamp = time.perf_counter() - return self.env.reset(seed=seed, options=options) - - -class ImageCropResizeWrapper(gym.Wrapper): - def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None): - self.env = env - self.crop_params_dict = crop_params_dict - for key in crop_params_dict: - assert key in self.env.observation_space, f"Key {key} not in observation space" - top, left, height, width = crop_params_dict[key] - new_shape = (top + height, left + width) - self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) - - self.resize_size = resize_size - if self.resize_size is None: - self.resize_size = (128, 128) - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - for k in self.crop_params_dict: - obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) - obs[k] = F.resize(obs[k], self.resize_size) - return obs, reward, terminated, truncated, info - - -class ConvertToLeRobotObservation(gym.ObservationWrapper): - def __init__(self, env, device): - super().__init__(env) - self.device = device - - def observation(self, observation): - observation = preprocess_observation(observation) - - observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation} - observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()} - return observation - - -class KeyboardInterfaceWrapper(gym.Wrapper): - def __init__(self, env): - super().__init__(env) - self.listener = None - self.events = { - "exit_early": False, - "pause_policy": False, - "reset_env": False, - "human_intervention_step": False, - } - self.event_lock = 
Lock() # Thread-safe access to events - self._init_keyboard_listener() - - def _init_keyboard_listener(self): - """Initialize keyboard listener if not in headless mode""" - - if is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." - ) - return - try: - from pynput import keyboard - - def on_press(key): - with self.event_lock: - try: - if key == keyboard.Key.right or key == keyboard.Key.esc: - print("Right arrow key pressed. Exiting loop...") - self.events["exit_early"] = True - elif key == keyboard.Key.space: - if not self.events["pause_policy"]: - print( - "Space key pressed. Human intervention required.\n" - "Place the leader in similar pose to the follower and press space again." - ) - self.events["pause_policy"] = True - log_say("Human intervention stage. Get ready to take over.", play_sounds=True) - elif self.events["pause_policy"] and not self.events["human_intervention_step"]: - self.events["human_intervention_step"] = True - print("Space key pressed. Human intervention starting.") - log_say("Starting human intervention.", play_sounds=True) - else: - self.events["pause_policy"] = False - self.events["human_intervention_step"] = False - print("Space key pressed for a third time.") - log_say("Continuing with policy actions.", play_sounds=True) - except Exception as e: - print(f"Error handling key press: {e}") - - self.listener = keyboard.Listener(on_press=on_press) - self.listener.start() - except ImportError: - logging.warning("Could not import pynput. Keyboard interface will not be available.") - self.listener = None - - def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: - is_intervention = False - terminated_by_keyboard = False - - # Extract policy_action if needed - if isinstance(self.env.action_space, gym.spaces.Tuple): - policy_action = action[0] - - # Check the event flags without holding the lock for too long. - with self.event_lock: - if self.events["exit_early"]: - terminated_by_keyboard = True - # If we need to wait for human intervention, we note that outside the lock. 
- pause_policy = self.events["pause_policy"] - - if pause_policy: - # Now, wait for human_intervention_step without holding the lock - while True: - with self.event_lock: - if self.events["human_intervention_step"]: - is_intervention = True - break - time.sleep(0.1) # Check more frequently if desired - - # Execute the step in the underlying environment - obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention)) - return obs, reward, terminated or terminated_by_keyboard, truncated, info - - def reset(self, **kwargs) -> Tuple[Any, Dict]: - """ - Reset the environment and clear any pending events - """ - with self.event_lock: - self.events = {k: False for k in self.events} - return self.env.reset(**kwargs) - - def close(self): - """ - Properly clean up the keyboard listener when the environment is closed - """ - if self.listener is not None: - self.listener.stop() - super().close() - - -class ResetWrapper(gym.Wrapper): - def __init__( - self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5 - ): - super().__init__(env) - self.reset_fn = reset_fn - self.reset_time_s = reset_time_s - - self.robot = self.unwrapped.robot - self.init_pos = self.unwrapped.initial_follower_position - - def reset(self, *, seed=None, options=None): - if self.reset_fn is not None: - self.reset_fn(self.env) - else: - log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True) - start_time = time.perf_counter() - while time.perf_counter() - start_time < self.reset_time_s: - self.robot.teleop_step() - - log_say("Manual reseting of the environment done.", play_sounds=True) - return super().reset(seed=seed, options=options) - - -def make_robot_env( - robot, - reward_classifier, - crop_params_dict=None, - fps=30, - control_time_s=20, - reset_follower_pos=True, - display_cameras=False, - device="cuda:0", - resize_size=None, - reset_time_s=10, - delta_action=0.1, - nb_repeats=1, - use_relative_joint_positions=False, -): - """ - Factory function to create the robot environment. - - Mimics gym.make() for consistent environment creation. 
- """ - env = HILSerlRobotEnv(robot, display_cameras) - env = ConvertToLeRobotObservation(env, device) - # if crop_params_dict is not None: - # env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size) - # env = RewardWrapper(env, reward_classifier) - env = TimeLimitWrapper(env, control_time_s, fps) - # if use_relative_joint_positions: - # env = RelativeJointPositionActionWrapper(env, delta=delta_action) - # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats) - env = KeyboardInterfaceWrapper(env) - env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s) - return env - - -def get_classifier(pretrained_path, config_path, device="mps"): - if pretrained_path is None or config_path is None: - return - - from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg - from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig - from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier - - cfg = init_hydra_config(config_path) - - classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) - classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths - model = Classifier(classifier_config) - model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) - model = model.to(device) - return model - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--fps", type=int, default=30, help="control frequency") - parser.add_argument( - "--robot-path", - type=str, - default="lerobot/configs/robot/koch.yaml", - help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", - ) - parser.add_argument( - "--robot-overrides", - type=str, - nargs="*", - help="Any key=value arguments to override config values (use dots for.nested=overrides)", - ) - parser.add_argument( - "-p", - "--pretrained-policy-name-or-path", - help=( - "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " - "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch " - "(useful for debugging). This argument is mutually exclusive with `--config`." - ), - ) - parser.add_argument( - "--config", - help=( - "Path to a yaml config you want to use for initializing a policy from scratch (useful for " - "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." 
- ), - ) - parser.add_argument( - "--display-cameras", help=("Whether to display the camera feed while the rollout is happening") - ) - parser.add_argument( - "--reward-classifier-pretrained-path", - type=str, - default=None, - help="Path to the pretrained classifier weights.", - ) - parser.add_argument( - "--reward-classifier-config-file", - type=str, - default=None, - help="Path to a yaml config file that is necessary to build the reward classifier model.", - ) - parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds") - parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes") - args = parser.parse_args() - - robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) - robot = make_robot(robot_cfg) - - reward_classifier = get_classifier( - args.reward_classifier_pretrained_path, args.reward_classifier_config_file - ) - - env = make_robot_env( - robot, - reward_classifier, - None, - args.fps, - args.control_time_s, - args.reset_follower_pos, - args.display_cameras, - device="mps", - resize_size=None, - reset_time_s=10, - delta_action=0.1, - nb_repeats=1, - use_relative_joint_positions=False, - ) - - env.reset() - init_pos = env.unwrapped.initial_follower_position - goal_pos = init_pos - - right_goal = init_pos.copy() - right_goal[0] += 50 - - left_goal = init_pos.copy() - left_goal[0] -= 50 - - # Michel is a beast - pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000) - - while True: - for i in range(len(pitch_angle)): - goal_pos[0] = pitch_angle[i] - obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False)) - if terminated or truncated: - logging.info("Max control time reached, reset environment.") - env.reset() - - for i in reversed(range(len(pitch_angle))): - goal_pos[0] = pitch_angle[i] - obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False)) - if terminated or truncated: - logging.info("Max control time reached, reset environment.") - env.reset() diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 0db19cd6..e0e01a5d 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,6 +21,7 @@ import hydra import numpy as np import torch import torch.nn as nn +import wandb from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored @@ -32,7 +31,6 @@ from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split from tqdm import tqdm -import wandb from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger @@ -45,6 +43,7 @@ from lerobot.common.utils.utils import ( init_hydra_config, set_global_seed, ) +from lerobot.scripts.server.buffer import random_shift def get_model(cfg, logger): # noqa I001 @@ -82,6 +81,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, for batch_idx, batch in enumerate(pbar): start_time = time.perf_counter() images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] + images = [random_shift(img, 4) for img in images] labels = batch[cfg.training.label_key].float().to(device) # Forward pass with optional AMP @@ -161,14 +161,17 @@ def validate(model, val_loader, criterion, device, logger, cfg): # Log sample predictions for visualization if len(samples) < cfg.eval.num_samples_to_log: - for i in range(min( cfg.eval.num_samples_to_log - len(samples), len(images))): + for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))): if model.config.num_classes == 2: confidence = round(outputs.probabilities[i].item(), 3) else: confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()] samples.append( { - **{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)}, + **{ + f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) + for img_idx, img_key in enumerate(cfg.training.image_keys) + }, "true_label": labels[i].item(), "predicted": predictions[i].item(), "confidence": confidence, @@ -270,11 +273,13 @@ def train(cfg: DictConfig) -> None: device = get_safe_torch_device(cfg.device, log=True) set_global_seed(cfg.seed) - out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "classifier" + out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2" logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None) # Setup dataset and dataloaders - dataset = LeRobotDataset(cfg.dataset_repo_id) + dataset = LeRobotDataset( + cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only + ) logging.info(f"Dataset size: {len(dataset)}") n_total = len(dataset) @@ -282,14 +287,13 @@ def train(cfg: DictConfig) -> None: train_dataset = torch.utils.data.Subset(dataset, range(0, n_train)) val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total)) - sampler = create_balanced_sampler(train_dataset, cfg) train_loader = DataLoader( train_dataset, batch_size=cfg.training.batch_size, num_workers=cfg.training.num_workers, sampler=sampler, - pin_memory=True, + pin_memory=device.type == "cuda", ) val_loader = DataLoader( @@ -297,7 +301,7 @@ def train(cfg: DictConfig) -> None: batch_size=cfg.eval.batch_size, shuffle=False, num_workers=cfg.training.num_workers, - pin_memory=True, + pin_memory=device.type == "cuda", ) # Resume training if requested
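Illustrative note (not part of the patch): with these changes, recording with record_delta_actions=True stores each teleop action as an offset from the follower arm's current "Present_Position", and replaying with replay_delta_actions=True adds the current position back before robot.send_action. Below is a minimal sketch of that convention, assuming NumPy joint-position arrays; the to_delta_action/from_delta_action helper names are hypothetical and used only for illustration.

import numpy as np


def to_delta_action(absolute_action: np.ndarray, current_joint_positions: np.ndarray) -> np.ndarray:
    # Recording: store the offset from the follower's current joint positions
    # instead of the absolute joint-position target.
    return absolute_action - current_joint_positions


def from_delta_action(delta_action: np.ndarray, current_joint_positions: np.ndarray) -> np.ndarray:
    # Replay: re-read the current joint positions and add the stored delta back
    # to recover an absolute target for robot.send_action().
    return delta_action + current_joint_positions


# Example with made-up joint values: the recorded delta is re-anchored to
# wherever the follower happens to be at replay time.
current = np.array([10.0, -5.0, 30.0, 0.0, 15.0, 2.0])
delta = to_delta_action(np.array([12.0, -5.0, 31.0, 0.5, 15.0, 2.0]), current)
print(from_delta_action(delta, current))  # recovers the absolute action

Because the delta is applied against the joint positions read at replay time, the replayed motion is relative to the follower's current pose rather than the absolute poses observed during recording.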