Simplify, add test content, add todo
This commit is contained in:
parent
79d114cc1f
commit
293bdc7f67
|
@ -297,6 +297,7 @@ def lerobot_dataset_factory(
|
||||||
):
|
):
|
||||||
def _create_lerobot_dataset(
|
def _create_lerobot_dataset(
|
||||||
root: Path,
|
root: Path,
|
||||||
|
repo_id: str = DUMMY_REPO_ID,
|
||||||
info_dict: dict = info,
|
info_dict: dict = info,
|
||||||
stats_dict: dict = stats,
|
stats_dict: dict = stats,
|
||||||
task_dicts: list[dict] = tasks,
|
task_dicts: list[dict] = tasks,
|
||||||
|
@ -322,7 +323,7 @@ def lerobot_dataset_factory(
|
||||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||||
|
|
||||||
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
|
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||||
|
|
||||||
return _create_lerobot_dataset
|
return _create_lerobot_dataset
|
||||||
|
|
||||||
|
@ -341,6 +342,7 @@ def lerobot_dataset_from_episodes_factory(
|
||||||
total_frames: int = 150,
|
total_frames: int = 150,
|
||||||
total_tasks: int = 1,
|
total_tasks: int = 1,
|
||||||
multi_task: bool = False,
|
multi_task: bool = False,
|
||||||
|
repo_id: str = DUMMY_REPO_ID,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
info_dict = info_factory(
|
info_dict = info_factory(
|
||||||
|
@ -356,6 +358,7 @@ def lerobot_dataset_from_episodes_factory(
|
||||||
hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
|
hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
|
||||||
return lerobot_dataset_factory(
|
return lerobot_dataset_factory(
|
||||||
root=root,
|
root=root,
|
||||||
|
repo_id=repo_id,
|
||||||
info_dict=info_dict,
|
info_dict=info_dict,
|
||||||
task_dicts=task_dicts,
|
task_dicts=task_dicts,
|
||||||
episode_dicts=episode_dicts,
|
episode_dicts=episode_dicts,
|
||||||
|
|
|
@ -41,26 +41,23 @@ from lerobot.common.datasets.utils import (
|
||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
from tests.fixtures.defaults import DEFAULT_FPS, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE
|
from tests.fixtures.defaults import DUMMY_REPO_ID
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
||||||
def dataset_create(tmp_path):
|
|
||||||
robot = make_robot("koch", mock=True)
|
|
||||||
return LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=tmp_path)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def dataset_init(lerobot_dataset_factory, tmp_path):
|
|
||||||
return lerobot_dataset_factory(root=tmp_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_same_attributes_defined(dataset_create, dataset_init):
|
|
||||||
"""
|
"""
|
||||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||||
objects have the same sets of attributes defined.
|
objects have the same sets of attributes defined.
|
||||||
"""
|
"""
|
||||||
|
# Instantiate both ways
|
||||||
|
robot = make_robot("koch", mock=True)
|
||||||
|
root_create = tmp_path / "create"
|
||||||
|
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||||
|
|
||||||
|
root_init = tmp_path / "init"
|
||||||
|
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||||
|
|
||||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||||
_ = dataset_init._hub_version
|
_ = dataset_init._hub_version
|
||||||
_ = dataset_create._hub_version
|
_ = dataset_create._hub_version
|
||||||
|
@ -68,35 +65,34 @@ def test_same_attributes_defined(dataset_create, dataset_init):
|
||||||
init_attr = set(vars(dataset_init).keys())
|
init_attr = set(vars(dataset_init).keys())
|
||||||
create_attr = set(vars(dataset_create).keys())
|
create_attr = set(vars(dataset_create).keys())
|
||||||
|
|
||||||
assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"
|
assert init_attr == create_attr
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
|
def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
|
||||||
total_episodes = 10
|
kwargs = {
|
||||||
total_frames = 400
|
"repo_id": DUMMY_REPO_ID,
|
||||||
dataset = lerobot_dataset_from_episodes_factory(
|
"total_episodes": 10,
|
||||||
root=tmp_path, total_episodes=total_episodes, total_frames=total_frames
|
"total_frames": 400,
|
||||||
)
|
"episodes": [2, 5, 6],
|
||||||
|
}
|
||||||
|
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
|
||||||
|
|
||||||
assert dataset.repo_id == DUMMY_REPO_ID
|
assert dataset.repo_id == kwargs["repo_id"]
|
||||||
assert dataset.num_episodes == total_episodes
|
assert dataset.total_episodes == kwargs["total_episodes"]
|
||||||
assert dataset.num_samples == total_frames
|
assert dataset.total_frames == kwargs["total_frames"]
|
||||||
assert dataset.info["fps"] == DEFAULT_FPS
|
assert dataset.episodes == kwargs["episodes"]
|
||||||
assert dataset.info["robot_type"] == DUMMY_ROBOT_TYPE
|
assert dataset.num_episodes == len(kwargs["episodes"])
|
||||||
|
assert dataset.num_frames == len(dataset)
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_length(dataset_init):
|
# TODO(aliberts):
|
||||||
dataset = dataset_init
|
# - [ ] test various attributes & state from init and create
|
||||||
assert len(dataset) == 3 # Number of frames in the episode
|
# - [ ] test init with episodes and check num_frames
|
||||||
|
# - [ ] test add_frame
|
||||||
|
# - [ ] test add_episode
|
||||||
def test_dataset_item(dataset_init):
|
# - [ ] test consolidate
|
||||||
dataset = dataset_init
|
# - [ ] test push_to_hub
|
||||||
item = dataset[0]
|
# - [ ] test smaller methods
|
||||||
assert item["episode_index"] == 0
|
|
||||||
assert item["frame_index"] == 0
|
|
||||||
assert item["state"].tolist() == [1, 2, 3]
|
|
||||||
assert item["action"].tolist() == [0.1, 0.2]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||||
|
@ -186,7 +182,7 @@ def test_multilerobotdataset_frames():
|
||||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||||
dataset = MultiLeRobotDataset(repo_ids)
|
dataset = MultiLeRobotDataset(repo_ids)
|
||||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||||
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
|
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||||
|
|
||||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||||
|
@ -266,6 +262,7 @@ def test_compute_stats_on_xarm():
|
||||||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(aliberts): Move to more appropriate location
|
||||||
def test_flatten_unflatten_dict():
|
def test_flatten_unflatten_dict():
|
||||||
d = {
|
d = {
|
||||||
"obs": {
|
"obs": {
|
||||||
|
@ -301,7 +298,6 @@ def test_flatten_unflatten_dict():
|
||||||
# "lerobot/cmu_stretch",
|
# "lerobot/cmu_stretch",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
|
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
|
||||||
def test_backward_compatibility(repo_id):
|
def test_backward_compatibility(repo_id):
|
||||||
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||||
|
|
Loading…
Reference in New Issue