Simplify, add test content, add todo
This commit is contained in:
parent
79d114cc1f
commit
293bdc7f67
|
@ -297,6 +297,7 @@ def lerobot_dataset_factory(
|
|||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
|
@ -322,7 +323,7 @@ def lerobot_dataset_factory(
|
|||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
return _create_lerobot_dataset
|
||||
|
||||
|
@ -341,6 +342,7 @@ def lerobot_dataset_from_episodes_factory(
|
|||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
**kwargs,
|
||||
):
|
||||
info_dict = info_factory(
|
||||
|
@ -356,6 +358,7 @@ def lerobot_dataset_from_episodes_factory(
|
|||
hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
|
||||
return lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
|
|
|
@ -41,26 +41,23 @@ from lerobot.common.datasets.utils import (
|
|||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.fixtures.defaults import DEFAULT_FPS, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE
|
||||
from tests.fixtures.defaults import DUMMY_REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dataset_create(tmp_path):
|
||||
robot = make_robot("koch", mock=True)
|
||||
return LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dataset_init(lerobot_dataset_factory, tmp_path):
|
||||
return lerobot_dataset_factory(root=tmp_path)
|
||||
|
||||
|
||||
def test_same_attributes_defined(dataset_create, dataset_init):
|
||||
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
objects have the same sets of attributes defined.
|
||||
"""
|
||||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||
_ = dataset_init._hub_version
|
||||
_ = dataset_create._hub_version
|
||||
|
@ -68,35 +65,34 @@ def test_same_attributes_defined(dataset_create, dataset_init):
|
|||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
||||
assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"
|
||||
assert init_attr == create_attr
|
||||
|
||||
|
||||
def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
|
||||
total_episodes = 10
|
||||
total_frames = 400
|
||||
dataset = lerobot_dataset_from_episodes_factory(
|
||||
root=tmp_path, total_episodes=total_episodes, total_frames=total_frames
|
||||
)
|
||||
kwargs = {
|
||||
"repo_id": DUMMY_REPO_ID,
|
||||
"total_episodes": 10,
|
||||
"total_frames": 400,
|
||||
"episodes": [2, 5, 6],
|
||||
}
|
||||
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
|
||||
|
||||
assert dataset.repo_id == DUMMY_REPO_ID
|
||||
assert dataset.num_episodes == total_episodes
|
||||
assert dataset.num_samples == total_frames
|
||||
assert dataset.info["fps"] == DEFAULT_FPS
|
||||
assert dataset.info["robot_type"] == DUMMY_ROBOT_TYPE
|
||||
assert dataset.repo_id == kwargs["repo_id"]
|
||||
assert dataset.total_episodes == kwargs["total_episodes"]
|
||||
assert dataset.total_frames == kwargs["total_frames"]
|
||||
assert dataset.episodes == kwargs["episodes"]
|
||||
assert dataset.num_episodes == len(kwargs["episodes"])
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
def test_dataset_length(dataset_init):
|
||||
dataset = dataset_init
|
||||
assert len(dataset) == 3 # Number of frames in the episode
|
||||
|
||||
|
||||
def test_dataset_item(dataset_init):
|
||||
dataset = dataset_init
|
||||
item = dataset[0]
|
||||
assert item["episode_index"] == 0
|
||||
assert item["frame_index"] == 0
|
||||
assert item["state"].tolist() == [1, 2, 3]
|
||||
assert item["action"].tolist() == [0.1, 0.2]
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
# - [ ] test add_frame
|
||||
# - [ ] test add_episode
|
||||
# - [ ] test consolidate
|
||||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
|
@ -186,7 +182,7 @@ def test_multilerobotdataset_frames():
|
|||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
|
||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||
|
@ -266,6 +262,7 @@ def test_compute_stats_on_xarm():
|
|||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
|
@ -301,7 +298,6 @@ def test_flatten_unflatten_dict():
|
|||
# "lerobot/cmu_stretch",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
|
||||
def test_backward_compatibility(repo_id):
|
||||
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||
|
|
Loading…
Reference in New Issue