backup wip

Alexander Soare 2024-05-20 17:42:35 +01:00
parent 9b62c25f6c
commit c99b845b8f
5 changed files with 157 additions and 73 deletions

View File

@@ -20,11 +20,15 @@ import logging
 import os
 from pathlib import Path

+import torch
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import OmegaConf
 from termcolor import colored
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler

 from lerobot.common.policies.policy_protocol import Policy
+from lerobot.common.utils.utils import get_global_random_state, set_global_random_state


 def log_output_dir(out_dir):
@@ -49,11 +53,11 @@ class Logger:
         self._log_dir = Path(log_dir)
         self._log_dir.mkdir(parents=True, exist_ok=True)
         self._job_name = job_name
-        self._model_dir = self._log_dir / "checkpoints"
+        self._checkpoint_dir = self._log_dir / "checkpoints"
+        self._last_checkpoint_path = self._checkpoint_dir / "last"
         self._buffer_dir = self._log_dir / "buffers"
         self._save_model = cfg.training.save_model
         self._disable_wandb_artifact = cfg.wandb.disable_artifact
-        self._save_buffer = cfg.training.get("save_buffer", False)
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
@@ -83,16 +87,20 @@ class Logger:
                 # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                 job_type="train_eval",
                 # TODO(rcadene): add resume option
-                resume=None,
+                resume="must",
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb

-    def save_model(self, policy: Policy, identifier):
+    @property
+    def last_checkpoint_path(self):
+        return self._last_checkpoint_path
+
+    def save_model(self, policy: Policy, identifier: str):
         if self._save_model:
-            self._model_dir.mkdir(parents=True, exist_ok=True)
-            save_dir = self._model_dir / str(identifier)
+            self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            save_dir = self._checkpoint_dir / str(identifier)
             policy.save_pretrained(save_dir)
             # Also save the full Hydra config for the env configuration.
             OmegaConf.save(self._cfg, save_dir / "config.yaml")
@@ -104,27 +112,47 @@ class Logger:
             )
             artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
             self._wandb.log_artifact(artifact)
+        os.symlink(save_dir.absolute(), self._last_checkpoint_path)  # TODO(now): Check this works

-    def save_buffer(self, buffer, identifier):
-        self._buffer_dir.mkdir(parents=True, exist_ok=True)
-        fp = self._buffer_dir / f"{str(identifier)}.pkl"
-        buffer.save(fp)
-        if self._wandb and not self._disable_wandb_artifact:
-            # note wandb artifact does not accept ":" or "/" in its name
-            artifact = self._wandb.Artifact(
-                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                type="buffer",
-            )
-            artifact.add_file(fp)
-            self._wandb.log_artifact(artifact)
+    def save_training_state(
+        self, train_step: int, optimizer: Optimizer, scheduler: LRScheduler | None, identifier: str
+    ):
+        training_state = {
+            "step": train_step,
+            "optimizer": optimizer.state_dict(),
+            **get_global_random_state(),
+        }
+        if scheduler is not None:
+            training_state["scheduler"] = scheduler.state_dict()
+        torch.save(training_state, self._checkpoint_dir / str(identifier) / "training_state.pth")

-    def finish(self, agent, buffer):
-        if self._save_model:
-            self.save_model(agent, identifier="final")
-        if self._save_buffer:
-            self.save_buffer(buffer, identifier="buffer")
-        if self._wandb:
-            self._wandb.finish()
+    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+        """
+        Load the optimizer and scheduler state_dict from the last checkpoint, set the random state, and return
+        the global training step.
+        """
+        training_state = torch.load(self._checkpoint_dir / "last" / "training_state.pth")
+        optimizer.load_state_dict(training_state["optimizer"])
+        if scheduler is not None:
+            scheduler.load_state_dict(training_state["scheduler"])
+        elif "scheduler" in training_state:
+            raise ValueError(
+                "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
+            )
+        # Small hack to get the expected keys: use `get_global_random_state`.
+        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+        return training_state["step"]
+
+    def save_checkpont(
+        self,
+        train_step: int,
+        policy: Policy,
+        optimizer: Optimizer,
+        scheduler: LRScheduler | None,
+        identifier: str,
+    ):
+        self.save_model(policy, identifier)
+        self.save_training_state(train_step, optimizer, scheduler, identifier)

     def log_dict(self, d, step, mode="train"):
         assert mode in {"train", "eval"}
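
The heart of the new checkpointing logic is that `save_training_state` bundles everything needed to resume (global step, optimizer state, optional scheduler state, and the global RNG state) into a single `training_state.pth`, while `load_last_training_state` restores it in place and hands the step back to the training loop. A minimal standalone sketch of that round trip, using a toy model and hypothetical paths rather than the `Logger` API itself:

```python
# Standalone sketch of the save/load round trip behind save_training_state /
# load_last_training_state. The model, optimizer, and paths are stand-ins.
from pathlib import Path

import torch

policy = torch.nn.Linear(4, 2)  # stand-in for the real policy
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
checkpoint_dir = Path("checkpoints") / "000100"  # hypothetical identifier
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Save: one file holds the step, optimizer state, and scheduler state.
training_state = {
    "step": 100,
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    # The real method also splats in get_global_random_state() here.
}
torch.save(training_state, checkpoint_dir / "training_state.pth")

# Load: restore in place and recover the step to continue the loop from.
training_state = torch.load(checkpoint_dir / "training_state.pth")
optimizer.load_state_dict(training_state["optimizer"])
scheduler.load_state_dict(training_state["scheduler"])
step = training_state["step"]
assert step == 100
```

Storing the RNG state in the same file is what lets a resumed run reproduce the data shuffling and sampling of an uninterrupted one, rather than just the weights.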

View File

@@ -19,7 +19,7 @@ import random
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator

 import hydra
 import numpy as np
@@ -48,6 +48,28 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
     return device


+def get_global_random_state() -> dict[str, Any]:
+    """Get the random state for `random`, `numpy`, and `torch`."""
+    return {
+        "random_state": random.getstate(),
+        "numpy_random_state": np.random.get_state(),
+        "torch_random_state": torch.random.get_rng_state(),
+        "torch_cuda_random_state": torch.cuda.random.get_rng_state(),
+    }
+
+
+def set_global_random_state(random_state_dict: dict[str, Any]):
+    """Set the random state for `random`, `numpy`, and `torch`.
+
+    Args:
+        random_state_dict: A dictionary of the form returned by `get_global_random_state`.
+    """
+    random.setstate(random_state_dict["random_state"])
+    np.random.set_state(random_state_dict["numpy_random_state"])
+    torch.random.set_rng_state(random_state_dict["torch_random_state"])
+    torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
+
+
 def set_global_seed(seed):
     """Set seed for reproducibility."""
     random.seed(seed)
@@ -69,16 +91,10 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
         c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
         ```
     """
-    random_state = random.getstate()
-    np_random_state = np.random.get_state()
-    torch_random_state = torch.random.get_rng_state()
-    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    random_state_dict = get_global_random_state()
     set_global_seed(seed)
     yield None
-    random.setstate(random_state)
-    np.random.set_state(np_random_state)
-    torch.random.set_rng_state(torch_random_state)
-    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+    set_global_random_state(random_state_dict)


 def init_logging():
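
These helpers generalize what `seeded_context` was doing inline: snapshot every generator, do some work, then put the snapshot back. The same pair is what the `Logger` uses to freeze RNG state into a checkpoint and restore it on resume. A minimal sketch of the pattern, restricted to the CPU generators so it runs anywhere (the project helpers additionally capture the `torch.cuda` state):

```python
# Snapshot/restore pattern behind get_global_random_state / set_global_random_state,
# limited to CPU generators for portability.
import random

import numpy as np
import torch

snapshot = {
    "random_state": random.getstate(),
    "numpy_random_state": np.random.get_state(),
    "torch_random_state": torch.random.get_rng_state(),
}

first = (random.random(), np.random.random(), torch.rand(1))

# Put the snapshot back; the generators rewind to where they were captured.
random.setstate(snapshot["random_state"])
np.random.set_state(snapshot["numpy_random_state"])
torch.random.set_rng_state(snapshot["torch_random_state"])

second = (random.random(), np.random.random(), torch.rand(1))
assert first[0] == second[0] and first[1] == second[1] and torch.equal(first[2], second[2])
```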

View File

@@ -5,10 +5,19 @@ defaults:
 hydra:
   run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
     dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
   job:
     name: default

+# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
+# Note that run resumption works by grabbing the configuration file from
+# {hydra.run.dir}/checkpoints/{specific_checkpoint_dir}/config.yaml. Any differences between the provided
+# configuration and the prior configuration (apart from the resume parameter itself) are ignored. If you wish
+# to change something, you can consider modifying the configuration in the file directly.
+resume: false
 device: cuda  # cpu
 # `seed` is used for training (eg: model initialization, dataset shuffling)
 # AND for the evaluation environments.
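
The new comments spell out the resumption contract: once a checkpoint exists, the run directory, not the command line, becomes the source of truth for the configuration. A minimal sketch (not the project's own loading code; the run-directory path is hypothetical) of what "grabbing the configuration file from the last checkpoint" amounts to:

```python
# Sketch: when resume is enabled, the effective config comes from the config.yaml the
# Logger saved alongside the last checkpoint, not from fresh CLI overrides.
from pathlib import Path

from omegaconf import OmegaConf

run_dir = Path("outputs/train/2024-05-20/17-42-35_example")  # hypothetical hydra.run.dir
checkpoint_cfg_path = run_dir / "checkpoints" / "last" / "config.yaml"

if not checkpoint_cfg_path.exists():
    raise RuntimeError(f"resume was requested but no checkpoint config exists at {checkpoint_cfg_path}")

cfg = OmegaConf.load(checkpoint_cfg_path)  # the prior run's settings win over new ones
```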

View File

@@ -34,6 +34,7 @@ from lerobot.common.policies.policy_protocol import PolicyWithUpdate
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
+    init_hydra_config,
     init_logging,
     set_global_seed,
 )
@@ -122,24 +123,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     return info


-@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
-def train_cli(cfg: dict):
-    train(
-        cfg,
-        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
-        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
-    )
-
-
-def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
-    from hydra import compose, initialize
-
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    initialize(config_path=config_path)
-    cfg = compose(config_name=config_name)
-    train(cfg, out_dir=out_dir, job_name=job_name)
-
-
 def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
@@ -316,15 +299,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     init_logging()

+    # log metrics to terminal and wandb
+    logger = Logger(out_dir, job_name, cfg)
+
     if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
         logging.warning("eval.batch_size > 1 not supported for online training steps")

+    set_global_seed(cfg.seed)
+
     # Check device is available
     get_safe_torch_device(cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(cfg.seed)

     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
@@ -333,18 +320,32 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     eval_env = make_env(cfg)

     logging.info("make_policy")
-    policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
+    policy = make_policy(
+        hydra_cfg=cfg,
+        dataset_stats=offline_dataset.stats,
+        pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
+    )

     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

+    step = 0  # number of policy updates (forward + backward + optim)
+
+    if cfg.resume:
+        print("You have set resume=True, indicating that you wish to resume a run.")
+        # Make sure there is a checkpoint.
+        if not Path(logger.last_checkpoint_path).exists():
+            raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(logger.last_checkpoint_path)
+        # TODO(now): Do a diff check.
+        cfg = checkpoint_cfg
+        step = logger.load_last_training_state(optimizer, lr_scheduler)
+
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
-
     log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -395,9 +396,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     dl_iter = cycle(dataloader)

     policy.train()
-    step = 0  # number of policy update (forward + backward + optim)
     is_offline = True
-    for offline_step in range(cfg.training.offline_steps):
+    for offline_step in range(step, cfg.training.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
         batch = next(dl_iter)
@@ -491,5 +491,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info("End of training")


+@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
     train_cli()
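
Putting the training-script changes together: the `Logger` is now constructed before the policy so `logger.last_checkpoint_path` can feed `make_policy`, the resume branch swaps in the checkpointed config and recovers `step`, and the offline loop starts at `range(step, ...)` instead of zero. A hypothetical pre-flight helper (the function and run path are illustrative, not part of the codebase) that checks whether a run directory satisfies what the resume branch expects:

```python
# Hypothetical check mirroring the resume branch in train(): the last checkpoint must
# exist and contain both the saved config and the training state.
from pathlib import Path


def can_resume(run_dir: Path) -> bool:
    last = run_dir / "checkpoints" / "last"
    return (last / "config.yaml").exists() and (last / "training_state.pth").exists()


if __name__ == "__main__":
    print(can_resume(Path("outputs/train/2024-05-20/17-42-35_example")))  # hypothetical run dir
```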

View File

@@ -11,12 +11,14 @@ from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
     reset_episode_index,
 )
-from lerobot.common.utils.utils import seeded_context, set_global_seed
+from lerobot.common.utils.utils import (
+    get_global_random_state,
+    seeded_context,
+    set_global_random_state,
+    set_global_seed,
+)


-@pytest.mark.parametrize(
-    "rand_fn",
-    (
+rand_fns = (
     [
         random.random,
         np.random.random,
@@ -25,8 +27,10 @@ from lerobot.common.utils.utils import seeded_context, set_global_seed
     + [lambda: torch.rand(1, device="cuda")]
     if torch.cuda.is_available()
     else []
-    ),
 )
+
+
+@pytest.mark.parametrize("rand_fn", rand_fns)
 def test_seeding(rand_fn: Callable[[], int]):
     set_global_seed(0)
     a = rand_fn()
@@ -46,6 +50,15 @@ def test_seeding(rand_fn: Callable[[], int]):
     assert c_ == c


+def test_get_set_random_state():
+    """Check that getting the random state, then setting it results in the same random number generation."""
+    random_state_dict = get_global_random_state()
+    rand_numbers = [rand_fn() for rand_fn in rand_fns]
+    set_global_random_state(random_state_dict)
+    rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
+    assert rand_numbers_ == rand_numbers
+
+
 def test_calculate_episode_data_index():
     dataset = Dataset.from_dict(
         {