diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 35c12062..569cad69 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -259,7 +259,7 @@ class Logger:
             if k == custom_step_key:
                 continue
 
-            if self._wandb_custom_step_key is not None:
+            if self._wandb_custom_step_key is not None and custom_step_key is not None:
                 # NOTE: Log the metric with the custom step key.
                 value_custom_step_key = d[custom_step_key]
                 self._wandb.log({f"{mode}/{k}": v, self._wandb_custom_step_key: value_custom_step_key})
diff --git a/lerobot/configs/policy/sac_manyskill.yaml b/lerobot/configs/policy/sac_manyskill.yaml
index e4c3f17d..fc824da5 100644
--- a/lerobot/configs/policy/sac_manyskill.yaml
+++ b/lerobot/configs/policy/sac_manyskill.yaml
@@ -82,7 +82,7 @@ policy:
   temperature_lr: 3e-4
   # critic_target_update_weight: 0.005
   critic_target_update_weight: 0.01
-  utd_ratio: 1
+  utd_ratio: 2
 
   # # Loss coefficients.
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 22777a26..bd15fc01 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -116,6 +116,7 @@ def learner_push_parameters(
         params_bytes = buf.getvalue()
 
         # Push them to the Actor’s "SendParameters" method
+        logging.info(f"[LEARNER] Pushing parameters to the Actor")
         response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))
         time.sleep(seconds_between_pushes)
 
@@ -144,7 +145,7 @@ def add_actor_information(
     # are divided by 200. So we need to have a single thread that does all the work.
     start = time.time()
     optimization_step = 0
-
+    timeout_for_adding_transitions = 1
     while True:
         time_for_adding_transitions = time.time()
         while not transition_queue.empty():
@@ -153,99 +154,103 @@ def add_actor_information(
             for transition in transition_list:
                 transition = move_transition_to_device(transition, device=device)
                 replay_buffer.add(**transition)
+            # logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
+            # logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}")
+            # logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
+            # logging.info(f"[LEARNER] size of transition queues: {transition }")
+
         if len(replay_buffer) > cfg.training.online_step_before_learning:
             logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
-            logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}")
-
         while not interaction_message_queue.empty():
             interaction_message = interaction_message_queue.get()
             logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step")
-            logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}")
+            # logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}")
 
