backup wip

parent c99b845b8f
commit fe8347246c
@@ -18,6 +18,8 @@
 import logging
 import os
+import re
+from glob import glob
 from pathlib import Path
 
 import torch
@@ -73,7 +75,19 @@ class Logger:
             os.environ["WANDB_SILENT"] = "true"
             import wandb
 
+            wandb_run_id = None
+            if cfg.resume:
+                # Get the WandB run ID.
+                paths = glob(str(self._checkpoint_dir / "../wandb/latest-run/run-*"))
+                if len(paths) != 1:
+                    raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+                match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
+                if match is None:
+                    raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
+                wandb_run_id = match.groups(0)[0]
+
             wandb.init(
+                id=wandb_run_id,
                 project=project,
                 entity=entity,
                 name=job_name,
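For reference, a minimal sketch of the run-ID recovery added above, applied to a made-up filename of the kind wandb leaves under wandb/latest-run/ (the filename and run ID here are illustrative, not taken from a real run):

import re

# Hypothetical filename found by the glob over wandb/latest-run/run-*.
filename = "run-abc123xy.wandb"

# Same pattern as in the diff: capture everything between "run-" and ".wandb".
match = re.search(r"run-([^\.]+).wandb", filename)
assert match is not None
wandb_run_id = match.groups(0)[0]  # same as match.group(1) once the group matched
print(wandb_run_id)  # prints: abc123xy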
@@ -87,14 +101,14 @@ class Logger:
                 # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                 job_type="train_eval",
                 # TODO(rcadene): add resume option
-                resume="must",
+                resume="must" if cfg.resume else None,
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
     @property
-    def last_checkpoint_path(self):
+    def last_checkpoint_path(self) -> Path:
         return self._last_checkpoint_path
 
     def save_model(self, policy: Policy, identifier: str):
@@ -112,6 +126,8 @@ class Logger:
             )
             artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
             self._wandb.log_artifact(artifact)
+        if self._last_checkpoint_path.exists():
+            os.remove(self._last_checkpoint_path)
         os.symlink(save_dir.absolute(), self._last_checkpoint_path)  # TODO(now): Check this works
 
     def save_training_state(
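As a side note, the remove-then-relink pattern above can be sketched standalone; the directory layout below is hypothetical and only meant to show the mechanics, not the repository's actual checkpoint layout:

import os
from pathlib import Path

# Hypothetical layout, created here just so the sketch runs end to end.
checkpoints_dir = Path("outputs/train/example_run/checkpoints")
save_dir = checkpoints_dir / "000100"            # newest checkpoint dir
save_dir.mkdir(parents=True, exist_ok=True)
last_checkpoint_path = checkpoints_dir / "last"  # stable alias to the newest checkpoint

# Drop the old alias if present, then point it at the newest checkpoint.
# Note: Path.exists() follows symlinks, so a dangling "last" link would not be
# removed here and os.symlink() would then raise FileExistsError (possibly what
# the TODO(now) in the diff is about); os.path.lexists() also catches that case.
if last_checkpoint_path.exists():
    os.remove(last_checkpoint_path)
os.symlink(save_dir.absolute(), last_checkpoint_path)
print(last_checkpoint_path.resolve())  # points at .../checkpoints/000100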
@@ -322,8 +322,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info("make_policy")
     policy = make_policy(
         hydra_cfg=cfg,
-        dataset_stats=offline_dataset.stats,
-        pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
+        dataset_stats=offline_dataset.stats if not cfg.resume else None,
+        pretrained_policy_name_or_path=str(logger.last_checkpoint_path) if cfg.resume else None,
     )
 
     # Create optimizer and scheduler
@@ -335,10 +335,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     if cfg.resume:
         print("You have set resume=True, indicating that you wish to resume a run.")
         # Make sure there is a checkpoint.
-        if not Path(logger.last_checkpoint_path).exists():
-            raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
+        if not logger.last_checkpoint_path.exists():
+            raise RuntimeError(
+                f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
+            )
         # Get the configuration file from the last checkpoint.
-        checkpoint_cfg = init_hydra_config(logger.last_checkpoint_path)
+        checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
         # TODO(now): Do a diff check.
         cfg = checkpoint_cfg
         step = logger.load_last_training_state(optimizer, lr_scheduler)
@@ -376,8 +378,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logging.info(f"Checkpoint policy after step {step}")
         # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
         # needed (choose 6 as a minimum for consistency without being overkill).
-        logger.save_model(
+        logger.save_checkpoint(
+            step,
             policy,
+            optimizer,
+            lr_scheduler,
             identifier=str(step).zfill(
                 max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
             ),
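For illustration, the identifier passed above zero-pads the step number to at least six digits, widening only if the total step budget needs more digits; the numbers below are made up:

# Hypothetical step budget and current step.
offline_steps, online_steps = 80_000, 0
step = 5_000

# Width is at least 6, or wider if the total number of steps has more digits.
width = max(6, len(str(offline_steps + online_steps)))
identifier = str(step).zfill(width)
print(identifier)  # prints: 005000  (80000 has 5 digits, so the minimum of 6 wins)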