diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 109f6951..8a374932 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -20,11 +20,15 @@ import logging
 import os
 from pathlib import Path
 
+import torch
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import OmegaConf
 from termcolor import colored
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
 
 from lerobot.common.policies.policy_protocol import Policy
+from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
 
 
 def log_output_dir(out_dir):
@@ -49,11 +53,11 @@ class Logger:
         self._log_dir = Path(log_dir)
         self._log_dir.mkdir(parents=True, exist_ok=True)
         self._job_name = job_name
-        self._model_dir = self._log_dir / "checkpoints"
+        self._checkpoint_dir = self._log_dir / "checkpoints"
+        self._last_checkpoint_path = self._checkpoint_dir / "last"
         self._buffer_dir = self._log_dir / "buffers"
         self._save_model = cfg.training.save_model
         self._disable_wandb_artifact = cfg.wandb.disable_artifact
-        self._save_buffer = cfg.training.get("save_buffer", False)
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
@@ -83,16 +87,20 @@
                 # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                 job_type="train_eval",
                 # TODO(rcadene): add resume option
-                resume=None,
+                resume="must",
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
-    def save_model(self, policy: Policy, identifier):
+    @property
+    def last_checkpoint_path(self):
+        return self._last_checkpoint_path
+
+    def save_model(self, policy: Policy, identifier: str):
         if self._save_model:
-            self._model_dir.mkdir(parents=True, exist_ok=True)
-            save_dir = self._model_dir / str(identifier)
+            self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            save_dir = self._checkpoint_dir / str(identifier)
             policy.save_pretrained(save_dir)
             # Also save the full Hydra config for the env configuration.
             OmegaConf.save(self._cfg, save_dir / "config.yaml")
@@ -104,27 +112,47 @@
                )
                artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
                self._wandb.log_artifact(artifact)
+            os.symlink(save_dir.absolute(), self._last_checkpoint_path)  # TODO(now): Check this works
 
-    def save_buffer(self, buffer, identifier):
-        self._buffer_dir.mkdir(parents=True, exist_ok=True)
-        fp = self._buffer_dir / f"{str(identifier)}.pkl"
-        buffer.save(fp)
-        if self._wandb and not self._disable_wandb_artifact:
-            # note wandb artifact does not accept ":" or "/" in its name
-            artifact = self._wandb.Artifact(
-                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                type="buffer",
+    def save_training_state(
+        self, train_step: int, optimizer: Optimizer, scheduler: LRScheduler | None, identifier: str
+    ):
+        training_state = {
+            "step": train_step,
+            "optimizer": optimizer.state_dict(),
+            **get_global_random_state(),
+        }
+        if scheduler is not None:
+            training_state["scheduler"] = scheduler.state_dict()
+        torch.save(training_state, self._checkpoint_dir / str(identifier) / "training_state.pth")
+
+    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+        """
+        Load the optimizer and scheduler state_dict from the last checkpoint, set the random state, and return
+        the global training step.
+        """
+        training_state = torch.load(self._checkpoint_dir / "last" / "training_state.pth")
+        optimizer.load_state_dict(training_state["optimizer"])
+        if scheduler is not None:
+            scheduler.load_state_dict(training_state["scheduler"])
+        elif "scheduler" in training_state:
+            raise ValueError(
+                "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
             )
-        artifact.add_file(fp)
-        self._wandb.log_artifact(artifact)
+        # Small hack to get the expected keys: use `get_global_random_state`.
+        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+        return training_state["step"]
 
-    def finish(self, agent, buffer):
-        if self._save_model:
-            self.save_model(agent, identifier="final")
-        if self._save_buffer:
-            self.save_buffer(buffer, identifier="buffer")
-        if self._wandb:
-            self._wandb.finish()
+    def save_checkpoint(
+        self,
+        train_step: int,
+        policy: Policy,
+        optimizer: Optimizer,
+        scheduler: LRScheduler | None,
+        identifier: str,
+    ):
+        self.save_model(policy, identifier)
+        self.save_training_state(train_step, optimizer, scheduler, identifier)
 
     def log_dict(self, d, step, mode="train"):
         assert mode in {"train", "eval"}
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index d62507b5..0eab089f 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -19,7 +19,7 @@ import random
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator
 
 import hydra
 import numpy as np
@@ -48,6 +48,28 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
     return device
 
 
+def get_global_random_state() -> dict[str, Any]:
+    """Get the random state for `random`, `numpy`, and `torch`."""
+    return {
+        "random_state": random.getstate(),
+        "numpy_random_state": np.random.get_state(),
+        "torch_random_state": torch.random.get_rng_state(),
+        "torch_cuda_random_state": torch.cuda.random.get_rng_state(),
+    }
+
+
+def set_global_random_state(random_state_dict: dict[str, Any]):
+    """Set the random state for `random`, `numpy`, and `torch`.
+
+    Args:
+        random_state_dict: A dictionary of the form returned by `get_global_random_state`.
+    """
+    random.setstate(random_state_dict["random_state"])
+    np.random.set_state(random_state_dict["numpy_random_state"])
+    torch.random.set_rng_state(random_state_dict["torch_random_state"])
+    torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
+
+
 def set_global_seed(seed):
     """Set seed for reproducibility."""
     random.seed(seed)
@@ -69,16 +91,10 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
     c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
     ```
     """
-    random_state = random.getstate()
-    np_random_state = np.random.get_state()
-    torch_random_state = torch.random.get_rng_state()
-    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    random_state_dict = get_global_random_state()
     set_global_seed(seed)
     yield None
-    random.setstate(random_state)
-    np.random.set_state(np_random_state)
-    torch.random.set_rng_state(torch_random_state)
-    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+    set_global_random_state(random_state_dict)
 
 
 def init_logging():
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index ae36b3e2..757580c9 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -5,10 +5,19 @@ defaults:
 
 hydra:
   run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir`, its contents will be overwritten unless you set `resume` to true.
     dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
   job:
     name: default
 
+# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
+# Note that run resumption works by grabbing the configuration file from
+# {hydra.run.dir}/checkpoints/{specific_checkpoint_dir}/config.yaml. Any differences between the provided
+# configuration and the prior configuration (apart from the resume parameter itself) are ignored. If you wish
+# to change something, modify that configuration file directly.
+resume: false
 device: cuda  # cpu
 # `seed` is used for training (eg: model initialization, dataset shuffling)
 # AND for the evaluation environments.
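Not part of the patch, but to make the intent of the two helpers above concrete: `get_global_random_state` and `set_global_random_state` are meant to round-trip the interpreter's global RNG state, which is what both `seeded_context` and checkpoint resumption rely on. A minimal sketch, assuming only the signatures shown in utils.py and a CUDA-enabled torch setup (the helpers read the CUDA RNG state unconditionally):

import random

import numpy as np
import torch

from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

snapshot = get_global_random_state()  # capture `random`, numpy, and torch (incl. CUDA) generator states
a = (random.random(), np.random.random(), torch.rand(1).item())
set_global_random_state(snapshot)  # restore the captured state
b = (random.random(), np.random.random(), torch.rand(1).item())
assert a == b  # the exact same numbers are drawn again

Returning the state as one flat dict is also what lets `save_training_state` in logger.py splat it straight into the checkpoint with `**get_global_random_state()`.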
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 7ca7a0b3..dc17f604 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -34,6 +34,7 @@ from lerobot.common.policies.policy_protocol import PolicyWithUpdate
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
+    init_hydra_config,
     init_logging,
     set_global_seed,
 )
@@ -122,24 +123,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     return info
 
 
-@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
-def train_cli(cfg: dict):
-    train(
-        cfg,
-        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
-        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
-    )
-
-
-def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
-    from hydra import compose, initialize
-
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    initialize(config_path=config_path)
-    cfg = compose(config_name=config_name)
-    train(cfg, out_dir=out_dir, job_name=job_name)
-
-
 def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
@@ -316,15 +299,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     init_logging()
 
+    # log metrics to terminal and wandb
+    logger = Logger(out_dir, job_name, cfg)
+
     if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
         logging.warning("eval.batch_size > 1 not supported for online training steps")
 
+    set_global_seed(cfg.seed)
+
     # Check device is available
     get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(cfg.seed)
 
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
@@ -333,18 +320,32 @@
     eval_env = make_env(cfg)
 
     logging.info("make_policy")
-    policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
+    policy = make_policy(
+        hydra_cfg=cfg,
+        dataset_stats=offline_dataset.stats,
+        pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
+    )
 
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
 
+    step = 0  # number of policy updates (forward + backward + optim)
+
+    if cfg.resume:
+        print("You have set resume=True, indicating that you wish to resume a run.")
+        # Make sure there is a checkpoint.
+        if not Path(logger.last_checkpoint_path).exists():
+            raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(logger.last_checkpoint_path)
+        # TODO(now): Do a diff check.
+        cfg = checkpoint_cfg
+        step = logger.load_last_training_state(optimizer, lr_scheduler)
+
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
-
     log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -395,9 +396,8 @@
     dl_iter = cycle(dataloader)
 
     policy.train()
-    step = 0  # number of policy update (forward + backward + optim)
     is_offline = True
-    for offline_step in range(cfg.training.offline_steps):
+    for offline_step in range(step, cfg.training.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
         batch = next(dl_iter)
@@ -491,5 +491,23 @@
     logging.info("End of training")
 
 
+@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
     train_cli()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a7f770fb..7f9f967c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -11,22 +11,26 @@ from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
     reset_episode_index,
 )
-from lerobot.common.utils.utils import seeded_context, set_global_seed
-
-
-@pytest.mark.parametrize(
-    "rand_fn",
-    (
-        [
-            random.random,
-            np.random.random,
-            lambda: torch.rand(1).item(),
-        ]
-        + [lambda: torch.rand(1, device="cuda")]
-        if torch.cuda.is_available()
-        else []
-    ),
+from lerobot.common.utils.utils import (
+    get_global_random_state,
+    seeded_context,
+    set_global_random_state,
+    set_global_seed,
 )
+
+rand_fns = (
+    [
+        random.random,
+        np.random.random,
+        lambda: torch.rand(1).item(),
+    ]
+    + [lambda: torch.rand(1, device="cuda")]
+    if torch.cuda.is_available()
+    else []
+)
+
+
+@pytest.mark.parametrize("rand_fn", rand_fns)
 def test_seeding(rand_fn: Callable[[], int]):
     set_global_seed(0)
     a = rand_fn()
@@ -46,6 +50,15 @@
     assert c_ == c
 
 
+def test_get_set_random_state():
+    """Check that getting the random state, then setting it results in the same random number generation."""
+    random_state_dict = get_global_random_state()
+    rand_numbers = [rand_fn() for rand_fn in rand_fns]
+    set_global_random_state(random_state_dict)
+    rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
+    assert rand_numbers_ == rand_numbers
+
+
 def test_calculate_episode_data_index():
     dataset = Dataset.from_dict(
         {
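Not part of the patch: a condensed sketch of how the new checkpointing API is meant to be driven from a training script. The `run_training` wrapper and its arguments are hypothetical and for illustration only; `Logger`, `last_checkpoint_path`, `load_last_training_state`, `save_checkpoint`, and the `cfg.resume` / `cfg.training.offline_steps` fields are the ones introduced or used in the diff above.

from pathlib import Path

from lerobot.common.logger import Logger


def run_training(cfg, out_dir, job_name, policy, optimizer, lr_scheduler):
    # Hypothetical wrapper, standing in for the body of train() in lerobot/scripts/train.py.
    logger = Logger(out_dir, job_name, cfg)

    step = 0  # number of policy updates so far
    if cfg.resume:
        # `checkpoints/last` is the symlink created by save_model() for the most recent checkpoint.
        if not Path(logger.last_checkpoint_path).exists():
            raise RuntimeError(f"resume=true but {logger.last_checkpoint_path} does not exist")
        # Restores the optimizer/scheduler state and the global RNG state, and returns the saved step.
        step = logger.load_last_training_state(optimizer, lr_scheduler)

    for step in range(step, cfg.training.offline_steps):
        ...  # one policy update (forward + backward + optimizer step), as in train.py's update_policy()

    # Writes checkpoints/final/ with the pretrained policy, config.yaml and training_state.pth,
    # and creates the checkpoints/last symlink pointing at it (see the TODO on save_model above).
    logger.save_checkpoint(step, policy, optimizer, lr_scheduler, identifier="final")

Note that when actually resuming, train.py also reloads the policy weights from the same place by passing `logger.last_checkpoint_path` as `pretrained_policy_name_or_path` to `make_policy`, so the weights, optimizer state, and RNG state all come from one checkpoint.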