Enable logging the whole forward dictionary

This commit is contained in:
Alexander Soare 2024-05-08 14:45:18 +01:00
parent 47de07658c
commit e1addd40f4
3 changed files with 7 additions and 2 deletions

View File

@ -114,6 +114,8 @@ class Logger:
assert mode in {"train", "eval"}
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):

View File

@ -38,7 +38,8 @@ class Policy(Protocol):
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and maybe other information.
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]):

View File

@ -72,6 +72,7 @@ def make_optimizer_and_scheduler(cfg, policy):
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"""Returns a dictionary of items for logging."""
start_time = time.time()
policy.train()
output_dict = policy.forward(batch)
@ -99,6 +100,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.time() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
return info
@ -122,7 +124,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
train(cfg, out_dir=out_dir, job_name=job_name)
def log_train_info(logger, info, step, cfg, dataset, is_offline):
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]