diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index fffe2085..6bb48e70 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -412,7 +412,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea names = ft["names"] # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. - if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + if names is not None and names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) shape = (shape[2], shape[0], shape[1]) elif key == "observation.environment_state": type = FeatureType.ENV diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index b6b3e547..e53ad945 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -68,4 +68,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g ) return env - diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a843ac1a..4c9102de 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -28,6 +28,7 @@ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType @@ -58,6 +59,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.sac.modeling_sac import SACPolicy return SACPolicy + elif name == "hilserl_classifier": + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + + return Classifier else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -73,6 +78,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return VQBeTConfig(**kwargs) elif policy_type == "pi0": return PI0Config(**kwargs) + elif policy_type == "hilserl_classifier": + return ClassifierConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index fe7eb142..00688931 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -1,12 +1,18 @@ -import json -import os -from dataclasses import asdict, dataclass +from dataclasses import dataclass, field +from typing import Dict, List + +from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig +from lerobot.common.optim.schedulers import LRSchedulerConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, PolicyFeature +@PreTrainedConfig.register_subclass(name="hilserl_classifier") @dataclass -class ClassifierConfig: +class ClassifierConfig(PreTrainedConfig): """Configuration for the Classifier model.""" + name: str = "hilserl_classifier" num_classes: int = 2 hidden_dim: int = 256 dropout_rate: float = 0.1 @@ -14,22 +20,35 @@ class ClassifierConfig: device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" 
num_cameras: int = 2 + learning_rate: float = 1e-4 + normalization_mode = None + # output_features: Dict[str, PolicyFeature] = field( + # default_factory=lambda: {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))} + # ) - def save_pretrained(self, save_dir): - """Save config to json file.""" - os.makedirs(save_dir, exist_ok=True) + @property + def observation_delta_indices(self) -> List | None: + return None - # Convert to dict and save as JSON - config_dict = asdict(self) - with open(os.path.join(save_dir, "config.json"), "w") as f: - json.dump(config_dict, f, indent=2) + @property + def action_delta_indices(self) -> List | None: + return None - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path): - """Load config from json file.""" - config_file = os.path.join(pretrained_model_name_or_path, "config.json") + @property + def reward_delta_indices(self) -> List | None: + return None - with open(config_file) as f: - config_dict = json.load(f) + def get_optimizer_preset(self) -> OptimizerConfig: + return AdamWConfig( + lr=self.learning_rate, + weight_decay=0.01, + grad_clip_norm=1.0, + ) - return cls(**config_dict) + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + return None + + def validate_features(self) -> None: + """Validate feature configurations.""" + # Classifier doesn't need specific feature validation + pass diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index 18c30493..3db6394e 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -1,11 +1,15 @@ import logging -from typing import Optional +from typing import Dict, Optional, Tuple import torch -from huggingface_hub import PyTorchModelHubMixin from torch import Tensor, nn -from .configuration_classifier import ClassifierConfig +from lerobot.common.constants import OBS_IMAGE +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( + ClassifierConfig, +) +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -32,25 +36,32 @@ class ClassifierOutput: ) -class Classifier( - nn.Module, - PyTorchModelHubMixin, - # Add Hub metadata - library_name="lerobot", - repo_url="https://github.com/huggingface/lerobot", - tags=["robotics", "vision-classifier"], -): +class Classifier(PreTrainedPolicy): """Image classifier built on top of a pre-trained encoder.""" - # Add name attribute for factory - name = "classifier" + name = "hilserl_classifier" + config_class = ClassifierConfig - def __init__(self, config: ClassifierConfig): + def __init__( + self, + config: ClassifierConfig, + dataset_stats: Dict[str, Dict[str, Tensor]] | None = None, + ): from transformers import AutoModel - super().__init__() + super().__init__(config) self.config = config - # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) + + # Initialize normalization (standardized with the policy framework) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = 
Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + # Set up encoder encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) # Extract vision model if we're given a multimodal model if hasattr(encoder, "vision_model"): @@ -81,8 +92,6 @@ class Classifier( else: raise ValueError("Unsupported CNN architecture") - self.encoder = self.encoder.to(self.config.device) - def _freeze_encoder(self) -> None: """Freeze the encoder parameters.""" for param in self.encoder.parameters(): @@ -109,22 +118,13 @@ class Classifier( 1 if self.config.num_classes == 2 else self.config.num_classes, ), ) - self.classifier_head = self.classifier_head.to(self.config.device) def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: """Extract the appropriate output from the encoder.""" - # Process images with the processor (handles resizing and normalization) - # processed = self.processor( - # images=x, # LeRobotDataset already provides proper tensor format - # return_tensors="pt", - # ) - # processed = processed["pixel_values"].to(x.device) - processed = x - with torch.no_grad(): if self.is_cnn: # The HF ResNet applies pooling internally - outputs = self.encoder(processed) + outputs = self.encoder(x) # Get pooled output directly features = outputs.pooler_output @@ -132,14 +132,24 @@ class Classifier( features = features.squeeze(-1).squeeze(-1) return features else: # Transformer models - outputs = self.encoder(processed) + outputs = self.encoder(x) if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: return outputs.pooler_output return outputs.last_hidden_state[:, 0, :] - def forward(self, xs: torch.Tensor) -> ClassifierOutput: - """Forward pass of the classifier.""" - # For training, we expect input to be a tensor directly from LeRobotDataset + def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]: + """Extract image tensors and label tensors from batch.""" + # Find image keys in input features + image_keys = [key for key in self.config.input_features if key.startswith(OBS_IMAGE)] + + # Extract the images and labels + images = [batch[key] for key in image_keys] + labels = batch["next.reward"] + + return images, labels + + def predict(self, xs: list) -> ClassifierOutput: + """Forward pass of the classifier for inference.""" encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs]) logits = self.classifier_head(encoder_outputs) @@ -151,10 +161,77 @@ class Classifier( return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) - def predict_reward(self, x, threshold=0.6): + def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: + """Standard forward pass for training compatible with train.py.""" + # Normalize inputs if needed + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images and labels + images, labels = self.extract_images_and_labels(batch) + + # Get predictions + outputs = self.predict(images) + + # Calculate loss if self.config.num_classes == 2: - probs = self.forward(x).probabilities + # Binary classification + loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels) + predictions = (torch.sigmoid(outputs.logits) > 0.5).float() + else: + # Multi-class classification + loss = nn.functional.cross_entropy(outputs.logits, labels.long()) + predictions = torch.argmax(outputs.logits, dim=1) + + # Calculate accuracy for logging + correct = 
(predictions == labels).sum().item() + total = labels.size(0) + accuracy = 100 * correct / total + + # Return loss and metrics for logging + output_dict = { + "accuracy": accuracy, + "correct": correct, + "total": total, + } + + return loss, output_dict + + def predict_reward(self, batch, threshold=0.6): + """Legacy method for compatibility.""" + images, _ = self.extract_images_and_labels(batch) + if self.config.num_classes == 2: + probs = self.predict(images).probabilities logging.debug(f"Predicted reward images: {probs}") return (probs > threshold).float() else: - return torch.argmax(self.forward(x).probabilities, dim=1) + return torch.argmax(self.predict(images).probabilities, dim=1) + + # Methods required by PreTrainedPolicy abstract class + + def get_optim_params(self) -> dict: + """Return optimizer parameters for the policy.""" + return { + "params": self.parameters(), + "lr": getattr(self.config, "learning_rate", 1e-4), + "weight_decay": getattr(self.config, "weight_decay", 0.01), + } + + def reset(self): + """Reset any stateful components (required by PreTrainedPolicy).""" + # Classifier doesn't have stateful components that need resetting + pass + + def select_action(self, batch: Dict[str, Tensor]) -> Tensor: + """Return action (class prediction) based on input observation.""" + images, _ = self.extract_images_and_labels(batch) + + with torch.no_grad(): + outputs = self.predict(images) + + if self.config.num_classes == 2: + # For binary classification return 0 or 1 + return (outputs.probabilities > 0.5).float() + else: + # For multi-class return the predicted class + return torch.argmax(outputs.probabilities, dim=1) diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py index 6b3d92e8..6040ff70 100644 --- a/lerobot/configs/types.py +++ b/lerobot/configs/types.py @@ -23,6 +23,7 @@ class FeatureType(str, Enum): VISUAL = "VISUAL" ENV = "ENV" ACTION = "ACTION" + REWARD = "REWARD" class NormalizationMode(str, Enum): diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 479cb21f..1013001a 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -274,10 +274,7 @@ def record( if not robot.is_connected: robot.connect() - listener, events = init_keyboard_listener(assign_rewards=assign_rewards) - - if reset_follower: - initial_position = robot.follower_arms["main"].read("Present_Position") + listener, events = init_keyboard_listener(assign_rewards=cfg.assign_rewards) # Execute a few seconds without recording to: # 1. 
teleoperate the robot to move it in starting position if no policy provided, @@ -394,6 +391,7 @@ def control_robot(cfg: ControlPipelineConfig): replay(robot, cfg.control) elif isinstance(cfg.control, RemoteRobotConfig): from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi + run_lekiwi(cfg.robot) if robot.is_connected: diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 1c056aa5..d576f2ef 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -1,13 +1,13 @@ +import argparse +import logging +import sys +import time + +import numpy as np +import torch + from lerobot.common.robot_devices.utils import busy_wait from lerobot.scripts.server.kinematics import RobotKinematics -import logging -import time -import torch -import numpy as np -import argparse -from lerobot.common.robot_devices.robots.utils import make_robot_from_config -from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig -from lerobot.common.robot_devices.robots.configs import RobotConfig logging.basicConfig(level=logging.INFO) @@ -458,12 +458,13 @@ class GamepadControllerHID(InputController): def test_forward_kinematics(robot, fps=10): logging.info("Testing Forward Kinematics") timestep = time.perf_counter() + kinematics = RobotKinematics(robot.robot_type) while time.perf_counter() - timestep < 60.0: loop_start_time = time.perf_counter() robot.teleop_step() obs = robot.capture_observation() joint_positions = obs["observation.state"].cpu().numpy() - ee_pos = RobotKinematics.fk_gripper_tip(joint_positions) + ee_pos = kinematics.fk_gripper_tip(joint_positions) logging.info(f"EE Position: {ee_pos[:3, 3]}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) @@ -485,21 +486,19 @@ def test_inverse_kinematics(robot, fps=10): def teleoperate_inverse_kinematics_with_leader(robot, fps=10): logging.info("Testing Inverse Kinematics") - fk_func = RobotKinematics.fk_gripper_tip + kinematics = RobotKinematics(robot.robot_type) timestep = time.perf_counter() while time.perf_counter() - timestep < 60.0: loop_start_time = time.perf_counter() obs = robot.capture_observation() joint_positions = obs["observation.state"].cpu().numpy() - ee_pos = fk_func(joint_positions) + ee_pos = kinematics.fk_gripper_tip(joint_positions) leader_joint_positions = robot.leader_arms["main"].read("Present_Position") - leader_ee = fk_func(leader_joint_positions) + leader_ee = kinematics.fk_gripper_tip(leader_joint_positions) desired_ee_pos = leader_ee - target_joint_state = RobotKinematics.ik( - joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func - ) + target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True) robot.send_action(torch.from_numpy(target_joint_state)) logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) @@ -513,10 +512,10 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): obs = robot.capture_observation() joint_positions = obs["observation.state"].cpu().numpy() - fk_func = RobotKinematics.fk_gripper_tip + kinematics = RobotKinematics(robot.robot_type) leader_joint_positions = robot.leader_arms["main"].read("Present_Position") - initial_leader_ee = fk_func(leader_joint_positions) + initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions) desired_ee_pos = np.diag(np.ones(4)) @@ -525,13 +524,13 
@@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): # Get leader state for teleoperation leader_joint_positions = robot.leader_arms["main"].read("Present_Position") - leader_ee = fk_func(leader_joint_positions) + leader_ee = kinematics.fk_gripper_tip(leader_joint_positions) # Get current state # obs = robot.capture_observation() # joint_positions = obs["observation.state"].cpu().numpy() joint_positions = robot.follower_arms["main"].read("Present_Position") - current_ee_pos = fk_func(joint_positions) + current_ee_pos = kinematics.fk_gripper_tip(joint_positions) # Calculate delta between leader and follower end-effectors # Scaling factor can be adjusted for sensitivity @@ -545,9 +544,7 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): if np.any(np.abs(ee_delta[:3, 3]) > 0.01): # Compute joint targets via inverse kinematics - target_joint_state = RobotKinematics.ik( - joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func - ) + target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True) initial_leader_ee = leader_ee.copy() @@ -580,7 +577,8 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, # Initial position capture obs = robot.capture_observation() joint_positions = obs["observation.state"].cpu().numpy() - current_ee_pos = fk_func(joint_positions) + kinematics = RobotKinematics(robot.robot_type) + current_ee_pos = kinematics.fk_gripper_tip(joint_positions) # Initialize desired position with current position desired_ee_pos = np.eye(4) # Identity matrix @@ -595,7 +593,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, # Get currrent robot state joint_positions = robot.follower_arms["main"].read("Present_Position") - current_ee_pos = fk_func(joint_positions) + current_ee_pos = kinematics.fk_gripper_tip(joint_positions) # Get movement deltas from the controller delta_x, delta_y, delta_z = controller.get_deltas() @@ -612,9 +610,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, # Only send commands if there's actual movement if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]): # Compute joint targets via inverse kinematics - target_joint_state = RobotKinematics.ik( - joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func - ) + target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True) # Send command to robot robot.send_action(torch.from_numpy(target_joint_state)) @@ -676,7 +672,17 @@ def teleoperate_gym_env(env, controller, fps: int = 30): # Close the environment env.close() + if __name__ == "__main__": + from lerobot.common.robot_devices.robots.configs import RobotConfig + from lerobot.common.robot_devices.robots.utils import make_robot_from_config + from lerobot.scripts.server.gym_manipulator import ( + EEActionSpaceConfig, + EnvWrapperConfig, + HILSerlRobotEnvConfig, + make_robot_env, + ) + parser = argparse.ArgumentParser(description="Test end-effector control") parser.add_argument( "--mode", @@ -698,12 +704,6 @@ if __name__ == "__main__": default="so100", help="Robot type (so100, koch, aloha, etc.)", ) - parser.add_argument( - "--config-path", - type=str, - default=None, - help="Path to the config file in json format", - ) args = parser.parse_args() @@ -725,7 +725,10 @@ if __name__ == "__main__": if args.mode.startswith("keyboard"): controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05) elif args.mode.startswith("gamepad"): - 
controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05) + if sys.platform == "darwin": + controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05) + else: + controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05) # Handle mode categories if args.mode in ["keyboard", "gamepad"]: @@ -734,12 +737,14 @@ if __name__ == "__main__": elif args.mode in ["keyboard_gym", "gamepad_gym"]: # Gym environment control modes - cfg = HILSerlRobotEnvConfig() - if args.config_path is not None: - cfg = HILSerlRobotEnvConfig.from_json(args.config_path) - + cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig()) + cfg.wrapper.ee_action_space_params = EEActionSpaceConfig( + x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds + ) + cfg.wrapper.ee_action_space_params.use_gamepad = False + cfg.device = "cpu" env = make_robot_env(cfg, robot) - teleoperate_gym_env(env, controller, fps=args.fps) + teleoperate_gym_env(env, controller, fps=cfg.fps) elif args.mode == "leader": # Leader-follower modes don't use controllers diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py index deec4d75..f34c8f9f 100644 --- a/lerobot/scripts/server/find_joint_limits.py +++ b/lerobot/scripts/server/find_joint_limits.py @@ -63,9 +63,10 @@ def find_ee_bounds( if time.perf_counter() - start_episode_t < 5: continue + kinematics = RobotKinematics(robot.robot_type) joint_positions = robot.follower_arms["main"].read("Present_Position") print(f"Joint positions: {joint_positions}") - ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3]) + ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3]) if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] @@ -81,20 +82,19 @@ def find_ee_bounds( break -def make_robot(robot_type="so100", mock=True): +def make_robot(robot_type="so100"): """ Create a robot instance using the appropriate robot config class. 
Args: robot_type: Robot type string (e.g., "so100", "koch", "aloha") - mock: Whether to use mock mode for hardware (default: True) Returns: Robot instance """ # Get the appropriate robot config class based on robot_type - robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock) + robot_config = RobotConfig.get_choice_class(robot_type)(mock=False) robot_config.leader_arms["main"].port = leader_port robot_config.follower_arms["main"].port = follower_port @@ -122,18 +122,12 @@ if __name__ == "__main__": default="so100", help="Robot type (so100, koch, aloha, etc.)", ) - parser.add_argument( - "--mock", - type=int, - default=1, - help="Use mock mode for hardware simulation", - ) - + # Only parse known args, leaving robot config args for Hydra if used - args, _ = parser.parse_known_args() + args = parser.parse_args() # Create robot with the appropriate config - robot = make_robot(args.robot_type, args.mock) + robot = make_robot(args.robot_type) if args.mode == "joint": find_joint_bounds(robot, args.control_time_s) diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 4abd385f..856ea843 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1,45 +1,37 @@ import logging import sys import time -import sys - +from dataclasses import dataclass from threading import Lock -from typing import Annotated, Any, Dict, Tuple +from typing import Annotated, Any, Dict, Optional, Tuple import gymnasium as gym import numpy as np import torch import torchvision.transforms.functional as F # noqa: N812 -import json -from dataclasses import dataclass - -from lerobot.common.envs.utils import preprocess_observation -from lerobot.configs.train import TrainPipelineConfig from lerobot.common.envs.configs import EnvConfig +from lerobot.common.envs.utils import preprocess_observation from lerobot.common.robot_devices.control_utils import ( busy_wait, is_headless, - # reset_follower_position, + reset_follower_position, ) - -from typing import Optional -from lerobot.common.utils.utils import log_say -from lerobot.common.robot_devices.robots.utils import make_robot_from_config - from lerobot.common.robot_devices.robots.configs import RobotConfig - -from lerobot.scripts.server.kinematics import RobotKinematics -from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill +from lerobot.common.robot_devices.robots.utils import make_robot_from_config +from lerobot.common.utils.utils import log_say from lerobot.configs import parser +from lerobot.scripts.server.kinematics import RobotKinematics logging.basicConfig(level=logging.INFO) + @dataclass class EEActionSpaceConfig: """Configuration parameters for end-effector action space.""" + x_step_size: float - y_step_size: float + y_step_size: float z_step_size: float bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds use_gamepad: bool = False @@ -48,6 +40,7 @@ class EEActionSpaceConfig: @dataclass class EnvWrapperConfig: """Configuration for environment wrappers.""" + display_cameras: bool = False delta_action: float = 0.1 use_relative_joint_positions: bool = True @@ -64,28 +57,27 @@ class EnvWrapperConfig: reward_classifier_config_file: Optional[str] = None +@EnvConfig.register_subclass(name="gym_manipulator") @dataclass -class HILSerlRobotEnvConfig: +class HILSerlRobotEnvConfig(EnvConfig): """Configuration for the HILSerlRobotEnv environment.""" - robot: RobotConfig - wrapper: EnvWrapperConfig - env_name: str = 
"real_robot" + + robot: Optional[RobotConfig] = None + wrapper: Optional[EnvWrapperConfig] = None fps: int = 10 mode: str = None # Either "record", "replay", None repo_id: Optional[str] = None dataset_root: Optional[str] = None task: str = "" - num_episodes: int = 10 # only for record mode + num_episodes: int = 10 # only for record mode episode: int = 0 device: str = "cuda" push_to_hub: bool = True pretrained_policy_name_or_path: Optional[str] = None - @classmethod - def from_json(cls, json_path: str): - with open(json_path, "r") as f: - config = json.load(f) - return cls(**config) + def gym_kwargs(self) -> dict: + return {} + class HILSerlRobotEnv(gym.Env): """ @@ -580,8 +572,7 @@ class ImageCropResizeWrapper(gym.Wrapper): if key_crop not in self.env.observation_space.keys(): # noqa: SIM118 raise ValueError(f"Key {key_crop} not in observation space") for key in crop_params_dict: - top, left, height, width = crop_params_dict[key] - new_shape = (top + height, left + width) + new_shape = (3, resize_size[0], resize_size[1]) self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) self.resize_size = resize_size @@ -1097,9 +1088,7 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention -def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv: -# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv: -# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv: +def make_robot_env(cfg) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. @@ -1111,16 +1100,16 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv: Returns: A vectorized gym environment with all the necessary wrappers applied. """ - if "maniskill" in cfg.name: - from lerobot.scripts.server.maniskill_manipulator import make_maniskill + # if "maniskill" in cfg.name: + # from lerobot.scripts.server.maniskill_manipulator import make_maniskill - logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") - env = make_maniskill( - cfg=cfg, - n_envs=1, - ) - return env - robot = cfg.robot + # logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") + # env = make_maniskill( + # cfg=cfg, + # n_envs=1, + # ) + # return env + robot = make_robot_from_config(cfg.robot) # Create base environment env = HILSerlRobotEnv( robot=robot, @@ -1150,10 +1139,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv: env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) if cfg.wrapper.ee_action_space_params is not None: env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) - if ( - cfg.wrapper.ee_action_space_params is not None - and cfg.wrapper.ee_action_space_params.use_gamepad - ): + if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad: # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) env = GamepadControlWrapper( env=env, @@ -1169,10 +1155,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv: reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_time_s=cfg.wrapper.reset_time_s, ) - if ( - cfg.wrapper.ee_action_space_params is None - and cfg.wrapper.joint_masking_action_space is not None - ): + if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = 
BatchCompitableWrapper(env=env) @@ -1180,7 +1163,10 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv: def get_classifier(cfg): - if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None: + if ( + cfg.wrapper.reward_classifier_pretrained_path is None + or cfg.wrapper.reward_classifier_config_file is None + ): return None from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( @@ -1258,7 +1244,8 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig): # Record episodes episode_index = 0 - while episode_index < cfg.record_num_episodes: + recorded_action = None + while episode_index < cfg.num_episodes: obs, _ = env.reset() start_episode_t = time.perf_counter() log_say(f"Recording episode {episode_index}", play_sounds=True) @@ -1279,16 +1266,19 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig): break # For teleop, get action from intervention - if policy is None: - action = {"action": info["action_intervention"].cpu().squeeze(0).float()} + recorded_action = { + "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action + } # Process observation for dataset obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + obs["observation.images.side"] = torch.clamp(obs["observation.images.side"], 0, 1) # Add frame to dataset - frame = {**obs, **action} - frame["next.reward"] = reward - frame["next.done"] = terminated or truncated + frame = {**obs, **recorded_action} + frame["next.reward"] = np.array([reward], dtype=np.float32) + frame["next.done"] = np.array([terminated or truncated], dtype=bool) + frame["task"] = cfg.task dataset.add_frame(frame) # Maintain consistent timing @@ -1309,9 +1299,9 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig): episode_index += 1 # Finalize dataset - dataset.consolidate(run_compute_stats=True) + # dataset.consolidate(run_compute_stats=True) if cfg.push_to_hub: - dataset.push_to_hub(cfg.repo_id) + dataset.push_to_hub() def replay_episode(env, repo_id, root=None, episode=0): @@ -1334,82 +1324,69 @@ def replay_episode(env, repo_id, root=None, episode=0): busy_wait(1 / 10 - dt_s) -# @parser.wrap() -# def main(cfg): +@parser.wrap() +def main(cfg: EnvConfig): + env = make_robot_env(cfg) -# robot = make_robot_from_config(cfg.robot) + if cfg.mode == "record": + policy = None + if cfg.pretrained_policy_name_or_path is not None: + from lerobot.common.policies.sac.modeling_sac import SACPolicy -# reward_classifier = None #get_classifier( -# # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file -# # ) -# user_relative_joint_positions = True + policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) + policy.to(cfg.device) + policy.eval() -# env = make_robot_env(cfg, robot) + record_dataset( + env, + policy=None, + cfg=cfg, + ) + exit() -# if cfg.mode == "record": -# policy = None -# if cfg.pretrained_policy_name_or_path is not None: -# from lerobot.common.policies.sac.modeling_sac import SACPolicy + if cfg.mode == "replay": + replay_episode( + env, + cfg.replay_repo_id, + root=cfg.dataset_root, + episode=cfg.replay_episode, + ) + exit() -# policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) -# policy.to(cfg.device) -# policy.eval() + env.reset() -# record_dataset( -# env, -# cfg.repo_id, -# root=cfg.dataset_root, -# num_episodes=cfg.num_episodes, -# fps=cfg.fps, -# task_description=cfg.task, -# policy=policy, -# ) -# exit() + # Retrieve the robot's 
action space for joint commands. + action_space_robot = env.action_space.spaces[0] -# if cfg.mode == "replay": -# replay_episode( -# env, -# cfg.replay_repo_id, -# root=cfg.dataset_root, -# episode=cfg.replay_episode, -# ) -# exit() + # Initialize the smoothed action as a random sample. + smoothed_action = action_space_robot.sample() -# env.reset() + # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. + # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. + alpha = 1.0 -# # Retrieve the robot's action space for joint commands. -# action_space_robot = env.action_space.spaces[0] + num_episode = 0 + sucesses = [] + while num_episode < 20: + start_loop_s = time.perf_counter() + # Sample a new random action from the robot's action space. + new_random_action = action_space_robot.sample() + # Update the smoothed action using an exponential moving average. + smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action -# # Initialize the smoothed action as a random sample. -# smoothed_action = action_space_robot.sample() + # Execute the step: wrap the NumPy action in a torch tensor. + obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) + if terminated or truncated: + sucesses.append(reward) + env.reset() + num_episode += 1 -# # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. -# # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. -# alpha = 1.0 + dt_s = time.perf_counter() - start_loop_s + busy_wait(1 / cfg.fps - dt_s) -# num_episode = 0 -# sucesses = [] -# while num_episode < 20: -# start_loop_s = time.perf_counter() -# # Sample a new random action from the robot's action space. -# new_random_action = action_space_robot.sample() -# # Update the smoothed action using an exponential moving average. -# smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action + logging.info(f"Success after 20 steps {sucesses}") + logging.info(f"success rate {sum(sucesses) / len(sucesses)}") -# # Execute the step: wrap the NumPy action in a torch tensor. -# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) -# if terminated or truncated: -# sucesses.append(reward) -# env.reset() -# num_episode += 1 -# dt_s = time.perf_counter() - start_loop_s -# busy_wait(1 / cfg.fps - dt_s) - -# logging.info(f"Success after 20 steps {sucesses}") -# logging.info(f"success rate {sum(sucesses) / len(sucesses)}") - -# if __name__ == "__main__": -# main() if __name__ == "__main__": - make_robot_env() \ No newline at end of file + main() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 04fab60f..4deb1972 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -15,12 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +import os import shutil import time from concurrent.futures import ThreadPoolExecutor -from pprint import pformat -import os from pathlib import Path +from pprint import pformat import draccus import grpc @@ -30,35 +30,42 @@ import hilserl_pb2_grpc # type: ignore import torch from termcolor import colored from torch import nn - from torch.multiprocessing import Queue from torch.optim.optimizer import Optimizer +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, + TRAINING_STEP, +) from lerobot.common.datasets.factory import make_dataset -from lerobot.configs.train import TrainPipelineConfig -from lerobot.configs import parser # TODO: Remove the import of maniskill from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.policies.factory import make_policy -from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig +from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy +from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.train_utils import ( get_step_checkpoint_dir, get_step_identifier, - load_training_state as utils_load_training_state, save_checkpoint, - update_last_checkpoint, save_training_state, + update_last_checkpoint, +) +from lerobot.common.utils.train_utils import ( + load_training_state as utils_load_training_state, ) -from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, init_logging, ) - -from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig from lerobot.scripts.server import learner_service from lerobot.scripts.server.buffer import ( ReplayBuffer, @@ -70,47 +77,39 @@ from lerobot.scripts.server.buffer import ( state_to_bytes, ) from lerobot.scripts.server.utils import setup_process_handlers -from lerobot.common.constants import ( - CHECKPOINTS_DIR, - LAST_CHECKPOINT_LINK, - PRETRAINED_MODEL_DIR, - TRAINING_STATE_DIR, - TRAINING_STEP, -) def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig: """ Handle the resume logic for training. - + If resume is True: - Verifies that a checkpoint exists - Loads the checkpoint configuration - Logs resumption details - Returns the checkpoint configuration - + If resume is False: - Checks if an output directory exists (to prevent accidental overwriting) - Returns the original configuration - + Args: cfg (TrainPipelineConfig): The training configuration - + Returns: TrainPipelineConfig: The updated configuration - + Raises: RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists """ out_dir = cfg.output_dir - + # Case 1: Not resuming, but need to check if directory exists to prevent overwrites if not cfg.resume: checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if os.path.exists(checkpoint_dir): raise RuntimeError( - f"Output directory {checkpoint_dir} already exists. " - "Use `resume=true` to resume training." + f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training." 
) return cfg @@ -131,7 +130,7 @@ def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig: # Load config using Draccus checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json") checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path) - + # Ensure resume flag is set in returned config checkpoint_cfg.resume = True return checkpoint_cfg @@ -143,11 +142,11 @@ def load_training_state( ): """ Loads the training state (optimizers, step count, etc.) from a checkpoint. - + Args: cfg (TrainPipelineConfig): Training configuration optimizers (Optimizer | dict): Optimizers to load state into - + Returns: tuple: (optimization_step, interaction_step) or (None, None) if not resuming """ @@ -156,23 +155,23 @@ def load_training_state( # Construct path to the last checkpoint directory checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) - + logging.info(f"Loading training state from {checkpoint_dir}") - + try: # Use the utility function from train_utils which loads the optimizer state step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) - + # Load interaction step separately from training_state.pt training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") interaction_step = 0 if os.path.exists(training_state_path): training_state = torch.load(training_state_path, weights_only=False) interaction_step = training_state.get("interaction_step", 0) - + logging.info(f"Resuming from step {step}, interaction step {interaction_step}") return step, interaction_step - + except Exception as e: logging.error(f"Failed to load training state: {e}") return None, None @@ -181,7 +180,7 @@ def load_training_state( def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None: """ Log information about the training process. - + Args: cfg (TrainPipelineConfig): Training configuration policy (nn.Module): Policy model @@ -189,7 +188,6 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None: num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.policy.online_steps=}") @@ -197,19 +195,15 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None: logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") -def initialize_replay_buffer( - cfg: TrainPipelineConfig, - device: str, - storage_device: str -) -> ReplayBuffer: +def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer: """ Initialize a replay buffer, either empty or from a dataset if resuming. - + Args: cfg (TrainPipelineConfig): Training configuration device (str): Device to store tensors on storage_device (str): Device for storage optimization - + Returns: ReplayBuffer: Initialized replay buffer """ @@ -224,7 +218,7 @@ def initialize_replay_buffer( logging.info("Resume training load the online dataset") dataset_path = os.path.join(cfg.output_dir, "dataset") - + # NOTE: In RL is possible to not have a dataset. repo_id = None if cfg.dataset is not None: @@ -250,13 +244,13 @@ def initialize_offline_replay_buffer( ) -> ReplayBuffer: """ Initialize an offline replay buffer from a dataset. 
- + Args: cfg (TrainPipelineConfig): Training configuration device (str): Device to store tensors on storage_device (str): Device for storage optimization active_action_dims (list[int] | None): Active action dimensions for masking - + Returns: ReplayBuffer: Initialized offline replay buffer """ @@ -314,7 +308,7 @@ def start_learner_threads( ) -> None: """ Start the learner threads for training. - + Args: cfg (TrainPipelineConfig): Training configuration wandb_logger (WandBLogger | None): Logger for metrics @@ -512,17 +506,19 @@ def add_actor_information_and_train( logging.info("Initializing policy") # Get checkpoint dir for resuming - checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None + checkpoint_dir = ( + os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None + ) pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None - + policy: SACPolicy = make_policy( cfg=cfg.policy, # ds_meta=cfg.dataset, - env_cfg=cfg.env + env_cfg=cfg.env, ) # Update the policy config with the grad_clip_norm value from training config if it exists - clip_grad_norm_value:float = cfg.policy.grad_clip_norm + clip_grad_norm_value: float = cfg.policy.grad_clip_norm # compile policy policy = torch.compile(policy) @@ -536,7 +532,7 @@ def add_actor_information_and_train( optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) - log_training_info(cfg=cfg, policy= policy) + log_training_info(cfg=cfg, policy=policy) replay_buffer = initialize_replay_buffer(cfg, device, storage_device) batch_size = cfg.batch_size @@ -615,14 +611,10 @@ def add_actor_information_and_train( interaction_message = bytes_to_python_object(interaction_message) # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging interaction_message["Interaction step"] += interaction_step_shift - + # Log interaction messages with WandB if available if wandb_logger: - wandb_logger.log_dict( - d=interaction_message, - mode="train", - custom_step_key="Interaction step" - ) + wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step") logging.debug("[LEARNER] Received interactions") @@ -636,7 +628,9 @@ def add_actor_information_and_train( if dataset_repo_id is not None: batch_offline = offline_replay_buffer.sample(batch_size=batch_size) - batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) actions = batch["action"] rewards = batch["reward"] @@ -759,14 +753,10 @@ def add_actor_information_and_train( if offline_replay_buffer is not None: training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) training_infos["Optimization step"] = optimization_step - + # Log training metrics if wandb_logger: - wandb_logger.log_dict( - d=training_infos, - mode="train", - custom_step_key="Optimization step" - ) + wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step") time_for_one_optimization_step = time.time() - time_for_one_optimization_step frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) @@ -795,29 +785,19 @@ def add_actor_information_and_train( interaction_step = ( 
interaction_message["Interaction step"] if interaction_message is not None else 0 ) - + # Create checkpoint directory checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) - + # Save checkpoint - save_checkpoint( - checkpoint_dir, - optimization_step, - cfg, - policy, - optimizers, - scheduler=None - ) - + save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None) + # Save interaction step manually training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) os.makedirs(training_state_dir, exist_ok=True) - training_state = { - "step": optimization_step, - "interaction_step": interaction_step - } + training_state = {"step": optimization_step, "interaction_step": interaction_step} torch.save(training_state, os.path.join(training_state_dir, "training_state.pt")) - + # Update the "last" symlink update_last_checkpoint(checkpoint_dir) @@ -826,17 +806,13 @@ def add_actor_information_and_train( dataset_dir = os.path.join(cfg.output_dir, "dataset") if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir): shutil.rmtree(dataset_dir) - + # Save dataset # NOTE: Handle the case where the dataset repo id is not specified in the config - # eg. RL training without demonstrations data + # eg. RL training without demonstrations data repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id - replay_buffer.to_lerobot_dataset( - repo_id=repo_id_buffer_save, - fps=fps, - root=dataset_dir - ) - + replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir) + if offline_replay_buffer is not None: dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline") if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir): @@ -882,9 +858,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): params=policy.actor.parameters_to_optimize, lr=cfg.policy.actor_lr, ) - optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr - ) + optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { @@ -898,19 +872,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): def train(cfg: TrainPipelineConfig, job_name: str | None = None): """ Main training function that initializes and runs the training process. - + Args: cfg (TrainPipelineConfig): The training configuration job_name (str | None, optional): Job name for logging. Defaults to None. 
""" - + cfg.validate() # if cfg.output_dir is None: # raise ValueError("Output directory must be specified in config") - + if job_name is None: job_name = cfg.job_name - + if job_name is None: raise ValueError("Job name must be specified either in config or as a parameter") @@ -920,11 +894,12 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None): # Setup WandB logging if enabled if cfg.wandb.enable and cfg.wandb.project: from lerobot.common.utils.wandb_utils import WandBLogger + wandb_logger = WandBLogger(cfg) else: wandb_logger = None logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) - + # Handle resume logic cfg = handle_resume_logic(cfg) @@ -944,9 +919,9 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None): @parser.wrap() def train_cli(cfg: TrainPipelineConfig): - if not use_threads(cfg): import torch.multiprocessing as mp + mp.set_start_method("spawn") # Use the job_name from the config diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 69a7c0d3..6576651c 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -122,6 +122,9 @@ def make_optimizer_and_scheduler(cfg, policy): optimizer = VQBeTOptimizer(policy, cfg) lr_scheduler = VQBeTScheduler(optimizer, cfg) + elif cfg.policy.name == "hilserl_classifier": + optimizer = torch.optim.AdamW(policy.parameters(), cfg.policy.learning_rate) + lr_scheduler = None else: raise NotImplementedError() diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index fc9a1d07..a69b2b3c 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -16,7 +16,6 @@ import time from contextlib import nullcontext from pprint import pformat -import hydra import numpy as np import torch import torch.nn as nn @@ -32,11 +31,8 @@ from tqdm import tqdm from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.logger import Logger -from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg -from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( - ClassifierConfig, -) +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.common.utils.utils import ( format_big_number, @@ -296,8 +292,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No init_logging() logging.info(OmegaConf.to_yaml(cfg)) - logger = Logger(cfg, out_dir, wandb_job_name=job_name) - # Initialize training environment device = get_safe_torch_device(cfg.device, log=True) set_global_seed(cfg.seed)