# NOTE(review): file reconstructed from a whitespace-mangled diff; the import
# block below lists only the imports visible as kept context in the patch
# hunks — confirm against the original file.
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat

import grpc

from lerobot.common.constants import (
    LAST_CHECKPOINT_LINK,
    PRETRAINED_MODEL_DIR,
    TRAINING_STATE_DIR,
)
from lerobot.common.datasets.factory import make_dataset

# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
    get_step_checkpoint_dir,
    save_checkpoint,
    update_last_checkpoint,
)
from lerobot.scripts.server.utils import setup_process_handlers

LOG_PREFIX = "[LEARNER]"

logging.basicConfig(level=logging.INFO)

#################################################
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
#################################################


@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):
    """CLI entry point: set the multiprocessing start method and run training.

    Args:
        cfg (TrainPipelineConfig): Parsed training configuration.
    """
    if not use_threads(cfg):
        import torch.multiprocessing as mp

        # "spawn" is required when child processes touch CUDA state;
        # fork()ed workers would inherit an unusable CUDA context.
        mp.set_start_method("spawn")

    # Use the job_name from the config
    train(
        cfg,
        job_name=cfg.job_name,
    )

    logging.info("[LEARNER] train_cli finished")


def train(cfg: TrainPipelineConfig, job_name: str | None = None):
    """
    Main training function that initializes and runs the training process.

    Args:
        cfg (TrainPipelineConfig): The training configuration
        job_name (str | None, optional): Job name for logging. Defaults to None
            (falls back to ``cfg.job_name``).

    Raises:
        ValueError: If no job name is available from either source.
    """
    cfg.validate()

    if job_name is None:
        job_name = cfg.job_name

    if job_name is None:
        raise ValueError("Job name must be specified either in config or as a parameter")

    # Create logs directory to ensure it exists
    log_dir = os.path.join(cfg.output_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"learner_{job_name}.log")

    # Initialize logging with explicit log file
    init_logging(log_file=log_file)
    logging.info(f"Learner logging initialized, writing to {log_file}")
    logging.info(pformat(cfg.to_dict()))

    # Setup WandB logging if enabled; imported lazily so wandb stays optional.
    if cfg.wandb.enable and cfg.wandb.project:
        from lerobot.common.utils.wandb_utils import WandBLogger

        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    # Handle resume logic (may rewrite cfg from a checkpointed config)
    cfg = handle_resume_logic(cfg)

    set_seed(seed=cfg.seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    shutdown_event = setup_process_handlers(use_threads(cfg))

    start_learner_threads(
        cfg=cfg,
        wandb_logger=wandb_logger,
        shutdown_event=shutdown_event,
    )
+ + Args: + cfg (TrainPipelineConfig): Training configuration + wandb_logger (WandBLogger | None): Logger for metrics + shutdown_event: Event to signal shutdown + """ + # Create multiprocessing queues + transition_queue = Queue() + interaction_message_queue = Queue() + parameters_queue = Queue() + + concurrency_entity = None + + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from torch.multiprocessing import Process + + concurrency_entity = Process + + communication_process = concurrency_entity( + target=start_learner_server, + args=( + parameters_queue, + transition_queue, + interaction_message_queue, + shutdown_event, + cfg, + ), + daemon=True, + ) + communication_process.start() + + add_actor_information_and_train( + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + parameters_queue=parameters_queue, + ) + logging.info("[LEARNER] Training process stopped") + + logging.info("[LEARNER] Closing queues") + transition_queue.close() + interaction_message_queue.close() + parameters_queue.close() + + communication_process.join() + logging.info("[LEARNER] Communication process joined") + + logging.info("[LEARNER] join queues") + transition_queue.cancel_join_thread() + interaction_message_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[LEARNER] queues closed") + + +################################################# +# Core algorithm functions # +################################################# + + +def add_actor_information_and_train( + cfg: TrainPipelineConfig, + wandb_logger: WandBLogger | None, + shutdown_event: any, # Event, + transition_queue: Queue, + interaction_message_queue: Queue, + parameters_queue: Queue, +): + """ + Handles data transfer from the actor to the learner, manages training updates, + and logs training progress in an online reinforcement learning 
def add_actor_information_and_train(
    cfg: TrainPipelineConfig,
    wandb_logger: WandBLogger | None,
    shutdown_event: any,  # Event,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    parameters_queue: Queue,
):
    """
    Handles data transfer from the actor to the learner, manages training updates,
    and logs training progress in an online reinforcement learning setup.

    This function continuously:
    - Transfers transitions from the actor to the replay buffer.
    - Logs received interaction messages.
    - Ensures training begins only when the replay buffer has a sufficient number of transitions.
    - Samples batches from the replay buffer and performs multiple critic updates.
    - Periodically updates the actor, critic, and temperature optimizers.
    - Logs training statistics, including loss values and optimization frequency.

    NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
    in the future. The reason why we did that is the GIL in Python. It's super slow the performance
    are divided by 200. So we need to have a single thread that does all the work.

    Args:
        cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
        wandb_logger (WandBLogger | None): Logger for tracking training progress.
        shutdown_event (Event): Event to signal shutdown.
        transition_queue (Queue): Queue for receiving transitions from the actor.
        interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
        parameters_queue (Queue): Queue for sending policy parameters to the actor.
    """
    # Extract all configuration variables at the beginning
    device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
    storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
    clip_grad_norm_value = cfg.policy.grad_clip_norm
    online_step_before_learning = cfg.policy.online_step_before_learning
    utd_ratio = cfg.policy.utd_ratio
    fps = cfg.env.fps
    log_freq = cfg.log_freq
    save_freq = cfg.save_freq
    policy_update_freq = cfg.policy.policy_update_freq
    policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
    saving_checkpoint = cfg.save_checkpoint
    online_steps = cfg.policy.online_steps

    # Initialize logging for multiprocessing (child processes don't inherit handlers)
    if not use_threads(cfg):
        log_dir = os.path.join(cfg.output_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
        init_logging(log_file=log_file)
        logging.info("Initialized logging for actor information and training process")

    logging.info("Initializing policy")

    policy: SACPolicy = make_policy(
        cfg=cfg.policy,
        # ds_meta=cfg.dataset,
        env_cfg=cfg.env,
    )

    # compile policy
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)
    policy.train()

    # Send the initial weights so actors can start rolling out immediately.
    push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)

    last_time_policy_pushed = time.time()

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
    resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)

    log_training_info(cfg=cfg, policy=policy)

    replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
    batch_size = cfg.batch_size
    offline_replay_buffer = None

    if cfg.dataset is not None:
        active_action_dims = None
        # TODO: FIX THIS
        if cfg.env.wrapper.joint_masking_action_space is not None:
            active_action_dims = [
                i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
            ]
        offline_replay_buffer = initialize_offline_replay_buffer(
            cfg=cfg,
            device=device,
            storage_device=storage_device,
            active_action_dims=active_action_dims,
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffer

    logging.info("Starting learner thread")
    interaction_message, transition = None, None
    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0

    dataset_repo_id = None
    if cfg.dataset is not None:
        dataset_repo_id = cfg.dataset.repo_id

    # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
    while True:
        # Exit the training loop if shutdown is requested
        if shutdown_event is not None and shutdown_event.is_set():
            logging.info("[LEARNER] Shutdown signal received. Exiting...")
            break

        # Process all available transitions
        logging.debug("[LEARNER] Waiting for transitions")
        process_transitions(
            transition_queue=transition_queue,
            replay_buffer=replay_buffer,
            offline_replay_buffer=offline_replay_buffer,
            device=device,
            dataset_repo_id=dataset_repo_id,
            shutdown_event=shutdown_event,
        )
        logging.debug("[LEARNER] Received transitions")

        # Process all available interaction messages
        logging.debug("[LEARNER] Waiting for interactions")
        interaction_message = process_interaction_messages(
            interaction_message_queue=interaction_message_queue,
            interaction_step_shift=interaction_step_shift,
            wandb_logger=wandb_logger,
            shutdown_event=shutdown_event,
        )
        logging.debug("[LEARNER] Received interactions")

        # Wait until the replay buffer has enough samples
        if len(replay_buffer) < online_step_before_learning:
            continue

        logging.debug("[LEARNER] Starting optimization loop")
        time_for_one_optimization_step = time.time()
        # utd_ratio - 1 critic-only updates; the final (utd-th) update below
        # also refreshes the actor/temperature and collects logging stats.
        for _ in range(utd_ratio - 1):
            batch = replay_buffer.sample(batch_size=batch_size)

            if dataset_repo_id is not None:
                batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
                batch = concatenate_batch_transitions(
                    left_batch_transitions=batch, right_batch_transition=batch_offline
                )

            actions = batch["action"]
            rewards = batch["reward"]
            observations = batch["state"]
            next_observations = batch["next_state"]
            done = batch["done"]
            check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

            observation_features, next_observation_features = get_observation_features(
                policy=policy, observations=observations, next_observations=next_observations
            )
            loss_critic = policy.compute_loss_critic(
                observations=observations,
                actions=actions,
                rewards=rewards,
                next_observations=next_observations,
                done=done,
                observation_features=observation_features,
                next_observation_features=next_observation_features,
            )
            optimizers["critic"].zero_grad()
            loss_critic.backward()

            # clip gradients
            critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
            )

            optimizers["critic"].step()

            policy.update_target_networks()

        batch = replay_buffer.sample(batch_size=batch_size)

        if dataset_repo_id is not None:
            batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
            batch = concatenate_batch_transitions(
                left_batch_transitions=batch, right_batch_transition=batch_offline
            )

        actions = batch["action"]
        rewards = batch["reward"]
        observations = batch["state"]
        next_observations = batch["next_state"]
        done = batch["done"]

        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

        observation_features, next_observation_features = get_observation_features(
            policy=policy, observations=observations, next_observations=next_observations
        )
        loss_critic = policy.compute_loss_critic(
            observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations,
            done=done,
            observation_features=observation_features,
            next_observation_features=next_observation_features,
        )
        optimizers["critic"].zero_grad()
        loss_critic.backward()

        # clip gradients
        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
            parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
        ).item()

        optimizers["critic"].step()

        training_infos = {}
        training_infos["loss_critic"] = loss_critic.item()
        training_infos["critic_grad_norm"] = critic_grad_norm

        # Actor/temperature are updated less often than the critic (delayed updates).
        if optimization_step % policy_update_freq == 0:
            for _ in range(policy_update_freq):
                loss_actor = policy.compute_loss_actor(
                    observations=observations,
                    observation_features=observation_features,
                )

                optimizers["actor"].zero_grad()
                loss_actor.backward()

                # clip gradients
                actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                ).item()

                optimizers["actor"].step()

                training_infos["loss_actor"] = loss_actor.item()
                training_infos["actor_grad_norm"] = actor_grad_norm

                # Temperature optimization
                loss_temperature = policy.compute_loss_temperature(
                    observations=observations,
                    observation_features=observation_features,
                )
                optimizers["temperature"].zero_grad()
                loss_temperature.backward()

                # clip gradients
                temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                ).item()

                optimizers["temperature"].step()

                training_infos["loss_temperature"] = loss_temperature.item()
                training_infos["temperature_grad_norm"] = temp_grad_norm
                training_infos["temperature"] = policy.temperature

        # Check if it's time to push updated policy to actors
        # NOTE(review): indentation reconstructed from a mangled diff — this
        # push check is placed at the main-loop level; confirm against upstream.
        if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
            push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
            last_time_policy_pushed = time.time()

        policy.update_target_networks()

        # Log training metrics at specified intervals
        if optimization_step % log_freq == 0:
            training_infos["replay_buffer_size"] = len(replay_buffer)
            if offline_replay_buffer is not None:
                training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
            training_infos["Optimization step"] = optimization_step

            # Log training metrics
            if wandb_logger:
                wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")

        # Calculate and log optimization frequency
        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
        frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

        logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

        # Log optimization frequency
        if wandb_logger:
            wandb_logger.log_dict(
                {
                    "Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
                    "Optimization step": optimization_step,
                },
                mode="train",
                custom_step_key="Optimization step",
            )

        optimization_step += 1
        if optimization_step % log_freq == 0:
            logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")

        # Save checkpoint at specified intervals
        if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
            save_training_checkpoint(
                cfg=cfg,
                optimization_step=optimization_step,
                online_steps=online_steps,
                interaction_message=interaction_message,
                policy=policy,
                optimizers=optimizers,
                replay_buffer=replay_buffer,
                offline_replay_buffer=offline_replay_buffer,
                dataset_repo_id=dataset_repo_id,
                fps=fps,
            )