-        # if len(replay_buffer.memory) < cfg.training.online_step_before_learning:
-        #     continue
+        if len(replay_buffer) < cfg.training.online_step_before_learning:
+            continue
+        time_for_one_optimization_step = time.time()
+        for _ in range(cfg.policy.utd_ratio - 1):
+            batch = replay_buffer.sample(batch_size)
-        # for _ in range(cfg.policy.utd_ratio - 1):
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)
-        #     batch = replay_buffer.sample(batch_size)
-        #     if cfg.dataset_repo_id is not None:
-        #         batch_offline = offline_replay_buffer.sample(batch_size)
-        #         batch = concatenate_batch_transitions(batch, batch_offline)
+            actions = batch["action"]
+            rewards = batch["reward"]
+            observations = batch["state"]
+            next_observations = batch["next_state"]
+            done = batch["done"]
-        #     actions = batch["action"]
-        #     rewards = batch["reward"]
-        #     observations = batch["state"]
-        #     next_observations = batch["next_state"]
-        #     done = batch["done"]
+            with policy_lock:
+                loss_critic = policy.compute_loss_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+                optimizers["critic"].zero_grad()
+                loss_critic.backward()
+                optimizers["critic"].step()
-        #     with policy_lock:
-        #         loss_critic = policy.compute_loss_critic(
-        #             observations=observations,
-        #             actions=actions,
-        #             rewards=rewards,
-        #             next_observations=next_observations,
-        #             done=done,
-        #         )
-        #         optimizers["critic"].zero_grad()
-        #         loss_critic.backward()
-        #         optimizers["critic"].step()
+        batch = replay_buffer.sample(batch_size)
-        # batch = replay_buffer.sample(batch_size)
+        if cfg.dataset_repo_id is not None:
+            batch_offline = offline_replay_buffer.sample(batch_size)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )
-        # if cfg.dataset_repo_id is not None:
-        #     batch_offline = offline_replay_buffer.sample(batch_size)
-        #     batch = concatenate_batch_transitions(
-        #         left_batch_transitions=batch, right_batch_transition=batch_offline
-        #     )
+        actions = batch["action"]
+        rewards = batch["reward"]
+        observations = batch["state"]
+        next_observations = batch["next_state"]
+        done = batch["done"]
-        # actions = batch["action"]
-        # rewards = batch["reward"]
-        # observations = batch["state"]
-        # next_observations = batch["next_state"]
-        # done = batch["done"]
+        with policy_lock:
+            loss_critic = policy.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+            )
+            optimizers["critic"].zero_grad()
+            loss_critic.backward()
+            optimizers["critic"].step()
-        # with policy_lock:
-        #     loss_critic = policy.compute_loss_critic(
-        #         observations=observations,
-        #         actions=actions,
-        #         rewards=rewards,
-        #         next_observations=next_observations,
-        #         done=done,
-        #     )
-        #     optimizers["critic"].zero_grad()
-        #     loss_critic.backward()
-        #     optimizers["critic"].step()
+        training_infos = {}
+        training_infos["loss_critic"] = loss_critic.item()
-        # training_infos = {}
-        # training_infos["loss_critic"] = loss_critic.item()
-        # if optimization_step % cfg.training.policy_update_freq == 0:
-        #     for _ in range(cfg.training.policy_update_freq):
-        #         with policy_lock:
-        #             loss_actor = policy.compute_loss_actor(observations=observations)
+        if optimization_step % cfg.training.policy_update_freq == 0:
+            for _ in range(cfg.training.policy_update_freq):
+                with policy_lock:
+                    loss_actor = policy.compute_loss_actor(observations=observations)
-        #             optimizers["actor"].zero_grad()
-        #             loss_actor.backward()
-        #             optimizers["actor"].step()
+                    optimizers["actor"].zero_grad()
+                    loss_actor.backward()
+                    optimizers["actor"].step()
-        #             training_infos["loss_actor"] = loss_actor.item()
+                    training_infos["loss_actor"] = loss_actor.item()
-        #         loss_temperature = policy.compute_loss_temperature(observations=observations)
-        #         optimizers["temperature"].zero_grad()
-        #         loss_temperature.backward()
-        #         optimizers["temperature"].step()
+                loss_temperature = policy.compute_loss_temperature(observations=observations)
+                optimizers["temperature"].zero_grad()
+                loss_temperature.backward()
+                optimizers["temperature"].step()
-        #         training_infos["loss_temperature"] = loss_temperature.item()
+                training_infos["loss_temperature"] = loss_temperature.item()
-        # if optimization_step % cfg.training.log_freq == 0:
-        #     logger.log_dict(training_infos, step=optimization_step, mode="train")
+        if optimization_step % cfg.training.log_freq == 0:
+            logger.log_dict(training_infos, step=optimization_step, mode="train")
-        # policy.update_target_networks()
-        # optimization_step += 1
-        # time_for_one_optimization_step = time.time() - time_for_one_optimization_step
+        policy.update_target_networks()
+        optimization_step += 1
+        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
-        # logger.log_dict({"[LEARNER] Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train")
-        # time_for_one_optimization_step = time.time()
+        logging.info(f"[LEARNER] Time for one optimization step: {time_for_one_optimization_step}")
+        logger.log_dict({"Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train")
 
 
 def make_optimizers_and_scheduler(cfg, policy):
@@ -360,13 +365,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     )
     transition_thread.start()
 
-    # param_push_thread = Thread(
-    #     target=learner_push_parameters,
-    #     args=(policy, policy_lock, "127.0.0.1", 50052, 15),
-    #     # args=("127.0.0.1", 50052),
-    #     daemon=True,
-    # )
-    # param_push_thread.start()
+    param_push_thread = Thread(
+        target=learner_push_parameters,
+        args=(policy, policy_lock, "127.0.0.1", 50051, 15),
+        # args=("127.0.0.1", 50052),
+        daemon=True,
+    )
+    param_push_thread.start()
 
     # interaction_thread = Thread(
     #     target=add_message_interaction_to_wandb,