diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 8a374932..d8398ff1 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -18,6 +18,8 @@
 import logging
 import os
+import re
+from glob import glob
 from pathlib import Path
 
 import torch
@@ -73,7 +75,19 @@ class Logger:
             os.environ["WANDB_SILENT"] = "true"
             import wandb
 
+            wandb_run_id = None
+            if cfg.resume:
+                # Get the WandB run ID.
+                paths = glob(str(self._checkpoint_dir / "../wandb/latest-run/run-*"))
+                if len(paths) != 1:
+                    raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+                match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
+                if match is None:
+                    raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+                wandb_run_id = match.groups(0)[0]
+
             wandb.init(
+                id=wandb_run_id,
                 project=project,
                 entity=entity,
                 name=job_name,
@@ -87,14 +101,14 @@ class Logger:
                 # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                 job_type="train_eval",
                 # TODO(rcadene): add resume option
-                resume="must",
+                resume="must" if cfg.resume else None,
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
     @property
-    def last_checkpoint_path(self):
+    def last_checkpoint_path(self) -> Path:
         return self._last_checkpoint_path
 
     def save_model(self, policy: Policy, identifier: str):
@@ -112,6 +126,8 @@ class Logger:
             )
             artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
             self._wandb.log_artifact(artifact)
+        if self._last_checkpoint_path.exists():
+            os.remove(self._last_checkpoint_path)
         os.symlink(save_dir.absolute(), self._last_checkpoint_path)  # TODO(now): Check this works
 
     def save_training_state(
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index dc17f604..3713c9a7 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -322,8 +322,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info("make_policy")
     policy = make_policy(
         hydra_cfg=cfg,
-        dataset_stats=offline_dataset.stats,
-        pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
+        dataset_stats=offline_dataset.stats if not cfg.resume else None,
+        pretrained_policy_name_or_path=str(logger.last_checkpoint_path) if cfg.resume else None,
     )
 
     # Create optimizer and scheduler
@@ -335,10 +335,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     if cfg.resume:
         print("You have set resume=True, indicating that you wish to resume a run.")
         # Make sure there is a checkpoint.
-        if not Path(logger.last_checkpoint_path).exists():
-            raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
+        if not logger.last_checkpoint_path.exists():
+            raise RuntimeError(
+                f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
+            )
         # Get the configuration file from the last checkpoint.
-        checkpoint_cfg = init_hydra_config(logger.last_checkpoint_path)
+        checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
         # TODO(now): Do a diff check.
         cfg = checkpoint_cfg
         step = logger.load_last_training_state(optimizer, lr_scheduler)
@@ -376,8 +378,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
-            logger.save_model(
+            logger.save_checkpoint(
+                step,
                 policy,
+                optimizer,
+                lr_scheduler,
                 identifier=str(step).zfill(
                     max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
                 ),
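
Note on the WandB resumption logic added to Logger.__init__ above: when cfg.resume is set, the previous run ID is recovered from the filename of the single run-*.wandb file under wandb/latest-run/ and passed to wandb.init(id=..., resume="must"), so the old run is continued rather than a new one created. Below is a minimal standalone sketch of that recovery step; the get_wandb_run_id helper and its log_dir argument are illustrative assumptions, not names from the patch.

    import re
    from glob import glob
    from pathlib import Path


    def get_wandb_run_id(log_dir: Path) -> str:
        # Assumed layout: <log_dir>/wandb/latest-run/run-<ID>.wandb (exactly one file per run).
        paths = glob(str(log_dir / "wandb/latest-run/run-*.wandb"))
        if len(paths) != 1:
            raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
        # The run ID is the part of the filename between "run-" and ".wandb".
        match = re.search(r"run-([^.]+)\.wandb", Path(paths[0]).name)
        if match is None:
            raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
        return match.group(1)

    # Usage (sketch): wandb.init(id=get_wandb_run_id(log_dir), resume="must", ...)
    # resume="must" makes wandb.init fail loudly if the run cannot actually be resumed,
    # instead of silently starting a fresh run.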