Improve log msg in train.py
This commit is contained in:
parent
0f2fa4d9ef
commit
4c400b41a5
|
@ -170,6 +170,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
|
||||
|
||||
if step > 0 and step % cfg.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
eval_info, first_video = eval_policy(
|
||||
env,
|
||||
td_policy,
|
||||
|
@ -179,10 +180,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(first_video, step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
|
||||
logging.info(f"Checkpoint model at step {step}")
|
||||
logging.info(f"Checkpoint policy at step {step}")
|
||||
logger.save_model(policy, identifier=step)
|
||||
logging.info("Resume training")
|
||||
|
||||
step += 1
|
||||
|
||||
|
@ -227,6 +230,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
|
||||
|
||||
if step > 0 and step % cfg.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
eval_info, first_video = eval_policy(
|
||||
env,
|
||||
td_policy,
|
||||
|
@ -236,10 +240,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(first_video, step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
|
||||
logging.info(f"Checkpoint model at step {step}")
|
||||
logging.info(f"Checkpoint policy at step {step}")
|
||||
logger.save_model(policy, identifier=step)
|
||||
logging.info("Resume training")
|
||||
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
|
Loading…
Reference in New Issue