From 9aa4cdb9762ee503d5a3ab7cf3586d47afa09de9 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Thu, 20 Jun 2024 08:27:01 +0100
Subject: [PATCH] Checkpoint on final step of training even when it doesn't
 coincide with `save_freq`. (#284)

---
 lerobot/configs/default.yaml | 3 ++-
 lerobot/scripts/train.py     | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index c479788b..df0dae7d 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -39,9 +39,10 @@ training:
   # `online_env_seed` is used for environments for online training data rollouts.
   online_env_seed: ???
   eval_freq: ???
-  save_freq: ???
   log_freq: 250
   save_checkpoint: true
+  # Checkpoint is saved every `save_freq` training iterations and after the last training step.
+  save_freq: ???
   num_workers: 4
   batch_size: ???
   image_transforms:
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 693ff40c..796881c4 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -351,7 +351,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")
 
-        if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
+        if cfg.training.save_checkpoint and (
+            step % cfg.training.save_freq == 0
+            or step == cfg.training.offline_steps + cfg.training.online_steps
+        ):
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
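
The sketch below restates the patched checkpoint condition as a small standalone Python function, purely for illustration. The helper name `should_checkpoint`, its arguments, and the example numbers are assumptions, not part of the patch; in lerobot/scripts/train.py the check is written inline against the Hydra config fields (`cfg.training.*`) shown in the diff.

# Minimal sketch of the new checkpoint condition, assuming a hypothetical
# helper `should_checkpoint`; in the repository the check is inline in train().
def should_checkpoint(step: int, save_freq: int, total_steps: int, save_checkpoint: bool = True) -> bool:
    """Checkpoint every `save_freq` steps, and always on the final step.

    `total_steps` stands in for cfg.training.offline_steps + cfg.training.online_steps.
    """
    return save_checkpoint and (step % save_freq == 0 or step == total_steps)

# Example numbers (assumed, not from the patch): with save_freq=5000 and
# total_steps=102000, step 100000 checkpoints (multiple of save_freq),
# step 101000 does not, and step 102000 checkpoints because it is the last step
# even though it is not a multiple of save_freq.
assert should_checkpoint(100000, save_freq=5000, total_steps=102000)
assert not should_checkpoint(101000, save_freq=5000, total_steps=102000)
assert should_checkpoint(102000, save_freq=5000, total_steps=102000)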