Adding parameter dataloading_s to console logs and wandb for tracking… (#243)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
parent
b0d954c6e1
commit
33362dbd17
|
@ -150,6 +150,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||||
grad_norm = info["grad_norm"]
|
grad_norm = info["grad_norm"]
|
||||||
lr = info["lr"]
|
lr = info["lr"]
|
||||||
update_s = info["update_s"]
|
update_s = info["update_s"]
|
||||||
|
dataloading_s = info["dataloading_s"]
|
||||||
|
|
||||||
# A sample is an (observation,action) pair, where observation and action
|
# A sample is an (observation,action) pair, where observation and action
|
||||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||||
|
@ -170,6 +171,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||||
f"lr:{lr:0.1e}",
|
f"lr:{lr:0.1e}",
|
||||||
# in seconds
|
# in seconds
|
||||||
f"updt_s:{update_s:.3f}",
|
f"updt_s:{update_s:.3f}",
|
||||||
|
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
|
||||||
]
|
]
|
||||||
logging.info(" ".join(log_items))
|
logging.info(" ".join(log_items))
|
||||||
|
|
||||||
|
@ -382,7 +384,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
for _ in range(step, cfg.training.offline_steps):
|
for _ in range(step, cfg.training.offline_steps):
|
||||||
if step == 0:
|
if step == 0:
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
dataloading_s = time.perf_counter() - start_time
|
||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(device, non_blocking=True)
|
batch[key] = batch[key].to(device, non_blocking=True)
|
||||||
|
@ -397,6 +402,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
use_amp=cfg.use_amp,
|
use_amp=cfg.use_amp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_info["dataloading_s"] = dataloading_s
|
||||||
|
|
||||||
if step % cfg.training.log_freq == 0:
|
if step % cfg.training.log_freq == 0:
|
||||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue