Alexander Soare 2024-05-21 18:37:39 +01:00
parent ae96c16cba
commit bcdcceb2f5
5 changed files with 87 additions and 44 deletions


@@ -48,13 +48,16 @@ test-act-ete-train:
 		training.batch_size=2 \
 		hydra.run.dir=tests/outputs/act/
 
+# TODO(alexander-soare): This does not test override_config_on_resume=false. To do so, we need a way of
+# interrupting the prior training before it is done so that we don't need to increase
+# `training.offline_steps`.
 test-act-ete-train-resume:
 	python lerobot/scripts/train.py \
 		policy=act \
 		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
-		training.offline_steps=2 \
+		training.offline_steps=4 \
 		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
@@ -64,9 +67,9 @@ test-act-ete-train-resume:
 		policy.n_action_steps=20 \
 		policy.chunk_size=20 \
 		training.batch_size=2 \
-		hydra.run.dir=tests/outputs/act/
-		resume=true
+		hydra.run.dir=tests/outputs/act/ \
+		resume=true \
+		override_config_on_resume=true
 
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \


@@ -154,7 +154,9 @@ There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_
 ## Resuming a training run
 
-If your training run is interrupted partway through (or you finish a training run and want to pick up where you left off), you may resume the run. All that's required is that you run the same training command again, but add `resume=true`. Note that the configuration you provide in your training command will be the one that's used. If you change something (for example, you increase the number of training steps), it will override the prior configuration.
+If your training run is interrupted partway through (or you finish a training run and want to pick up where you left off), you may resume the run. All that's required is that you run the same training command again, but add `resume=true`.
+
+Note that with `resume=true` the default behavior is to ignore the configuration you provide with your training command, and use the one from the checkpoint. But, it is possible to use `override_config_on_resume=true` to override the prior configuration with the configuration in your training command.
 
 ---
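For clarity, the precedence rule described in the updated paragraph amounts to the following sketch (the function and dictionary names here are illustrative only, not the actual train.py implementation):

def pick_config(provided_cfg: dict, checkpoint_cfg: dict) -> dict:
    """Illustrative only: decide which configuration a resumed run should use."""
    if provided_cfg["resume"] and not provided_cfg["override_config_on_resume"]:
        # Default resume behavior: take the checkpoint's config, but keep the resume-related
        # flags that were passed with the training command.
        cfg = dict(checkpoint_cfg)
        cfg["resume"] = provided_cfg["resume"]
        cfg["override_config_on_resume"] = provided_cfg["override_config_on_resume"]
        return cfg
    # Fresh run, or resume with override_config_on_resume=true: the provided config wins.
    return dict(provided_cfg)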


@@ -68,13 +68,18 @@ class Logger:
     pretrained_model_dir_name = "pretrained_model"
     training_state_file_name = "training_state.pth"
 
-    def __init__(self, log_dir: str, job_name: str, cfg: DictConfig):
+    def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
+        """
+        Args:
+            log_dir: The directory to save all logs and training outputs to.
+            job_name: The WandB job name.
+        """
         self._cfg = cfg
         self.log_dir = Path(log_dir)
         self.log_dir.mkdir(parents=True, exist_ok=True)
-        self.checkpoints_dir = self.log_dir / "checkpoints"
-        self.last_checkpoint_dir = self.checkpoints_dir / "last"
-        self.last_pretrained_model_dir = self.last_checkpoint_dir / self.pretrained_model_dir_name
+        self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
+        self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
+        self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
 
         # Set up WandB.
         self._group = cfg_to_group(cfg)
@@ -97,7 +102,7 @@ class Logger:
                 id=wandb_run_id,
                 project=project,
                 entity=entity,
-                name=job_name,
+                name=wandb_job_name,
                 notes=cfg.get("wandb", {}).get("notes"),
                 tags=cfg_to_group(cfg, return_list=True),
                 dir=log_dir,
@@ -112,6 +117,24 @@ class Logger:
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
+    @classmethod
+    def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which checkpoints will be saved."""
+        return Path(log_dir) / "checkpoints"
+
+    @classmethod
+    def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
+        return cls.get_checkpoints_dir(log_dir) / "last"
+
+    @classmethod
+    def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
+        """
+        Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
+        be saved.
+        """
+        return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
+
     def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
         """Save the weights of the Policy model using PyTorchModelHubMixin.


@@ -13,9 +13,12 @@ hydra:
 # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
 # `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
-# Note that when resuming a run, the provided configuration takes precedence over the checkpoint
-# configuration.
+# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+# regardless of what's provided with the training command at the time of resumption.
 resume: false
+# Set `override_config_on_resume` to true to use the provided configuration instead of the one from the
+# checkpoint.
+override_config_on_resume: false
 device: cuda # cpu
 # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
 # automatic gradient scaling is used.
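To make the interaction of the two flags concrete, here is a small sketch using OmegaConf directly (values are created inline for illustration; in practice they come from this Hydra config):

from omegaconf import OmegaConf

# Defaults as declared in this config file.
cfg = OmegaConf.create({"resume": False, "override_config_on_resume": False})

# resume=true with override_config_on_resume=false (the default): the checkpoint's config is used.
# resume=true with override_config_on_resume=true: the config provided on the command line is used.
cfg.resume = True
use_checkpoint_config = cfg.resume and not cfg.override_config_on_resume
print(f"{use_checkpoint_config=}")  # prints: use_checkpoint_config=True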


@@ -223,44 +223,56 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     init_logging()
 
-    # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
-
     # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
     # to check for any differences between the provided config and the checkpoint's config.
     if cfg.resume:
-        if not logger.last_checkpoint_dir.exists():
+        if not Logger.get_last_checkpoint_dir(out_dir).exists():
             raise RuntimeError(
-                f"You have set resume=True, but there is no model checpoint in {logger.last_checkpoint_dir}."
+                "You have set resume=True, but there is no model checkpoint in "
+                f"{Logger.get_last_checkpoint_dir(out_dir)}."
             )
-        else:
-            checkpoint_cfg_path = str(logger.last_pretrained_model_dir / "config.yaml")
-            logging.info(
-                colored(
-                    "You have set resume=True, indicating that you wish to resume a run. The provided config "
-                    f"is being overriden by {checkpoint_cfg_path}",
-                    color="yellow",
-                    attrs=["bold"],
-                )
-            )
-            # Get the configuration file from the last checkpoint.
-            checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
-            # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
-            resolve_delta_timestamps(cfg)
-            diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
-            if len(diff) > 0:
-                # Log a warning about differences between the checkpoint configuration and the provided
-                # configuration (but ignore the `resume` parameter).
-                if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-                    del diff["values_changed"]["root['resume']"]
-                logging.warning(
-                    colored(
-                        "At least one difference was detected between the checkpoint configuration and the "
-                        f"provided configuration: \n{pformat(diff)}\nNote that the provided configuration "
-                        "takes precedence.",
-                        color="yellow",
-                    )
-                )
+        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        logging.info(
+            colored(
+                "You have set resume=True, indicating that you wish to resume a run",
+                color="yellow",
+                attrs=["bold"],
+            )
+        )
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+        # Check for differences between the checkpoint configuration and provided configuration.
+        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
+        resolve_delta_timestamps(cfg)
+        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        # Ignore the `resume` and `override_config_on_resume` parameters.
+        if "values_changed" in diff:
+            for k in ["root['resume']", "root['override_config_on_resume']"]:
+                if k in diff["values_changed"]:
+                    del diff["values_changed"][k]
+        # Log a warning about differences between the checkpoint configuration and the provided
+        # configuration.
+        logging.warning(
+            colored(
+                "At least one difference was detected between the checkpoint configuration and "
+                f"the provided configuration: \n{pformat(diff)}\nNote that since "
+                f"{cfg.override_config_on_resume=}, the "
+                f"{'provided' if cfg.override_config_on_resume else 'checkpoint'} configuration takes "
+                "precedence.",
+                color="yellow",
+            )
+        )
+        if not cfg.override_config_on_resume:
+            # Use the checkpoint config instead of the provided config (but keep the provided `resume` and
+            # `override_config_on_resume` parameters).
+            resume = cfg.resume
+            override_config_on_resume = cfg.override_config_on_resume
+            cfg = checkpoint_cfg
+            cfg.resume = resume
+            cfg.override_config_on_resume = override_config_on_resume
+
+    # log metrics to terminal and wandb
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
 
     if cfg.training.online_steps > 0:
         raise NotImplementedError("Online training is not implemented yet.")
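For reference, the DeepDiff filtering in the new code behaves as in this small, self-contained example (the config dicts are made up; only the shape of the diff matters):

from deepdiff import DeepDiff

checkpoint_cfg = {"resume": False, "override_config_on_resume": False, "training": {"offline_steps": 2}}
provided_cfg = {"resume": True, "override_config_on_resume": False, "training": {"offline_steps": 4}}

diff = DeepDiff(checkpoint_cfg, provided_cfg)
# Differences in the resume-related flags are expected when resuming, so they are dropped
# before warning about genuine configuration drift.
if "values_changed" in diff:
    for k in ["root['resume']", "root['override_config_on_resume']"]:
        if k in diff["values_changed"]:
            del diff["values_changed"][k]

# What remains is the real change: root['training']['offline_steps'] went from 2 to 4.
print(diff)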