backup wip

This commit is contained in:
Alexander Soare 2024-05-20 17:42:35 +01:00
parent 9b62c25f6c
commit c99b845b8f
5 changed files with 157 additions and 73 deletions

View File

@@ -20,11 +20,15 @@ import logging
import os
from pathlib import Path
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
def log_output_dir(out_dir):
@@ -49,11 +53,11 @@ class Logger:
self._log_dir = Path(log_dir)
self._log_dir.mkdir(parents=True, exist_ok=True)
self._job_name = job_name
self._model_dir = self._log_dir / "checkpoints"
self._checkpoint_dir = self._log_dir / "checkpoints"
self._last_checkpoint_path = self._checkpoint_dir / "last"
self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.training.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._save_buffer = cfg.training.get("save_buffer", False)
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._cfg = cfg
@@ -83,16 +87,20 @@ class Logger:
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
# TODO(rcadene): add resume option
resume=None,
resume="must",
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def save_model(self, policy: Policy, identifier):
@property
def last_checkpoint_path(self):
return self._last_checkpoint_path
def save_model(self, policy: Policy, identifier: str):
if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True)
save_dir = self._model_dir / str(identifier)
self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
save_dir = self._checkpoint_dir / str(identifier)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
@@ -104,27 +112,47 @@ class Logger:
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
# `os.symlink` fails if the link path already exists, so drop any stale "last" symlink first.
self._last_checkpoint_path.unlink(missing_ok=True)
os.symlink(save_dir.absolute(), self._last_checkpoint_path)
def save_buffer(self, buffer, identifier):
self._buffer_dir.mkdir(parents=True, exist_ok=True)
fp = self._buffer_dir / f"{str(identifier)}.pkl"
buffer.save(fp)
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="buffer",
def save_training_state(
self, train_step: int, optimizer: Optimizer, scheduler: LRScheduler | None, identifier: str
):
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, self._checkpoint_dir / str(identifier) / "training_state.pth")
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Load the optimizer and scheduler state_dict from the last checkpoint, set the random state, and return
the global training step.
"""
training_state = torch.load(self._checkpoint_dir / "last" / "training_state.pth")
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError(
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
artifact.add_file(fp)
self._wandb.log_artifact(artifact)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
def finish(self, agent, buffer):
if self._save_model:
self.save_model(agent, identifier="final")
if self._save_buffer:
self.save_buffer(buffer, identifier="buffer")
if self._wandb:
self._wandb.finish()
def save_checkpoint(
self,
train_step: int,
policy: Policy,
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
):
self.save_model(policy, identifier)
self.save_training_state(train_step, optimizer, scheduler, identifier)
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}
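With these pieces in place, `save_checkpoint` and `load_last_training_state` form a save/resume pair. A minimal sketch of the round trip, assuming `logger`, `policy`, `optimizer`, and `scheduler` objects already exist (the identifier below is illustrative):

    # Write model weights plus optimizer/scheduler/RNG state under checkpoints/005000,
    # and point the "last" symlink at it.
    logger.save_checkpoint(
        train_step=5000, policy=policy, optimizer=optimizer, scheduler=scheduler, identifier="005000"
    )
    # On a later run: restore optimizer/scheduler/RNG state from checkpoints/last and
    # recover the global step to continue training from.
    step = logger.load_last_training_state(optimizer, scheduler)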

View File

@@ -19,7 +19,7 @@ import random
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Generator
from typing import Any, Generator
import hydra
import numpy as np
@@ -48,6 +48,28 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
return device
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
return {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
"torch_cuda_random_state": torch.cuda.random.get_rng_state(),
}
def set_global_random_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_global_random_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_global_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
@@ -69,16 +91,10 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
c = random.random() # produces yet another random number, but the same one it would have produced had we never drawn `b`
```
"""
random_state = random.getstate()
np_random_state = np.random.get_state()
torch_random_state = torch.random.get_rng_state()
torch_cuda_random_state = torch.cuda.random.get_rng_state()
random_state_dict = get_global_random_state()
set_global_seed(seed)
yield None
random.setstate(random_state)
np.random.set_state(np_random_state)
torch.random.set_rng_state(torch_random_state)
torch.cuda.random.set_rng_state(torch_cuda_random_state)
set_global_random_state(random_state_dict)
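For reference, a short usage sketch of `seeded_context` matching the behavior its docstring describes (illustrative, not part of this diff):

    import torch
    from lerobot.common.utils.utils import seeded_context, set_global_seed

    set_global_seed(0)
    a = torch.rand(1)
    with seeded_context(1337):
        b = torch.rand(1)  # determined by seed 1337
    c = torch.rand(1)  # the same value we would have gotten had the `with` block never run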
def init_logging():

View File

@@ -5,10 +5,19 @@ defaults:
hydra:
run:
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir`, its contents will be overwritten unless you set `resume` to true.
dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
job:
name: default
# Set `resume` to true to resume a previous run. For this to work, `hydra.run.dir` must point to the
# directory of an existing run with at least one checkpoint in it.
# Note that run resumption works by grabbing the configuration file from
# {hydra.run.dir}/checkpoints/{specific_checkpoint_dir}/config.yaml. Any differences between the provided
# configuration and the prior configuration (apart from the resume parameter itself) are ignored. If you
# wish to change something, modify the checkpoint's configuration file directly.
resume: false
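# As a hypothetical example (the run directory below is a placeholder), resuming could look like:
#   python lerobot/scripts/train.py hydra.run.dir=outputs/train/2024-05-20/17-42-35_pusht_diffusion_default resume=true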
device: cuda # cpu
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.

View File

@@ -34,6 +34,7 @@ from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
@@ -122,24 +123,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
return info
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
@@ -316,15 +299,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
init_logging()
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
logging.warning("eval.batch_size > 1 not supported for online training steps")
set_global_seed(cfg.seed)
# Check device is available
get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
@@ -333,18 +320,32 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
eval_env = make_env(cfg)
logging.info("make_policy")
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats,
pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
print("You have set resume=True, indicating that you wish to resume a run.")
# Make sure there is a checkpoint.
if not logger.last_checkpoint_path.exists():
raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path / "config.yaml"))
# TODO(now): Do a diff check.
cfg = checkpoint_cfg
step = logger.load_last_training_state(optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -395,9 +396,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
dl_iter = cycle(dataloader)
policy.train()
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.training.offline_steps):
for offline_step in range(step, cfg.training.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
batch = next(dl_iter)
@@ -491,5 +491,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
if __name__ == "__main__":
train_cli()
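For interactive sessions, `train_notebook` composes the Hydra config manually; a sketch of calling it (the output directory is a placeholder, and the import path assumes this file lives at lerobot/scripts/train.py):

    from lerobot.scripts.train import train_notebook

    train_notebook(out_dir="outputs/train/notebook_run", job_name="default")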

View File

@@ -11,22 +11,26 @@ from lerobot.common.datasets.utils import (
hf_transform_to_torch,
reset_episode_index,
)
from lerobot.common.utils.utils import seeded_context, set_global_seed
@pytest.mark.parametrize(
"rand_fn",
(
[
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
+ [lambda: torch.rand(1, device="cuda")]
if torch.cuda.is_available()
else []
),
from lerobot.common.utils.utils import (
get_global_random_state,
seeded_context,
set_global_random_state,
set_global_seed,
)
# Parentheses matter here: `a + b if cond else []` parses as `(a + b) if cond else []`,
# which would leave `rand_fns` empty on CPU-only machines.
rand_fns = [
    random.random,
    np.random.random,
    lambda: torch.rand(1).item(),
] + ([lambda: torch.rand(1, device="cuda")] if torch.cuda.is_available() else [])
@pytest.mark.parametrize("rand_fn", rand_fns)
def test_seeding(rand_fn: Callable[[], int]):
set_global_seed(0)
a = rand_fn()
@@ -46,6 +50,15 @@ def test_seeding(rand_fn: Callable[[], int]):
assert c_ == c
def test_get_set_random_state():
"""Check that getting the random state, then setting it results in the same random number generation."""
random_state_dict = get_global_random_state()
rand_numbers = [rand_fn() for rand_fn in rand_fns]
set_global_random_state(random_state_dict)
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
assert rand_numbers_ == rand_numbers
def test_calculate_episode_data_index():
dataset = Dataset.from_dict(
{