diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 8ab95df8..83e51c7a 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -7,16 +7,22 @@ Example: import lerobot print(lerobot.available_envs) print(lerobot.available_tasks_per_env) - print(lerobot.available_datasets_per_env) print(lerobot.available_datasets) print(lerobot.available_policies) + print(lerobot.available_policies_per_env) ``` -When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: -- Set the required class attributes: `available_datasets`. -- Set the required class attributes: `name`. -- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) -- Update variables in `tests/test_available.py` by importing your new class +When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps: +- Update `available_datasets` in `lerobot/__init__.py` +- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets` + +When implementing a new environment (e.g. `gym_aloha`), follow these steps: +- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py` + +When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: +- Update `available_policies` in `lerobot/__init__.py` +- Set the required `name` class attribute. +- Update variables in `tests/test_available.py` by importing your new Policy class """ from lerobot.__version__ import __version__ # noqa: F401 @@ -36,7 +42,7 @@ available_tasks_per_env = { "xarm": ["XarmLift-v0"], } -available_datasets_per_env = { +available_datasets = { "aloha": [ "aloha_sim_insertion_human", "aloha_sim_insertion_scripted", @@ -47,10 +53,23 @@ available_datasets_per_env = { "xarm": ["xarm_lift_medium"], } -available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]] - available_policies = [ "act", "diffusion", "tdmpc", ] + +available_policies_per_env = { + "aloha": ["act"], + "pusht": ["diffusion"], + "xarm": ["tdmpc"], +} + +env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] +env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets] +env_dataset_policy_triplets = [ + (env, dataset, policy) + for env, datasets in available_datasets.items() + for dataset in datasets + for policy in available_policies_per_env[env] +] diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 87ee57a8..4ba4d925 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -3,6 +3,7 @@ from pathlib import Path import torch from datasets import load_dataset, load_from_disk +import lerobot from lerobot.common.datasets.utils import load_previous_and_future_frames @@ -14,12 +15,7 @@ class AlohaDataset(torch.utils.data.Dataset): https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted """ - available_datasets = [ - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", - ] + available_datasets = lerobot.available_datasets["aloha"] fps = 50 image_keys = ["observation.images.top"] diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index b9d06ba4..2689ebfa 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -3,6 +3,7 @@ from pathlib import Path import torch from datasets import load_dataset, load_from_disk +import lerobot from lerobot.common.datasets.utils import load_previous_and_future_frames @@ -17,7 +18,7 @@ class PushtDataset(torch.utils.data.Dataset): If `None`, no shift is applied to current timestamp and the data from the current frame is loaded. """ - available_datasets = ["pusht"] + available_datasets = lerobot.available_datasets["pusht"] fps = 10 image_keys = ["observation.image"] diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 28ef4fa8..aa5646f9 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -3,6 +3,7 @@ from pathlib import Path import torch from datasets import load_dataset, load_from_disk +import lerobot from lerobot.common.datasets.utils import load_previous_and_future_frames @@ -11,9 +12,7 @@ class XarmDataset(torch.utils.data.Dataset): https://huggingface.co/datasets/lerobot/xarm_lift_medium """ - available_datasets = [ - "xarm_lift_medium", - ] + available_datasets = lerobot.available_datasets["xarm"] fps = 15 image_keys = ["observation.image"] diff --git a/lerobot/common/import_utils.py b/lerobot/common/import_utils.py new file mode 100644 index 00000000..642e0ff1 --- /dev/null +++ b/lerobot/common/import_utils.py @@ -0,0 +1,44 @@ +import importlib +import logging + + +def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: + """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py + Check if the package spec exists and grab its version to avoid importing a local directory. + **Note:** this doesn't work for all packages. + """ + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + logging.debug(f"Detected {pkg_name} version: {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +_torch_available, _torch_version = is_package_available("torch", return_version=True) +_gym_xarm_available = is_package_available("gym_xarm") +_gym_aloha_available = is_package_available("gym_aloha") +_gym_pusht_available = is_package_available("gym_pusht") diff --git a/tests/test_available.py b/tests/test_available.py index 373cc1a7..560364c0 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -1,53 +1,39 @@ -""" -This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully -imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) are valid. - -When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: -- Set the required class attributes: `available_datasets`. -- Set the required class attributes: `name`. -- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) -- Update variables in `tests/test_available.py` by importing your new class -""" - import importlib import pytest import lerobot import gymnasium as gym -from lerobot.common.datasets.xarm import XarmDataset -from lerobot.common.datasets.aloha import AlohaDataset -from lerobot.common.datasets.pusht import PushtDataset - +from lerobot.common.import_utils import is_package_available from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy -def test_available(): +@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs) +def test_available_env_task(env_name: str, task_name: list): + """ + This test verifies that all environments listed in `lerobot/__init__.py` can + be sucessfully imported if — they're installed — and that their + `available_tasks_per_env` are valid. + """ + package_name = f"gym_{env_name}" + if not is_package_available(package_name): + pytest.skip(f"gym-{env_name} not installed") + + importlib.import_module(package_name) + gym_handle = f"{package_name}/{task_name}" + assert gym_handle in gym.envs.registry.keys(), gym_handle + + +def test_available_policies(): + """ + This test verifies that the class attribute `name` for all policies is + consistent with those listed in `lerobot/__init__.py`. + """ policy_classes = [ ActionChunkingTransformerPolicy, DiffusionPolicy, TDMPCPolicy, ] - - dataset_class_per_env = { - "aloha": AlohaDataset, - "pusht": PushtDataset, - "xarm": XarmDataset, - } - policies = [pol_cls.name for pol_cls in policy_classes] assert set(policies) == set(lerobot.available_policies), policies - - for env_name in lerobot.available_envs: - for task_name in lerobot.available_tasks_per_env[env_name]: - package_name = f"gym_{env_name}" - importlib.import_module(package_name) - gym_handle = f"{package_name}/{task_name}" - assert gym_handle in gym.envs.registry.keys(), gym_handle - - dataset_class = dataset_class_per_env[env_name] - available_datasets = lerobot.available_datasets_per_env[env_name] - assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}" - -