Improve log msg in train.py

This commit is contained in:
Remi Cadene 2024-03-03 13:22:09 +00:00
parent 0f2fa4d9ef
commit 4c400b41a5
1 changed files with 8 additions and 2 deletions

View File

@ -170,6 +170,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
if step > 0 and step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info, first_video = eval_policy(
env,
td_policy,
@ -179,10 +180,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logging.info("Resume training")
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
logging.info(f"Checkpoint model at step {step}")
logging.info(f"Checkpoint policy at step {step}")
logger.save_model(policy, identifier=step)
logging.info("Resume training")
step += 1
@ -227,6 +230,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
if step > 0 and step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info, first_video = eval_policy(
env,
td_policy,
@ -236,10 +240,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logging.info("Resume training")
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
logging.info(f"Checkpoint model at step {step}")
logging.info(f"Checkpoint policy at step {step}")
logger.save_model(policy, identifier=step)
logging.info("Resume training")
step += 1
online_step += 1