Fix test_online_buffer.py

This commit is contained in:
Simon Alibert 2024-11-03 13:15:01 +01:00
parent df2cb51364
commit ac79e8cb36
1 changed files with 16 additions and 50 deletions

View File

@ -19,11 +19,8 @@ from uuid import uuid4
import numpy as np
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.utils import hf_transform_to_torch
# Some constants for OnlineBuffer tests.
data_key = "data"
@ -212,29 +209,19 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [0, 6])
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float
lerobot_dataset_from_episodes_factory,
tmp_path,
offline_dataset_size: int,
online_dataset_size: int,
online_sampling_ratio: float,
):
# Pass/skip the test if both datasets sizes are zero.
if offline_dataset_size + online_dataset_size == 0:
return
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(
hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))})
offline_dataset = lerobot_dataset_from_episodes_factory(
tmp_path, total_episodes=1, total_frames=offline_dataset_size
)
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
if offline_dataset_size == 0:
offline_dataset.episode_data_index = {}
else:
# Set up an episode_data_index with at least two episodes.
offline_dataset.episode_data_index = {
"from": torch.tensor([0, offline_dataset_size // 2]),
"to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]),
}
# Create spoof online datset.
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
@ -254,16 +241,9 @@ def test_compute_sampler_weights_trivial(
assert torch.allclose(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio():
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {
"from": torch.tensor([0, 2]),
"to": torch.tensor([2, 4]),
}
# Create spoof online datset.
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_sampling_ratio = 0.8
@ -275,16 +255,11 @@ def test_compute_sampler_weights_nontrivial_ratio():
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
lerobot_dataset_from_episodes_factory, tmp_path
):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([4]),
}
# Create spoof online datset.
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights(
@ -295,18 +270,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
)
def test_compute_sampler_weights_drop_n_last_frames():
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_from_episodes_factory, tmp_path):
"""Note: test copied from test_sampler."""
data_dict = {
"timestamp": [0, 0.1],
"index": [0, 1],
"episode_index": [0, 0],
"frame_index": [0, 1],
}
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])}
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=2)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))