def start_learner_server(
    parameters_queue: Queue,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    shutdown_event: any,  # Event,
    cfg: TrainPipelineConfig,
):
    """
    Start the learner server for training.

    Serves the gRPC LearnerService until ``shutdown_event`` is set, then stops
    the server gracefully.

    Args:
        parameters_queue: Queue for sending policy parameters to the actor
        transition_queue: Queue for receiving transitions from the actor
        interaction_message_queue: Queue for receiving interaction messages from the actor
        shutdown_event: Event to signal shutdown
        cfg: Training configuration
    """
    if not use_threads(cfg):
        # Create a process-specific log file
        log_dir = os.path.join(cfg.output_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")

        # Initialize logging with explicit log file
        init_logging(log_file=log_file)
        logging.info("Learner server process logging initialized")

        # Setup process handlers to handle shutdown signal
        # But use shutdown event from the main process
        # Return back for MP
        setup_process_handlers(False)

    service = learner_service.LearnerService(
        shutdown_event=shutdown_event,
        parameters_queue=parameters_queue,
        seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
        transition_queue=transition_queue,
        interaction_message_queue=interaction_message_queue,
    )

    # Raise gRPC message-size limits: serialized policy weights exceed the 4 MB default.
    server = grpc.server(
        ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
        options=[
            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
        ],
    )

    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
        service,
        server,
    )

    host = cfg.policy.actor_learner_config.learner_host
    port = cfg.policy.actor_learner_config.learner_port

    server.add_insecure_port(f"{host}:{port}")
    server.start()
    logging.info("[LEARNER] gRPC server started")

    # Block until the trainer signals shutdown, then drain in-flight RPCs.
    shutdown_event.wait()
    logging.info("[LEARNER] Stopping gRPC server...")
    # NOTE(review): constant name "STUTDOWN_TIMEOUT" looks misspelled
    # ("SHUTDOWN"); it is declared in learner_service, so fix it there.
    server.stop(learner_service.STUTDOWN_TIMEOUT)
    logging.info("[LEARNER] gRPC server stopped")
