From 114ec644d072d50b88d8c75cea7f1ea32fa653dc Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 25 Mar 2025 14:24:46 +0100 Subject: [PATCH] Change config logic in: - gym_manipulator - find_joint_limits - end_effector_utils --- lerobot/common/envs/utils.py | 15 +- .../common/robot_devices/control_configs.py | 2 + lerobot/common/robot_devices/control_utils.py | 6 +- .../common/robot_devices/robots/configs.py | 4 +- .../robot_devices/robots/manipulator.py | 24 +- lerobot/scripts/control_robot.py | 20 +- .../server/end_effector_control_utils.py | 53 ++- lerobot/scripts/server/find_joint_limits.py | 67 +++- lerobot/scripts/server/gym_manipulator.py | 344 ++++++++---------- pyproject.toml | 1 - 10 files changed, 256 insertions(+), 280 deletions(-) diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 9feb3c39..3f694481 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -44,13 +44,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten if "images" not in key: continue - for imgkey, img in imgs.items(): - # TODO(aliberts, rcadene): use transforms.ToTensor()? + # TODO(aliberts, rcadene): use transforms.ToTensor()? + if not torch.is_tensor(img): img = torch.from_numpy(img) - # sanity check that images are channel last - _, h, w, c = img.shape - assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" + if img.ndim == 3: + img = img.unsqueeze(0) + + # sanity check that images are channel last + _, h, w, c = img.shape + assert c < h and c < w, ( + f"expect channel last images, but instead got {img.shape=}" + ) # sanity check that images are uint8 assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index cb558c71..00577ddb 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -87,6 +87,8 @@ class RecordControlConfig(ControlConfig): play_sounds: bool = True # Resume recording on an existing dataset. resume: bool = False + # Reset follower arms to an initial configuration. + reset_follower_arms: bool = True def __post_init__(self): # HACK: We parse again the cli args here to get the pretrained path if there was one. 
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 53cd508f..f4ca4f7d 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -221,7 +221,7 @@ def record_episode( events=events, policy=policy, fps=fps, - record_delta_actions=record_delta_actions, + # record_delta_actions=record_delta_actions, teleoperate=policy is None, single_task=single_task, ) @@ -267,8 +267,8 @@ def control_loop( if teleoperate: observation, action = robot.teleop_step(record_data=True) - if record_delta_actions: - action["action"] = action["action"] - current_joint_positions + # if record_delta_actions: + # action["action"] = action["action"] - current_joint_positions else: observation = robot.capture_observation() diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index e940b442..8d66dae2 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig): leader_arms: dict[str, MotorsBusConfig] = field( default_factory=lambda: { "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760431091", + port="/dev/tty.usbmodem58760433331", motors={ # name: (index, model) "shoulder_pan": [1, "sts3215"], @@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig): follower_arms: dict[str, MotorsBusConfig] = field( default_factory=lambda: { "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", + port="/dev/tty.usbmodem58760431631", motors={ # name: (index, model) "shoulder_pan": [1, "sts3215"], diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index e7993621..e14a5264 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -475,12 +475,12 @@ class ManipulatorRobot: goal_pos = leader_pos[name] # If specified, clip the goal positions within predefined bounds specified in the config of the robot - if self.config.joint_position_relative_bounds is not None: - goal_pos = torch.clamp( - goal_pos, - self.config.joint_position_relative_bounds["min"], - self.config.joint_position_relative_bounds["max"], - ) + # if self.config.joint_position_relative_bounds is not None: + # goal_pos = torch.clamp( + # goal_pos, + # self.config.joint_position_relative_bounds["min"], + # self.config.joint_position_relative_bounds["max"], + # ) # Cap goal position when too far away from present position. # Slower fps expected due to reading from the follower. @@ -604,12 +604,12 @@ class ManipulatorRobot: from_idx = to_idx # If specified, clip the goal positions within predefined bounds specified in the config of the robot - if self.config.joint_position_relative_bounds is not None: - goal_pos = torch.clamp( - goal_pos, - self.config.joint_position_relative_bounds["min"], - self.config.joint_position_relative_bounds["max"], - ) + # if self.config.joint_position_relative_bounds is not None: + # goal_pos = torch.clamp( + # goal_pos, + # self.config.joint_position_relative_bounds["min"], + # self.config.joint_position_relative_bounds["max"], + # ) # Cap goal position when too far away from present position. # Slower fps expected due to reading from the follower. 
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 0399c0e1..633fada7 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -371,8 +371,8 @@
         start_episode_t = time.perf_counter()
 
         action = actions[idx]["action"]
-        if replay_delta_actions:
-            action = action + current_joint_positions
+        # if replay_delta_actions:
+        #     action = action + current_joint_positions
         robot.send_action(action)
 
         dt_s = time.perf_counter() - start_episode_t
@@ -382,9 +382,7 @@
         log_control_info(robot, dt_s, fps=cfg.fps)
 
 
-def _init_rerun(
-    control_config: ControlConfig, session_name: str = "lerobot_control_loop"
-) -> None:
+def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
     """Initializes the Rerun SDK for visualizing the control loop.
 
     Args:
@@ -430,23 +428,17 @@ def control_robot(cfg: ControlPipelineConfig):
     if isinstance(cfg.control, CalibrateControlConfig):
         calibrate(robot, cfg.control)
     elif isinstance(cfg.control, TeleoperateControlConfig):
-        _init_rerun(
-            control_config=cfg.control, session_name="lerobot_control_loop_teleop"
-        )
+        _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
         teleoperate(robot, cfg.control)
     elif isinstance(cfg.control, RecordControlConfig):
-        _init_rerun(
-            control_config=cfg.control, session_name="lerobot_control_loop_record"
-        )
+        _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
         record(robot, cfg.control)
     elif isinstance(cfg.control, ReplayControlConfig):
         replay(robot, cfg.control)
     elif isinstance(cfg.control, RemoteRobotConfig):
         from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
 
-        _init_rerun(
-            control_config=cfg.control, session_name="lerobot_control_loop_remote"
-        )
+        _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
         run_lekiwi(cfg.robot)
 
     if robot.is_connected:
diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py
index d5b217e4..1c056aa5 100644
--- a/lerobot/scripts/server/end_effector_control_utils.py
+++ b/lerobot/scripts/server/end_effector_control_utils.py
@@ -1,14 +1,13 @@
 import argparse
 import logging
 import time
 
 import numpy as np
 import torch
 
-from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.robot_devices.robots.configs import RobotConfig
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.robot_devices.utils import busy_wait
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.server.gym_manipulator import HILSerlRobotEnvConfig, make_robot_env
 from lerobot.scripts.server.kinematics import RobotKinematics
 
 logging.basicConfig(level=logging.INFO)
@@ -677,15 +676,6 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
 
     # Close the environment
     env.close()
 
-
-def make_robot_from_config(config_path, overrides=None):
-    """Helper function to create a robot from a config file."""
-    if overrides is None:
-        overrides = []
-    robot_cfg = init_hydra_config(config_path, overrides)
-    return make_robot(robot_cfg)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
@@ -703,21 +693,22 @@
help="Control mode to use", ) parser.add_argument( - "--task", + "--robot-type", type=str, - default="Robot manipulation task", - help="Description of the task being performed", + default="so100", + help="Robot type (so100, koch, aloha, etc.)", ) parser.add_argument( - "--push-to-hub", - default=True, - type=bool, - help="Push the dataset to Hugging Face Hub", + "--config-path", + type=str, + default=None, + help="Path to the config file in json format", ) - # Add the rest of your existing arguments + args = parser.parse_args() - robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", []) + robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False) + robot = make_robot_from_config(robot_config) if not robot.is_connected: robot.connect() @@ -743,12 +734,12 @@ if __name__ == "__main__": elif args.mode in ["keyboard_gym", "gamepad_gym"]: # Gym environment control modes - from lerobot.scripts.server.gym_manipulator import make_robot_env + cfg = HILSerlRobotEnvConfig() + if args.config_path is not None: + cfg = HILSerlRobotEnvConfig.from_json(args.config_path) - cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", []) - cfg.env.wrapper.ee_action_space_params.use_gamepad = False - env = make_robot_env(robot, None, cfg) - teleoperate_gym_env(env, controller) + env = make_robot_env(cfg, robot) + teleoperate_gym_env(env, controller, fps=args.fps) elif args.mode == "leader": # Leader-follower modes don't use controllers diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py index f8891ba7..deec4d75 100644 --- a/lerobot/scripts/server/find_joint_limits.py +++ b/lerobot/scripts/server/find_joint_limits.py @@ -5,10 +5,13 @@ import cv2 import numpy as np from lerobot.common.robot_devices.control_utils import is_headless -from lerobot.common.robot_devices.robots.factory import make_robot -from lerobot.common.utils.utils import init_hydra_config +from lerobot.common.robot_devices.robots.utils import make_robot_from_config from lerobot.scripts.server.kinematics import RobotKinematics +from lerobot.configs import parser +from lerobot.common.robot_devices.robots.configs import RobotConfig +follower_port = "/dev/tty.usbmodem58760431631" +leader_port = "/dev/tty.usbmodem58760433331" def find_joint_bounds( robot, @@ -78,20 +81,28 @@ def find_ee_bounds( break +def make_robot(robot_type="so100", mock=True): + """ + Create a robot instance using the appropriate robot config class. 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--robot-path",
-        type=str,
-        default="lerobot/configs/robot/koch.yaml",
-        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
-    )
-    parser.add_argument(
-        "--robot-overrides",
-        type=str,
-        nargs="*",
-        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
-    )
+    # Argparse for script-specific arguments; add_help=False avoids clashing with
+    # any outer parser that also processes the command line.
+    parser = argparse.ArgumentParser(add_help=False)
     parser.add_argument(
         "--mode",
         type=str,
@@ -105,13 +116,29 @@
         default=30,
         help="Time step to use for control.",
     )
-    args = parser.parse_args()
-
-    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
-
-    robot = make_robot(robot_cfg)
+    parser.add_argument(
+        "--robot-type",
+        type=str,
+        default="so100",
+        help="Robot type (so100, koch, aloha, etc.)",
+    )
+    parser.add_argument(
+        "--mock",
+        type=int,
+        default=1,
+        help="Use mock mode for hardware simulation (1 = mock, 0 = real hardware)",
+    )
+
+    # Only parse known args, leaving any remaining CLI arguments untouched
+    args, _ = parser.parse_known_args()
+
+    # Create robot with the appropriate config
+    robot = make_robot(args.robot_type, bool(args.mock))
+
     if args.mode == "joint":
         find_joint_bounds(robot, args.control_time_s)
     elif args.mode == "ee":
         find_ee_bounds(robot, args.control_time_s)
+
     if robot.is_connected:
         robot.disconnect()
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 63a0fbc9..2f39bfdb 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -1,7 +1,8 @@
-import argparse
+import json
 import logging
 import sys
 import time
+from dataclasses import dataclass, field
 from threading import Lock
-from typing import Annotated, Any, Dict, Tuple
+from typing import Annotated, Any, Dict, Optional, Tuple
 
@@ -16,12 +20,69 @@ from lerobot.common.robot_devices.control_utils import (
     is_headless,
     reset_follower_position,
 )
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config, log_say
+from lerobot.common.robot_devices.robots.configs import RobotConfig
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.common.utils.utils import log_say
+from lerobot.configs import parser
 from lerobot.scripts.server.kinematics import RobotKinematics
 
 logging.basicConfig(level=logging.INFO)
+
+
+@dataclass
+class EEActionSpaceConfig:
+    """Configuration parameters for end-effector action space."""
+
+    x_step_size: float
+    y_step_size: float
+    z_step_size: float
+    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
+    use_gamepad: bool = False
+
+
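+# Illustrative `bounds` value (an assumption, not a shipped default; use the
+# limits printed by find_joint_limits.py for your own setup):
+#   {"min": [0.15, -0.20, 0.02], "max": [0.40, 0.20, 0.30]}
+
+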
+@dataclass
+class EnvWrapperConfig:
+    """Configuration for environment wrappers."""
+
+    display_cameras: bool = False
+    delta_action: float = 0.1
+    use_relative_joint_positions: bool = True
+    add_joint_velocity_to_observation: bool = False
+    add_ee_pose_to_observation: bool = False
+    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
+    resize_size: Optional[Tuple[int, int]] = None
+    control_time_s: float = 20.0
+    fixed_reset_joint_positions: Optional[Any] = None
+    reset_time_s: float = 5.0
+    joint_masking_action_space: Optional[Any] = None
+    ee_action_space_params: Optional[EEActionSpaceConfig] = None
+    reward_classifier_pretrained_path: Optional[str] = None
+    reward_classifier_config_file: Optional[str] = None
+
+
+@dataclass
+class HILSerlRobotEnvConfig:
+    """Configuration for the HILSerlRobotEnv environment."""
+
+    # Defaults allow `HILSerlRobotEnvConfig()` to be constructed without arguments,
+    # as done in end_effector_control_utils.py.
+    robot: Optional[RobotConfig] = None
+    wrapper: EnvWrapperConfig = field(default_factory=EnvWrapperConfig)
+    env_name: str = "real_robot"
+    fps: int = 10
+    mode: Optional[str] = None  # Either "record", "replay", or None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    task: str = ""
+    num_episodes: int = 10  # Only used in "record" mode
+    episode: int = 0
+    device: str = "cuda"
+    push_to_hub: bool = True
+    pretrained_policy_name_or_path: Optional[str] = None
+
+    @classmethod
+    def from_json(cls, json_path: str):
+        with open(json_path) as f:
+            config = json.load(f)
+        # Rebuild nested dataclasses from the plain dicts produced by json.load.
+        wrapper = config.pop("wrapper", None)
+        if isinstance(wrapper, dict):
+            ee_params = wrapper.pop("ee_action_space_params", None)
+            if isinstance(ee_params, dict):
+                ee_params = EEActionSpaceConfig(**ee_params)
+            config["wrapper"] = EnvWrapperConfig(**wrapper, ee_action_space_params=ee_params)
+        return cls(**config)
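+
+
+# Illustrative JSON for `HILSerlRobotEnvConfig.from_json` (field names follow the
+# dataclasses above; the values are assumptions, not files shipped with the repo):
+#   {
+#       "fps": 10,
+#       "mode": "record",
+#       "repo_id": "user/so100_pick_cube",
+#       "task": "Pick up the cube",
+#       "wrapper": {
+#           "control_time_s": 20.0,
+#           "ee_action_space_params": {
+#               "x_step_size": 0.01,
+#               "y_step_size": 0.01,
+#               "z_step_size": 0.01,
+#               "bounds": {"min": [0.15, -0.2, 0.02], "max": [0.4, 0.2, 0.3]},
+#               "use_gamepad": true
+#           }
+#       }
+#   }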
 
 class HILSerlRobotEnv(gym.Env):
     """
@@ -77,14 +138,17 @@
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
 
         # Retrieve the size of the joint position interval bound.
-        self.relative_bounds_size = (
-            (
-                self.robot.config.joint_position_relative_bounds["max"]
-                - self.robot.config.joint_position_relative_bounds["min"]
-            )
-            if self.robot.config.joint_position_relative_bounds is not None
-            else None
-        )
+        self.relative_bounds_size = None
+        # (
+        #     (
+        #         self.robot.config.joint_position_relative_bounds["max"]
+        #         - self.robot.config.joint_position_relative_bounds["min"]
+        #     )
+        #     if self.robot.config.joint_position_relative_bounds is not None
+        #     else None
+        # )
+        self.robot.config.joint_position_relative_bounds = None
 
         self.robot.config.max_relative_target = (
             self.relative_bounds_size.float() if self.relative_bounds_size is not None else None
@@ -258,7 +322,8 @@
         if teleop_action.dim() == 1:
             teleop_action = teleop_action.unsqueeze(0)
 
-        # self.render()
+        if self.display_cameras:
+            self.render()
 
         self.current_step += 1
@@ -1029,25 +1094,19 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
-def make_robot_env(
-    robot,
-    reward_classifier,
-    cfg,
-    n_envs: int = 1,
-) -> gym.vector.VectorEnv:
+def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
 
     Args:
-        robot: Robot instance to control
-        reward_classifier: Classifier model for computing rewards
-        cfg: Configuration object containing environment parameters
-        n_envs: Number of environments to create in parallel. Defaults to 1.
+        cfg: Configuration object containing environment parameters.
+        robot: Robot instance to control.
 
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
""" - if "maniskill" in cfg.env.name: + if "maniskill" in cfg.env_name: from lerobot.scripts.server.maniskill_manipulator import make_maniskill logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") @@ -1056,69 +1115,69 @@ def make_robot_env( n_envs=1, ) return env + # Create base environment env = HILSerlRobotEnv( robot=robot, - display_cameras=cfg.env.wrapper.display_cameras, - delta=cfg.env.wrapper.delta_action, - use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions - and cfg.env.wrapper.ee_action_space_params is None, + display_cameras=cfg.wrapper.display_cameras, + delta=cfg.wrapper.delta_action, + use_delta_action_space=cfg.wrapper.use_relative_joint_positions + and cfg.wrapper.ee_action_space_params is None, ) # Add observation and image processing - if cfg.env.wrapper.add_joint_velocity_to_observation: + if cfg.wrapper.add_joint_velocity_to_observation: env = AddJointVelocityToObservation(env=env, fps=cfg.fps) - if cfg.env.wrapper.add_ee_pose_to_observation: - env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds) + if cfg.wrapper.add_ee_pose_to_observation: + env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds) - env = ConvertToLeRobotObservation(env=env, device=cfg.env.device) + env = ConvertToLeRobotObservation(env=env, device=cfg.device) - if cfg.env.wrapper.crop_params_dict is not None: + if cfg.wrapper.crop_params_dict is not None: env = ImageCropResizeWrapper( env=env, - crop_params_dict=cfg.env.wrapper.crop_params_dict, - resize_size=cfg.env.wrapper.resize_size, + crop_params_dict=cfg.wrapper.crop_params_dict, + resize_size=cfg.wrapper.resize_size, ) # Add reward computation and control wrappers # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) - env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps) - if cfg.env.wrapper.ee_action_space_params is not None: - env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params) + env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + if cfg.wrapper.ee_action_space_params is not None: + env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) if ( - cfg.env.wrapper.ee_action_space_params is not None - and cfg.env.wrapper.ee_action_space_params.use_gamepad + cfg.wrapper.ee_action_space_params is not None + and cfg.wrapper.ee_action_space_params.use_gamepad ): - # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params) + # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) env = GamepadControlWrapper( env=env, - x_step_size=cfg.env.wrapper.ee_action_space_params.x_step_size, - y_step_size=cfg.env.wrapper.ee_action_space_params.y_step_size, - z_step_size=cfg.env.wrapper.ee_action_space_params.z_step_size, + x_step_size=cfg.wrapper.ee_action_space_params.x_step_size, + y_step_size=cfg.wrapper.ee_action_space_params.y_step_size, + z_step_size=cfg.wrapper.ee_action_space_params.z_step_size, ) else: env = KeyboardInterfaceWrapper(env=env) env = ResetWrapper( env=env, - reset_pose=cfg.env.wrapper.fixed_reset_joint_positions, - reset_time_s=cfg.env.wrapper.reset_time_s, + reset_pose=cfg.wrapper.fixed_reset_joint_positions, + reset_time_s=cfg.wrapper.reset_time_s, ) if ( - cfg.env.wrapper.ee_action_space_params is None - and cfg.env.wrapper.joint_masking_action_space is 
not None
+        cfg.wrapper.ee_action_space_params is None
+        and cfg.wrapper.joint_masking_action_space is not None
     ):
-        env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
+        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
 
     env = BatchCompitableWrapper(env=env)
 
     return env
 
 
-def get_classifier(pretrained_path, config_path, device="mps"):
-    if pretrained_path is None or config_path is None:
+def get_classifier(cfg):
+    if (
+        cfg.wrapper.reward_classifier_pretrained_path is None
+        or cfg.wrapper.reward_classifier_config_file is None
+    ):
         return None
 
-    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
     from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
         ClassifierConfig,
     )
@@ -1126,8 +1185,6 @@
         Classifier,
     )
 
-    cfg = init_hydra_config(config_path)
-    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
-    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
+    # Assumption: the classifier config file is a JSON dump of ClassifierConfig fields.
+    with open(cfg.wrapper.reward_classifier_config_file) as f:
+        classifier_config = ClassifierConfig(**json.load(f))
+    classifier_config.num_cameras = len(cfg.robot.cameras)  # TODO automate these paths
     model = Classifier(classifier_config)
@@ -1136,21 +1193,11 @@
     return model
 
 
-def record_dataset(
-    env,
-    repo_id,
-    root=None,
-    num_episodes=1,
-    control_time_s=20,
-    fps=30,
-    push_to_hub=True,
-    task_description="",
-    policy=None,
-):
+def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
     """
     Record a dataset of robot interactions using either a policy or teleop.
 
     Args:
-        env: The environment to record from
-        repo_id: Repository ID for dataset storage
-        root: Local root directory for dataset (optional)
+        env: The environment to record from.
+        policy: Policy used to select actions; teleoperation is used when None.
+        cfg: Config providing repo_id, dataset_root, fps, task, num_episodes, and push_to_hub.
@@ -1195,9 +1242,9 @@
     # Create dataset
     dataset = LeRobotDataset.create(
-        repo_id,
-        fps,
-        root=root,
+        cfg.repo_id,
+        cfg.fps,
+        root=cfg.dataset_root,
         use_videos=True,
         image_writer_threads=4,
         image_writer_processes=0,
@@ -1206,17 +1253,17 @@
     # Record episodes
     episode_index = 0
-    while episode_index < num_episodes:
+    while episode_index < cfg.num_episodes:
         obs, _ = env.reset()
         start_episode_t = time.perf_counter()
 
         log_say(f"Recording episode {episode_index}", play_sounds=True)
 
         # Run episode steps
-        while time.perf_counter() - start_episode_t < control_time_s:
+        while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s:
             start_loop_t = time.perf_counter()
 
             # Get action from policy if available
             if policy is not None:
                 action = policy.select_action(obs)
 
             # Step environment
@@ -1240,9 +1287,9 @@
             dataset.add_frame(frame)
 
             # Maintain consistent timing
-            if fps:
+            if cfg.fps:
                 dt_s = time.perf_counter() - start_loop_t
-                busy_wait(1 / fps - dt_s)
+                busy_wait(1 / cfg.fps - dt_s)
 
             if terminated or truncated:
                 break
@@ -1253,13 +1300,13 @@
             logging.info(f"Re-recording episode {episode_index}")
             continue
 
-        dataset.save_episode(task_description)
+        dataset.save_episode(cfg.task)
         episode_index += 1
 
     # Finalize dataset
     dataset.consolidate(run_compute_stats=True)
-    if push_to_hub:
-        dataset.push_to_hub(repo_id)
+    if cfg.push_to_hub:
+        dataset.push_to_hub(cfg.repo_id)
 
 
 def replay_episode(env, repo_id, root=None, episode=0):
default="lerobot/configs/robot/koch.yaml", - help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", - ) - parser.add_argument( - "--robot-overrides", - type=str, - nargs="*", - help="Any key=value arguments to override config values (use dots for.nested=overrides)", - ) - parser.add_argument( - "-p", - "--pretrained-policy-name-or-path", - help=( - "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " - "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch " - ), - ) - parser.add_argument( - "--display-cameras", - help=("Whether to display the camera feed while the rollout is happening"), - ) - parser.add_argument( - "--reward-classifier-pretrained-path", - type=str, - default=None, - help="Path to the pretrained classifier weights.", - ) - parser.add_argument( - "--reward-classifier-config-file", - type=str, - default=None, - help="Path to a yaml config file that is necessary to build the reward classifier model.", - ) - parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file") - parser.add_argument( - "--env-overrides", - type=str, - default=None, - help="Overrides for the env yaml file", - ) - parser.add_argument( - "--control-time-s", - type=float, - default=20, - help="Maximum episode length in seconds", - ) - parser.add_argument( - "--reset-follower-pos", - type=int, - default=1, - help="Reset follower between episodes", - ) - parser.add_argument( - "--replay-repo-id", - type=str, - default=None, - help="Repo ID of the episode to replay", - ) - parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay") - parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay") - parser.add_argument( - "--record-repo-id", - type=str, - default=None, - help="Repo ID of the dataset to record", - ) - parser.add_argument( - "--record-num-episodes", - type=int, - default=1, - help="Number of episodes to record", - ) - parser.add_argument( - "--record-episode-task", - type=str, - default="", - help="Single line description of the task to record", - ) +@parser.wrap() +def main(cfg: HILSerlRobotEnvConfig): - args = parser.parse_args() + robot = make_robot_from_config(cfg.robot) - robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) - robot = make_robot(robot_cfg) - - reward_classifier = get_classifier( - args.reward_classifier_pretrained_path, args.reward_classifier_config_file - ) + reward_classifier = None #get_classifier( + # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file + # ) user_relative_joint_positions = True - cfg = init_hydra_config(args.env_path, args.env_overrides) - env = make_robot_env( - robot, - reward_classifier, - cfg, # .wrapper, - ) + env = make_robot_env(cfg, robot) - if args.record_repo_id is not None: + if cfg.mode == "record": policy = None - if args.pretrained_policy_name_or_path is not None: + if cfg.pretrained_policy_name_or_path is not None: from lerobot.common.policies.sac.modeling_sac import SACPolicy - policy = SACPolicy.from_pretrained(args.pretrained_policy_name_or_path) + policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) policy.to(cfg.device) policy.eval() record_dataset( env, - args.record_repo_id, - root=args.dataset_root, - num_episodes=args.record_num_episodes, - fps=args.fps, - task_description=args.record_episode_task, + cfg.repo_id, + root=cfg.dataset_root, + 
diff --git a/pyproject.toml b/pyproject.toml
index a8814a4c..a34ad5e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,7 +59,6 @@ dependencies = [
     "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
    "imageio[ffmpeg]>=2.34.0",
     "jsonlines>=4.0.0",
-    "mani-skill>=3.0.0b18",
     "numba>=0.59.0",
     "omegaconf>=2.3.0",
     "opencv-python-headless>=4.9.0",