Centralize availables

This commit is contained in:
Simon Alibert 2024-04-17 16:40:40 +02:00
parent 0928afd37d
commit 6dbbe87c2c
6 changed files with 100 additions and 55 deletions

View File

@ -7,16 +7,22 @@ Example:
import lerobot import lerobot
print(lerobot.available_envs) print(lerobot.available_envs)
print(lerobot.available_tasks_per_env) print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets_per_env)
print(lerobot.available_datasets) print(lerobot.available_datasets)
print(lerobot.available_policies) print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
``` ```
When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Set the required class attributes: `available_datasets`. - Update `available_datasets` in `lerobot/__init__.py`
- Set the required class attributes: `name`. - Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- Update variables in `tests/test_available.py` by importing your new class When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
""" """
from lerobot.__version__ import __version__ # noqa: F401 from lerobot.__version__ import __version__ # noqa: F401
@ -36,7 +42,7 @@ available_tasks_per_env = {
"xarm": ["XarmLift-v0"], "xarm": ["XarmLift-v0"],
} }
available_datasets_per_env = { available_datasets = {
"aloha": [ "aloha": [
"aloha_sim_insertion_human", "aloha_sim_insertion_human",
"aloha_sim_insertion_scripted", "aloha_sim_insertion_scripted",
@ -47,10 +53,23 @@ available_datasets_per_env = {
"xarm": ["xarm_lift_medium"], "xarm": ["xarm_lift_medium"],
} }
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
available_policies = [ available_policies = [
"act", "act",
"diffusion", "diffusion",
"tdmpc", "tdmpc",
] ]
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
env_dataset_policy_triplets = [
(env, dataset, policy)
for env, datasets in available_datasets.items()
for dataset in datasets
for policy in available_policies_per_env[env]
]

View File

@ -3,6 +3,7 @@ from pathlib import Path
import torch import torch
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
import lerobot
from lerobot.common.datasets.utils import load_previous_and_future_frames from lerobot.common.datasets.utils import load_previous_and_future_frames
@ -14,12 +15,7 @@ class AlohaDataset(torch.utils.data.Dataset):
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
""" """
available_datasets = [ available_datasets = lerobot.available_datasets["aloha"]
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
fps = 50 fps = 50
image_keys = ["observation.images.top"] image_keys = ["observation.images.top"]

View File

@ -3,6 +3,7 @@ from pathlib import Path
import torch import torch
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
import lerobot
from lerobot.common.datasets.utils import load_previous_and_future_frames from lerobot.common.datasets.utils import load_previous_and_future_frames
@ -17,7 +18,7 @@ class PushtDataset(torch.utils.data.Dataset):
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded. If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
""" """
available_datasets = ["pusht"] available_datasets = lerobot.available_datasets["pusht"]
fps = 10 fps = 10
image_keys = ["observation.image"] image_keys = ["observation.image"]

View File

@ -3,6 +3,7 @@ from pathlib import Path
import torch import torch
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
import lerobot
from lerobot.common.datasets.utils import load_previous_and_future_frames from lerobot.common.datasets.utils import load_previous_and_future_frames
@ -11,9 +12,7 @@ class XarmDataset(torch.utils.data.Dataset):
https://huggingface.co/datasets/lerobot/xarm_lift_medium https://huggingface.co/datasets/lerobot/xarm_lift_medium
""" """
available_datasets = [ available_datasets = lerobot.available_datasets["xarm"]
"xarm_lift_medium",
]
fps = 15 fps = 15
image_keys = ["observation.image"] image_keys = ["observation.image"]

View File

@ -0,0 +1,44 @@
import importlib
import importlib.metadata
import importlib.util
import logging
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.
"""
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
# Primary method to get the package version
package_version = importlib.metadata.version(pkg_name)
except importlib.metadata.PackageNotFoundError:
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
package_version = temp_version
package_exists = True
else:
package_exists = False
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
logging.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
return package_exists, package_version
else:
return package_exists
# Availability flags, resolved once when this module is imported.
_torch_available, _torch_version = is_package_available(pkg_name="torch", return_version=True)
_gym_xarm_available = is_package_available(pkg_name="gym_xarm")
_gym_aloha_available = is_package_available(pkg_name="gym_aloha")
_gym_pusht_available = is_package_available(pkg_name="gym_pusht")

View File

@ -1,53 +1,39 @@
"""
This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully
imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) are valid.
When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
- Set the required class attributes: `available_datasets`.
- Set the required class attributes: `name`.
- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- Update variables in `tests/test_available.py` by importing your new class
"""
import importlib import importlib
import pytest import pytest
import lerobot import lerobot
import gymnasium as gym import gymnasium as gym
from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.import_utils import is_package_available
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
def test_available(): @pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
def test_available_env_task(env_name: str, task_name: str):
    """
    This test verifies that all environments listed in `lerobot/__init__.py` can
    be successfully imported — if they're installed — and that their
    `available_tasks_per_env` are valid.

    Args:
        env_name: Environment key; the gym package is looked up as `gym_<env_name>`.
        task_name: A single task id for that environment (e.g. "XarmLift-v0").
    """
    package_name = f"gym_{env_name}"
    # Environment packages are optional: skip rather than fail when one is missing.
    if not is_package_available(package_name):
        pytest.skip(f"gym-{env_name} not installed")

    # Import the package before the registry lookup below.
    importlib.import_module(package_name)
    gym_handle = f"{package_name}/{task_name}"
    assert gym_handle in gym.envs.registry.keys(), gym_handle
def test_available_policies():
"""
This test verifies that the class attribute `name` for all policies is
consistent with those listed in `lerobot/__init__.py`.
"""
policy_classes = [ policy_classes = [
ActionChunkingTransformerPolicy, ActionChunkingTransformerPolicy,
DiffusionPolicy, DiffusionPolicy,
TDMPCPolicy, TDMPCPolicy,
] ]
dataset_class_per_env = {
"aloha": AlohaDataset,
"pusht": PushtDataset,
"xarm": XarmDataset,
}
policies = [pol_cls.name for pol_cls in policy_classes] policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies assert set(policies) == set(lerobot.available_policies), policies
for env_name in lerobot.available_envs:
for task_name in lerobot.available_tasks_per_env[env_name]:
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry.keys(), gym_handle
dataset_class = dataset_class_per_env[env_name]
available_datasets = lerobot.available_datasets_per_env[env_name]
assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"