backup wip

This commit is contained in:
Alexander Soare 2024-05-20 18:47:36 +01:00
parent c99b845b8f
commit fe8347246c
2 changed files with 29 additions and 8 deletions

View File

@ -18,6 +18,8 @@
import logging
import os
import re
from glob import glob
from pathlib import Path
import torch
@ -73,7 +75,19 @@ class Logger:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
# Get the WandB run ID.
paths = glob(str(self._checkpoint_dir / "../wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
wandb.init(
id=wandb_run_id,
project=project,
entity=entity,
name=job_name,
@ -87,14 +101,14 @@ class Logger:
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
# TODO(rcadene): add resume option
resume="must",
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@property
def last_checkpoint_path(self):
def last_checkpoint_path(self) -> Path:
return self._last_checkpoint_path
def save_model(self, policy: Policy, identifier: str):
@ -112,6 +126,8 @@ class Logger:
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
if self._last_checkpoint_path.exists():
os.remove(self._last_checkpoint_path)
os.symlink(save_dir.absolute(), self._last_checkpoint_path) # TODO(now): Check this works
def save_training_state(

View File

@ -322,8 +322,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_policy")
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats,
pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
dataset_stats=offline_dataset.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_checkpoint_path) if cfg.resume else None,
)
# Create optimizer and scheduler
@ -335,10 +335,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if cfg.resume:
print("You have set resume=True, indicating that you wish to resume a run.")
# Make sure there is a checkpoint.
if not Path(logger.last_checkpoint_path).exists():
raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
if not logger.last_checkpoint_path.exists():
raise RuntimeError(
f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
)
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(logger.last_checkpoint_path)
checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
# TODO(now): Do a diff check.
cfg = checkpoint_cfg
step = logger.load_last_training_state(optimizer, lr_scheduler)
@ -376,8 +378,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_model(
logger.save_checkpont(
step,
policy,
optimizer,
lr_scheduler,
identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
),