diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py
index 9a192406..b62ff140 100644
--- a/lerobot/common/utils/wandb_utils.py
+++ b/lerobot/common/utils/wandb_utils.py
@@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
     """Return a group name for logging. Optionally returns group name as list."""
     lst = [
         f"policy:{cfg.policy.type}",
-        f"dataset:{cfg.dataset.repo_id}",
         f"seed:{cfg.seed}",
     ]
+    if cfg.dataset is not None:
+        lst.append(f"dataset:{cfg.dataset.repo_id}")
     if cfg.env is not None:
         lst.append(f"env:{cfg.env.type}")
     return lst if return_list else "-".join(lst)
@@ -92,6 +93,10 @@ class WandBLogger:
             resume="must" if cfg.resume else None,
             mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
         )
+        run_id = wandb.run.id
+        # NOTE: We will override the cfg.wandb.run_id with the wandb run id.
+        # This is because we want to be able to resume the run from the wandb run id.
+        cfg.wandb.run_id = run_id
         # Handle custom step key for rl asynchronous training.
         self._wandb_custom_step_key: set[str] | None = None
         print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
@@ -110,7 +115,7 @@ class WandBLogger:
         artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
         self._wandb.log_artifact(artifact)
 
-    def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
+    def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
         if mode not in {"train", "eval"}:
             raise ValueError(mode)
         if step is None and custom_step_key is None:
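
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller might exercise the relaxed log_dict signature. The logger variable and metric names below are hypothetical; only the keyword arguments come from the signature changed in this diff.

    # Illustrative only: `logger` stands in for an already-constructed WandBLogger,
    # and the metric names are made up for this sketch.
    metrics = {"interaction_step": 1200, "reward": 0.73}

    # Previously `step` was a required positional argument:
    logger.log_dict(metrics, step=500, mode="train")

    # With this patch, an async RL worker that only tracks its own counter can
    # omit `step` and drive the x-axis through `custom_step_key` instead
    # (log_dict still raises if both `step` and `custom_step_key` are None):
    logger.log_dict(metrics, mode="train", custom_step_key="interaction_step")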