diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 42be632d..825fa162 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -14,10 +14,12 @@
 import abc
 from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Tuple
 
 import draccus
 
 from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
+from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.configs.types import FeatureType, PolicyFeature
@@ -159,20 +161,84 @@ class XarmEnv(EnvConfig):
 @dataclass
 class VideoRecordConfig:
     """Configuration for video recording in ManiSkill environments."""
+
     enabled: bool = False
     record_dir: str = "videos"
     trajectory_name: str = "trajectory"
 
+
 @dataclass
 class WrapperConfig:
     """Configuration for environment wrappers."""
+
     delta_action: float | None = None
     joint_masking_action_space: list[bool] | None = None
 
+
+@dataclass
+class EEActionSpaceConfig:
+    """Configuration parameters for end-effector action space."""
+
+    x_step_size: float
+    y_step_size: float
+    z_step_size: float
+    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
+    use_gamepad: bool = False
+
+
+@dataclass
+class EnvWrapperConfig:
+    """Configuration for environment wrappers."""
+
+    display_cameras: bool = False
+    delta_action: float = 0.1
+    use_relative_joint_positions: bool = True
+    add_joint_velocity_to_observation: bool = False
+    add_ee_pose_to_observation: bool = False
+    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
+    resize_size: Optional[Tuple[int, int]] = None
+    control_time_s: float = 20.0
+    fixed_reset_joint_positions: Optional[Any] = None
+    reset_time_s: float = 5.0
+    joint_masking_action_space: Optional[Any] = None
+    ee_action_space_params: Optional[EEActionSpaceConfig] = None
+    use_gripper: bool = False
+
+
+@EnvConfig.register_subclass(name="gym_manipulator")
+@dataclass
+class HILSerlRobotEnvConfig(EnvConfig):
+    """Configuration for the HILSerlRobotEnv environment."""
+
+    robot: Optional[RobotConfig] = None
+    wrapper: Optional[EnvWrapperConfig] = None
+    fps: int = 10
+    name: str = "real_robot"
+    mode: str = None  # Either "record", "replay", None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    task: str = ""
+    num_episodes: int = 10  # only for record mode
+    episode: int = 0
+    device: str = "cuda"
+    push_to_hub: bool = True
+    pretrained_policy_name_or_path: Optional[str] = None
+    reward_classifier: dict[str, str | None] = field(
+        default_factory=lambda: {
+            "pretrained_path": None,
+            "config_path": None,
+        }
+    )
+
+    def gym_kwargs(self) -> dict:
+        return {}
+
+
 @EnvConfig.register_subclass("maniskill_push")
 @dataclass
 class ManiskillEnvConfig(EnvConfig):
     """Configuration for the ManiSkill environment."""
+
     name: str = "maniskill/pushcube"
     task: str = "PushCube-v1"
     image_size: int = 64
@@ -185,7 +251,7 @@ class ManiskillEnvConfig(EnvConfig):
     render_mode: str = "rgb_array"
     render_size: int = 64
     device: str = "cuda"
-    robot: str = "so100" # This is a hack to make the robot config work
+    robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
     features: dict[str, PolicyFeature] = field(
@@ -218,4 +284,4 @@ class ManiskillEnvConfig(EnvConfig):
             "control_mode": self.control_mode,
             "sensor_configs": {"width": self.image_size, "height": self.image_size},
             "num_envs": 1,
-        }
\ No newline at end of file
+        }
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
index 7e14628d..d5bfbb4b 100644
--- a/lerobot/common/envs/utils.py
+++ b/lerobot/common/envs/utils.py
@@ -49,9 +49,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
         # sanity check that images are channel last
         _, h, w, c = img.shape
-        assert c < h and c < w, (
-            f"expect channel last images, but instead got {img.shape=}"
-        )
+        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
 
         # sanity check that images are uint8
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -91,7 +89,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
         else:
             feature = ft
 
-        policy_key = env_cfg.features_map[key]
+        policy_key = env_cfg.features_map.get(key, key)
         policy_features[policy_key] = feature
 
     return policy_features
diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index fa3b0187..4b20bce4 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -20,7 +20,7 @@
 from typing import Any, Optional
 
 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
+from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
 
 
 @dataclass
@@ -29,7 +29,6 @@ class ConcurrencyConfig:
     learner: str = "threads"
 
-
 @dataclass
 class ActorLearnerConfig:
     learner_host: str = "127.0.0.1"
@@ -110,6 +109,7 @@ class SACConfig(PreTrainedConfig):
         use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
         grad_clip_norm: Gradient clipping norm for the SAC algorithm.
""" + normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, @@ -152,8 +152,8 @@ class SACConfig(PreTrainedConfig): camera_number: int = 1 device: str = "cuda" storage_device: str = "cpu" - # Set to "helper2424/resnet10" for hil serl - vision_encoder_name: str | None = None + # Set to "helper2424/resnet10" for hil serl + vision_encoder_name: str | None = None freeze_vision_encoder: bool = True image_encoder_hidden_dim: int = 32 shared_encoder: bool = True @@ -163,7 +163,7 @@ class SACConfig(PreTrainedConfig): online_env_seed: int = 10000 online_buffer_capacity: int = 100000 offline_buffer_capacity: int = 100000 - online_step_before_learning: int = 100 + online_step_before_learning: int = 100 policy_update_freq: int = 1 # SAC algorithm parameters @@ -181,24 +181,14 @@ class SACConfig(PreTrainedConfig): target_entropy: float | None = None use_backup_entropy: bool = True grad_clip_norm: float = 40.0 - + # Network configuration - critic_network_kwargs: CriticNetworkConfig = field( - default_factory=CriticNetworkConfig - ) - actor_network_kwargs: ActorNetworkConfig = field( - default_factory=ActorNetworkConfig - ) - policy_kwargs: PolicyConfig = field( - default_factory=PolicyConfig - ) - - actor_learner_config: ActorLearnerConfig = field( - default_factory=ActorLearnerConfig - ) - concurrency: ConcurrencyConfig = field( - default_factory=ConcurrencyConfig - ) + critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) + policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) + + actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) + concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) def __post_init__(self): super().__post_init__() @@ -218,18 +208,20 @@ class SACConfig(PreTrainedConfig): return None def validate_features(self) -> None: - if "observation.image" not in self.input_features: - raise ValueError("You must provide 'observation.image' in the input features") - - if "observation.state" not in self.input_features: - raise ValueError("You must provide 'observation.state' in the input features") - + has_image = any(key.startswith("observation.image") for key in self.input_features) + has_state = "observation.state" in self.input_features + + if not (has_state or has_image): + raise ValueError( + "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features" + ) + if "action" not in self.output_features: raise ValueError("You must provide 'action' in the output features") @property def image_features(self) -> list[str]: - return [key for key in self.input_features.keys() if 'image' in key] + return [key for key in self.input_features.keys() if "image" in key] @property def observation_delta_indices(self) -> list: @@ -243,9 +235,13 @@ class SACConfig(PreTrainedConfig): def reward_delta_indices(self) -> None: return None + if __name__ == "__main__": import draccus + config = SACConfig() draccus.set_config_type("json") - draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), ) - + draccus.dump( + config=config, + stream=open(file="run_config.json", mode="w"), + ) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b54def54..3c9b26ec 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ 
b/lerobot/common/policies/sac/modeling_sac.py @@ -39,7 +39,6 @@ from lerobot.common.policies.utils import get_device_from_parameters class SACPolicy( PreTrainedPolicy, ): - config_class = SACConfig name = "sac" @@ -53,9 +52,7 @@ class SACPolicy( self.config = config if config.dataset_stats is not None: - input_normalization_params = _convert_normalization_params_to_tensor( - config.dataset_stats - ) + input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) self.normalize_inputs = Normalize( config.input_features, config.normalization_mapping, @@ -64,12 +61,10 @@ class SACPolicy( else: self.normalize_inputs = nn.Identity() - output_normalization_params = _convert_normalization_params_to_tensor( - config.dataset_stats - ) + output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) # HACK: This is hacky and should be removed - dataset_stats = dataset_stats or output_normalization_params + dataset_stats = dataset_stats or output_normalization_params self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -138,7 +133,6 @@ class SACPolicy( self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)])) self.temperature = self.log_alpha.exp().item() - def get_optim_params(self) -> dict: return { "actor": self.actor.parameters_to_optimize, @@ -655,9 +649,10 @@ class SACObservationEncoder(nn.Module): class DefaultImageEncoder(nn.Module): def __init__(self, config: SACConfig): super().__init__() + image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118 self.image_enc_layers = nn.Sequential( nn.Conv2d( - in_channels=config.input_features["observation.image"].shape[0], + in_channels=config.input_features[image_key].shape[0], out_channels=config.image_encoder_hidden_dim, kernel_size=7, stride=2, @@ -685,7 +680,9 @@ class DefaultImageEncoder(nn.Module): ), nn.ReLU(), ) - dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape) + # Get first image key from input features + image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118 + dummy_batch = torch.zeros(1, *config.input_features[image_key].shape) with torch.inference_mode(): self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:] self.image_enc_layers.extend( @@ -844,8 +841,10 @@ if __name__ == "__main__": import draccus from lerobot.configs import parser + @parser.wrap() def main(config: SACConfig): policy = SACPolicy(config=config) print("yolo") - main() \ No newline at end of file + + main() diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index dcb9c3d3..e28cfea1 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -28,7 +28,6 @@ from torch.multiprocessing import Event, Queue # TODO: Remove the import of maniskill from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy -from lerobot.common.robot_devices.robots.utils import Robot, make_robot from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.utils import ( @@ -268,7 +267,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) def act_with_policy( cfg: TrainPipelineConfig, - robot: Robot, + # robot: Robot, reward_classifier: nn.Module, 
     shutdown_event: any,  # Event,
     parameters_queue: Queue,
@@ -287,7 +286,7 @@ def act_with_policy(
 
     logging.info("make_env online")
 
-    online_env = make_robot_env( cfg=cfg.env)
+    online_env = make_robot_env(cfg=cfg.env)
 
     set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -503,7 +502,6 @@ def actor_cli(cfg: TrainPipelineConfig):
     mp.set_start_method("spawn")
 
     init_logging(log_file="actor.log")
 
-    robot = make_robot(robot_type=cfg.env.robot)
 
     shutdown_event = setup_process_handlers(use_threads(cfg))
@@ -563,18 +561,17 @@ def actor_cli(cfg: TrainPipelineConfig):
     # HACK: FOR MANISKILL we do not have a reward classifier
     # TODO: Remove this once we merge into main
    reward_classifier = None
-    if (
-        cfg.env.reward_classifier["pretrained_path"] is not None
-        and cfg.env.reward_classifier["config_path"] is not None
-    ):
-        reward_classifier = get_classifier(
-            pretrained_path=cfg.env.reward_classifier["pretrained_path"],
-            config_path=cfg.env.reward_classifier["config_path"],
-        )
+    # if (
+    #     cfg.env.reward_classifier["pretrained_path"] is not None
+    #     and cfg.env.reward_classifier["config_path"] is not None
+    # ):
+    #     reward_classifier = get_classifier(
+    #         pretrained_path=cfg.env.reward_classifier["pretrained_path"],
+    #         config_path=cfg.env.reward_classifier["config_path"],
+    #     )
 
     act_with_policy(
         cfg=cfg,
-        robot=robot,
         reward_classifier=reward_classifier,
         shutdown_event=shutdown_event,
         parameters_queue=parameters_queue,
diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py
index d576f2ef..a50dc4df 100644
--- a/lerobot/scripts/server/end_effector_control_utils.py
+++ b/lerobot/scripts/server/end_effector_control_utils.py
@@ -29,6 +29,9 @@ class InputController:
         self.z_step_size = z_step_size
         self.running = True
         self.episode_end_status = None  # None, "success", or "failure"
+        self.intervention_flag = False
+        self.open_gripper_command = False
+        self.close_gripper_command = False
 
     def start(self):
         """Start the controller and initialize resources."""
@@ -70,6 +73,19 @@ class InputController:
         self.episode_end_status = None  # Reset after reading
         return status
 
+    def should_intervene(self):
+        """Return True if intervention flag was set."""
+        return self.intervention_flag
+
+    def gripper_command(self):
+        """Return the current gripper command."""
+        if self.open_gripper_command == self.close_gripper_command:
+            return "no-op"
+        elif self.open_gripper_command:
+            return "open"
+        elif self.close_gripper_command:
+            return "close"
+
 
 class KeyboardController(InputController):
     """Generate motion deltas from keyboard input."""
@@ -326,7 +342,6 @@ class GamepadControllerHID(InputController):
         self.buttons = {}
         self.quit_requested = False
         self.save_requested = False
-        self.intervention_flag = False
 
     def find_device(self):
         """Look for the gamepad device by vendor and product ID."""
@@ -416,7 +431,13 @@ class GamepadControllerHID(InputController):
             buttons = data[5]
 
             # Check if RB is pressed then the intervention flag should be set
-            self.intervention_flag = data[6] == 2
+            self.intervention_flag = data[6] in [2, 6, 10, 14]
+
+            # Check if RT is pressed
+            self.open_gripper_command = data[6] in [8, 10, 12]
+
+            # Check if LT is pressed
+            self.close_gripper_command = data[6] in [4, 6, 12]
 
             # Check if Y/Triangle button (bit 7) is pressed for saving
             # Check if X/Square button (bit 5) is pressed for failure
@@ -676,12 +697,8 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
 if __name__ == "__main__":
     from lerobot.common.robot_devices.robots.configs import RobotConfig
     from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-    from lerobot.scripts.server.gym_manipulator import (
-        EEActionSpaceConfig,
-        EnvWrapperConfig,
-        HILSerlRobotEnvConfig,
-        make_robot_env,
-    )
+    from lerobot.scripts.server.gym_manipulator import make_robot_env
+    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig
 
     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 1f05d9a7..dbef1d6d 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -1,9 +1,8 @@
 import logging
 import sys
 import time
-from dataclasses import dataclass
 from threading import Lock
-from typing import Annotated, Any, Dict, Optional, Tuple
+from typing import Annotated, Any, Dict, Tuple
 
 import gymnasium as gym
 import numpy as np
@@ -17,66 +16,13 @@ from lerobot.common.robot_devices.control_utils import (
     is_headless,
     reset_follower_position,
 )
-from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.utils.utils import log_say
 from lerobot.configs import parser
 from lerobot.scripts.server.kinematics import RobotKinematics
 
 logging.basicConfig(level=logging.INFO)
-
-
-@dataclass
-class EEActionSpaceConfig:
-    """Configuration parameters for end-effector action space."""
-
-    x_step_size: float
-    y_step_size: float
-    z_step_size: float
-    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
-    use_gamepad: bool = False
-
-
-@dataclass
-class EnvWrapperConfig:
-    """Configuration for environment wrappers."""
-
-    display_cameras: bool = False
-    delta_action: float = 0.1
-    use_relative_joint_positions: bool = True
-    add_joint_velocity_to_observation: bool = False
-    add_ee_pose_to_observation: bool = False
-    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
-    resize_size: Optional[Tuple[int, int]] = None
-    control_time_s: float = 20.0
-    fixed_reset_joint_positions: Optional[Any] = None
-    reset_time_s: float = 5.0
-    joint_masking_action_space: Optional[Any] = None
-    ee_action_space_params: Optional[EEActionSpaceConfig] = None
-    reward_classifier_pretrained_path: Optional[str] = None
-    reward_classifier_config_file: Optional[str] = None
-
-
-@EnvConfig.register_subclass(name="gym_manipulator")
-@dataclass
-class HILSerlRobotEnvConfig(EnvConfig):
-    """Configuration for the HILSerlRobotEnv environment."""
-
-    robot: Optional[RobotConfig] = None
-    wrapper: Optional[EnvWrapperConfig] = None
-    fps: int = 10
-    mode: str = None  # Either "record", "replay", None
-    repo_id: Optional[str] = None
-    dataset_root: Optional[str] = None
-    task: str = ""
-    num_episodes: int = 10  # only for record mode
-    episode: int = 0
-    device: str = "cuda"
-    push_to_hub: bool = True
-    pretrained_policy_name_or_path: Optional[str] = None
-
-    def gym_kwargs(self) -> dict:
-        return {}
+MAX_GRIPPER_COMMAND = 25
 
 
 class HILSerlRobotEnv(gym.Env):
@@ -813,9 +759,10 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
 
 
 class EEActionWrapper(gym.ActionWrapper):
-    def __init__(self, env, ee_action_space_params=None):
+    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
         self.ee_action_space_params = ee_action_space_params
+        self.use_gripper = use_gripper
 
         # Initialize kinematics instance for the appropriate robot type
         robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
@@ -829,10 +776,12 @@ class EEActionWrapper(gym.ActionWrapper):
                 ee_action_space_params.z_step_size,
             ]
         )
+        if self.use_gripper:
+            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
         ee_action_space = gym.spaces.Box(
             low=-action_space_bounds,
             high=action_space_bounds,
-            shape=(3,),
+            shape=(3 + int(self.use_gripper),),
             dtype=np.float32,
         )
         if isinstance(self.action_space, gym.spaces.Tuple):
@@ -848,6 +797,10 @@ class EEActionWrapper(gym.ActionWrapper):
         if isinstance(action, tuple):
             action, _ = action
 
+        if self.use_gripper:
+            gripper_command = action[-1]
+            action = action[:-1]
+
         current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
         current_ee_pos = self.fk_function(current_joint_pos)
         if isinstance(action, torch.Tensor):
@@ -863,6 +816,12 @@ class EEActionWrapper(gym.ActionWrapper):
             position_only=True,
             fk_func=self.fk_function,
         )
+        if self.use_gripper:
+            gripper_command = gripper_command * MAX_GRIPPER_COMMAND
+            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+            target_joint_pos[-1] = gripper_action
+
         return target_joint_pos, is_intervention
@@ -912,6 +871,7 @@ class GamepadControlWrapper(gym.Wrapper):
         x_step_size=1.0,
         y_step_size=1.0,
         z_step_size=1.0,
+        use_gripper=False,
         auto_reset=False,
         input_threshold=0.001,
     ):
@@ -948,6 +908,7 @@ class GamepadControlWrapper(gym.Wrapper):
                 z_step_size=z_step_size,
             )
         self.auto_reset = auto_reset
+        self.use_gripper = use_gripper
         self.input_threshold = input_threshold
 
         self.controller.start()
@@ -977,6 +938,15 @@ class GamepadControlWrapper(gym.Wrapper):
         # Create action from gamepad input
         gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
 
+        if self.use_gripper:
+            gripper_command = self.controller.gripper_command()
+            if gripper_command == "open":
+                gamepad_action = np.concatenate([gamepad_action, [1.0]])
+            elif gripper_command == "close":
+                gamepad_action = np.concatenate([gamepad_action, [-1.0]])
+            else:
+                gamepad_action = np.concatenate([gamepad_action, [0.0]])
+
         # Check episode ending buttons
         # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
         episode_end_status = self.controller.get_episode_end_status()
@@ -1023,6 +993,7 @@ class GamepadControlWrapper(gym.Wrapper):
                 final_action = (torch.from_numpy(gamepad_action), False)
             else:
                 final_action = torch.from_numpy(gamepad_action)
+
         else:
             # Use the original action
             final_action = action
@@ -1138,7 +1109,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.ee_action_space_params is not None:
-        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
+        env = EEActionWrapper(
+            env=env,
+            ee_action_space_params=cfg.wrapper.ee_action_space_params,
+            use_gripper=cfg.wrapper.use_gripper,
+        )
     if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
@@ -1146,6 +1121,7 @@
             x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
             y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
             z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
+            use_gripper=cfg.wrapper.use_gripper,
         )
     else:
         env = KeyboardInterfaceWrapper(env=env)
@@ -1184,7 +1160,7 @@ def get_classifier(cfg):
     return model
 
 
-def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
+def record_dataset(env, policy, cfg):
    """
    Record a dataset of robot interactions using either a policy or teleop.