Add wandb run id in config
This commit is contained in:
parent
056f79d358
commit
0b5b62c8fb
|
@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||||
"""Return a group name for logging. Optionally returns group name as list."""
|
"""Return a group name for logging. Optionally returns group name as list."""
|
||||||
lst = [
|
lst = [
|
||||||
f"policy:{cfg.policy.type}",
|
f"policy:{cfg.policy.type}",
|
||||||
f"dataset:{cfg.dataset.repo_id}",
|
|
||||||
f"seed:{cfg.seed}",
|
f"seed:{cfg.seed}",
|
||||||
]
|
]
|
||||||
|
if cfg.dataset is not None:
|
||||||
|
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||||
if cfg.env is not None:
|
if cfg.env is not None:
|
||||||
lst.append(f"env:{cfg.env.type}")
|
lst.append(f"env:{cfg.env.type}")
|
||||||
return lst if return_list else "-".join(lst)
|
return lst if return_list else "-".join(lst)
|
||||||
|
@ -92,6 +93,10 @@ class WandBLogger:
|
||||||
resume="must" if cfg.resume else None,
|
resume="must" if cfg.resume else None,
|
||||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||||
)
|
)
|
||||||
|
run_id = wandb.run.id
|
||||||
|
# NOTE: We will override the cfg.wandb.run_id with the wandb run id.
|
||||||
|
# This is because we want to be able to resume the run from the wandb run id.
|
||||||
|
cfg.wandb.run_id = run_id
|
||||||
# Handle custom step key for rl asynchronous training.
|
# Handle custom step key for rl asynchronous training.
|
||||||
self._wandb_custom_step_key: set[str] | None = None
|
self._wandb_custom_step_key: set[str] | None = None
|
||||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||||
|
@ -110,7 +115,7 @@ class WandBLogger:
|
||||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||||
self._wandb.log_artifact(artifact)
|
self._wandb.log_artifact(artifact)
|
||||||
|
|
||||||
def log_dict(self, d: dict, step: int, mode: str = "train", custom_step_key: str | None = None):
|
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
|
||||||
if mode not in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
if step is None and custom_step_key is None:
|
if step is None and custom_step_key is None:
|
||||||
|
|
Loading…
Reference in New Issue