diff --git a/Makefile b/Makefile
index 964de343..9877ee86 100644
--- a/Makefile
+++ b/Makefile
@@ -48,13 +48,16 @@ test-act-ete-train:
 		training.batch_size=2 \
 		hydra.run.dir=tests/outputs/act/
 
+# TODO(alexander-soare): This does not test override_config_on_resume=false. To do so, we need a way of
+# interrupting the prior training before it is done so that we don't need to increase
+# `training.offline_steps`.
 test-act-ete-train-resume:
 	python lerobot/scripts/train.py \
 		policy=act \
 		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
-		training.offline_steps=2 \
+		training.offline_steps=4 \
 		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
@@ -64,9 +67,9 @@ test-act-ete-train-resume:
 		policy.n_action_steps=20 \
 		policy.chunk_size=20 \
 		training.batch_size=2 \
-		hydra.run.dir=tests/outputs/act/
-		resume=true
-
+		hydra.run.dir=tests/outputs/act/ \
+		resume=true \
+		override_config_on_resume=true
 
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md
index 9af23c90..29e0a73c 100644
--- a/examples/4_train_policy_with_script.md
+++ b/examples/4_train_policy_with_script.md
@@ -154,7 +154,9 @@ There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_
 
 ## Resuming a training run
 
-If your training run is interrupted partway through (or you finish a training run and want to pick up where you left off), you may resume the run. All that's required is that you run the same training command again, but add `resume=true`. Note that the configuration you provide in your training command will be the one that's used. If you change something (for example, you increase the number of training steps), it will override the prior configuration.
+If your training run is interrupted partway through (or you finish a training run and want to pick up where you left off), you may resume the run. All that's required is that you run the same training command again, but add `resume=true`.
+
+Note that with `resume=true`, the default behavior is to ignore the configuration you provide with your training command and use the one from the checkpoint. However, you can set `override_config_on_resume=true` to override the prior configuration with the one provided in your training command.
 
 ---
 
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 867b804b..c2b6f05b 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -68,13 +68,18 @@ class Logger:
     pretrained_model_dir_name = "pretrained_model"
     training_state_file_name = "training_state.pth"
 
-    def __init__(self, log_dir: str, job_name: str, cfg: DictConfig):
+    def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
+        """
+        Args:
+            log_dir: The directory to save all logs and training outputs to.
+            wandb_job_name: The WandB job name.
+        """
         self._cfg = cfg
         self.log_dir = Path(log_dir)
         self.log_dir.mkdir(parents=True, exist_ok=True)
-        self.checkpoints_dir = self.log_dir / "checkpoints"
-        self.last_checkpoint_dir = self.checkpoints_dir / "last"
-        self.last_pretrained_model_dir = self.last_checkpoint_dir / self.pretrained_model_dir_name
+        self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
+        self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
+        self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
 
         # Set up WandB.
         self._group = cfg_to_group(cfg)
@@ -97,7 +102,7 @@ class Logger:
                 id=wandb_run_id,
                 project=project,
                 entity=entity,
-                name=job_name,
+                name=wandb_job_name,
                 notes=cfg.get("wandb", {}).get("notes"),
                 tags=cfg_to_group(cfg, return_list=True),
                 dir=log_dir,
@@ -112,6 +117,24 @@ class Logger:
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
+    @classmethod
+    def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which checkpoints will be saved."""
+        return Path(log_dir) / "checkpoints"
+
+    @classmethod
+    def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
+        return cls.get_checkpoints_dir(log_dir) / "last"
+
+    @classmethod
+    def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
+        """
+        Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
+        be saved.
+        """
+        return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
+
     def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
         """Save the weights of the Policy model using PyTorchModelHubMixin.
 
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 352a9192..5593af83 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -13,9 +13,12 @@ hydra:
 
 # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
 # `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
-# Note that when resuming a run, the provided configuration takes precedence over the checkpoint
-# configuration.
+# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+# regardless of what's provided with the training command at the time of resumption.
 resume: false
+# Set `override_config_on_resume` to true to use the provided configuration instead of the one from the
+# checkpoint.
+override_config_on_resume: false
 device: cuda # cpu
 # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
 # automatic gradient scaling is used.
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index b7308273..879605c1 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -223,44 +223,56 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     init_logging()
 
-    # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
-
     # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
     # to check for any differences between the provided config and the checkpoint's config.
     if cfg.resume:
-        if not logger.last_checkpoint_dir.exists():
+        if not Logger.get_last_checkpoint_dir(out_dir).exists():
             raise RuntimeError(
-                f"You have set resume=True, but there is no model checpoint in {logger.last_checkpoint_dir}."
+                "You have set resume=True, but there is no model checkpoint in "
+                f"{Logger.get_last_checkpoint_dir(out_dir)}."
             )
-        else:
-            checkpoint_cfg_path = str(logger.last_pretrained_model_dir / "config.yaml")
-            logging.info(
-                colored(
-                    "You have set resume=True, indicating that you wish to resume a run. The provided config "
-                    f"is being overriden by {checkpoint_cfg_path}",
-                    color="yellow",
-                    attrs=["bold"],
-                )
+        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        logging.info(
+            colored(
+                "You have set resume=True, indicating that you wish to resume a run",
+                color="yellow",
+                attrs=["bold"],
             )
-            # Get the configuration file from the last checkpoint.
-            checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
-            # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
-            resolve_delta_timestamps(cfg)
-            diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
-            if len(diff) > 0:
-                # Log a warning about differences between the checkpoint configuration and the provided
-                # configuration (but ignore the `resume` parameter).
-                if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-                    del diff["values_changed"]["root['resume']"]
-                logging.warning(
-                    colored(
-                        "At least one difference was detected between the checkpoint configuration and the "
-                        f"provided configuration: \n{pformat(diff)}\nNote that the provided configuration "
-                        "takes precedence.",
-                        color="yellow",
-                    )
-                )
+        )
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+        # Check for differences between the checkpoint configuration and provided configuration.
+        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
+        resolve_delta_timestamps(cfg)
+        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        # Ignore the `resume` and `override_config_on_resume` parameters.
+        if "values_changed" in diff:
+            for k in ["root['resume']", "root['override_config_on_resume']"]:
+                if k in diff["values_changed"]:
+                    del diff["values_changed"][k]
+        # Log a warning about differences between the checkpoint configuration and the provided
+        # configuration.
+        logging.warning(
+            colored(
+                "At least one difference was detected between the checkpoint configuration and "
+                f"the provided configuration: \n{pformat(diff)}\nNote that since "
+                f"{cfg.override_config_on_resume=}, the "
+                f"{'provided' if cfg.override_config_on_resume else 'checkpoint'} configuration takes "
+                "precedence.",
+                color="yellow",
+            )
+        )
+        if not cfg.override_config_on_resume:
+            # Use the checkpoint config instead of the provided config (but keep the provided `resume` and
+            # `override_config_on_resume` parameters).
+            resume = cfg.resume
+            override_config_on_resume = cfg.override_config_on_resume
+            cfg = checkpoint_cfg
+            cfg.resume = resume
+            cfg.override_config_on_resume = override_config_on_resume
+
+    # log metrics to terminal and wandb
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
 
     if cfg.training.online_steps > 0:
         raise NotImplementedError("Online training is not implemented yet.")
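For reference, a minimal sketch of how the new `Logger` classmethods can be used to locate checkpoint files before a `Logger` (and hence a WandB run) is constructed, which is what lets `train.py` validate `resume=true` up front. The `tests/outputs/act` directory is just the example `hydra.run.dir` used by the Makefile targets above; any previous run directory works.

```python
from lerobot.common.logger import Logger

# Example output directory (the `hydra.run.dir` of the run being resumed).
out_dir = "tests/outputs/act"

# The classmethods resolve checkpoint paths from the log directory alone, with no Logger
# instance and no WandB state.
last_checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)  # <out_dir>/checkpoints/last
last_pretrained_dir = Logger.get_last_pretrained_model_dir(out_dir)  # .../last/pretrained_model

if not last_checkpoint_dir.exists():
    raise RuntimeError(f"resume=true was set, but no checkpoint was found in {last_checkpoint_dir}")

checkpoint_cfg_path = last_pretrained_dir / "config.yaml"
```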
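And a simplified sketch of the precedence rule `train.py` now applies on resume. `resolve_resume_config` is a hypothetical helper written only for illustration; the patched script does this inline and also emits a colored warning built from the `DeepDiff` result. The key point is that by default the checkpoint configuration wins, and only the `resume` and `override_config_on_resume` flags from the command line are carried over.

```python
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf


def resolve_resume_config(cfg: DictConfig, checkpoint_cfg: DictConfig) -> DictConfig:
    """Illustrative only: mirrors the config-precedence logic added to train.py."""
    # Surface any differences between the two configs, ignoring the resumption flags themselves.
    diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
    if "values_changed" in diff:
        for k in ("root['resume']", "root['override_config_on_resume']"):
            if k in diff["values_changed"]:
                del diff["values_changed"][k]
    if diff:
        print(f"Provided config differs from the checkpoint config:\n{diff}")

    if not cfg.override_config_on_resume:
        # Default behavior: the checkpoint config takes precedence, but the resumption flags
        # provided on the command line are kept.
        resume, override = cfg.resume, cfg.override_config_on_resume
        cfg = checkpoint_cfg
        cfg.resume = resume
        cfg.override_config_on_resume = override
    return cfg
```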