Add context manager for seeding (#164)

This commit is contained in:
Alexander Soare 2024-05-09 17:58:39 +01:00 committed by GitHub
parent 473345fdf6
commit b187942db4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 0 deletions

View File

@ -1,8 +1,10 @@
import logging import logging
import os.path as osp import os.path as osp
import random import random
from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Generator
import hydra import hydra
import numpy as np import numpy as np
@ -39,6 +41,31 @@ def set_global_seed(seed):
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state = random.getstate()
np_random_state = np.random.get_state()
torch_random_state = torch.random.get_rng_state()
torch_cuda_random_state = torch.cuda.random.get_rng_state()
set_global_seed(seed)
yield None
random.setstate(random_state)
np.random.set_state(np_random_state)
torch.random.set_rng_state(torch_random_state)
torch.cuda.random.set_rng_state(torch_cuda_random_state)
def init_logging(): def init_logging():
def custom_format(record): def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

38
tests/test_utils.py Normal file
View File

@ -0,0 +1,38 @@
import random
from typing import Callable
import numpy as np
import pytest
import torch
from lerobot.common.utils.utils import seeded_context, set_global_seed
@pytest.mark.parametrize(
"rand_fn",
[
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
+ [lambda: torch.rand(1, device="cuda")]
if torch.cuda.is_available()
else [],
)
def test_seeding(rand_fn: Callable[[], int]):
set_global_seed(0)
a = rand_fn()
with seeded_context(1337):
c = rand_fn()
b = rand_fn()
set_global_seed(0)
a_ = rand_fn()
b_ = rand_fn()
# Check that `set_global_seed` lets us reproduce a and b.
assert a_ == a
# Additionally, check that the `seeded_context` didn't interrupt the global RNG.
assert b_ == b
set_global_seed(1337)
c_ = rand_fn()
# Check that `seeded_context` and `global_seed` give the same reproducibility.
assert c_ == c