From b187942db428ef7a7ae01a2dc1a27fa6e62c03b2 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Thu, 9 May 2024 17:58:39 +0100
Subject: [PATCH] Add context manager for seeding (#164)

---
 lerobot/common/utils/utils.py | 27 +++++++++++++++++++++++++++
 tests/test_utils.py           | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+)
 create mode 100644 tests/test_utils.py

diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index 9d0ddd98..8fe621f4 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -1,8 +1,10 @@
 import logging
 import os.path as osp
 import random
+from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
+from typing import Generator
 
 import hydra
 import numpy as np
@@ -39,6 +41,31 @@ def set_global_seed(seed):
         torch.cuda.manual_seed_all(seed)
 
 
+@contextmanager
+def seeded_context(seed: int) -> Generator[None, None, None]:
+    """Set the seed when entering a context, and restore the prior random state at exit.
+
+    Example usage:
+
+    ```
+    a = random.random()  # produces some random number
+    with seeded_context(1337):
+        b = random.random()  # produces some other random number
+    c = random.random()  # produces yet another random number, the same one it would have produced if we had never drawn `b`
+    ```
+    """
+    random_state = random.getstate()
+    np_random_state = np.random.get_state()
+    torch_random_state = torch.random.get_rng_state()
+    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    set_global_seed(seed)
+    yield None
+    random.setstate(random_state)
+    np.random.set_state(np_random_state)
+    torch.random.set_rng_state(torch_random_state)
+    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+
+
 def init_logging():
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..bcdd95b4
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,38 @@
+import random
+from typing import Callable
+
+import numpy as np
+import pytest
+import torch
+
+from lerobot.common.utils.utils import seeded_context, set_global_seed
+
+
+@pytest.mark.parametrize(
+    "rand_fn",
+    [
+        random.random,
+        np.random.random,
+        lambda: torch.rand(1).item(),
+    ]
+    + ([lambda: torch.rand(1, device="cuda")]
+       if torch.cuda.is_available()
+       else []),
+)
+def test_seeding(rand_fn: Callable[[], int]):
+    set_global_seed(0)
+    a = rand_fn()
+    with seeded_context(1337):
+        c = rand_fn()
+    b = rand_fn()
+    set_global_seed(0)
+    a_ = rand_fn()
+    b_ = rand_fn()
+    # Check that `set_global_seed` lets us reproduce a and b.
+    assert a_ == a
+    # Additionally, check that `seeded_context` didn't interrupt the global RNG.
+    assert b_ == b
+    set_global_seed(1337)
+    c_ = rand_fn()
+    # Check that `seeded_context` and `set_global_seed` give the same reproducibility.
+    assert c_ == c
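
A minimal usage sketch (not part of the patch) of what the new test verifies: `seeded_context` seeds the global RNGs on entry and restores the previous RNG state on exit, so draws made inside the context do not disturb the outer stream. It assumes an environment where the unconditional `torch.cuda` state calls inside `seeded_context` can run.

```
import random

from lerobot.common.utils.utils import seeded_context, set_global_seed

set_global_seed(0)
a = random.random()          # first draw from the stream seeded with 0
with seeded_context(1337):
    c = random.random()      # depends only on the seed 1337
b = random.random()          # second draw from the stream seeded with 0

set_global_seed(0)
assert a == random.random()  # replaying the outer stream reproduces `a`...
assert b == random.random()  # ...and `b`, as if the seeded block never drew anything

set_global_seed(1337)
assert c == random.random()  # matches the draw made inside the context
```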