Add context manager for seeding (#164)
This commit is contained in:
parent
473345fdf6
commit
b187942db4
|
@ -1,8 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import random
|
import random
|
||||||
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -39,6 +41,31 @@ def set_global_seed(seed):
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||||
|
"""Set the seed when entering a context, and restore the prior random state at exit.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
a = random.random() # produces some random number
|
||||||
|
with seeded_context(1337):
|
||||||
|
b = random.random() # produces some other random number
|
||||||
|
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
random_state = random.getstate()
|
||||||
|
np_random_state = np.random.get_state()
|
||||||
|
torch_random_state = torch.random.get_rng_state()
|
||||||
|
torch_cuda_random_state = torch.cuda.random.get_rng_state()
|
||||||
|
set_global_seed(seed)
|
||||||
|
yield None
|
||||||
|
random.setstate(random_state)
|
||||||
|
np.random.set_state(np_random_state)
|
||||||
|
torch.random.set_rng_state(torch_random_state)
|
||||||
|
torch.cuda.random.set_rng_state(torch_cuda_random_state)
|
||||||
|
|
||||||
|
|
||||||
def init_logging():
|
def init_logging():
|
||||||
def custom_format(record):
|
def custom_format(record):
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
import random
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.utils.utils import seeded_context, set_global_seed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"rand_fn",
|
||||||
|
[
|
||||||
|
random.random,
|
||||||
|
np.random.random,
|
||||||
|
lambda: torch.rand(1).item(),
|
||||||
|
]
|
||||||
|
+ [lambda: torch.rand(1, device="cuda")]
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else [],
|
||||||
|
)
|
||||||
|
def test_seeding(rand_fn: Callable[[], int]):
|
||||||
|
set_global_seed(0)
|
||||||
|
a = rand_fn()
|
||||||
|
with seeded_context(1337):
|
||||||
|
c = rand_fn()
|
||||||
|
b = rand_fn()
|
||||||
|
set_global_seed(0)
|
||||||
|
a_ = rand_fn()
|
||||||
|
b_ = rand_fn()
|
||||||
|
# Check that `set_global_seed` lets us reproduce a and b.
|
||||||
|
assert a_ == a
|
||||||
|
# Additionally, check that the `seeded_context` didn't interrupt the global RNG.
|
||||||
|
assert b_ == b
|
||||||
|
set_global_seed(1337)
|
||||||
|
c_ = rand_fn()
|
||||||
|
# Check that `seeded_context` and `global_seed` give the same reproducibility.
|
||||||
|
assert c_ == c
|
Loading…
Reference in New Issue