lerobot/tests/datasets/test_aggregate.py

30 lines
906 B
Python
Raw Normal View History

2025-02-24 02:18:46 +08:00
from lerobot.common.datasets.aggregate import aggregate_datasets
2025-04-19 21:41:53 +08:00
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
2025-02-24 02:18:46 +08:00
from tests.fixtures.constants import DUMMY_REPO_ID
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate two factory-built datasets and verify the result is fully readable.

    Two datasets are created under ``tmp_path`` with identical episode and
    frame counts, merged with ``aggregate_datasets``, and the aggregated
    dataset is then reloaded from disk and iterated end to end.

    Args:
        tmp_path: pytest temporary directory; every dataset root lives under it.
        lerobot_dataset_factory: project fixture called with ``root``,
            ``repo_id``, ``total_episodes`` and ``total_frames`` to build a
            dataset.
    """
    episodes_per_dataset = 10
    frames_per_dataset = 400

    sources = [
        lerobot_dataset_factory(
            root=tmp_path / f"test_{idx}",
            repo_id=f"{DUMMY_REPO_ID}_{idx}",
            total_episodes=episodes_per_dataset,
            total_frames=frames_per_dataset,
        )
        for idx in range(2)
    ]

    # Name the aggregate target once so the write and the reload below cannot
    # drift apart.
    aggr_repo_id = f"{DUMMY_REPO_ID}_aggr"
    aggr_root = tmp_path / "test_aggr"

    aggregate_datasets(
        repo_ids=[ds.repo_id for ds in sources],
        roots=[ds.root for ds in sources],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Iterating loads every item, so a broken item still surfaces as an
    # exception; counting the items additionally catches frames that were
    # silently dropped or duplicated during aggregation.
    # NOTE(review): assumes the factory yields exactly `total_frames` items
    # per dataset -- confirm against the fixture.
    num_items_read = sum(1 for _ in aggr_ds)
    expected_items = frames_per_dataset * len(sources)
    assert num_items_read == expected_items, (
        f"aggregated dataset yielded {num_items_read} items, expected {expected_items}"
    )