def save_training_checkpoint(
    cfg: TrainPipelineConfig,
    optimization_step: int,
    online_steps: int,
    interaction_message: dict | None,
    policy: nn.Module,
    optimizers: dict[str, Optimizer],
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer | None = None,
    dataset_repo_id: str | None = None,
    fps: int = 30,
) -> None:
    """
    Save training checkpoint and associated data.

    This function performs the following steps:
    1. Creates a checkpoint directory with the current optimization step
    2. Saves the policy model, configuration, and optimizer states
    3. Saves the current interaction step for resuming training
    4. Updates the "last" checkpoint symlink to point to this checkpoint
    5. Saves the replay buffer as a dataset for later use
    6. If an offline replay buffer exists, saves it as a separate dataset

    Args:
        cfg: Training configuration
        optimization_step: Current optimization step
        online_steps: Total number of online steps
        interaction_message: Dictionary containing interaction information
        policy: Policy model to save
        optimizers: Dictionary of optimizers
        replay_buffer: Replay buffer to save as dataset
        offline_replay_buffer: Optional offline replay buffer to save
        dataset_repo_id: Repository ID for dataset
        fps: Frames per second for dataset
    """
    logging.info(f"Checkpoint policy after step {optimization_step}")
    # FIX: dropped unused local `_num_digits` (computed but never read).
    interaction_step = interaction_message["Interaction step"] if interaction_message is not None else 0

    # Create checkpoint directory
    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)

    # Save checkpoint
    save_checkpoint(
        checkpoint_dir=checkpoint_dir,
        step=optimization_step,
        cfg=cfg,
        policy=policy,
        optimizer=optimizers,
        scheduler=None,
    )

    # Save interaction step manually so a resumed run keeps a monotonic counter
    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
    os.makedirs(training_state_dir, exist_ok=True)
    training_state = {"step": optimization_step, "interaction_step": interaction_step}
    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))

    # Update the "last" symlink
    update_last_checkpoint(checkpoint_dir)

    # TODO: temporarily save replay buffer here, remove later when on the robot
    # We want to control this with the keyboard inputs
    dataset_dir = os.path.join(cfg.output_dir, "dataset")
    if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
        shutil.rmtree(dataset_dir)

    # Save dataset
    # NOTE: Handle the case where the dataset repo id is not specified in the config
    # eg. RL training without demonstrations data
    repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
    replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)

    if offline_replay_buffer is not None:
        dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
        if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
            shutil.rmtree(dataset_offline_dir)

        offline_replay_buffer.to_lerobot_dataset(
            cfg.dataset.repo_id,
            fps=fps,
            root=dataset_offline_dir,
        )

    # FIX: the previous message "Resume training" was misleading — nothing is
    # resumed here; this function only persisted state.
    logging.info("Checkpoint and replay-buffer datasets saved")
