From 626e5dd35cbab6b38e57d941c1487655575e71b2 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Thu, 27 Mar 2025 08:11:56 +0000
Subject: [PATCH] Add wandb run id in config

---
 lerobot/common/utils/wandb_utils.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py
index 9a192406..b62ff140 100644
--- a/lerobot/common/utils/wandb_utils.py
+++ b/lerobot/common/utils/wandb_utils.py
@@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
     """Return a group name for logging. Optionally returns group name as list."""
     lst = [
         f"policy:{cfg.policy.type}",
-        f"dataset:{cfg.dataset.repo_id}",
         f"seed:{cfg.seed}",
     ]
+    if cfg.dataset is not None:
+        lst.append(f"dataset:{cfg.dataset.repo_id}")
     if cfg.env is not None:
         lst.append(f"env:{cfg.env.type}")
     return lst if return_list else "-".join(lst)
@@ -92,6 +93,10 @@ class WandBLogger:
             resume="must" if cfg.resume else None,
             mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
         )
+        run_id = wandb.run.id
+        # NOTE: We will override the cfg.wandb.run_id with the wandb run id.
+        # This is because we want to be able to resume the run from the wandb run id.
+        cfg.wandb.run_id = run_id
         # Handle custom step key for rl asynchronous training.
         self._wandb_custom_step_key: set[str] | None = None
         print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
@@ -110,7 +115,7 @@ class WandBLogger:
         artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
         self._wandb.log_artifact(artifact)
 
-    def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
+    def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
         if mode not in {"train", "eval"}:
             raise ValueError(mode)
         if step is None and custom_step_key is None: