backup wip

commit c99b845b8f
parent 9b62c25f6c
@@ -20,11 +20,15 @@ import logging
 import os
 from pathlib import Path
 
+import torch
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import OmegaConf
 from termcolor import colored
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
 
 from lerobot.common.policies.policy_protocol import Policy
+from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
 
 
 def log_output_dir(out_dir):
@@ -49,11 +53,11 @@ class Logger:
         self._log_dir = Path(log_dir)
         self._log_dir.mkdir(parents=True, exist_ok=True)
         self._job_name = job_name
-        self._model_dir = self._log_dir / "checkpoints"
+        self._checkpoint_dir = self._log_dir / "checkpoints"
+        self._last_checkpoint_path = self._checkpoint_dir / "last"
         self._buffer_dir = self._log_dir / "buffers"
         self._save_model = cfg.training.save_model
         self._disable_wandb_artifact = cfg.wandb.disable_artifact
-        self._save_buffer = cfg.training.get("save_buffer", False)
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
@@ -83,16 +87,20 @@ class Logger:
                 # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                 job_type="train_eval",
                 # TODO(rcadene): add resume option
-                resume=None,
+                resume="must",
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
-    def save_model(self, policy: Policy, identifier):
+    @property
+    def last_checkpoint_path(self):
+        return self._last_checkpoint_path
+
+    def save_model(self, policy: Policy, identifier: str):
         if self._save_model:
-            self._model_dir.mkdir(parents=True, exist_ok=True)
-            save_dir = self._model_dir / str(identifier)
+            self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            save_dir = self._checkpoint_dir / str(identifier)
             policy.save_pretrained(save_dir)
             # Also save the full Hydra config for the env configuration.
             OmegaConf.save(self._cfg, save_dir / "config.yaml")
@@ -104,27 +112,47 @@ class Logger:
                 )
                 artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
                 self._wandb.log_artifact(artifact)
+        os.symlink(save_dir.absolute(), self._last_checkpoint_path)  # TODO(now): Check this works
 
-    def save_buffer(self, buffer, identifier):
-        self._buffer_dir.mkdir(parents=True, exist_ok=True)
-        fp = self._buffer_dir / f"{str(identifier)}.pkl"
-        buffer.save(fp)
-        if self._wandb and not self._disable_wandb_artifact:
-            # note wandb artifact does not accept ":" or "/" in its name
-            artifact = self._wandb.Artifact(
-                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                type="buffer",
-            )
-            artifact.add_file(fp)
-            self._wandb.log_artifact(artifact)
+    def save_training_state(
+        self, train_step: int, optimizer: Optimizer, scheduler: LRScheduler | None, identifier: str
+    ):
+        training_state = {
+            "step": train_step,
+            "optimizer": optimizer.state_dict(),
+            **get_global_random_state(),
+        }
+        if scheduler is not None:
+            training_state["scheduler"] = scheduler.state_dict()
+        torch.save(training_state, self._checkpoint_dir / str(identifier) / "training_state.pth")
+
+    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+        """
+        Load the optimizer and scheduler state_dict from the last checkpoint, set the random state, and return
+        the global training step.
+        """
+        training_state = torch.load(self._checkpoint_dir / "last" / "training_state.pth")
+        optimizer.load_state_dict(training_state["optimizer"])
+        if scheduler is not None:
+            scheduler.load_state_dict(training_state["scheduler"])
+        elif "scheduler" in training_state:
+            raise ValueError(
+                "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
+            )
+        # Small hack to get the expected keys: use `get_global_random_state`.
+        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+        return training_state["step"]
 
-    def finish(self, agent, buffer):
-        if self._save_model:
-            self.save_model(agent, identifier="final")
-        if self._save_buffer:
-            self.save_buffer(buffer, identifier="buffer")
-        if self._wandb:
-            self._wandb.finish()
+    def save_checkpoint(
+        self,
+        train_step: int,
+        policy: Policy,
+        optimizer: Optimizer,
+        scheduler: LRScheduler | None,
+        identifier: str,
+    ):
+        self.save_model(policy, identifier)
+        self.save_training_state(train_step, optimizer, scheduler, identifier)
 
     def log_dict(self, d, step, mode="train"):
         assert mode in {"train", "eval"}
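The Logger changes above add a checkpointing layer: save_model writes the policy weights and the Hydra config under checkpoints/{identifier} and symlinks checkpoints/last to the newest one, while save_training_state and load_last_training_state persist and restore the step counter, optimizer, optional scheduler, and global RNG state. As a rough standalone sketch of that training-state round trip (plain torch only, hypothetical paths and a toy model, no lerobot imports):

# Standalone sketch of the training-state round trip behind Logger.save_training_state /
# Logger.load_last_training_state. The directory layout and the toy model are hypothetical.
from pathlib import Path

import torch

checkpoint_dir = Path("outputs/train/example/checkpoints")  # hypothetical run directory
save_dir = checkpoint_dir / "000100"
save_dir.mkdir(parents=True, exist_ok=True)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save: one file holding the step counter, optimizer state, and RNG state, next to the weights.
training_state = {
    "step": 100,
    "optimizer": optimizer.state_dict(),
    "torch_random_state": torch.random.get_rng_state(),
}
torch.save(training_state, save_dir / "training_state.pth")

# "last" points at the newest checkpoint, mirroring the os.symlink call in save_model.
last = checkpoint_dir / "last"
if not last.is_symlink():
    last.symlink_to(save_dir.absolute())

# Load: restore the optimizer and RNG state, then resume the step counter.
restored = torch.load(last / "training_state.pth")
optimizer.load_state_dict(restored["optimizer"])
torch.random.set_rng_state(restored["torch_random_state"])
step = restored["step"]
print(f"resuming from step {step}")

The real methods live on Logger, also record numpy/python/CUDA RNG state, and optionally upload a wandb artifact; the sketch only shows the file layout and the save/restore ordering. The next hunks touch the shared RNG utilities that back get_global_random_state and set_global_random_state.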
@@ -19,7 +19,7 @@ import random
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator
 
 import hydra
 import numpy as np
@@ -48,6 +48,28 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
     return device
 
 
+def get_global_random_state() -> dict[str, Any]:
+    """Get the random state for `random`, `numpy`, and `torch`."""
+    return {
+        "random_state": random.getstate(),
+        "numpy_random_state": np.random.get_state(),
+        "torch_random_state": torch.random.get_rng_state(),
+        "torch_cuda_random_state": torch.cuda.random.get_rng_state(),
+    }
+
+
+def set_global_random_state(random_state_dict: dict[str, Any]):
+    """Set the random state for `random`, `numpy`, and `torch`.
+
+    Args:
+        random_state_dict: A dictionary of the form returned by `get_global_random_state`.
+    """
+    random.setstate(random_state_dict["random_state"])
+    np.random.set_state(random_state_dict["numpy_random_state"])
+    torch.random.set_rng_state(random_state_dict["torch_random_state"])
+    torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
+
+
 def set_global_seed(seed):
     """Set seed for reproducibility."""
     random.seed(seed)
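A minimal CPU-only sketch of the capture/restore pattern these two helpers implement (the repo versions additionally record the torch.cuda generator state), along the lines of the test_get_set_random_state test added further down:

# Capture the global RNG state, draw some numbers, restore the state, and check
# that the same numbers are replayed. CPU generators only, so it runs without CUDA.
import random

import numpy as np
import torch


def get_state() -> dict:
    return {
        "random_state": random.getstate(),
        "numpy_random_state": np.random.get_state(),
        "torch_random_state": torch.random.get_rng_state(),
    }


def set_state(state: dict) -> None:
    random.setstate(state["random_state"])
    np.random.set_state(state["numpy_random_state"])
    torch.random.set_rng_state(state["torch_random_state"])


state = get_state()
first = (random.random(), np.random.random(), torch.rand(1).item())
set_state(state)
second = (random.random(), np.random.random(), torch.rand(1).item())
assert first == second  # restoring the state replays the same numbers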
@@ -69,16 +91,10 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
         c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
     ```
     """
-    random_state = random.getstate()
-    np_random_state = np.random.get_state()
-    torch_random_state = torch.random.get_rng_state()
-    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    random_state_dict = get_global_random_state()
     set_global_seed(seed)
     yield None
-    random.setstate(random_state)
-    np.random.set_state(np_random_state)
-    torch.random.set_rng_state(torch_random_state)
-    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+    set_global_random_state(random_state_dict)
 
 
 def init_logging():
@@ -5,10 +5,19 @@ defaults:
 
 hydra:
   run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir`, its contents will be overwritten unless you set `resume` to true.
     dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
   job:
     name: default
 
+# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
+# Note that run resumption works by grabbing the configuration file from
+# {hydra.run.dir}/checkpoints/{specific_checkpoint_dir}/config.yaml. Any differences between the provided
+# configuration and the prior configuration (apart from the resume parameter itself) are ignored. If you wish
+# to change something, consider modifying the configuration in that file directly.
+resume: false
 device: cuda  # cpu
 # `seed` is used for training (eg: model initialization, dataset shuffling)
 # AND for the evaluation environments.
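The comments added above spell out the resume contract: the run directory must already hold checkpoints/{checkpoint}/config.yaml, and on resume that stored configuration takes precedence over a freshly composed one. A small sketch of that lookup with OmegaConf (the run directory name and layout here are hypothetical, not part of the diff):

# Sketch of picking up the stored config from the last checkpoint when resume=true.
from pathlib import Path

from omegaconf import OmegaConf

run_dir = Path("outputs/train/2024-05-01/12-00-00_pusht_diffusion_default")  # hypothetical
last_checkpoint = run_dir / "checkpoints" / "last"

if (last_checkpoint / "config.yaml").exists():
    # On resume, the checkpointed config takes precedence over the freshly composed one.
    cfg = OmegaConf.load(last_checkpoint / "config.yaml")
    print(f"resuming with device={cfg.device} and seed={cfg.seed}")
else:
    raise RuntimeError(f"resume=true but no checkpoint found under {last_checkpoint}")

The remaining hunks wire this flag through the training script.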
@@ -34,6 +34,7 @@ from lerobot.common.policies.policy_protocol import PolicyWithUpdate
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
+    init_hydra_config,
     init_logging,
     set_global_seed,
 )
@@ -122,24 +123,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     return info
 
 
-@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
-def train_cli(cfg: dict):
-    train(
-        cfg,
-        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
-        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
-    )
-
-
-def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
-    from hydra import compose, initialize
-
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    initialize(config_path=config_path)
-    cfg = compose(config_name=config_name)
-    train(cfg, out_dir=out_dir, job_name=job_name)
-
-
 def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
@@ -316,15 +299,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
 
     init_logging()
 
+    # log metrics to terminal and wandb
+    logger = Logger(out_dir, job_name, cfg)
+
     if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
         logging.warning("eval.batch_size > 1 not supported for online training steps")
 
+    set_global_seed(cfg.seed)
+
     # Check device is available
     get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(cfg.seed)
 
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
@@ -333,18 +320,32 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     eval_env = make_env(cfg)
 
     logging.info("make_policy")
-    policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
+    policy = make_policy(
+        hydra_cfg=cfg,
+        dataset_stats=offline_dataset.stats,
+        pretrained_policy_name_or_path=logger.last_checkpoint_path if cfg.resume else None,
+    )
 
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
 
+    step = 0  # number of policy updates (forward + backward + optim)
+
+    if cfg.resume:
+        print("You have set resume=True, indicating that you wish to resume a run.")
+        # Make sure there is a checkpoint.
+        if not Path(logger.last_checkpoint_path).exists():
+            raise RuntimeError(f"You have set resume=True, but {logger.last_checkpoint_path} does not exist.")
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(logger.last_checkpoint_path)
+        # TODO(now): Do a diff check.
+        cfg = checkpoint_cfg
+        step = logger.load_last_training_state(optimizer, lr_scheduler)
+
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
-
     log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -395,9 +396,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     dl_iter = cycle(dataloader)
 
     policy.train()
-    step = 0  # number of policy update (forward + backward + optim)
     is_offline = True
-    for offline_step in range(cfg.training.offline_steps):
+    for offline_step in range(step, cfg.training.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
         batch = next(dl_iter)
@@ -491,5 +491,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info("End of training")
 
 
+@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
     train_cli()
@@ -11,12 +11,14 @@ from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
     reset_episode_index,
 )
-from lerobot.common.utils.utils import seeded_context, set_global_seed
+from lerobot.common.utils.utils import (
+    get_global_random_state,
+    seeded_context,
+    set_global_random_state,
+    set_global_seed,
+)
 
-
-@pytest.mark.parametrize(
-    "rand_fn",
-    (
+rand_fns = (
     [
         random.random,
         np.random.random,
@@ -25,8 +27,10 @@ from lerobot.common.utils.utils import seeded_context, set_global_seed
     + [lambda: torch.rand(1, device="cuda")]
     if torch.cuda.is_available()
     else []
-    ),
-)
+)
+
+
+@pytest.mark.parametrize("rand_fn", rand_fns)
 def test_seeding(rand_fn: Callable[[], int]):
     set_global_seed(0)
     a = rand_fn()
@@ -46,6 +50,15 @@ def test_seeding(rand_fn: Callable[[], int]):
     assert c_ == c
 
 
+def test_get_set_random_state():
+    """Check that getting the random state, then setting it results in the same random number generation."""
+    random_state_dict = get_global_random_state()
+    rand_numbers = [rand_fn() for rand_fn in rand_fns]
+    set_global_random_state(random_state_dict)
+    rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
+    assert rand_numbers_ == rand_numbers
+
+
 def test_calculate_episode_data_index():
     dataset = Dataset.from_dict(
         {