From 91f8d3793827a1b5bc3d6d325dae908bbe30aa11 Mon Sep 17 00:00:00 2001
From: Eugene Mironov
Date: Wed, 19 Mar 2025 20:27:32 +0700
Subject: [PATCH] [PORT HIL-SERL] Optimize training loop, extract config usage
 (#855)

Co-authored-by: Adil Zouitine
---
 lerobot/scripts/server/learner_server.py | 47 +++++++++++++++---------
 1 file changed, 29 insertions(+), 18 deletions(-)

diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 2b19fea2..89e51f62 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -509,6 +509,22 @@ def add_actor_information_and_train(
         resume_interaction_step if resume_interaction_step is not None else 0
     )
 
+    # Extract variables from cfg
+    online_step_before_learning = cfg.training.online_step_before_learning
+    utd_ratio = cfg.policy.utd_ratio
+    dataset_repo_id = cfg.dataset_repo_id
+    fps = cfg.fps
+    log_freq = cfg.training.log_freq
+    save_freq = cfg.training.save_freq
+    device = cfg.device
+    storage_device = cfg.training.storage_device
+    policy_update_freq = cfg.training.policy_update_freq
+    policy_parameters_push_frequency = (
+        cfg.actor_learner_config.policy_parameters_push_frequency
+    )
+    save_checkpoint = cfg.training.save_checkpoint
+    online_steps = cfg.training.online_steps
+
     while True:
         if shutdown_event is not None and shutdown_event.is_set():
             logging.info("[LEARNER] Shutdown signal received. Exiting...")
@@ -546,15 +562,15 @@ def add_actor_information_and_train(
 
         logging.debug("[LEARNER] Received interactions")
 
-        if len(replay_buffer) < cfg.training.online_step_before_learning:
+        if len(replay_buffer) < online_step_before_learning:
             continue
 
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
-        for _ in range(cfg.policy.utd_ratio - 1):
+        for _ in range(utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)
 
-            if cfg.dataset_repo_id is not None:
+            if dataset_repo_id is not None:
                 batch_offline = offline_replay_buffer.sample(batch_size)
                 batch = concatenate_batch_transitions(batch, batch_offline)
 
@@ -591,7 +607,7 @@ def add_actor_information_and_train(
         batch = replay_buffer.sample(batch_size)
 
-        if cfg.dataset_repo_id is not None:
+        if dataset_repo_id is not None:
             batch_offline = offline_replay_buffer.sample(batch_size)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )
@@ -633,8 +649,8 @@ def add_actor_information_and_train(
         training_infos["loss_critic"] = loss_critic.item()
         training_infos["critic_grad_norm"] = critic_grad_norm
 
-        if optimization_step % cfg.training.policy_update_freq == 0:
-            for _ in range(cfg.training.policy_update_freq):
+        if optimization_step % policy_update_freq == 0:
+            for _ in range(policy_update_freq):
                 loss_actor = policy.compute_loss_actor(
                     observations=observations,
                     observation_features=observation_features,
@@ -672,14 +688,12 @@ def add_actor_information_and_train(
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature
 
-        if (
-            time.time() - last_time_policy_pushed
-            > cfg.actor_learner_config.policy_parameters_push_frequency
-        ):
+        if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
            push_actor_policy_to_queue(parameters_queue, policy)
            last_time_policy_pushed = time.time()
 
         policy.update_target_networks()
+
         if optimization_step % cfg.training.log_freq == 0:
             training_infos["replay_buffer_size"] = len(replay_buffer)
             if offline_replay_buffer is not None:
@@ -711,17 +725,14 @@ def add_actor_information_and_train(
             )
 
         optimization_step += 1
-        if optimization_step % cfg.training.log_freq == 0:
+        if optimization_step % log_freq == 0:
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
-        if cfg.training.save_checkpoint and (
-            optimization_step % cfg.training.save_freq == 0
-            or optimization_step == cfg.training.online_steps
+        if save_checkpoint and (
+            optimization_step % save_freq == 0 or optimization_step == online_steps
         ):
             logging.info(f"Checkpoint policy after step {optimization_step}")
-            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
-            # needed (choose 6 as a minimum for consistency without being overkill).
-            _num_digits = max(6, len(str(cfg.training.online_steps)))
+            _num_digits = max(6, len(str(online_steps)))
             step_identifier = f"{optimization_step:0{_num_digits}d}"
             interaction_step = (
                 interaction_message["Interaction step"]
@@ -745,7 +756,7 @@ def add_actor_information_and_train(
                     dataset_dir,
                 )
             replay_buffer.to_lerobot_dataset(
-                cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
+                dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
             )
             if offline_replay_buffer is not None:
                 dataset_dir = logger.log_dir / "dataset_offline"
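
Reviewer note: the core optimization in this patch is hoisting chained cfg.*
attribute lookups out of the hot training loop into locals. In CPython, each
read of cfg.training.log_freq costs two attribute lookups per iteration, while
a local variable read is generally a single fast frame-local access. Below is
a minimal, self-contained timing sketch of that effect; the names (Cfg built
from SimpleNamespace, the two loop functions) are hypothetical illustrations,
not lerobot code.

import time
from types import SimpleNamespace

# Hypothetical stand-in for the learner's config object.
cfg = SimpleNamespace(training=SimpleNamespace(log_freq=500))


def loop_with_attribute_access(steps: int) -> None:
    for step in range(steps):
        if step % cfg.training.log_freq == 0:  # two attribute lookups per iteration
            pass


def loop_with_hoisted_local(steps: int) -> None:
    log_freq = cfg.training.log_freq  # extracted once, mirroring the patch
    for step in range(steps):
        if step % log_freq == 0:  # one local read per iteration
            pass


for fn in (loop_with_attribute_access, loop_with_hoisted_local):
    start = time.perf_counter()
    fn(5_000_000)
    print(f"{fn.__name__}: {time.perf_counter() - start:.3f}s")

Beyond the per-iteration savings, the extracted locals also shorten the loop
body itself, e.g. the policy-push frequency check collapses from a four-line
if statement to a single line.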