ready for review

Alexander Soare 2024-05-30 17:17:37 +01:00
parent 111cd58f8a
commit 00cefb44bc
1 changed file with 31 additions and 6 deletions

tests/test_datasets.py

@@ -16,6 +16,7 @@
 import json
 import logging
 from copy import deepcopy
+from itertools import chain
 from pathlib import Path

 import einops
@@ -31,7 +32,7 @@ from lerobot.common.datasets.compute_stats import (
     get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 from lerobot.common.datasets.utils import (
     flatten_dict,
     hf_transform_to_torch,
@@ -42,11 +43,7 @@ from lerobot.common.utils.utils import init_hydra_config, seeded_context
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE


-@pytest.mark.parametrize(
-    "env_name, repo_id, policy_name",
-    lerobot.env_dataset_policy_triplets
-    + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
-)
+@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
 def test_factory(env_name, repo_id, policy_name):
     """
     Tests that:
@@ -113,6 +110,30 @@ def test_factory(env_name, repo_id, policy_name):
         assert key in item, f"{key}"


+def test_multilerobotdataset_frames():
+    """Check that all dataset frames are incorporated."""
+    # Note: use the image variants of the datasets to make the test approximately 3x faster.
+    repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
+    sub_datasets = [LeRobotDataset(repo_id, root="tests/data") for repo_id in repo_ids]
+    dataset = MultiLeRobotDataset(repo_ids, root="tests/data")
+    assert len(dataset) == sum(len(d) for d in sub_datasets)
+    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
+    # Run through all items of the LeRobotDatasets in parallel with the items of the
+    # MultiLeRobotDataset and check that they match.
+    expected_dataset_indices = []
+    for i, sub_dataset in enumerate(sub_datasets):
+        expected_dataset_indices.extend([i] * len(sub_dataset))
+    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
+    ):
+        dataset_index = dataset_item.pop("dataset_index")
+        assert dataset_index == expected_dataset_index
+        assert sub_dataset_item.keys() == dataset_item.keys()
+        for k in sub_dataset_item:
+            assert torch.equal(sub_dataset_item[k], dataset_item[k])
+
+
 def test_compute_stats_on_xarm():
     """Check that the statistics are computed correctly according to the stats_patterns property.
@@ -351,3 +372,7 @@ def test_aggregate_stats():
         for agg_fn in ["mean", "min", "max"]:
             assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
         assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
+
+
+if __name__ == "__main__":
+    test_multilerobotdataset_frames()
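
For context on what the new test pins down: MultiLeRobotDataset aggregates several repositories behind one dataset interface, its length, num_samples, and num_episodes sum over the sub-datasets, and each item gains a "dataset_index" key. The sketch below shows how a consumer might rely on that contract. It is a minimal sketch, not part of this commit: only the names visible in the diff above (the repo_ids/root constructor arguments, num_samples, num_episodes, "dataset_index") are taken from the source; the DataLoader wiring, batch size, and shuffling are illustrative assumptions.

from torch.utils.data import DataLoader

from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

# Aggregate several repositories behind a single dataset interface, as in the test above.
repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
dataset = MultiLeRobotDataset(repo_ids, root="tests/data")

# Items are dicts of tensors; "dataset_index" identifies the source dataset
# (0 for the first repo_id, 1 for the second), matching expected_dataset_indices in the test.
item = dataset[0]
assert int(item["dataset_index"]) == 0

# Assumed-typical consumption through a standard PyTorch DataLoader; these
# batching and shuffling choices are arbitrary illustrations, not from the diff.
loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))

Reading (or popping) "dataset_index" is how downstream code attributes a frame back to its source dataset after aggregation; the test relies on exactly this key when zipping the aggregate dataset against chain(*sub_datasets) with strict=True.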