Merge remote-tracking branch 'upstream/main' into custom_eval_out_dir

Alexander Soare 2024-06-04 18:25:12 +01:00
commit 5db9a96677
2 changed files with 8 additions and 1 deletion


@@ -147,7 +147,7 @@ class Normalize(nn.Module):
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
                 # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min)
+                batch[key] = (batch[key] - min) / (max - min + 1e-8)
                 # normalize to [-1, 1]
                 batch[key] = batch[key] * 2 - 1
             else:
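
The `+ 1e-8` guards against a zero denominator when a feature's `min` and `max` statistics coincide (a constant feature), which would otherwise produce NaN. A minimal standalone sketch of the stabilized mapping to [-1, 1]; the function name is illustrative, not part of the repo:

```python
import torch


def minmax_normalize(x: torch.Tensor, min_: torch.Tensor, max_: torch.Tensor) -> torch.Tensor:
    # Map to [0, 1]; the epsilon keeps the division finite when max_ == min_.
    x = (x - min_) / (max_ - min_ + 1e-8)
    # Map to [-1, 1].
    return x * 2 - 1


x = torch.tensor([3.0, 3.0])
print(minmax_normalize(x, x.min(), x.max()))  # constant input -> finite values, no NaN
```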


@@ -150,6 +150,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     grad_norm = info["grad_norm"]
     lr = info["lr"]
     update_s = info["update_s"]
+    dataloading_s = info["dataloading_s"]
     # A sample is an (observation,action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
@@ -170,6 +171,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
         f"lr:{lr:0.1e}",
         # in seconds
         f"updt_s:{update_s:.3f}",
+        f"data_s:{dataloading_s:.3f}",  # if not ~0, you are bottlenecked by cpu or io
     ]
     logging.info(" ".join(log_items))
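
A minimal sketch of how the new field renders in a log line, following the f-string formats above (the values here are illustrative):

```python
# Same :.3f formats as the log_items above; values are made up for illustration.
update_s = 0.105
dataloading_s = 0.002
log_items = [
    f"updt_s:{update_s:.3f}",
    f"data_s:{dataloading_s:.3f}",
]
print(" ".join(log_items))  # -> updt_s:0.105 data_s:0.002
```
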
@@ -382,7 +384,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     for _ in range(step, cfg.training.offline_steps):
         if step == 0:
             logging.info("Start offline training on a fixed dataset")
+
+        start_time = time.perf_counter()
         batch = next(dl_iter)
+        dataloading_s = time.perf_counter() - start_time
 
         for key in batch:
             batch[key] = batch[key].to(device, non_blocking=True)
@@ -397,6 +402,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             use_amp=cfg.use_amp,
         )
 
+        train_info["dataloading_s"] = dataloading_s
+
         if step % cfg.training.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
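
Taken together, these train-loop hunks wall-clock the fetch from the dataloader iterator and thread the measurement into the info dict that `log_train_info` receives. A standalone sketch of the same pattern, assuming a generic PyTorch DataLoader; the dataset and names below are placeholders, not lerobot's API:

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 4))  # placeholder dataset
dl_iter = iter(DataLoader(dataset, batch_size=32))

start_time = time.perf_counter()
batch = next(dl_iter)  # blocks until the next batch is ready
dataloading_s = time.perf_counter() - start_time

# Near-zero: the training step dominates. Consistently large: the loop is
# starved by CPU preprocessing or disk/network I/O (e.g. try more workers).
print(f"data_s:{dataloading_s:.3f}")
```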