Simplify, add test content, add todo

This commit is contained in:
Simon Alibert 2024-11-01 19:55:28 +01:00
parent 79d114cc1f
commit 293bdc7f67
2 changed files with 38 additions and 39 deletions

View File

@ -297,6 +297,7 @@ def lerobot_dataset_factory(
):
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info_dict: dict = info,
stats_dict: dict = stats,
task_dicts: list[dict] = tasks,
@ -322,7 +323,7 @@ def lerobot_dataset_factory(
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
return _create_lerobot_dataset
@ -341,6 +342,7 @@ def lerobot_dataset_from_episodes_factory(
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
repo_id: str = DUMMY_REPO_ID,
**kwargs,
):
info_dict = info_factory(
@ -356,6 +358,7 @@ def lerobot_dataset_from_episodes_factory(
hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
return lerobot_dataset_factory(
root=root,
repo_id=repo_id,
info_dict=info_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,

View File

@ -41,26 +41,23 @@ from lerobot.common.datasets.utils import (
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.fixtures.defaults import DEFAULT_FPS, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE
from tests.fixtures.defaults import DUMMY_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
@pytest.fixture(scope="function")
def dataset_create(tmp_path):
robot = make_robot("koch", mock=True)
return LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=tmp_path)
@pytest.fixture(scope="function")
def dataset_init(lerobot_dataset_factory, tmp_path):
return lerobot_dataset_factory(root=tmp_path)
def test_same_attributes_defined(dataset_create, dataset_init):
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
objects have the same sets of attributes defined.
"""
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
# Access the '_hub_version' cached_property in both instances to force its creation
_ = dataset_init._hub_version
_ = dataset_create._hub_version
@ -68,35 +65,34 @@ def test_same_attributes_defined(dataset_create, dataset_init):
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"
assert init_attr == create_attr
def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
total_episodes = 10
total_frames = 400
dataset = lerobot_dataset_from_episodes_factory(
root=tmp_path, total_episodes=total_episodes, total_frames=total_frames
)
kwargs = {
"repo_id": DUMMY_REPO_ID,
"total_episodes": 10,
"total_frames": 400,
"episodes": [2, 5, 6],
}
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
assert dataset.repo_id == DUMMY_REPO_ID
assert dataset.num_episodes == total_episodes
assert dataset.num_samples == total_frames
assert dataset.info["fps"] == DEFAULT_FPS
assert dataset.info["robot_type"] == DUMMY_ROBOT_TYPE
assert dataset.repo_id == kwargs["repo_id"]
assert dataset.total_episodes == kwargs["total_episodes"]
assert dataset.total_frames == kwargs["total_frames"]
assert dataset.episodes == kwargs["episodes"]
assert dataset.num_episodes == len(kwargs["episodes"])
assert dataset.num_frames == len(dataset)
def test_dataset_length(dataset_init):
dataset = dataset_init
assert len(dataset) == 3 # Number of frames in the episode
def test_dataset_item(dataset_init):
dataset = dataset_init
item = dataset[0]
assert item["episode_index"] == 0
assert item["frame_index"] == 0
assert item["state"].tolist() == [1, 2, 3]
assert item["action"].tolist() == [0.1, 0.2]
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_frame
# - [ ] test add_episode
# - [ ] test consolidate
# - [ ] test push_to_hub
# - [ ] test smaller methods
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@ -186,7 +182,7 @@ def test_multilerobotdataset_frames():
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
@ -266,6 +262,7 @@ def test_compute_stats_on_xarm():
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
d = {
"obs": {
@ -301,7 +298,6 @@ def test_flatten_unflatten_dict():
# "lerobot/cmu_stretch",
],
)
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
def test_backward_compatibility(repo_id):
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""