39 lines
947 B
Python
39 lines
947 B
Python
import random
|
|
from typing import Callable
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from lerobot.common.utils.utils import seeded_context, set_global_seed
|
|
|
|
|
|
@pytest.mark.parametrize(
    "rand_fn",
    [
        random.random,
        np.random.random,
        lambda: torch.rand(1).item(),
    ]
    # The parentheses around the conditional are required: `x + y if c else z`
    # parses as `(x + y) if c else z`, which would leave the parameter list
    # empty (and the whole test skipped) on machines without CUDA.
    + ([lambda: torch.rand(1, device="cuda").item()] if torch.cuda.is_available() else []),
)
def test_seeding(rand_fn: Callable[[], float]):
    """Check that seeding makes draws from `rand_fn` reproducible.

    The test is run once per random source: Python's `random`, NumPy's global
    RNG, torch on CPU and, when CUDA is available, torch on the GPU.

    Three properties are asserted:
      1. Calling `set_global_seed(0)` again reproduces the first draw.
      2. A draw made inside `seeded_context(1337)` does not disturb the
         sequence of draws made outside of it.
      3. A draw made right after `set_global_seed(1337)` equals the draw made
         inside `seeded_context(1337)`.

    Args:
        rand_fn: Zero-argument callable returning one random number.
    """
    # First pass: two draws from the global stream, with one draw inside a
    # seeded context in between them.
    set_global_seed(0)
    first = rand_fn()
    with seeded_context(1337):
        in_context = rand_fn()
    second = rand_fn()

    # Second pass: same global seed, but without the seeded context.
    set_global_seed(0)
    first_again = rand_fn()
    second_again = rand_fn()

    # Check that `set_global_seed` lets us reproduce the first draw.
    assert first_again == first
    # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
    assert second_again == second

    set_global_seed(1337)
    from_global_seed = rand_fn()
    # Check that `seeded_context` and `set_global_seed` give the same reproducibility.
    assert from_global_seed == in_context
|