diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 9056e4cb..f573c710 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -493,15 +493,10 @@ def add_actor_information_and_train(
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
-        for _ in range(utd_ratio):
-            # profiler = cProfile.Profile()
-            # profiler.enable()
+        for _ in range(utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)
-            # profiler.disable()
-            # profiler.dump_stats("sample_buffer.prof")
-
-            if cfg.dataset_repo_id is not None:
+            if dataset_repo_id is not None:
                 batch_offline = offline_replay_buffer.sample(batch_size)
                 batch = concatenate_batch_transitions(batch, batch_offline)
 
             actions = batch["action"]
             rewards = batch["reward"]
@@ -530,6 +525,40 @@ def add_actor_information_and_train(
             loss_critic.backward()
             optimizers["critic"].step()
 
+        batch = replay_buffer.sample(batch_size)
+
+        if dataset_repo_id is not None:
+            batch_offline = offline_replay_buffer.sample(batch_size)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )
+
+        actions = batch["action"]
+        rewards = batch["reward"]
+        observations = batch["state"]
+        next_observations = batch["next_state"]
+        done = batch["done"]
+
+        check_nan_in_transition(
+            observations=observations, actions=actions, next_state=next_observations
+        )
+
+        observation_features, next_observation_features = get_observation_features(
+            policy, observations, next_observations
+        )
+        loss_critic = policy.compute_loss_critic(
+            observations=observations,
+            actions=actions,
+            rewards=rewards,
+            next_observations=next_observations,
+            done=done,
+            observation_features=observation_features,
+            next_observation_features=next_observation_features,
+        )
+        optimizers["critic"].zero_grad()
+        loss_critic.backward()
+        optimizers["critic"].step()
+
         training_infos = {}
         training_infos["loss_critic"] = loss_critic.item()
 
@@ -566,6 +595,7 @@ def add_actor_information_and_train(
             logger.log_dict(
                 d=training_infos, mode="train", custom_step_key="Optimization step"
            )
+            # logging.info(f"Training infos: {training_infos}")
 
         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
         frequency_for_one_optimization_step = 1 / (
@@ -593,8 +623,6 @@ def add_actor_information_and_train(
             optimization_step % save_freq == 0 or optimization_step == online_steps
         ):
             logging.info(f"Checkpoint policy after step {optimization_step}")
-            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
-            # needed (choose 6 as a minimum for consistency without being overkill).
             _num_digits = max(6, len(str(online_steps)))
             step_identifier = f"{optimization_step:0{_num_digits}d}"
             interaction_step = (
@@ -624,203 +652,6 @@ def add_actor_information_and_train(
         logging.info("Resume training")
 
-        if shutdown_event is not None and shutdown_event.is_set():
-            logging.info("[LEARNER] Shutdown signal received. Exiting...")
-            break
-
-        logging.debug("[LEARNER] Waiting for transitions")
-        while not transition_queue.empty() and not shutdown_event.is_set():
-            transition_list = transition_queue.get()
-            transition_list = bytes_to_transitions(transition_list)
-
-            for transition in transition_list:
-                transition = move_transition_to_device(transition, device=device)
-                replay_buffer.add(**transition)
-                if transition.get("complementary_info", {}).get("is_intervention"):
-                    offline_replay_buffer.add(**transition)
-        logging.debug("[LEARNER] Received transitions")
-
-        logging.debug("[LEARNER] Waiting for interactions")
-        while not interaction_message_queue.empty() and not shutdown_event.is_set():
-            interaction_message = interaction_message_queue.get()
-            interaction_message = bytes_to_python_object(interaction_message)
-            # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
-            interaction_message["Interaction step"] += interaction_step_shift
-            logger.log_dict(
-                interaction_message, mode="train", custom_step_key="Interaction step"
-            )
-
-        logging.debug("[LEARNER] Received interactions")
-
-        if len(replay_buffer) < cfg.training.online_step_before_learning:
-            continue
-
-        logging.debug("[LEARNER] Starting optimization loop")
-        time_for_one_optimization_step = time.time()
-        for _ in range(cfg.policy.utd_ratio - 1):
-            batch = replay_buffer.sample(batch_size)
-
-            if cfg.dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size)
-                batch = concatenate_batch_transitions(batch, batch_offline)
-
-            actions = batch["action"]
-            rewards = batch["reward"]
-            observations = batch["state"]
-            next_observations = batch["next_state"]
-            done = batch["done"]
-            check_nan_in_transition(
-                observations=observations, actions=actions, next_state=next_observations
-            )
-
-            observation_features, next_observation_features = get_observation_features(
-                policy, observations, next_observations
-            )
-            loss_critic = policy.compute_loss_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
-            optimizers["critic"].zero_grad()
-            loss_critic.backward()
-            optimizers["critic"].step()
-
-        batch = replay_buffer.sample(batch_size)
-
-        if cfg.dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size)
-            batch = concatenate_batch_transitions(
-                left_batch_transitions=batch, right_batch_transition=batch_offline
-            )
-
-        actions = batch["action"]
-        rewards = batch["reward"]
-        observations = batch["state"]
-        next_observations = batch["next_state"]
-        done = batch["done"]
-
-        check_nan_in_transition(
-            observations=observations, actions=actions, next_state=next_observations
-        )
-
-        observation_features, next_observation_features = get_observation_features(
-            policy, observations, next_observations
-        )
-        loss_critic = policy.compute_loss_critic(
-            observations=observations,
-            actions=actions,
-            rewards=rewards,
-            next_observations=next_observations,
-            done=done,
-            observation_features=observation_features,
-            next_observation_features=next_observation_features,
-        )
-        optimizers["critic"].zero_grad()
-        loss_critic.backward()
-        optimizers["critic"].step()
-
-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-
-        if optimization_step % cfg.training.policy_update_freq == 0:
-            for _ in range(cfg.training.policy_update_freq):
-                loss_actor = policy.compute_loss_actor(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
-
-                optimizers["actor"].zero_grad()
-                loss_actor.backward()
-                optimizers["actor"].step()
-
-                training_infos["loss_actor"] = loss_actor.item()
-
-                loss_temperature = policy.compute_loss_temperature(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
-                optimizers["temperature"].zero_grad()
-                loss_temperature.backward()
-                optimizers["temperature"].step()
-
-                training_infos["loss_temperature"] = loss_temperature.item()
-
-        if (
-            time.time() - last_time_policy_pushed
-            > cfg.actor_learner_config.policy_parameters_push_frequency
-        ):
-            push_actor_policy_to_queue(parameters_queue, policy)
-            last_time_policy_pushed = time.time()
-
-        policy.update_target_networks()
-        if optimization_step % cfg.training.log_freq == 0:
-            training_infos["Optimization step"] = optimization_step
-            logger.log_dict(
-                d=training_infos, mode="train", custom_step_key="Optimization step"
-            )
-            # logging.info(f"Training infos: {training_infos}")
-
-        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
-        frequency_for_one_optimization_step = 1 / (
-            time_for_one_optimization_step + 1e-9
-        )
-
-        logging.info(
-            f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
-        )
-
-        logger.log_dict(
-            {
-                "Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
-                "Optimization step": optimization_step,
-            },
-            mode="train",
-            custom_step_key="Optimization step",
-        )
-
-        optimization_step += 1
-        if optimization_step % cfg.training.log_freq == 0:
-            logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
-
-        if cfg.training.save_checkpoint and (
-            optimization_step % cfg.training.save_freq == 0
-            or optimization_step == cfg.training.online_steps
-        ):
-            logging.info(f"Checkpoint policy after step {optimization_step}")
-            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
-            # needed (choose 6 as a minimum for consistency without being overkill).
-            _num_digits = max(6, len(str(cfg.training.online_steps)))
-            step_identifier = f"{optimization_step:0{_num_digits}d}"
-            interaction_step = (
-                interaction_message["Interaction step"]
-                if interaction_message is not None
-                else 0
-            )
-            logger.save_checkpoint(
-                optimization_step,
-                policy,
-                optimizers,
-                scheduler=None,
-                identifier=step_identifier,
-                interaction_step=interaction_step,
-            )
-
-            # TODO : temporarly save replay buffer here, remove later when on the robot
-            # We want to control this with the keyboard inputs
-            dataset_dir = logger.log_dir / "dataset"
-            if dataset_dir.exists() and dataset_dir.is_dir():
-                shutil.rmtree(
-                    dataset_dir,
-                )
-            replay_buffer.to_lerobot_dataset(
-                cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
-            )
-
-        logging.info("Resume training")
 
 
 def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """
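Note on the resulting loop shape: hoisting the last iteration out of `for _ in range(utd_ratio)` means the loop body performs `utd_ratio - 1` critic-only updates, and the final update is written out once below the loop so that its batch and cached observation features stay in scope for the subsequent actor and temperature losses, instead of being resampled and recomputed. The critic still receives exactly `utd_ratio` gradient steps per optimization step. A minimal sketch of that pattern, assuming the same object interfaces the diff shows (the helper's import path is an assumption about this branch's layout):

    # Sketch only, not part of learner_server.py. Interfaces follow the diff
    # above; the import path for the batch-concatenation helper is assumed.
    from lerobot.scripts.server.buffer import concatenate_batch_transitions

    def utd_critic_updates(policy, replay_buffer, offline_replay_buffer,
                           optimizers, batch_size, utd_ratio, dataset_repo_id):
        def sample_batch():
            batch = replay_buffer.sample(batch_size)
            if dataset_repo_id is not None:
                # Mix online transitions with offline (intervention) data,
                # mirroring the patched loop.
                batch = concatenate_batch_transitions(
                    left_batch_transitions=batch,
                    right_batch_transition=offline_replay_buffer.sample(batch_size),
                )
            return batch

        def critic_step(batch):
            # The real loop also passes cached observation features here.
            loss = policy.compute_loss_critic(
                observations=batch["state"],
                actions=batch["action"],
                rewards=batch["reward"],
                next_observations=batch["next_state"],
                done=batch["done"],
            )
            optimizers["critic"].zero_grad()
            loss.backward()
            optimizers["critic"].step()
            return loss

        # utd_ratio - 1 critic-only gradient steps, each on a fresh batch...
        for _ in range(utd_ratio - 1):
            critic_step(sample_batch())

        # ...then one final step whose batch is returned so the caller can
        # reuse it for the actor and temperature updates.
        final_batch = sample_batch()
        return final_batch, critic_step(final_batch)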
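A smaller point: the checkpoint-identifier logic (unchanged here apart from the dropped comment) zero-pads the step number to at least six digits, presumably so checkpoint directories sort lexicographically in step order. For example, with a hypothetical online_steps of 1_000_000:

    # Width tracks online_steps but never drops below 6 digits.
    optimization_step = 2500
    _num_digits = max(6, len(str(1_000_000)))      # -> 7
    print(f"{optimization_step:0{_num_digits}d}")  # -> "0002500"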