Add tasks and episodes factories

Simon Alibert 2024-11-01 13:37:17 +01:00
parent cd1509d805
commit 2650872b76
4 changed files with 231 additions and 99 deletions

View File

@@ -1,7 +1,6 @@
 import datasets
 import pytest

-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import get_episode_data_index
 from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
@@ -66,9 +65,3 @@ def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
 def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset:
     image_keys = DUMMY_CAMERA_KEYS
     return hf_dataset_factory(image_keys=image_keys)
-
-
-@pytest.fixture(scope="session")
-def lerobot_dataset(lerobot_dataset_factory, tmp_path_factory) -> LeRobotDataset:
-    root = tmp_path_factory.getbasetemp()
-    return lerobot_dataset_factory(root=root)
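
Note: with the session-scoped `lerobot_dataset` fixture removed, tests build datasets per-test through the factory instead. A minimal sketch of the new usage — the test name and the use of pytest's built-in `tmp_path` are illustrative assumptions, not part of this diff:

def test_with_factory(lerobot_dataset_factory, tmp_path):
    # Each test gets an isolated dataset rooted in its own temporary directory.
    dataset = lerobot_dataset_factory(root=tmp_path)
    assert dataset.repo_id == DUMMY_REPO_ID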

View File

@@ -1,3 +1,4 @@
+import random
 from pathlib import Path
 from unittest.mock import patch
@@ -12,10 +13,16 @@ from lerobot.common.datasets.utils import (
     DEFAULT_VIDEO_PATH,
     hf_transform_to_torch,
 )
-from tests.fixtures.defaults import DUMMY_CAMERA_KEYS, DUMMY_KEYS, DUMMY_REPO_ID
+from tests.fixtures.defaults import (
+    DEFAULT_FPS,
+    DUMMY_CAMERA_KEYS,
+    DUMMY_KEYS,
+    DUMMY_REPO_ID,
+    DUMMY_ROBOT_TYPE,
+)


-def get_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | None = None) -> dict:
+def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | None = None) -> dict:
     shapes = {}
     if keys:
         shapes.update({key: 10 for key in keys})
@@ -25,10 +32,6 @@ def get_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | None) -> dict:
 def get_task_index(tasks_dicts: dict, task: str) -> int:
-    """
-    Given a task in natural language, returns its task_index if the task already exists in the dataset,
-    otherwise creates a new task_index.
-    """
     tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
     task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
     return task_to_task_index[task]
@@ -46,8 +49,8 @@ def img_array_factory():
 def info_factory():
     def _create_info(
         codebase_version: str = CODEBASE_VERSION,
-        fps: int = 30,
-        robot_type: str = "dummy_robot",
+        fps: int = DEFAULT_FPS,
+        robot_type: str = DUMMY_ROBOT_TYPE,
         keys: list[str] = DUMMY_KEYS,
         image_keys: list[str] | None = None,
         video_keys: list[str] = DUMMY_CAMERA_KEYS,
@@ -65,7 +68,7 @@ def info_factory():
         if not image_keys:
             image_keys = []
         if not shapes:
-            shapes = get_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
+            shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
         if not names:
             names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys}
@@ -115,7 +118,7 @@ def stats_factory():
         if not image_keys:
             image_keys = []
         if not shapes:
-            shapes = get_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
+            shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
         stats = {}
         for key in keys:
             shape = shapes[key]
@@ -138,6 +141,68 @@ def stats_factory():
     return _create_stats


+@pytest.fixture(scope="session")
+def tasks_factory():
+    def _create_tasks(total_tasks: int = 3) -> list[dict]:
+        tasks_list = []
+        for i in range(total_tasks):
+            task_dict = {"task_index": i, "task": f"Perform action {i}."}
+            tasks_list.append(task_dict)
+        return tasks_list
+
+    return _create_tasks
+
+
+@pytest.fixture(scope="session")
+def episodes_factory(tasks_factory):
+    def _create_episodes(
+        total_episodes: int = 3,
+        total_frames: int = 400,
+        task_dicts: list[dict] | None = None,
+        multi_task: bool = False,
+    ):
+        if total_episodes <= 0 or total_frames <= 0:
+            raise ValueError("total_episodes and total_frames must be positive integers.")
+        if total_frames < total_episodes:
+            raise ValueError("total_frames must be greater than or equal to total_episodes.")
+
+        if not task_dicts:
+            min_tasks = 2 if multi_task else 1
+            total_tasks = random.randint(min_tasks, total_episodes)
+            task_dicts = tasks_factory(total_tasks)
+
+        if total_episodes < len(task_dicts) and not multi_task:
+            raise ValueError("The number of tasks should be less than or equal to the number of episodes.")
+
+        # Generate random lengths that sum up to total_frames
+        lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
+
+        tasks_list = [task_dict["task"] for task_dict in task_dicts]
+        num_tasks_available = len(tasks_list)
+
+        episodes_list = []
+        remaining_tasks = tasks_list.copy()
+        for ep_idx in range(total_episodes):
+            num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
+            tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
+            episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
+            if remaining_tasks:
+                for task in episode_tasks:
+                    remaining_tasks.remove(task)
+
+            episodes_list.append(
+                {
+                    "episode_index": ep_idx,
+                    "tasks": episode_tasks,
+                    "length": lengths[ep_idx],
+                }
+            )
+
+        return episodes_list
+
+    return _create_episodes
+
+
 @pytest.fixture(scope="session")
 def hf_dataset_factory(img_array_factory, episodes, tasks):
     def _create_hf_dataset(
@@ -146,12 +211,12 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
         keys: list[str] = DUMMY_KEYS,
         image_keys: list[str] | None = None,
         shapes: dict | None = None,
-        fps: int = 30,
+        fps: int = DEFAULT_FPS,
     ) -> datasets.Dataset:
         if not image_keys:
             image_keys = []
         if not shapes:
-            shapes = get_dummy_shapes(keys=keys, camera_keys=image_keys)
+            shapes = make_dummy_shapes(keys=keys, camera_keys=image_keys)
         key_features = {
             key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
             for key in keys
@@ -225,8 +290,8 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
 def lerobot_dataset_factory(
     info,
     stats,
-    episodes,
     tasks,
+    episodes,
     hf_dataset,
     mock_snapshot_download_factory,
 ):
@@ -260,3 +325,42 @@ def lerobot_dataset_factory(
         return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)

     return _create_lerobot_dataset
+
+
+@pytest.fixture(scope="session")
+def lerobot_dataset_from_episodes_factory(
+    info_factory,
+    tasks_factory,
+    episodes_factory,
+    hf_dataset_factory,
+    lerobot_dataset_factory,
+):
+    def _create_lerobot_dataset_total_episodes(
+        root: Path,
+        total_episodes: int = 3,
+        total_frames: int = 150,
+        total_tasks: int = 1,
+        multi_task: bool = False,
+        **kwargs,
+    ):
+        info_dict = info_factory(
+            total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+        )
+        task_dicts = tasks_factory(total_tasks)
+        episode_dicts = episodes_factory(
+            total_episodes=total_episodes,
+            total_frames=total_frames,
+            task_dicts=task_dicts,
+            multi_task=multi_task,
+        )
+        hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
+        return lerobot_dataset_factory(
+            root=root,
+            info_dict=info_dict,
+            task_dicts=task_dicts,
+            episode_dicts=episode_dicts,
+            hf_ds=hf_dataset,
+            **kwargs,
+        )
+
+    return _create_lerobot_dataset_total_episodes
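
For reference, a sketch of what these factories produce when composed. The calls are shown on the inner `_create_*` callables the fixtures return, and the concrete lengths and task assignments are illustrative only, since both are randomized:

tasks = tasks_factory(total_tasks=2)
# [{"task_index": 0, "task": "Perform action 0."},
#  {"task_index": 1, "task": "Perform action 1."}]

episodes = episodes_factory(total_episodes=3, total_frames=400, task_dicts=tasks)
# e.g. [{"episode_index": 0, "tasks": ["Perform action 1."], "length": 137}, ...]
# np.random.multinomial guarantees the per-episode lengths sum to exactly 400.

A test can then build a complete dataset in one call, as `test_dataset_initialization` in the last file of this commit does:

dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, total_episodes=10, total_frames=400)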

View File

@@ -2,5 +2,7 @@ from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
 LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
 DUMMY_REPO_ID = "dummy/repo"
+DUMMY_ROBOT_TYPE = "dummy_robot"
 DUMMY_KEYS = ["state", "action"]
 DUMMY_CAMERA_KEYS = ["laptop", "phone"]
+DEFAULT_FPS = 30
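
These two constants centralize values that were previously hard-coded literals (`30`, `"dummy_robot"`) in the factories, so fixtures and tests can assert against a single source of truth. A sketch, assuming the `info_factory` fixture from the file above:

info = info_factory()
assert info["fps"] == DEFAULT_FPS  # previously the literal 30
assert info["robot_type"] == DUMMY_ROBOT_TYPE  # previously the literal "dummy_robot"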

View File

@@ -16,6 +16,7 @@
 import json
 import logging
 from copy import deepcopy
+from itertools import chain
 from pathlib import Path

 import einops
@@ -29,9 +30,10 @@ import lerobot
 from lerobot.common.datasets.compute_stats import (
     aggregate_stats,
     compute_stats,
+    get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 from lerobot.common.datasets.utils import (
     create_branch,
     flatten_dict,
@@ -39,7 +41,7 @@ from lerobot.common.datasets.utils import (
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
-from tests.fixtures.defaults import DUMMY_REPO_ID
+from tests.fixtures.defaults import DEFAULT_FPS, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
@@ -69,6 +71,34 @@ def test_same_attributes_defined(dataset_create, dataset_init):
     assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"


+def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
+    total_episodes = 10
+    total_frames = 400
+    dataset = lerobot_dataset_from_episodes_factory(
+        root=tmp_path, total_episodes=total_episodes, total_frames=total_frames
+    )
+
+    assert dataset.repo_id == DUMMY_REPO_ID
+    assert dataset.num_episodes == total_episodes
+    assert dataset.num_samples == total_frames
+    assert dataset.info["fps"] == DEFAULT_FPS
+    assert dataset.info["robot_type"] == DUMMY_ROBOT_TYPE
+
+
+def test_dataset_length(dataset_init):
+    dataset = dataset_init
+    assert len(dataset) == 3  # Number of frames in the episode
+
+
+def test_dataset_item(dataset_init):
+    dataset = dataset_init
+    item = dataset[0]
+    assert item["episode_index"] == 0
+    assert item["frame_index"] == 0
+    assert item["state"].tolist() == [1, 2, 3]
+    assert item["action"].tolist() == [0.1, 0.2]
+
+
 @pytest.mark.skip("TODO after v2 migration / removing hydra")
 @pytest.mark.parametrize(
     "env_name, repo_id, policy_name",
@@ -141,97 +171,99 @@ def test_factory(env_name, repo_id, policy_name):
         assert key in item, f"{key}"


-# # TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
-# def test_multilerobotdataset_frames():
-#     """Check that all dataset frames are incorporated."""
-#     # Note: use the image variants of the dataset to make the test approx 3x faster.
-#     # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
-#     # logic that wouldn't be caught with two repo IDs.
-#     repo_ids = [
-#         "lerobot/aloha_sim_insertion_human_image",
-#         "lerobot/aloha_sim_transfer_cube_human_image",
-#         "lerobot/aloha_sim_insertion_scripted_image",
-#     ]
-#     sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
-#     dataset = MultiLeRobotDataset(repo_ids)
-#     assert len(dataset) == sum(len(d) for d in sub_datasets)
-#     assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
-#     assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
-#     # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
-#     # check they match.
-#     expected_dataset_indices = []
-#     for i, sub_dataset in enumerate(sub_datasets):
-#         expected_dataset_indices.extend([i] * len(sub_dataset))
-#     for expected_dataset_index, sub_dataset_item, dataset_item in zip(
-#         expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
-#     ):
-#         dataset_index = dataset_item.pop("dataset_index")
-#         assert dataset_index == expected_dataset_index
-#         assert sub_dataset_item.keys() == dataset_item.keys()
-#         for k in sub_dataset_item:
-#             assert torch.equal(sub_dataset_item[k], dataset_item[k])
+# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
+def test_multilerobotdataset_frames():
+    """Check that all dataset frames are incorporated."""
+    # Note: use the image variants of the dataset to make the test approx 3x faster.
+    # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
+    # logic that wouldn't be caught with two repo IDs.
+    repo_ids = [
+        "lerobot/aloha_sim_insertion_human_image",
+        "lerobot/aloha_sim_transfer_cube_human_image",
+        "lerobot/aloha_sim_insertion_scripted_image",
+    ]
+    sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
+    dataset = MultiLeRobotDataset(repo_ids)
+    assert len(dataset) == sum(len(d) for d in sub_datasets)
+    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
+    # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
+    # check they match.
+    expected_dataset_indices = []
+    for i, sub_dataset in enumerate(sub_datasets):
+        expected_dataset_indices.extend([i] * len(sub_dataset))
+    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
+    ):
+        dataset_index = dataset_item.pop("dataset_index")
+        assert dataset_index == expected_dataset_index
+        assert sub_dataset_item.keys() == dataset_item.keys()
+        for k in sub_dataset_item:
+            assert torch.equal(sub_dataset_item[k], dataset_item[k])


 # TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
-# def test_compute_stats_on_xarm():
-#     """Check that the statistics are computed correctly according to the stats_patterns property.
-#     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
-#     because we are working with a small dataset).
-#     """
-#     dataset = LeRobotDataset("lerobot/xarm_lift_medium")
-#     # reduce size of dataset sample on which stats compute is tested to 10 frames
-#     dataset.hf_dataset = dataset.hf_dataset.select(range(10))
-#     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
-#     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
-#     # dataset into even batches.
-#     computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
-#     # get einops patterns to aggregate batches and compute statistics
-#     stats_patterns = get_stats_einops_patterns(dataset)
-#     # get all frames from the dataset in the same dtype and range as during compute_stats
-#     dataloader = torch.utils.data.DataLoader(
-#         dataset,
-#         num_workers=0,
-#         batch_size=len(dataset),
-#         shuffle=False,
-#     )
-#     full_batch = next(iter(dataloader))
-#     # compute stats based on all frames from the dataset without any batching
-#     expected_stats = {}
-#     for k, pattern in stats_patterns.items():
-#         full_batch[k] = full_batch[k].float()
-#         expected_stats[k] = {}
-#         expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
-#         expected_stats[k]["std"] = torch.sqrt(
-#             einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
-#         )
-#         expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
-#         expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
-#     # test computed stats match expected stats
-#     for k in stats_patterns:
-#         assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
-#         assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
-#         assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
-#         assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
-#     # load stats used during training which are expected to match the ones returned by computed_stats
-#     loaded_stats = dataset.stats  # noqa: F841
-#     # TODO(rcadene): we can't test this because expected_stats is computed on a subset
-#     # # test loaded stats match expected stats
-#     # for k in stats_patterns:
-#     #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
-#     #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
-#     #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
-#     #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
+def test_compute_stats_on_xarm():
+    """Check that the statistics are computed correctly according to the stats_patterns property.
+    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
+    because we are working with a small dataset).
+    """
+    dataset = LeRobotDataset("lerobot/xarm_lift_medium")
+    # reduce size of dataset sample on which stats compute is tested to 10 frames
+    dataset.hf_dataset = dataset.hf_dataset.select(range(10))
+    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
+    # dataset into even batches.
+    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
+    # get einops patterns to aggregate batches and compute statistics
+    stats_patterns = get_stats_einops_patterns(dataset)
+    # get all frames from the dataset in the same dtype and range as during compute_stats
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=len(dataset),
+        shuffle=False,
+    )
+    full_batch = next(iter(dataloader))
+    # compute stats based on all frames from the dataset without any batching
+    expected_stats = {}
+    for k, pattern in stats_patterns.items():
+        full_batch[k] = full_batch[k].float()
+        expected_stats[k] = {}
+        expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
+        expected_stats[k]["std"] = torch.sqrt(
+            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
+        )
+        expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
+    # test computed stats match expected stats
+    for k in stats_patterns:
+        assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
+        assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
+        assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
+        assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
+    # load stats used during training which are expected to match the ones returned by computed_stats
+    loaded_stats = dataset.stats  # noqa: F841
+    # TODO(rcadene): we can't test this because expected_stats is computed on a subset
+    # # test loaded stats match expected stats
+    # for k in stats_patterns:
+    #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
+    #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
+    #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
+    #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
 def test_flatten_unflatten_dict():
@@ -269,6 +301,7 @@ def test_flatten_unflatten_dict():
         # "lerobot/cmu_stretch",
     ],
 )
 # TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
 def test_backward_compatibility(repo_id):
     """The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""