From e1addd40f40ea8391f82dcb7643a1265450fb143 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 8 May 2024 14:45:18 +0100
Subject: [PATCH] Enable logging the whole forward dictionary

---
 lerobot/common/logger.py                   | 2 ++
 lerobot/common/policies/policy_protocol.py | 3 ++-
 lerobot/scripts/train.py                   | 4 +++-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 8e7fe7f2..3ff39d3d 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -114,6 +114,8 @@ class Logger:
         assert mode in {"train", "eval"}
         if self._wandb is not None:
             for k, v in d.items():
+                if not isinstance(v, (int, float, str)):
+                    continue
                 self._wandb.log({f"{mode}/{k}": v}, step=step)
 
     def log_video(self, video_path: str, step: int, mode: str = "train"):
diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py
index 5749c6a8..b00cff5c 100644
--- a/lerobot/common/policies/policy_protocol.py
+++ b/lerobot/common/policies/policy_protocol.py
@@ -38,7 +38,8 @@ class Policy(Protocol):
     def forward(self, batch: dict[str, Tensor]) -> dict:
         """Run the batch through the model and compute the loss for training or validation.
 
-        Returns a dictionary with "loss" and maybe other information.
+        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
+        other items should be logging-friendly, native Python types.
         """
 
     def select_action(self, batch: dict[str, Tensor]):
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index d5fedc84..7319e03f 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -72,6 +72,7 @@ def make_optimizer_and_scheduler(cfg, policy):
 
 
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+    """Returns a dictionary of items for logging."""
     start_time = time.time()
     policy.train()
     output_dict = policy.forward(batch)
@@ -99,6 +100,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.time() - start_time,
+        **{k: v for k, v in output_dict.items() if k != "loss"},
     }
 
     return info
@@ -122,7 +124,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
     train(cfg, out_dir=out_dir, job_name=job_name)
 
 
-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
    loss = info["loss"]
    grad_norm = info["grad_norm"]
    lr = info["lr"]
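
Note (editorial addition, not part of the patch): a minimal sketch of how the three changes
interact, assuming a hypothetical Policy.forward() output with made-up keys ("l1_loss",
"kld_loss", "debug_tensor"). Per the updated protocol docstring, everything except "loss"
should be a logging-friendly native type; anything else is silently skipped by the new
isinstance filter in Logger.log_dict.

    import torch

    # Hypothetical output of Policy.forward() under the updated protocol:
    # "loss" is a Tensor; all other items are native Python types.
    output_dict = {
        "loss": torch.tensor(0.42, requires_grad=True),
        "l1_loss": 0.30,                 # float: forwarded to wandb by Logger.log_dict
        "kld_loss": 0.12,                # float: forwarded to wandb by Logger.log_dict
        "debug_tensor": torch.zeros(3),  # Tensor: dropped by the isinstance filter
    }

    # What update_policy() now merges into its logging dict via the **-unpacking:
    extras = {k: v for k, v in output_dict.items() if k != "loss"}
    print(extras)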