Checkpoint on final step of training even when it doesn't coincide with `save_freq`. (#284)
This commit is contained in:
parent
2abef3bef9
commit
9aa4cdb976
lerobot
|
@ -39,9 +39,10 @@ training:
|
|||
# `online_env_seed` is the seed used for the environments that produce online training data rollouts.
|
||||
online_env_seed: ???
|
||||
eval_freq: ???
|
||||
save_freq: ???
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
# A checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: ???
|
||||
num_workers: 4
|
||||
batch_size: ???
|
||||
image_transforms:
|
||||
|
|
|
@ -351,7 +351,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
|
||||
if cfg.training.save_checkpoint and (
|
||||
step % cfg.training.save_freq == 0
|
||||
or step == cfg.training.offline_steps + cfg.training.online_steps
|
||||
):
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
|
|
Loading…
Reference in New Issue