Fix test_online_buffer.py
This commit is contained in:
parent
df2cb51364
commit
ac79e8cb36
|
@ -19,11 +19,8 @@ from uuid import uuid4
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
|
||||
# Some constants for OnlineBuffer tests.
|
||||
data_key = "data"
|
||||
|
@ -212,29 +209,19 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
|||
|
||||
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
@pytest.mark.parametrize("offline_dataset_size", [0, 6])
|
||||
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
|
||||
@pytest.mark.parametrize("online_dataset_size", [0, 4])
|
||||
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
|
||||
def test_compute_sampler_weights_trivial(
|
||||
offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float
|
||||
lerobot_dataset_from_episodes_factory,
|
||||
tmp_path,
|
||||
offline_dataset_size: int,
|
||||
online_dataset_size: int,
|
||||
online_sampling_ratio: float,
|
||||
):
|
||||
# Pass/skip the test if both dataset sizes are zero.
|
||||
if offline_dataset_size + online_dataset_size == 0:
|
||||
return
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(
|
||||
hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))})
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(
|
||||
tmp_path, total_episodes=1, total_frames=offline_dataset_size
|
||||
)
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
if offline_dataset_size == 0:
|
||||
offline_dataset.episode_data_index = {}
|
||||
else:
|
||||
# Set up an episode_data_index with at least two episodes.
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0, offline_dataset_size // 2]),
|
||||
"to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]),
|
||||
}
|
||||
# Create spoof online dataset.
|
||||
online_dataset, _ = make_new_buffer()
|
||||
if online_dataset_size > 0:
|
||||
online_dataset.add_data(
|
||||
|
@ -254,16 +241,9 @@ def test_compute_sampler_weights_trivial(
|
|||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio():
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0, 2]),
|
||||
"to": torch.tensor([2, 4]),
|
||||
}
|
||||
# Create spoof online dataset.
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_sampling_ratio = 0.8
|
||||
|
@ -275,16 +255,11 @@ def test_compute_sampler_weights_nontrivial_ratio():
|
|||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
||||
lerobot_dataset_from_episodes_factory, tmp_path
|
||||
):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([4]),
|
||||
}
|
||||
# Create spoof online dataset.
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
weights = compute_sampler_weights(
|
||||
|
@ -295,18 +270,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
|
|||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_drop_n_last_frames():
|
||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_from_episodes_factory, tmp_path):
|
||||
"""Note: test copied from test_sampler."""
|
||||
data_dict = {
|
||||
"timestamp": [0, 0.1],
|
||||
"index": [0, 1],
|
||||
"episode_index": [0, 0],
|
||||
"frame_index": [0, 1],
|
||||
}
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])}
|
||||
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=2)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
|
||||
|
|
Loading…
Reference in New Issue