2025-02-24 02:18:46 +08:00
|
|
|
from lerobot.common.datasets.aggregate import aggregate_datasets
|
2025-04-19 21:41:53 +08:00
|
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
2025-02-24 02:18:46 +08:00
|
|
|
from tests.fixtures.constants import DUMMY_REPO_ID
|
|
|
|
|
|
|
|
|
|
|
|
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate two dummy datasets and check the merged result is complete and readable.

    Two datasets (10 episodes / 400 frames requested each) are generated under
    ``tmp_path``. They are merged with ``aggregate_datasets`` into
    ``tmp_path / "test_aggr"`` and the merged dataset is reloaded from disk.

    The merged episode and frame totals must equal the sums of the inputs, and
    iterating the merged dataset must yield exactly that many frames.
    """
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "test_0",
        repo_id=f"{DUMMY_REPO_ID}_0",
        total_episodes=10,
        total_frames=400,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "test_1",
        repo_id=f"{DUMMY_REPO_ID}_1",
        total_episodes=10,
        total_frames=400,
    )

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
        aggr_root=tmp_path / "test_aggr",
    )

    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")

    # Compare against the inputs' own metadata rather than the requested 10/400,
    # so the check does not depend on how the factory distributes frames.
    expected_episodes = ds_0.meta.total_episodes + ds_1.meta.total_episodes
    expected_frames = ds_0.meta.total_frames + ds_1.meta.total_frames
    assert aggr_ds.meta.total_episodes == expected_episodes
    assert aggr_ds.meta.total_frames == expected_frames

    # Iterating reads every frame of the merged dataset, so a frame that fails to
    # load raises here. Counting them also catches dropped or duplicated frames
    # that the metadata alone would not reveal.
    num_loaded = sum(1 for _ in aggr_ds)
    assert num_loaded == expected_frames
|