+ + Args: + cfg: Configuration object containing hyperparameters. + policy (nn.Module): The policy model containing the actor, critic, and temperature components. + + Returns: + Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: + A tuple containing: + - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. + - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. + + """ + optimizer_actor = torch.optim.Adam( + # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor + params=policy.actor.parameters_to_optimize, + lr=cfg.policy.actor_lr, + ) + optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) + lr_scheduler = None + optimizers = { + "actor": optimizer_actor, + "critic": optimizer_critic, + "temperature": optimizer_temperature, + } + return optimizers, lr_scheduler + + +################################################# +# Training setup functions # +################################################# + + def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig: """ Handle the resume logic for training. @@ -279,9 +939,28 @@ def initialize_offline_replay_buffer( return offline_replay_buffer +################################################# +# Utilities/Helpers functions # +################################################# + + def get_observation_features( policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Get observation features from the policy encoder. It act as cache for the observation features. + when the encoder is frozen, the observation features are not updated. 
+ We can save compute by caching the observation features. + + Args: + policy: The policy model + observations: The current observations + next_observations: The next observations + + Returns: + tuple: observation_features, next_observation_features + """ + if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: return None, None @@ -300,130 +979,6 @@ def use_threads(cfg: TrainPipelineConfig) -> bool: return cfg.policy.concurrency.learner == "threads" -def start_learner_threads( - cfg: TrainPipelineConfig, - wandb_logger: WandBLogger | None, - shutdown_event: any, # Event, -) -> None: - """ - Start the learner threads for training. - - Args: - cfg (TrainPipelineConfig): Training configuration - wandb_logger (WandBLogger | None): Logger for metrics - shutdown_event: Event to signal shutdown - """ - # Create multiprocessing queues - transition_queue = Queue() - interaction_message_queue = Queue() - parameters_queue = Queue() - - concurrency_entity = None - - if use_threads(cfg): - from threading import Thread - - concurrency_entity = Thread - else: - from torch.multiprocessing import Process - - concurrency_entity = Process - - communication_process = concurrency_entity( - target=start_learner_server, - args=( - parameters_queue, - transition_queue, - interaction_message_queue, - shutdown_event, - cfg, - ), - daemon=True, - ) - communication_process.start() - - add_actor_information_and_train( - cfg=cfg, - wandb_logger=wandb_logger, - shutdown_event=shutdown_event, - transition_queue=transition_queue, - interaction_message_queue=interaction_message_queue, - parameters_queue=parameters_queue, - ) - logging.info("[LEARNER] Training process stopped") - - logging.info("[LEARNER] Closing queues") - transition_queue.close() - interaction_message_queue.close() - parameters_queue.close() - - communication_process.join() - logging.info("[LEARNER] Communication process joined") - - logging.info("[LEARNER] join queues") - 
transition_queue.cancel_join_thread() - interaction_message_queue.cancel_join_thread() - parameters_queue.cancel_join_thread() - - logging.info("[LEARNER] queues closed") - - -def start_learner_server( - parameters_queue: Queue, - transition_queue: Queue, - interaction_message_queue: Queue, - shutdown_event: any, # Event, - cfg: TrainPipelineConfig, -): - if not use_threads(cfg): - # Create a process-specific log file - log_dir = os.path.join(cfg.output_dir, "logs") - os.makedirs(log_dir, exist_ok=True) - log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log") - - # Initialize logging with explicit log file - init_logging(log_file=log_file) - logging.info(f"Learner server process logging initialized") - - # Setup process handlers to handle shutdown signal - # But use shutdown event from the main process - # Return back for MP - setup_process_handlers(False) - - service = learner_service.LearnerService( - shutdown_event=shutdown_event, - parameters_queue=parameters_queue, - seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency, - transition_queue=transition_queue, - interaction_message_queue=interaction_message_queue, - ) - - server = grpc.server( - ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), - options=[ - ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), - ], - ) - - hilserl_pb2_grpc.add_LearnerServiceServicer_to_server( - service, - server, - ) - - host = cfg.policy.actor_learner_config.learner_host - port = cfg.policy.actor_learner_config.learner_port - - server.add_insecure_port(f"{host}:{port}") - server.start() - logging.info("[LEARNER] gRPC server started") - - shutdown_event.wait() - logging.info("[LEARNER] Stopping gRPC server...") - server.stop(learner_service.STUTDOWN_TIMEOUT) - logging.info("[LEARNER] gRPC server stopped") - - def check_nan_in_transition( observations: torch.Tensor, actions: 
def process_interaction_message(
    message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):
    """Decode one actor interaction message, realign its step counter, and log it.

    Args:
        message: Serialized interaction payload received from the actor.
        interaction_step_shift (int): Offset added to ``"Interaction step"`` so
            the counter stays monotonic across checkpoint resumes.
        wandb_logger (WandBLogger | None): Optional metrics logger.

    Returns:
        The decoded message dict with its interaction step shifted.
    """
    decoded = bytes_to_python_object(message)

    # Keep logged steps consistent with any previously checkpointed state.
    decoded["Interaction step"] = decoded["Interaction step"] + interaction_step_shift

    if wandb_logger:
        wandb_logger.log_dict(d=decoded, mode="train", custom_step_key="Interaction step")

    return decoded
cfg=cfg, - device=device, - storage_device=storage_device, - active_action_dims=active_action_dims, - ) - batch_size: int = batch_size // 2 # We will sample from both replay buffer - - # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions - # in the future. The reason why we did that is the GIL in Python. It's super slow the performance - # are divided by 200. So we need to have a single thread that does all the work. - time.time() - logging.info("Starting learner thread") - interaction_message, transition = None, None - optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 - interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 - - # Extract variables from cfg - online_step_before_learning = cfg.policy.online_step_before_learning - utd_ratio = cfg.policy.utd_ratio - - dataset_repo_id = None - if cfg.dataset is not None: - dataset_repo_id = cfg.dataset.repo_id - - fps = cfg.env.fps - log_freq = cfg.log_freq - save_freq = cfg.save_freq - device = cfg.policy.device - storage_device = cfg.policy.storage_device - policy_update_freq = cfg.policy.policy_update_freq - policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency - saving_checkpoint = cfg.save_checkpoint - online_steps = cfg.policy.online_steps - - while True: - if shutdown_event is not None and shutdown_event.is_set(): - logging.info("[LEARNER] Shutdown signal received. 
Exiting...") - break - - logging.debug("[LEARNER] Waiting for transitions") - while not transition_queue.empty() and not shutdown_event.is_set(): - transition_list = transition_queue.get() - transition_list = bytes_to_transitions(transition_list) - - for transition in transition_list: - transition = move_transition_to_device(transition, device=device) - if check_nan_in_transition( - transition["state"], transition["action"], transition["next_state"] - ): - logging.warning("NaN detected in transition, skipping") - continue - replay_buffer.add(**transition) - - if dataset_repo_id is not None and transition.get("complementary_info", {}).get( - "is_intervention" - ): - offline_replay_buffer.add(**transition) - - logging.debug("[LEARNER] Received transitions") - logging.debug("[LEARNER] Waiting for interactions") - while not interaction_message_queue.empty() and not shutdown_event.is_set(): - interaction_message = interaction_message_queue.get() - interaction_message = bytes_to_python_object(interaction_message) - # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging - interaction_message["Interaction step"] += interaction_step_shift - - # Log interaction messages with WandB if available - if wandb_logger: - wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step") - - logging.debug("[LEARNER] Received interactions") - - if len(replay_buffer) < online_step_before_learning: - continue - - logging.debug("[LEARNER] Starting optimization loop") - time_for_one_optimization_step = time.time() - for _ in range(utd_ratio - 1): - batch = replay_buffer.sample(batch_size=batch_size) - - if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch["action"] - rewards = batch["reward"] - observations = batch["state"] - 
next_observations = batch["next_state"] - done = batch["done"] - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - loss_critic = policy.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) - optimizers["critic"].zero_grad() - loss_critic.backward() - - # clip gradients - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value - ) - - optimizers["critic"].step() - - batch = replay_buffer.sample(batch_size=batch_size) - - if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) - - actions = batch["action"] - rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] - done = batch["done"] - - check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) - - observation_features, next_observation_features = get_observation_features( - policy=policy, observations=observations, next_observations=next_observations - ) - loss_critic = policy.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) - optimizers["critic"].zero_grad() - loss_critic.backward() - - # clip gradients - critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.critic_ensemble.parameters(), 
max_norm=clip_grad_norm_value - ).item() - - optimizers["critic"].step() - - training_infos = {} - training_infos["loss_critic"] = loss_critic.item() - training_infos["critic_grad_norm"] = critic_grad_norm - - if optimization_step % policy_update_freq == 0: - for _ in range(policy_update_freq): - loss_actor = policy.compute_loss_actor( - observations=observations, - observation_features=observation_features, - ) - - optimizers["actor"].zero_grad() - loss_actor.backward() - - # clip gradients - actor_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value - ).item() - - optimizers["actor"].step() - - training_infos["loss_actor"] = loss_actor.item() - training_infos["actor_grad_norm"] = actor_grad_norm - - # Temperature optimization - loss_temperature = policy.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - ) - optimizers["temperature"].zero_grad() - loss_temperature.backward() - - # clip gradients - temp_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=[policy.log_alpha], max_norm=clip_grad_norm_value - ).item() - - optimizers["temperature"].step() - - training_infos["loss_temperature"] = loss_temperature.item() - training_infos["temperature_grad_norm"] = temp_grad_norm - training_infos["temperature"] = policy.temperature - - if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: - push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) - last_time_policy_pushed = time.time() - - policy.update_target_networks() - - if optimization_step % log_freq == 0: - training_infos["replay_buffer_size"] = len(replay_buffer) - if offline_replay_buffer is not None: - training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) - training_infos["Optimization step"] = optimization_step - - # Log training metrics - if wandb_logger: - wandb_logger.log_dict(d=training_infos, mode="train", 
def process_transitions(
    transition_queue: Queue,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    device: str,
    dataset_repo_id: str | None,
    shutdown_event: any,
):
    """Drain every transition batch currently queued into the replay buffer(s).

    Each serialized batch is deserialized, moved to ``device`` and appended to
    ``replay_buffer``. Transitions containing NaNs are dropped with a warning.
    Intervention transitions are additionally mirrored into
    ``offline_replay_buffer`` when a dataset repo id is configured.

    Args:
        transition_queue: Queue delivering serialized transitions from the actor.
        replay_buffer: Online replay buffer receiving every valid transition.
        offline_replay_buffer: Buffer that mirrors intervention transitions.
        device: Device the transitions are moved to before insertion.
        dataset_repo_id: Dataset repository id; when ``None`` the offline
            buffer is never written to.
        shutdown_event: Event that aborts draining as soon as it is set.
    """
    while not (transition_queue.empty() or shutdown_event.is_set()):
        raw_batch = transition_queue.get()

        for item in bytes_to_transitions(buffer=raw_batch):
            item = move_transition_to_device(transition=item, device=device)

            # Drop corrupted transitions instead of poisoning the buffer.
            has_nan = check_nan_in_transition(
                observations=item["state"],
                actions=item["action"],
                next_state=item["next_state"],
            )
            if has_nan:
                logging.warning("[LEARNER] NaN detected in transition, skipping")
                continue

            replay_buffer.add(**item)

            # Interventions are mirrored into the offline buffer as well.
            if dataset_repo_id is not None and item.get("complementary_info", {}).get(
                "is_intervention"
            ):
                offline_replay_buffer.add(**item)
def process_interaction_messages(
    interaction_message_queue: Queue,
    interaction_step_shift: int,
    wandb_logger: WandBLogger | None,
    shutdown_event: any,
) -> dict | None:
    """Drain the interaction-message queue, processing and logging each entry.

    Args:
        interaction_message_queue: Queue delivering serialized interaction
            messages from the actor.
        interaction_step_shift: Offset applied to each message's step counter
            to keep logging consistent after a checkpoint resume.
        wandb_logger: Optional logger forwarded to each message's processing.
        shutdown_event: Event that aborts draining as soon as it is set.

    Returns:
        dict | None: The most recently processed message, or ``None`` when
        nothing was read (empty queue or shutdown requested).
    """
    latest: dict | None = None

    while not (interaction_message_queue.empty() or shutdown_event.is_set()):
        latest = process_interaction_message(
            message=interaction_message_queue.get(),
            interaction_step_shift=interaction_step_shift,
            wandb_logger=wandb_logger,
        )

    return latest