backup wip
This commit is contained in:
parent
c99b845b8f
commit
fe8347246c
|
@ -18,6 +18,8 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
@ -73,7 +75,19 @@ class Logger:
|
|||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
# Get the WandB run ID.
|
||||
paths = glob(str(self._checkpoint_dir / "../wandb/latest-run/run-*"))
|
||||
if len(paths) != 1:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
|
||||
if match is None:
|
||||
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
|
||||
wandb_run_id = match.groups(0)[0]
|
||||
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=project,
|
||||
entity=entity,
|
||||
name=job_name,
|
||||
|
@ -87,14 +101,14 @@ class Logger:
|
|||
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
|
||||
job_type="train_eval",
|
||||
# TODO(rcadene): add resume option
|
||||
resume="must",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
@property
|
||||
def last_checkpoint_path(self):
|
||||
def last_checkpoint_path(self) -> Path:
|
||||
return self._last_checkpoint_path
|
||||
|
||||
def save_model(self, policy: Policy, identifier: str):
|
||||
|
@ -112,6 +126,8 @@ class Logger:
|
|||
)
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
if self._last_checkpoint_path.exists():
|
||||
os.remove(self._last_checkpoint_path)
|
||||
os.symlink(save_dir.absolute(), self._last_checkpoint_path) # TODO(now): Check this works
|
||||
|
||||
def save_training_state(
|
||||
|
|
|
@ -322,8 +322,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.info("make_policy")
|
||||
policy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
dataset_stats=offline_dataset.stats,
|
||||
pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
|
||||
dataset_stats=offline_dataset.stats if not cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_checkpoint_path) if cfg.resume else None,
|
||||
)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
|
@ -335,10 +335,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
if cfg.resume:
|
||||
print("You have set resume=True, indicating that you wish to resume a run.")
|
||||
# Make sure there is a checkpoint.
|
||||
if not Path(logger.last_checkpoint_path).exists():
|
||||
raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
|
||||
if not logger.last_checkpoint_path.exists():
|
||||
raise RuntimeError(
|
||||
f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
|
||||
)
|
||||
# Get the configuration file from the last checkpoint.
|
||||
checkpoint_cfg = init_hydra_config(logger.last_checkpoint_path)
|
||||
checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
|
||||
# TODO(now): Do a diff check.
|
||||
cfg = checkpoint_cfg
|
||||
step = logger.load_last_training_state(optimizer, lr_scheduler)
|
||||
|
@ -376,8 +378,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.info(f"Checkpoint policy after step {step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
logger.save_model(
|
||||
logger.save_checkpont(
|
||||
step,
|
||||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
identifier=str(step).zfill(
|
||||
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
),
|
||||
|
|
Loading…
Reference in New Issue