Simplify, add test content, add todo

Simon Alibert 2024-11-01 19:55:28 +01:00
parent 79d114cc1f
commit 293bdc7f67
2 changed files with 38 additions and 39 deletions

View File

@@ -297,6 +297,7 @@ def lerobot_dataset_factory(
 ):
     def _create_lerobot_dataset(
         root: Path,
+        repo_id: str = DUMMY_REPO_ID,
         info_dict: dict = info,
         stats_dict: dict = stats,
         task_dicts: list[dict] = tasks,
@@ -322,7 +323,7 @@ def lerobot_dataset_factory(
         mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
         mock_snapshot_download_patch.side_effect = mock_snapshot_download
-        return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
+        return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
 
     return _create_lerobot_dataset
@@ -341,6 +342,7 @@ def lerobot_dataset_from_episodes_factory(
         total_frames: int = 150,
         total_tasks: int = 1,
         multi_task: bool = False,
+        repo_id: str = DUMMY_REPO_ID,
         **kwargs,
     ):
         info_dict = info_factory(
@@ -356,6 +358,7 @@ def lerobot_dataset_from_episodes_factory(
         hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
         return lerobot_dataset_factory(
             root=root,
+            repo_id=repo_id,
             info_dict=info_dict,
             task_dicts=task_dicts,
             episode_dicts=episode_dicts,
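Net effect of this file's hunks: both factory fixtures now accept an optional repo_id instead of hard-coding DUMMY_REPO_ID, and the episodes factory forwards it down to lerobot_dataset_factory. A minimal sketch of a test exercising the new parameter (the test name and repo id string here are made up):

    def test_custom_repo_id(lerobot_dataset_from_episodes_factory, tmp_path):
        # "dummy/other_repo" is a placeholder; the fixture still defaults to DUMMY_REPO_ID.
        dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, repo_id="dummy/other_repo")
        assert dataset.repo_id == "dummy/other_repo"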

View File

@@ -41,26 +41,23 @@ from lerobot.common.datasets.utils import (
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
-from tests.fixtures.defaults import DEFAULT_FPS, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE
+from tests.fixtures.defaults import DUMMY_REPO_ID
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
 
 
-@pytest.fixture(scope="function")
-def dataset_create(tmp_path):
-    robot = make_robot("koch", mock=True)
-    return LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=tmp_path)
-
-
-@pytest.fixture(scope="function")
-def dataset_init(lerobot_dataset_factory, tmp_path):
-    return lerobot_dataset_factory(root=tmp_path)
-
-
-def test_same_attributes_defined(dataset_create, dataset_init):
+def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
     """
     Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
     objects have the same sets of attributes defined.
     """
+    # Instantiate both ways
+    robot = make_robot("koch", mock=True)
+    root_create = tmp_path / "create"
+    dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
+
+    root_init = tmp_path / "init"
+    dataset_init = lerobot_dataset_factory(root=root_init)
+
     # Access the '_hub_version' cached_property in both instances to force its creation
     _ = dataset_init._hub_version
     _ = dataset_create._hub_version
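The rewrite above folds the dataset_create/dataset_init fixtures into the test body and then compares set(vars(obj).keys()) on the two instances. A self-contained illustration of that comparison pattern (toy classes, not the real dataset types):

    class FromInit:
        def __init__(self):
            self.repo_id, self.root = "dummy/repo", "/tmp/init"

    class FromCreate:
        def __init__(self):
            self.repo_id, self.root = "dummy/repo", "/tmp/create"

    # vars() exposes each instance's __dict__, so comparing key sets checks that both
    # construction paths define the same attributes, regardless of their values.
    assert set(vars(FromInit()).keys()) == set(vars(FromCreate()).keys())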
@@ -68,35 +65,34 @@ def test_same_attributes_defined(dataset_create, dataset_init):
     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())
-    assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"
+    assert init_attr == create_attr
 
 
 def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
-    total_episodes = 10
-    total_frames = 400
-    dataset = lerobot_dataset_from_episodes_factory(
-        root=tmp_path, total_episodes=total_episodes, total_frames=total_frames
-    )
+    kwargs = {
+        "repo_id": DUMMY_REPO_ID,
+        "total_episodes": 10,
+        "total_frames": 400,
+        "episodes": [2, 5, 6],
+    }
+    dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
 
-    assert dataset.repo_id == DUMMY_REPO_ID
-    assert dataset.num_episodes == total_episodes
-    assert dataset.num_samples == total_frames
-    assert dataset.info["fps"] == DEFAULT_FPS
-    assert dataset.info["robot_type"] == DUMMY_ROBOT_TYPE
+    assert dataset.repo_id == kwargs["repo_id"]
+    assert dataset.total_episodes == kwargs["total_episodes"]
+    assert dataset.total_frames == kwargs["total_frames"]
+    assert dataset.episodes == kwargs["episodes"]
+    assert dataset.num_episodes == len(kwargs["episodes"])
+    assert dataset.num_frames == len(dataset)
 
 
-def test_dataset_length(dataset_init):
-    dataset = dataset_init
-    assert len(dataset) == 3  # Number of frames in the episode
-
-
-def test_dataset_item(dataset_init):
-    dataset = dataset_init
-    item = dataset[0]
-    assert item["episode_index"] == 0
-    assert item["frame_index"] == 0
-    assert item["state"].tolist() == [1, 2, 3]
-    assert item["action"].tolist() == [0.1, 0.2]
+# TODO(aliberts):
+# - [ ] test various attributes & state from init and create
+# - [ ] test init with episodes and check num_frames
+# - [ ] test add_frame
+# - [ ] test add_episode
+# - [ ] test consolidate
+# - [ ] test push_to_hub
+# - [ ] test smaller methods
 
 
 @pytest.mark.skip("TODO after v2 migration / removing hydra")
@@ -186,7 +182,7 @@ def test_multilerobotdataset_frames():
     sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
     dataset = MultiLeRobotDataset(repo_ids)
     assert len(dataset) == sum(len(d) for d in sub_datasets)
-    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+    assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
     assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
 
     # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
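This hunk is the num_samples -> num_frames rename applied to the multi-dataset test. The aggregation it asserts could be implemented along these lines (a hypothetical minimal class, not the actual MultiLeRobotDataset):

    class MultiDatasetSketch:
        def __init__(self, datasets):
            self._datasets = datasets

        @property
        def num_frames(self) -> int:
            # Matches the assertion: sum of frames across all sub-datasets.
            return sum(d.num_frames for d in self._datasets)

        @property
        def num_episodes(self) -> int:
            return sum(d.num_episodes for d in self._datasets)

        def __len__(self) -> int:
            return self.num_frames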
@@ -266,6 +262,7 @@ def test_compute_stats_on_xarm():
     # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
 
 
+# TODO(aliberts): Move to more appropriate location
 def test_flatten_unflatten_dict():
     d = {
         "obs": {
@@ -301,7 +298,6 @@ def test_flatten_unflatten_dict():
         # "lerobot/cmu_stretch",
     ],
 )
-
 # TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
 def test_backward_compatibility(repo_id):
     """The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""