From 12f570292cfc58ad3be24082ce54d13b55841d19 Mon Sep 17 00:00:00 2001
From: Eugene Mironov
Date: Wed, 12 Mar 2025 23:44:17 +0700
Subject: [PATCH] Optimize training loop, extract config usage

---
 lerobot/scripts/server/learner_server.py | 179 +++++++++++++++++++++++
 1 file changed, 179 insertions(+)

diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 7bd4aee0..9056e4cb 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -444,6 +444,22 @@ def add_actor_information_and_train(
         resume_interaction_step if resume_interaction_step is not None else 0
     )
 
+    # Extract variables from cfg
+    online_step_before_learning = cfg.training.online_step_before_learning
+    utd_ratio = cfg.policy.utd_ratio
+    dataset_repo_id = cfg.dataset_repo_id
+    fps = cfg.fps
+    log_freq = cfg.training.log_freq
+    save_freq = cfg.training.save_freq
+    device = cfg.device
+    storage_device = cfg.training.storage_device
+    policy_update_freq = cfg.training.policy_update_freq
+    policy_parameters_push_frequency = (
+        cfg.actor_learner_config.policy_parameters_push_frequency
+    )
+    save_checkpoint = cfg.training.save_checkpoint
+    online_steps = cfg.training.online_steps
+
     while True:
         if shutdown_event is not None and shutdown_event.is_set():
             logging.info("[LEARNER] Shutdown signal received. Exiting...")
             break
@@ -472,6 +488,169 @@ def add_actor_information_and_train(
 
         logging.debug("[LEARNER] Received interactions")
 
+        if len(replay_buffer) < online_step_before_learning:
+            continue
+
+        logging.debug("[LEARNER] Starting optimization loop")
+        time_for_one_optimization_step = time.time()
+        for _ in range(utd_ratio):
+            # profiler = cProfile.Profile()
+            # profiler.enable()
+            batch = replay_buffer.sample(batch_size)
+
+            # profiler.disable()
+            # profiler.dump_stats("sample_buffer.prof")
+
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)
+
+            actions = batch["action"]
+            rewards = batch["reward"]
+            observations = batch["state"]
+            next_observations = batch["next_state"]
+            done = batch["done"]
+            check_nan_in_transition(
+                observations=observations, actions=actions, next_state=next_observations
+            )
+
+            observation_features, next_observation_features = get_observation_features(
+                policy, observations, next_observations
+            )
+            loss_critic = policy.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+            optimizers["critic"].zero_grad()
+            loss_critic.backward()
+            optimizers["critic"].step()
+
+        training_infos = {}
+        training_infos["loss_critic"] = loss_critic.item()
+
+        if optimization_step % policy_update_freq == 0:
+            for _ in range(policy_update_freq):
+                loss_actor = policy.compute_loss_actor(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+
+                optimizers["actor"].zero_grad()
+                loss_actor.backward()
+                optimizers["actor"].step()
+
+                training_infos["loss_actor"] = loss_actor.item()
+
+                loss_temperature = policy.compute_loss_temperature(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+                optimizers["temperature"].zero_grad()
+                loss_temperature.backward()
+                optimizers["temperature"].step()
+
+                training_infos["loss_temperature"] = loss_temperature.item()
+
+        if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
+            push_actor_policy_to_queue(parameters_queue, policy)
+            last_time_policy_pushed = time.time()
+
+        policy.update_target_networks()
+        if optimization_step % log_freq == 0:
+            training_infos["Optimization step"] = optimization_step
+            logger.log_dict(
+                d=training_infos, mode="train", custom_step_key="Optimization step"
+            )
+
+        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
+        frequency_for_one_optimization_step = 1 / (
+            time_for_one_optimization_step + 1e-9
+        )
+
+        logging.info(
+            f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
+        )
+
+        logger.log_dict(
+            {
+                "Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
+                "Optimization step": optimization_step,
+            },
+            mode="train",
+            custom_step_key="Optimization step",
+        )
+
+        optimization_step += 1
+        if optimization_step % log_freq == 0:
+            logging.info(f"[LEARNER] Number of optimization steps: {optimization_step}")
+
+        if save_checkpoint and (
+            optimization_step % save_freq == 0 or optimization_step == online_steps
+        ):
+            logging.info(f"Checkpoint policy after step {optimization_step}")
+            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
+            # needed (choose 6 as a minimum for consistency without being overkill).
+            _num_digits = max(6, len(str(online_steps)))
+            step_identifier = f"{optimization_step:0{_num_digits}d}"
+            interaction_step = (
+                interaction_message["Interaction step"]
+                if interaction_message is not None
+                else 0
+            )
+            logger.save_checkpoint(
+                optimization_step,
+                policy,
+                optimizers,
+                scheduler=None,
+                identifier=step_identifier,
+                interaction_step=interaction_step,
+            )
+
+            # TODO: temporarily save the replay buffer here; remove later when running on the robot.
+            # We want to control this with the keyboard inputs.
+            dataset_dir = logger.log_dir / "dataset"
+            if dataset_dir.exists() and dataset_dir.is_dir():
+                shutil.rmtree(
+                    dataset_dir,
+                )
+            replay_buffer.to_lerobot_dataset(
+                dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
+            )
+
+            logging.info("Resume training")
+
+        if shutdown_event is not None and shutdown_event.is_set():
+            logging.info("[LEARNER] Shutdown signal received. Exiting...")
+            break
+
+        logging.debug("[LEARNER] Waiting for transitions")
+        while not transition_queue.empty() and not shutdown_event.is_set():
+            transition_list = transition_queue.get()
+            transition_list = bytes_to_transitions(transition_list)
+
+            for transition in transition_list:
+                transition = move_transition_to_device(transition, device=device)
+                replay_buffer.add(**transition)
+                if transition.get("complementary_info", {}).get("is_intervention"):
+                    offline_replay_buffer.add(**transition)
+        logging.debug("[LEARNER] Received transitions")
+        logging.debug("[LEARNER] Waiting for interactions")
+        while not interaction_message_queue.empty() and not shutdown_event.is_set():
+            interaction_message = interaction_message_queue.get()
+            interaction_message = bytes_to_python_object(interaction_message)
+            # If cfg.resume is set, shift the interaction step by the last checkpointed step so the logging is not broken.
+            interaction_message["Interaction step"] += interaction_step_shift
+            logger.log_dict(
+                interaction_message, mode="train", custom_step_key="Interaction step"
+            )
+
+        logging.debug("[LEARNER] Received interactions")
+
         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue
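
Note on the "# Extract variables from cfg" block above: each cfg.training.* access inside the
hot while-loop re-resolves a nested attribute chain on every iteration, and it costs more still
if cfg is an OmegaConf/Hydra-style DictConfig, where attribute access additionally goes through
__getattr__ with validation. Hoisting the values into plain locals once, before the loop, keeps
that lookup out of every optimization step. The sketch below is illustrative only: the Config and
TrainingConfig dataclasses are hypothetical stand-ins for the real config, not lerobot code.

    import time
    from dataclasses import dataclass, field

    # Hypothetical stand-ins for the nested config object in the patch;
    # only log_freq is modeled, to keep the timing comparison minimal.
    @dataclass
    class TrainingConfig:
        log_freq: int = 100

    @dataclass
    class Config:
        training: TrainingConfig = field(default_factory=TrainingConfig)

    cfg = Config()
    n_steps = 2_000_000

    # Variant 1: resolve cfg.training.log_freq on every pass through the loop.
    t0 = time.perf_counter()
    for step in range(n_steps):
        _ = step % cfg.training.log_freq == 0
    chained = time.perf_counter() - t0

    # Variant 2: hoist the value into a local before the loop, as the patch does.
    log_freq = cfg.training.log_freq
    t0 = time.perf_counter()
    for step in range(n_steps):
        _ = step % log_freq == 0
    hoisted = time.perf_counter() - t0

    print(f"attribute chain: {chained:.3f}s  hoisted local: {hoisted:.3f}s")

On CPython the hoisted variant is reliably faster, and the gap widens with heavier config
backends; the other benefit, arguably the larger one here, is readability, since the loop body
now names exactly the settings it uses.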