diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml
index fa3dca37..2776b39d 100644
--- a/lerobot/configs/policy/sac_maniskill.yaml
+++ b/lerobot/configs/policy/sac_maniskill.yaml
@@ -21,7 +21,7 @@ training:
   eval_freq: 2500
   log_freq: 500
-  save_freq: 1000000
+  save_freq: 2000000

  online_steps: 1000000
  online_rollout_n_episodes: 10
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index cf6d8c76..dbafeb42 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -34,8 +34,7 @@ from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import nn
-
-from lerobot.common.datasets.factory import make_dataset
+from torch.optim.optimizer import Optimizer

 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -53,18 +52,164 @@ from lerobot.common.utils.utils import (
 )
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
-    concatenate_batch_transitions,
     move_state_dict_to_device,
     move_transition_to_device,
 )

 logging.basicConfig(level=logging.INFO)

-# TODO: Implement it in cleaner way maybe
 transition_queue = queue.Queue()
 interaction_message_queue = queue.Queue()


+def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
+    """Validate the resume flag against `out_dir` and return the config to train with."""
+    if not cfg.resume:
+        if Logger.get_last_checkpoint_dir(out_dir).exists():
+            raise RuntimeError(
+                f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
+                "Use `resume=true` to resume training."
+            )
+        return cfg
+
+    # resume=true: reload the config that was saved with the last checkpoint
+    checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
+    if not checkpoint_dir.exists():
+        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
+
+    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+    logging.info(
+        colored(
+            "Resume=True detected, resuming previous run",
+            color="yellow",
+            attrs=["bold"],
+        )
+    )
+
+    checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+    diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+
+    # Ignore the expected difference in the `resume` flag itself.
+    if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
+        del diff["values_changed"]["root['resume']"]
+
+    if len(diff) > 0:
+        logging.warning(
+            f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
+            "Checkpoint configuration takes precedence."
+        )
+
+    checkpoint_cfg.resume = True
+    return checkpoint_cfg
+
+
+def load_training_state(
+    cfg: DictConfig,
+    logger: Logger,
+    optimizers: Optimizer | dict[str, Optimizer],
+) -> tuple[int | None, int | None]:
+    """Restore optimizer and RNG state in place; return (optimization_step, interaction_step)."""
+    if not cfg.resume:
+        return None, None
+
+    training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+
+    if isinstance(training_state["optimizer"], dict):
+        assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
+            "Optimizer dictionaries do not have the same keys during resume!"
+        )
+        for k, v in training_state["optimizer"].items():
+            optimizers[k].load_state_dict(v)
+    else:
+        optimizers.load_state_dict(training_state["optimizer"])
+
+    # Use `get_global_random_state` to obtain the expected keys for the RNG state.
+    set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+    return training_state["step"], training_state["interaction_step"]
+
+
+def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
+    """Log the output directory, task, step budget, and parameter counts."""
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
+    log_output_dir(out_dir)
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.training.online_steps=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
+
+def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
+    """Create a fresh replay buffer, or reload it from the on-disk dataset when resuming."""
+    if not cfg.resume:
+        return ReplayBuffer(
+            capacity=cfg.training.online_buffer_capacity,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+        )
+
+    dataset = LeRobotDataset(
+        repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+    )
+    return ReplayBuffer.from_lerobot_dataset(
+        lerobot_dataset=dataset,
+        capacity=cfg.training.online_buffer_capacity,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+    )
+
+
+def start_learner_threads(
+    cfg: DictConfig,
+    device: str,
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer | None,
+    batch_size: int,
+    optimizers: dict,
+    policy: SACPolicy,
+    policy_lock: Lock,
+    logger: Logger,
+    resume_optimization_step: int | None = None,
+    resume_interaction_step: int | None = None,
+) -> None:
+    """Start the receiver, trainer, and parameter-push threads, then block until they exit."""
+    actor_ip = cfg.actor_learner_config.actor_ip
+    port = cfg.actor_learner_config.port
+
+    # Receives transitions and interaction messages from the actor over gRPC.
+    server_thread = Thread(
+        target=stream_transitions_from_actor,
+        args=(
+            actor_ip,
+            port,
+        ),
+        daemon=True,
+    )
+
+    # Drains the queues, fills the replay buffer, and runs the optimization loop.
+    transition_thread = Thread(
+        target=add_actor_information_and_train,
+        daemon=True,
+        args=(
+            cfg,
+            device,
+            replay_buffer,
+            offline_replay_buffer,
+            batch_size,
+            optimizers,
+            policy,
+            policy_lock,
+            logger,
+            resume_optimization_step,
+            resume_interaction_step,
+        ),
+    )
+
+    # Periodically pushes updated policy parameters back to the actor (15: seconds between pushes).
+    param_push_thread = Thread(
+        target=learner_push_parameters,
+        args=(policy, policy_lock, actor_ip, port, 15),
+        daemon=True,
+    )
+
+    server_thread.start()
+    transition_thread.start()
+    param_push_thread.start()
+
+    param_push_thread.join()
+    transition_thread.join()
+    server_thread.join()
+
+
 def stream_transitions_from_actor(host="127.0.0.1", port=50051):
     """
     Runs a gRPC client that listens for transition and interaction messages from an Actor service.
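Aside (not part of the patch): `load_training_state` mutates the passed optimizers in place and accepts either a single optimizer or the dict of named optimizers that `make_optimizers_and_scheduler` builds (e.g. actor/critic/temperature for SAC). A minimal sketch of the save/restore round-trip it assumes, using hypothetical stand-in networks:

import torch
from torch import nn

# Hypothetical stand-ins for the policy's sub-networks.
actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
optimizers = {
    "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(critic.parameters(), lr=3e-4),
}

# Saving: the checkpoint stores one state_dict per named optimizer.
training_state = {"optimizer": {k: opt.state_dict() for k, opt in optimizers.items()}}

# Restoring: the dict branch of load_training_state() above.
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
    optimizers[k].load_state_dict(v)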
@@ -373,49 +518,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info(pformat(OmegaConf.to_container(cfg))) logger = Logger(cfg, out_dir, wandb_job_name=job_name) - - ## Handle resume by reloading the state of the policy and optimization - # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need - # to check for any differences between the provided config and the checkpoint's config. - if cfg.resume: - if not Logger.get_last_checkpoint_dir(out_dir).exists(): - raise RuntimeError( - "You have set resume=True, but there is no model checkpoint in " - f"{Logger.get_last_checkpoint_dir(out_dir)}" - ) - checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") - logging.info( - colored( - "You have set resume=True, indicating that you wish to resume a run", - color="yellow", - attrs=["bold"], - ) - ) - # Get the configuration file from the last checkpoint. - checkpoint_cfg = init_hydra_config(checkpoint_cfg_path) - # Check for differences between the checkpoint configuration and provided configuration. - diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) - # Ignore the `resume` and parameters. - if "values_changed" in diff and "root['resume']" in diff["values_changed"]: - del diff["values_changed"]["root['resume']"] - - # Log a warning about differences between the checkpoint configuration and the provided - # configuration. - if len(diff) > 0: - logging.warning( - "At least one difference was detected between the checkpoint configuration and " - f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration " - "takes precedence.", - ) - # Use the checkpoint config instead of the provided config (but keep `resume` parameter). - cfg = checkpoint_cfg - cfg.resume = True - elif Logger.get_last_checkpoint_dir(out_dir).exists(): - raise RuntimeError( - f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If " - "you meant to resume training, please use `resume=true` in your command or yaml configuration." - ) - # =========================== + cfg = handle_resume_logic(cfg, out_dir) set_global_seed(cfg.seed) @@ -438,57 +541,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dataset_stats=None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, ) - # device=device, - # ) assert isinstance(policy, nn.Module) optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) - # load last training state - # We can't use the logger function in `lerobot/common/logger.py` - # because it only loads the optimization step and not the interaction one - # to avoid altering that code, we will just load the optimization state manually - resume_interaction_step, resume_optimization_step = None, None - if cfg.resume: - training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name) - if type(training_state["optimizer"]) is dict: - assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), ( - "Optimizer dictionaries do not have the same keys during resume!" - ) - for k, v in training_state["optimizer"].items(): - optimizers[k].load_state_dict(v) - else: - optimizers.load_state_dict(training_state["optimizer"]) - # Small hack to get the expected keys: use `get_global_random_state`. 
- set_global_random_state({k: training_state[k] for k in get_global_random_state()}) - resume_optimization_step = training_state["step"] - resume_interaction_step = training_state["interaction_step"] + resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers) - num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) - num_total_params = sum(p.numel() for p in policy.parameters()) + log_training_info(cfg, out_dir, policy) - log_output_dir(out_dir) - logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.training.online_steps=}") - logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") - logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") - - if not cfg.resume: - replay_buffer = ReplayBuffer( - capacity=cfg.training.online_buffer_capacity, - device=device, - state_keys=cfg.policy.input_shapes.keys(), - ) - else: - # Reload replay buffer - dataset = LeRobotDataset( - repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset" - ) - replay_buffer = ReplayBuffer.from_lerobot_dataset( - lerobot_dataset=dataset, - capacity=cfg.training.online_buffer_capacity, - device=device, - state_keys=cfg.policy.input_shapes.keys(), - ) + replay_buffer = initialize_replay_buffer(cfg, logger, device) batch_size = cfg.training.batch_size offline_replay_buffer = None @@ -501,47 +561,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # ) # batch_size: int = batch_size // 2 # We will sample from both replay buffer - actor_ip = cfg.actor_learner_config.actor_ip - port = cfg.actor_learner_config.port - - server_thread = Thread( - target=stream_transitions_from_actor, - args=( - actor_ip, - port, - ), - daemon=True, + start_learner_threads( + cfg, + device, + replay_buffer, + offline_replay_buffer, + batch_size, + optimizers, + policy, + policy_lock, + logger, + resume_optimization_step, + resume_interaction_step, ) - server_thread.start() - - transition_thread = Thread( - target=add_actor_information_and_train, - daemon=True, - args=( - cfg, - device, - replay_buffer, - offline_replay_buffer, - batch_size, - optimizers, - policy, - policy_lock, - logger, - resume_optimization_step, - resume_interaction_step, - ), - ) - transition_thread.start() - - param_push_thread = Thread( - target=learner_push_parameters, - args=(policy, policy_lock, actor_ip, port, 15), - daemon=True, - ) - param_push_thread.start() - - transition_thread.join() - server_thread.join() @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
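For reference, `start_learner_threads` keeps the wiring that was previously inlined in `train`: three daemon threads share the module-level queues and the policy lock, and the caller blocks on `join()`. A self-contained sketch of that producer/consumer pattern with generic workers (not the patch's gRPC functions):

import queue
import threading
import time

transitions: queue.Queue = queue.Queue()

def producer() -> None:
    # Plays the role of stream_transitions_from_actor: fills the shared queue.
    for step in range(3):
        transitions.put(step)
        time.sleep(0.1)

def consumer() -> None:
    # Plays the role of add_actor_information_and_train: drains the queue.
    for _ in range(3):
        print(f"training on transition {transitions.get()}")

threads = [
    threading.Thread(target=producer, daemon=True),
    threading.Thread(target=consumer, daemon=True),
]
for t in threads:
    t.start()
for t in threads:
    t.join()  # as in start_learner_threads: block until the workers exit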