Checkpoint on final step of training even when it doesn't coincide with `save_freq`.

This commit is contained in:
Alexander Soare 2024-06-20 08:27:01 +01:00 committed by GitHub
parent 2abef3bef9
commit 9aa4cdb976
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 2 deletions
lerobot
configs
scripts

View File

@@ -39,9 +39,10 @@ training:
# `online_env_seed` is used for environments for online training data rollouts.
online_env_seed: ???
eval_freq: ???
save_freq: ???
log_freq: 250
save_checkpoint: true
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: ???
num_workers: 4
batch_size: ???
image_transforms:

View File

@@ -351,7 +351,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
if cfg.training.save_checkpoint and (
step % cfg.training.save_freq == 0
or step == cfg.training.offline_steps + cfg.training.online_steps
):
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).