revision
parent ae96c16cba
commit bcdcceb2f5

Makefile | 11
@@ -48,13 +48,16 @@ test-act-ete-train:
 		training.batch_size=2 \
 		hydra.run.dir=tests/outputs/act/
 
+# TODO(alexander-soare): This does not test override_config_on_resume=false. To do so, we need a way of
+# interrupting the prior training before it is done so that we don't need to increase
+# `training.offline_steps`.
 test-act-ete-train-resume:
 	python lerobot/scripts/train.py \
 		policy=act \
 		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
-		training.offline_steps=2 \
+		training.offline_steps=4 \
 		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
@@ -64,9 +67,9 @@ test-act-ete-train-resume:
 		policy.n_action_steps=20 \
 		policy.chunk_size=20 \
 		training.batch_size=2 \
-		hydra.run.dir=tests/outputs/act/
-		resume=true
+		hydra.run.dir=tests/outputs/act/ \
+		resume=true \
+		override_config_on_resume=true
 
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
@@ -154,7 +154,9 @@ There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_
 
 ## Resuming a training run
 
-If your training run is interrupted partway through (or you finish a training run and want to pick up where you left off), you may resume the run. All that's required is that you run the same training command again, but add `resume=true`. Note that the configuration you provide in your training command will be the one that's used. If you change something (for example, you increase the number of training steps), it will override the prior configuration.
+If your training run is interrupted partway through (or you finish a training run and want to pick up where you left off), you may resume the run. All that's required is that you run the same training command again, but add `resume=true`.
+
+Note that with `resume=true` the default behavior is to ignore the configuration you provide with your training command, and use the one from the checkpoint. But, it is possible to use `override_config_on_resume=true` to override the prior configuration with the configuration in your training command.
 
 ---
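The updated documentation describes a precedence rule between `resume` and `override_config_on_resume`. Below is a minimal sketch of that rule; the function name and the plain-dict configs are illustrative only and not part of this commit.

```python
# Sketch only: which configuration is used, per the documented behavior of
# `resume` and `override_config_on_resume` (the helper itself is hypothetical).
def effective_config(provided_cfg: dict, checkpoint_cfg: dict) -> dict:
    if not provided_cfg.get("resume", False):
        # Fresh run: the provided configuration is used as-is.
        return provided_cfg
    if provided_cfg.get("override_config_on_resume", False):
        # Resuming, but the provided configuration takes precedence.
        return provided_cfg
    # Default when resuming: the checkpoint configuration wins.
    return checkpoint_cfg
```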
@@ -68,13 +68,18 @@ class Logger:
     pretrained_model_dir_name = "pretrained_model"
     training_state_file_name = "training_state.pth"
 
-    def __init__(self, log_dir: str, job_name: str, cfg: DictConfig):
+    def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
+        """
+        Args:
+            log_dir: The directory to save all logs and training outputs to.
+            wandb_job_name: The WandB job name.
+        """
         self._cfg = cfg
         self.log_dir = Path(log_dir)
         self.log_dir.mkdir(parents=True, exist_ok=True)
-        self.checkpoints_dir = self.log_dir / "checkpoints"
-        self.last_checkpoint_dir = self.checkpoints_dir / "last"
-        self.last_pretrained_model_dir = self.last_checkpoint_dir / self.pretrained_model_dir_name
+        self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
+        self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
+        self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
 
         # Set up WandB.
         self._group = cfg_to_group(cfg)
@@ -97,7 +102,7 @@ class Logger:
                 id=wandb_run_id,
                 project=project,
                 entity=entity,
-                name=job_name,
+                name=wandb_job_name,
                 notes=cfg.get("wandb", {}).get("notes"),
                 tags=cfg_to_group(cfg, return_list=True),
                 dir=log_dir,
@@ -112,6 +117,24 @@ class Logger:
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
+    @classmethod
+    def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which checkpoints will be saved."""
+        return Path(log_dir) / "checkpoints"
+
+    @classmethod
+    def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
+        """Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
+        return cls.get_checkpoints_dir(log_dir) / "last"
+
+    @classmethod
+    def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
+        """
+        Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
+        be saved.
+        """
+        return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
+
     def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
         """Save the weights of the Policy model using PyTorchModelHubMixin.
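The new classmethods make checkpoint paths computable without constructing a `Logger` (and therefore without initializing WandB). A minimal usage sketch follows; the run directory is hypothetical, and the `lerobot.common.logger` import path is an assumption about where this class lives.

```python
from pathlib import Path

from lerobot.common.logger import Logger  # assumed import path for the class shown above

log_dir = Path("outputs/train/example_run")  # hypothetical run directory

# The path helpers can be called on the class itself, e.g. to check for a
# resumable checkpoint before any Logger instance exists.
last_checkpoint_dir = Logger.get_last_checkpoint_dir(log_dir)
if last_checkpoint_dir.exists():
    config_path = Logger.get_last_pretrained_model_dir(log_dir) / "config.yaml"
    print(f"Resuming from {config_path}")
else:
    print(f"No checkpoint found in {last_checkpoint_dir}; starting a fresh run.")
```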
@@ -13,9 +13,12 @@ hydra:
 
 # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
 # `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
-# Note that when resuming a run, the provided configuration takes precedence over the checkpoint
-# configuration.
+# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+# regardless of what's provided with the training command at the time of resumption.
 resume: false
+# Set `override_config_on_resume` to true to use the provided configuration instead of the one from the
+# checkpoint.
+override_config_on_resume: false
 device: cuda # cpu
 # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
 # automatic gradient scaling is used.
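Since runs are configured through Hydra/OmegaConf, the two flags above end up in the resolved config alongside any command-line overrides. A self-contained OmegaConf sketch of that merge is below (Hydra performs this in the real project; the values here are illustrative only).

```python
from omegaconf import OmegaConf

# Defaults mirroring the two new keys in the YAML above (other keys omitted).
defaults = OmegaConf.create({"resume": False, "override_config_on_resume": False, "device": "cuda"})

# Command-line style overrides, as one might pass when resuming a run.
overrides = OmegaConf.from_dotlist(["resume=true", "override_config_on_resume=true"])

cfg = OmegaConf.merge(defaults, overrides)
assert cfg.resume and cfg.override_config_on_resume
```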
@@ -223,44 +223,56 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     init_logging()
 
-    # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
-
     # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
     # to check for any differences between the provided config and the checkpoint's config.
     if cfg.resume:
-        if not logger.last_checkpoint_dir.exists():
+        if not Logger.get_last_checkpoint_dir(out_dir).exists():
             raise RuntimeError(
-                f"You have set resume=True, but there is no model checpoint in {logger.last_checkpoint_dir}."
+                "You have set resume=True, but there is no model checkpoint in "
+                f"{Logger.get_last_checkpoint_dir(out_dir)}."
             )
-        else:
-            checkpoint_cfg_path = str(logger.last_pretrained_model_dir / "config.yaml")
-            logging.info(
-                colored(
-                    "You have set resume=True, indicating that you wish to resume a run. The provided config "
-                    f"is being overriden by {checkpoint_cfg_path}",
-                    color="yellow",
-                    attrs=["bold"],
-                )
-            )
-            # Get the configuration file from the last checkpoint.
-            checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
-            # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
-            resolve_delta_timestamps(cfg)
-            diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
-            if len(diff) > 0:
-                # Log a warning about differences between the checkpoint configuration and the provided
-                # configuration (but ignore the `resume` parameter).
-                if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-                    del diff["values_changed"]["root['resume']"]
-                logging.warning(
-                    colored(
-                        "At least one difference was detected between the checkpoint configuration and the "
-                        f"provided configuration: \n{pformat(diff)}\nNote that the provided configuration "
-                        "takes precedence.",
-                        color="yellow",
-                    )
-                )
+        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        logging.info(
+            colored(
+                "You have set resume=True, indicating that you wish to resume a run",
+                color="yellow",
+                attrs=["bold"],
+            )
+        )
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+        # Check for differences between the checkpoint configuration and provided configuration.
+        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
+        resolve_delta_timestamps(cfg)
+        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        # Ignore the `resume` and `override_config_on_resume` parameters.
+        if "values_changed" in diff:
+            for k in ["root['resume']", "root['override_config_on_resume']"]:
+                if k in diff["values_changed"]:
+                    del diff["values_changed"][k]
+        # Log a warning about differences between the checkpoint configuration and the provided
+        # configuration.
+        logging.warning(
+            colored(
+                "At least one difference was detected between the checkpoint configuration and "
+                f"the provided configuration: \n{pformat(diff)}\nNote that since "
+                f"{cfg.override_config_on_resume=}, the "
+                f"{'provided' if cfg.override_config_on_resume else 'checkpoint'} configuration takes "
+                "precedence.",
+                color="yellow",
+            )
+        )
+        if not cfg.override_config_on_resume:
+            # Use the checkpoint config instead of the provided config (but keep the provided `resume` and
+            # `override_config_on_resume` parameters).
+            resume = cfg.resume
+            override_config_on_resume = cfg.override_config_on_resume
+            cfg = checkpoint_cfg
+            cfg.resume = resume
+            cfg.override_config_on_resume = override_config_on_resume
+
+    # log metrics to terminal and wandb
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
 
     if cfg.training.online_steps > 0:
         raise NotImplementedError("Online training is not implemented yet.")
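For reference, here is a self-contained sketch of the config-diff check added above, using plain dicts in place of the Hydra/OmegaConf configs; the example keys and values are illustrative only and not part of this commit.

```python
from pprint import pformat

from deepdiff import DeepDiff

# Illustrative stand-ins for the checkpoint config and the provided config.
checkpoint_cfg = {"resume": False, "override_config_on_resume": False, "training": {"offline_steps": 2}}
provided_cfg = {"resume": True, "override_config_on_resume": True, "training": {"offline_steps": 4}}

diff = DeepDiff(checkpoint_cfg, provided_cfg)
# Drop the `resume` and `override_config_on_resume` entries, which are expected
# to differ when resuming.
if "values_changed" in diff:
    for k in ["root['resume']", "root['override_config_on_resume']"]:
        if k in diff["values_changed"]:
            del diff["values_changed"][k]

if len(diff) > 0:
    print(f"Differences between checkpoint and provided configuration:\n{pformat(diff)}")
```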