2024-05-10 00:58:39 +08:00
|
|
|
import random
|
|
|
|
from typing import Callable
|
2024-08-23 19:27:08 +08:00
|
|
|
from uuid import uuid4
|
2024-05-10 00:58:39 +08:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pytest
|
|
|
|
import torch
|
2024-05-20 20:04:04 +08:00
|
|
|
from datasets import Dataset
|
2024-05-10 00:58:39 +08:00
|
|
|
|
2024-05-20 20:04:04 +08:00
|
|
|
from lerobot.common.datasets.utils import (
|
|
|
|
calculate_episode_data_index,
|
|
|
|
hf_transform_to_torch,
|
|
|
|
reset_episode_index,
|
|
|
|
)
|
2024-05-28 19:04:23 +08:00
|
|
|
from lerobot.common.utils.utils import (
|
|
|
|
get_global_random_state,
|
2024-08-23 19:27:08 +08:00
|
|
|
init_hydra_config,
|
2024-05-28 19:04:23 +08:00
|
|
|
seeded_context,
|
|
|
|
set_global_random_state,
|
|
|
|
set_global_seed,
|
|
|
|
)
|
2024-05-10 00:58:39 +08:00
|
|
|
|
2024-05-28 19:04:23 +08:00
|
|
|
# Random generation functions for testing the seeding and random state get/set.
|
|
|
|
rand_fns = [
|
|
|
|
random.random,
|
|
|
|
np.random.random,
|
|
|
|
lambda: torch.rand(1).item(),
|
|
|
|
]
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
rand_fns.append(lambda: torch.rand(1, device="cuda"))
|
2024-05-10 00:58:39 +08:00
|
|
|
|
2024-05-28 19:04:23 +08:00
|
|
|
|
|
|
|
@pytest.mark.parametrize("rand_fn", rand_fns)
|
2024-05-10 00:58:39 +08:00
|
|
|
def test_seeding(rand_fn: Callable[[], int]):
|
|
|
|
set_global_seed(0)
|
|
|
|
a = rand_fn()
|
|
|
|
with seeded_context(1337):
|
|
|
|
c = rand_fn()
|
|
|
|
b = rand_fn()
|
|
|
|
set_global_seed(0)
|
|
|
|
a_ = rand_fn()
|
|
|
|
b_ = rand_fn()
|
|
|
|
# Check that `set_global_seed` lets us reproduce a and b.
|
|
|
|
assert a_ == a
|
|
|
|
# Additionally, check that the `seeded_context` didn't interrupt the global RNG.
|
|
|
|
assert b_ == b
|
|
|
|
set_global_seed(1337)
|
|
|
|
c_ = rand_fn()
|
|
|
|
# Check that `seeded_context` and `global_seed` give the same reproducibility.
|
|
|
|
assert c_ == c
|
2024-05-20 20:04:04 +08:00
|
|
|
|
|
|
|
|
2024-05-28 19:04:23 +08:00
|
|
|
def test_get_set_random_state():
|
|
|
|
"""Check that getting the random state, then setting it results in the same random number generation."""
|
|
|
|
random_state_dict = get_global_random_state()
|
|
|
|
rand_numbers = [rand_fn() for rand_fn in rand_fns]
|
|
|
|
set_global_random_state(random_state_dict)
|
|
|
|
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
|
|
|
|
assert rand_numbers_ == rand_numbers
|
|
|
|
|
|
|
|
|
2024-05-20 20:04:04 +08:00
|
|
|
def test_calculate_episode_data_index():
|
|
|
|
dataset = Dataset.from_dict(
|
|
|
|
{
|
|
|
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
|
|
|
"index": [0, 1, 2, 3, 4, 5],
|
|
|
|
"episode_index": [0, 0, 1, 2, 2, 2],
|
|
|
|
},
|
|
|
|
)
|
|
|
|
dataset.set_transform(hf_transform_to_torch)
|
|
|
|
episode_data_index = calculate_episode_data_index(dataset)
|
|
|
|
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
|
|
|
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
|
|
|
|
|
|
|
|
|
|
|
def test_reset_episode_index():
|
|
|
|
dataset = Dataset.from_dict(
|
|
|
|
{
|
|
|
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
|
|
|
"index": [0, 1, 2, 3, 4, 5],
|
|
|
|
"episode_index": [10, 10, 11, 12, 12, 12],
|
|
|
|
},
|
|
|
|
)
|
|
|
|
dataset.set_transform(hf_transform_to_torch)
|
|
|
|
correct_episode_index = [0, 0, 1, 2, 2, 2]
|
|
|
|
dataset = reset_episode_index(dataset)
|
|
|
|
assert dataset["episode_index"] == correct_episode_index
|
2024-08-23 19:27:08 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_init_hydra_config_empty():
|
|
|
|
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
|
|
|
|
with open(test_file, "w") as f:
|
|
|
|
f.write("\n")
|
|
|
|
init_hydra_config(test_file)
|