Checkpoint on final step of training even when it doesn't coincide with `save_freq`. (#284)
This commit is contained in:
parent
2abef3bef9
commit
9aa4cdb976
|
@ -39,9 +39,10 @@ training:
|
||||||
# `online_env_seed` is used for environments for online training data rollouts.
|
# `online_env_seed` is used for environments for online training data rollouts.
|
||||||
online_env_seed: ???
|
online_env_seed: ???
|
||||||
eval_freq: ???
|
eval_freq: ???
|
||||||
save_freq: ???
|
|
||||||
log_freq: 250
|
log_freq: 250
|
||||||
save_checkpoint: true
|
save_checkpoint: true
|
||||||
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||||
|
save_freq: ???
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
batch_size: ???
|
batch_size: ???
|
||||||
image_transforms:
|
image_transforms:
|
||||||
|
|
|
@ -351,7 +351,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
|
||||||
if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
|
if cfg.training.save_checkpoint and (
|
||||||
|
step % cfg.training.save_freq == 0
|
||||||
|
or step == cfg.training.offline_steps + cfg.training.online_steps
|
||||||
|
):
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||||
|
|
Loading…
Reference in New Issue