diff --git a/tests/test_online_buffer.py b/tests/test_online_buffer.py index 37000e4f..20e26177 100644 --- a/tests/test_online_buffer.py +++ b/tests/test_online_buffer.py @@ -19,11 +19,8 @@ from uuid import uuid4 import numpy as np import pytest import torch -from datasets import Dataset -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights -from lerobot.common.datasets.utils import hf_transform_to_torch # Some constants for OnlineBuffer tests. data_key = "data" @@ -212,29 +209,19 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range(): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. -@pytest.mark.parametrize("offline_dataset_size", [0, 6]) +@pytest.mark.parametrize("offline_dataset_size", [1, 6]) @pytest.mark.parametrize("online_dataset_size", [0, 4]) @pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0]) def test_compute_sampler_weights_trivial( - offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float + lerobot_dataset_from_episodes_factory, + tmp_path, + offline_dataset_size: int, + online_dataset_size: int, + online_sampling_ratio: float, ): - # Pass/skip the test if both datasets sizes are zero. - if offline_dataset_size + online_dataset_size == 0: - return - # Create spoof offline dataset. - offline_dataset = LeRobotDataset.from_preloaded( - hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))}) + offline_dataset = lerobot_dataset_from_episodes_factory( + tmp_path, total_episodes=1, total_frames=offline_dataset_size ) - offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) - if offline_dataset_size == 0: - offline_dataset.episode_data_index = {} - else: - # Set up an episode_data_index with at least two episodes. - offline_dataset.episode_data_index = { - "from": torch.tensor([0, offline_dataset_size // 2]), - "to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]), - } - # Create spoof online datset. online_dataset, _ = make_new_buffer() if online_dataset_size > 0: online_dataset.add_data( @@ -254,16 +241,9 @@ def test_compute_sampler_weights_trivial( assert torch.allclose(weights, expected_weights) -def test_compute_sampler_weights_nontrivial_ratio(): +def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - # Create spoof offline dataset. - offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))})) - offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) - offline_dataset.episode_data_index = { - "from": torch.tensor([0, 2]), - "to": torch.tensor([2, 4]), - } - # Create spoof online datset. + offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4) online_dataset, _ = make_new_buffer() online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) online_sampling_ratio = 0.8 @@ -275,16 +255,11 @@ def test_compute_sampler_weights_nontrivial_ratio(): ) -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(): +def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n( + lerobot_dataset_from_episodes_factory, tmp_path +): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - # Create spoof offline dataset. - offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))})) - offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) - offline_dataset.episode_data_index = { - "from": torch.tensor([0]), - "to": torch.tensor([4]), - } - # Create spoof online datset. + offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4) online_dataset, _ = make_new_buffer() online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) weights = compute_sampler_weights( @@ -295,18 +270,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(): ) -def test_compute_sampler_weights_drop_n_last_frames(): +def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_from_episodes_factory, tmp_path): """Note: test copied from test_sampler.""" - data_dict = { - "timestamp": [0, 0.1], - "index": [0, 1], - "episode_index": [0, 0], - "frame_index": [0, 1], - } - offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict)) - offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) - offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])} - + offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=2) online_dataset, _ = make_new_buffer() online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))