Add wandb run id in config

This commit is contained in:
AdilZouitine 2025-03-27 08:11:56 +00:00
parent dd37bd412e
commit 626e5dd35c
1 changed file with 7 additions and 2 deletions

View File

@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.dataset is not None:
lst.append(f"dataset:{cfg.dataset.repo_id}")
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
@ -92,6 +93,10 @@ class WandBLogger:
resume="must" if cfg.resume else None,
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
run_id = wandb.run.id
# NOTE: Overwrite cfg.wandb.run_id with the id of the active wandb run,
# so that this run can later be resumed from that id.
cfg.wandb.run_id = run_id
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
@ -110,7 +115,7 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
if mode not in {"train", "eval"}:
raise ValueError(mode)
if step is None and custom_step_key is None: