196 lines
7.0 KiB
Python
196 lines
7.0 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import random
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Any, Generator
|
|
|
|
import numpy as np
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
|
|
from lerobot.common.constants import RNG_STATE
|
|
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict
|
|
|
|
|
|
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
    `safetensors.save_file()` or `torch.save()`.

    Keys:
        py_rng_version: 1-element int64 tensor holding the state version (`random.getstate()[0]`).
        py_rng_state: int64 tensor holding the generator's internal state (`random.getstate()[1]`).
        py_rng_gauss_next: 1-element float64 tensor holding the value cached by `random.gauss()`
            (`random.getstate()[2]`). Only present when a value is actually cached, so the key set
            is unchanged in the common case where `random.gauss()` has no pending value.
    """
    version, internal_state, gauss_next = random.getstate()
    state_dict = {
        "py_rng_version": torch.tensor([version], dtype=torch.int64),
        "py_rng_state": torch.tensor(internal_state, dtype=torch.int64),
    }
    if gauss_next is not None:
        # `random.gauss()` produces values in pairs and caches the second one. Dropping it would make
        # the next `random.gauss()` call after a restore differ from the uninterrupted sequence.
        # float64 keeps the Python float bit-exact.
        state_dict["py_rng_gauss_next"] = torch.tensor([gauss_next], dtype=torch.float64)
    return state_dict
|
|
|
|
|
|
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.

    Args:
        rng_state_dict: Must contain "py_rng_version" (1-element tensor) and "py_rng_state"
            (tensor of the generator's internal state). May contain "py_rng_gauss_next"
            (1-element tensor) holding the value cached by `random.gauss()`; when that key is
            absent the cached value is reset to None.

    Raises:
        KeyError: If one of the two required keys is missing.
    """
    # Optional key: dictionaries that do not carry it restore with no cached gaussian.
    cached_gauss = rng_state_dict.get("py_rng_gauss_next")
    gauss_next = None if cached_gauss is None else cached_gauss.item()
    py_state = (
        rng_state_dict["py_rng_version"].item(),
        # `random.setstate()` expects the internal state as a tuple of ints.
        tuple(rng_state_dict["py_rng_state"].tolist()),
        gauss_next,
    )
    random.setstate(py_state)
|
|
|
|
|
|
def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
|
|
"""
|
|
Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
|
|
`safetensors.save_file()` or `torch.save()`.
|
|
"""
|
|
np_state = np.random.get_state()
|
|
# Ensure no breaking changes from numpy
|
|
assert np_state[0] == "MT19937"
|
|
return {
|
|
"np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
|
|
"np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
|
|
"np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
|
|
"np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
|
|
}
|
|
|
|
|
|
def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Rebuilds the `numpy` global rng state from a dictionary produced by `serialize_numpy_rng_state()`
    and installs it with `np.random.set_state()`.

    The four tensors are converted back to the positional fields of numpy's state tuple; the
    generator name is always "MT19937".
    """
    key_values = rng_state_dict["np_rng_state_values"].numpy()
    position = rng_state_dict["np_rng_state_index"].item()
    has_gauss = rng_state_dict["np_rng_has_gauss"].item()
    cached_gaussian = rng_state_dict["np_rng_cached_gaussian"].item()
    np.random.set_state(("MT19937", key_values, position, has_gauss, cached_gaussian))
|
|
|
|
|
|
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
    """
    Captures the `torch` rng state as a flat dict[str, torch.Tensor] suitable for
    `safetensors.save_file()` or `torch.save()`.

    "torch_rng_state" is always present. "torch_cuda_rng_state" (from `torch.cuda.get_rng_state()`)
    is added only when `torch.cuda.is_available()` is true.
    """
    cpu_state = torch.get_rng_state()
    if not torch.cuda.is_available():
        return {"torch_rng_state": cpu_state}
    return {
        "torch_rng_state": cpu_state,
        "torch_cuda_rng_state": torch.cuda.get_rng_state(),
    }
|
|
|
|
|
|
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Applies a `torch` rng state captured by `serialize_torch_rng_state()`.

    The "torch_rng_state" entry is always applied. The "torch_cuda_rng_state" entry is applied only
    when CUDA is available and the entry exists in `rng_state_dict`.
    """
    torch.set_rng_state(rng_state_dict["torch_rng_state"])
    if not torch.cuda.is_available():
        return
    cuda_key = "torch_cuda_rng_state"
    if cuda_key in rng_state_dict:
        torch.cuda.set_rng_state(rng_state_dict[cuda_key])
