Add a test for MultiLeRobotDataset making sure it produces all frames. (#230)
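
The behavior being pinned down: a MultiLeRobotDataset built from several repo_ids should act like the
concatenation of the individual LeRobotDatasets, with every frame present exactly once and each item
tagged with a "dataset_index" identifying its source dataset. As a rough illustration (not part of
this commit, sketched from the test below):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset

    repo_ids = [
        "lerobot/aloha_sim_insertion_human_image",
        "lerobot/aloha_sim_transfer_cube_human_image",
    ]
    dataset = MultiLeRobotDataset(repo_ids)
    # Total frame count should be the sum over the individual datasets.
    assert len(dataset) == sum(len(LeRobotDataset(r)) for r in repo_ids)
    # Each item records which sub-dataset it came from.
    first = dataset[0]
    print(first["dataset_index"])  # expected to identify the first repo_id (index 0)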

Co-authored-by: Remi <re.cadene@gmail.com>
Alexander Soare 2024-05-30 17:46:25 +01:00 committed by GitHub
parent 111cd58f8a
commit 0b51a335bc
1 changed file with 28 additions and 1 deletion


@@ -16,6 +16,7 @@
 import json
 import logging
 from copy import deepcopy
+from itertools import chain
 from pathlib import Path
 
 import einops
@@ -31,7 +32,7 @@ from lerobot.common.datasets.compute_stats import (
     get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 from lerobot.common.datasets.utils import (
     flatten_dict,
     hf_transform_to_torch,
@@ -113,6 +114,32 @@ def test_factory(env_name, repo_id, policy_name):
         assert key in item, f"{key}"
 
 
+def test_multilerobotdataset_frames():
+    """Check that all dataset frames are incorporated."""
+    # Note: use the image variants of the dataset to make the test approx 3x faster.
+    repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
+    sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
+    dataset = MultiLeRobotDataset(repo_ids)
+    assert len(dataset) == sum(len(d) for d in sub_datasets)
+    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
+    # Run through all items of the LeRobotDatasets in parallel with the items of the
+    # MultiLeRobotDataset and check they match.
+    expected_dataset_indices = []
+    for i, sub_dataset in enumerate(sub_datasets):
+        expected_dataset_indices.extend([i] * len(sub_dataset))
+
+    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
+    ):
+        dataset_index = dataset_item.pop("dataset_index")
+        assert dataset_index == expected_dataset_index
+        assert sub_dataset_item.keys() == dataset_item.keys()
+        for k in sub_dataset_item:
+            assert torch.equal(sub_dataset_item[k], dataset_item[k])
+
+
 def test_compute_stats_on_xarm():
     """Check that the statistics are computed correctly according to the stats_patterns property.