|
|
|
|
|
|
def serialize_rng_state() -> dict[str, torch.Tensor]:
    """
    Collects the rng state of `random`, `numpy`, and `torch` into one flat dict[str, torch.Tensor]
    that can be saved using `safetensors.save_file()` or `torch.save()`.

    The result is the union of the dictionaries returned by `serialize_python_rng_state()`,
    `serialize_numpy_rng_state()` and `serialize_torch_rng_state()`, merged in that order.
    """
    combined: dict[str, torch.Tensor] = {}
    for capture in (
        serialize_python_rng_state,
        serialize_numpy_rng_state,
        serialize_torch_rng_state,
    ):
        combined.update(capture())
    return combined
|
|
|
|
|
|
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state of `random`, `numpy`, and `torch` from a dictionary produced by
    `serialize_rng_state()`.

    Entries are routed by key prefix: keys starting with "py" go to the `random` restorer, "np" to
    the `numpy` restorer and "torch" to the `torch` restorer.
    """

    def _with_prefix(prefix: str) -> dict[str, torch.Tensor]:
        # Keep only the entries whose key starts with `prefix`.
        return {name: tensor for name, tensor in rng_state_dict.items() if name.startswith(prefix)}

    python_part = _with_prefix("py")
    numpy_part = _with_prefix("np")
    torch_part = _with_prefix("torch")

    deserialize_python_rng_state(python_part)
    deserialize_numpy_rng_state(numpy_part)
    deserialize_torch_rng_state(torch_part)
|
|
|
|
|
|
def save_rng_state(save_dir: Path) -> None:
    """
    Serializes the current `random`, `numpy`, and `torch` rng states and writes them with
    `safetensors.torch.save_file()` to `save_dir / RNG_STATE`.

    The serialized dictionary is passed through `flatten_dict` before being written.
    """
    tensors = flatten_dict(serialize_rng_state())
    save_file(tensors, save_dir / RNG_STATE)
|
|
|
|
|
|
def load_rng_state(save_dir: Path) -> None:
    """
    Reads `save_dir / RNG_STATE` with `safetensors.torch.load_file()` and restores the `random`,
    `numpy`, and `torch` rng states from it.

    The loaded dictionary is passed through `unflatten_dict` before being deserialized.
    """
    stored_tensors = load_file(save_dir / RNG_STATE)
    deserialize_rng_state(unflatten_dict(stored_tensors))
|
|
|
|
|
|
def get_rng_state() -> dict[str, Any]:
    """Snapshot the random state of `random`, `numpy`, and `torch`.

    Returns:
        A dictionary with the keys "random_state", "numpy_random_state" and "torch_random_state",
        plus "torch_cuda_random_state" when CUDA is available. The values are the native state
        objects of each library, not tensors ready for serialization.
    """
    snapshot: dict[str, Any] = {}
    snapshot["random_state"] = random.getstate()
    snapshot["numpy_random_state"] = np.random.get_state()
    snapshot["torch_random_state"] = torch.random.get_rng_state()
    if not torch.cuda.is_available():
        return snapshot
    snapshot["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return snapshot
|
|
|
|
|
|
def set_rng_state(random_state_dict: dict[str, Any]) -> None:
    """Set the random state for `random`, `numpy`, and `torch`.

    Args:
        random_state_dict: A dictionary of the form returned by `get_rng_state`. The keys
            "random_state", "numpy_random_state" and "torch_random_state" are required.
            "torch_cuda_random_state" is optional and is applied only when CUDA is available.

    Raises:
        KeyError: If one of the three required keys is missing.
    """
    random.setstate(random_state_dict["random_state"])
    np.random.set_state(random_state_dict["numpy_random_state"])
    torch.random.set_rng_state(random_state_dict["torch_random_state"])
    # A snapshot without a CUDA entry must not raise KeyError just because CUDA is available now;
    # this mirrors the key check done in `deserialize_torch_rng_state`.
    if torch.cuda.is_available() and "torch_cuda_random_state" in random_state_dict:
        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
|
|
|
|
|
def set_seed(seed) -> None:
    """Seed `random`, `numpy`, and `torch` (plus every CUDA device when available) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    The prior state of `random`, `numpy`, and `torch` is restored whether the body of the `with`
    block finishes normally or raises, so a failure inside the block cannot leave the global
    generators seeded with `seed`.

    Example usage:

    ```
    a = random.random()  # produces some random number
    with seeded_context(1337):
        b = random.random()  # produces some other random number
    c = random.random()  # produces yet another random number, the same as if `b` had never been drawn
    ```

    Args:
        seed: Seed applied to all generators for the duration of the context.
    """
    random_state_dict = get_rng_state()
    set_seed(seed)
    try:
        yield None
    finally:
        # Without try/finally, an exception raised inside the `with` body would propagate through
        # the `yield` and skip the restore.
        set_rng_state(random_state_dict)
|