Tests cleaning & simplification (#81)

Simon Alibert 2024-04-18 14:47:42 +02:00 committed by GitHub
parent 0928afd37d
commit 7ad1909641
24 changed files with 277 additions and 157 deletions

View File

@ -11,7 +11,7 @@ body:
id: system-info
attributes:
label: System Info
description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.commands.env` and copy-pasting its outputs below
description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below
render: Shell
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
validations:

View File

@ -117,11 +117,9 @@ jobs:
# run tests & coverage
#----------------------------------------------
- name: Run tests
env:
LEROBOT_TESTS_DEVICE: cpu
run: |
source .venv/bin/activate
pytest --cov=./lerobot --cov-report=xml tests
pytest -v --cov=./lerobot --cov-report=xml tests
# TODO(aliberts): Link with HF Codecov account
# - name: Upload coverage reports to Codecov with GitHub Action

View File

@ -1,4 +1,4 @@
exclude: ^(data/|tests/)
exclude: ^(data/|tests/data)
default_language_version:
python: python3.10
repos:

View File

@ -65,6 +65,26 @@ A good feature request addresses the following points:
If your issue is well written, we're already 80% of the way there by the time you
post it.
## Adding new policies, datasets or environments
Look at our implementations for [datasets](./lerobot/common/datasets/), [policies](./lerobot/common/policies/),
environments ([aloha](https://github.com/huggingface/gym-aloha),
[xarm](https://github.com/huggingface/gym-xarm),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same API design.
When implementing a new dataset class (e.g. `AlohaDataset`), follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Copy it into the required `available_datasets` class attribute of your dataset class (see the sketch after this list)
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`), follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
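As a reference, here is a minimal sketch of the class attributes a new dataset class could declare, mirroring the existing `AlohaDataset`/`PushtDataset`/`XarmDataset` classes (the class name `MyNewDataset` and the dataset id `my_new_dataset` are hypothetical placeholders):

```python
import torch


class MyNewDataset(torch.utils.data.Dataset):
    """Hypothetical dataset class illustrating the required class attributes."""

    # Copied from `available_datasets` in lerobot/__init__.py (keep in sync)
    available_datasets = ["my_new_dataset"]
    # fps and image_keys mirror the existing dataset classes; values here are placeholders
    fps = 10
    image_keys = ["observation.image"]

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, idx: int) -> dict:
        raise NotImplementedError
```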
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or

View File

@ -7,7 +7,7 @@ from pathlib import Path
from huggingface_hub import snapshot_download
from lerobot.common.utils import init_hydra_config
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub.

View File

@ -13,7 +13,7 @@ from omegaconf import OmegaConf
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
from lerobot.common.utils.utils import init_hydra_config
output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)

View File

@ -7,16 +7,22 @@ Example:
import lerobot
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets_per_env)
print(lerobot.available_datasets)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
```
When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
- Set the required class attributes: `available_datasets`.
- Set the required class attributes: `name`.
- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- Update variables in `tests/test_available.py` by importing your new class
When implementing a new dataset class (e.g. `AlohaDataset`), follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`), follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
"""
from lerobot.__version__ import __version__ # noqa: F401
@ -36,7 +42,7 @@ available_tasks_per_env = {
"xarm": ["XarmLift-v0"],
}
available_datasets_per_env = {
available_datasets = {
"aloha": [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
@ -47,10 +53,23 @@ available_datasets_per_env = {
"xarm": ["xarm_lift_medium"],
}
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
available_policies = [
"act",
"diffusion",
"tdmpc",
]
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
env_dataset_policy_triplets = [
(env, dataset, policy)
for env, datasets in available_datasets.items()
for dataset in datasets
for policy in available_policies_per_env[env]
]
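These module-level pairs and triplets are consumed directly by `pytest.mark.parametrize` in the test suite; a minimal sketch of the intended usage (the test name below is hypothetical, the parametrization mirrors `tests/test_available.py` in this diff):

```python
import pytest

import lerobot


@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
def test_task_is_listed(env_name: str, task_name: str):
    # Each (env, task) pair is generated from `available_tasks_per_env`.
    assert task_name in lerobot.available_tasks_per_env[env_name]
```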

View File

@ -14,6 +14,7 @@ class AlohaDataset(torch.utils.data.Dataset):
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
"""
# Copied from lerobot/__init__.py
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",

View File

@ -17,6 +17,7 @@ class PushtDataset(torch.utils.data.Dataset):
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
"""
# Copied from lerobot/__init__.py
available_datasets = ["pusht"]
fps = 10
image_keys = ["observation.image"]

View File

@ -11,9 +11,8 @@ class XarmDataset(torch.utils.data.Dataset):
https://huggingface.co/datasets/lerobot/xarm_lift_medium
"""
available_datasets = [
"xarm_lift_medium",
]
# Copied from lerobot/__init__.py
available_datasets = ["xarm_lift_medium"]
fps = 15
image_keys = ["observation.image"]

View File

@ -2,7 +2,7 @@ import inspect
from omegaconf import DictConfig, OmegaConf
from lerobot.common.utils import get_safe_torch_device
from lerobot.common.utils.utils import get_safe_torch_device
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):

View File

@ -11,7 +11,7 @@ import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
from lerobot.common.utils.utils import get_safe_torch_device
FIRST_FRAME = 0

View File

@ -0,0 +1,44 @@
import importlib
import importlib.metadata
import importlib.util
import logging
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.
"""
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
# Primary method to get the package version
package_version = importlib.metadata.version(pkg_name)
except importlib.metadata.PackageNotFoundError:
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
package_version = temp_version
package_exists = True
else:
package_exists = False
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
logging.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
return package_exists, package_version
else:
return package_exists
_torch_available, _torch_version = is_package_available("torch", return_version=True)
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")
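For illustration, a short sketch of how these helpers and flags could be used to guard optional dependencies (the guarded import below is an assumption for this example, not part of this commit):

```python
from lerobot.common.utils.import_utils import is_package_available

# Version-aware check; falls back to importing torch itself for dev builds.
torch_available, torch_version = is_package_available("torch", return_version=True)
print(f"torch available: {torch_available}, version: {torch_version}")

# Plain boolean check to guard an optional simulation backend.
if is_package_available("gym_pusht"):
    import gym_pusht  # noqa: F401
```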

View File

@ -15,7 +15,7 @@ cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not
# TODO(aliberts): refactor into an actual command `lerobot env`
def get_env_info() -> dict:
def display_sys_info() -> dict:
"""Run this to get basic system info to help for tracking issues & bugs."""
info = {
"`lerobot` version": version,
@ -40,4 +40,4 @@ def format_dict(d: dict) -> str:
if __name__ == "__main__":
get_env_info()
display_sys_info()

View File

@ -50,7 +50,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
def write_video(video_path, stacked_frames, fps):

View File

@ -13,7 +13,7 @@ from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import (
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_logging,

View File

@ -9,7 +9,7 @@ import torch
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
from lerobot.common.utils import init_logging
from lerobot.common.utils.utils import init_logging
NUM_EPISODES_TO_RENDER = 50
MAX_NUM_STEPS = 1000

View File

@ -1,53 +1,60 @@
"""
This test verifies that all environments, datasets, and policies listed in `lerobot/__init__.py` can be successfully
imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid.
When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
- Set the required class attributes: `available_datasets`.
- Set the required class attributes: `name`.
- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- Update variables in `tests/test_available.py` by importing your new class
"""
import importlib
import pytest
import lerobot
import gymnasium as gym
from lerobot.common.datasets.xarm import XarmDataset
import gymnasium as gym
import pytest
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from tests.utils import require_env
def test_available():
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
@require_env
def test_available_env_task(env_name: str, task_name: list):
"""
This test verifies that all environments listed in `lerobot/__init__.py` can
be successfully imported if they're installed, and that their
`available_tasks_per_env` entries are valid.
"""
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry, gym_handle
@pytest.mark.parametrize(
"env_name, dataset_class",
[
("aloha", AlohaDataset),
("pusht", PushtDataset),
("xarm", XarmDataset),
],
)
def test_available_datasets(env_name, dataset_class):
"""
This test verifies that the class attribute `available_datasets` for all
dataset classes is consistent with those listed in `lerobot/__init__.py`.
"""
available_env_datasets = lerobot.available_datasets[env_name]
assert set(available_env_datasets) == set(
dataset_class.available_datasets
), f"{env_name=} {available_env_datasets=}"
def test_available_policies():
"""
This test verifies that the class attribute `name` for all policies is
consistent with those listed in `lerobot/__init__.py`.
"""
policy_classes = [
ActionChunkingTransformerPolicy,
DiffusionPolicy,
TDMPCPolicy,
]
dataset_class_per_env = {
"aloha": AlohaDataset,
"pusht": PushtDataset,
"xarm": XarmDataset,
}
policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies
for env_name in lerobot.available_envs:
for task_name in lerobot.available_tasks_per_env[env_name]:
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry.keys(), gym_handle
dataset_class = dataset_class_per_env[env_name]
available_datasets = lerobot.available_datasets_per_env[env_name]
assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"

View File

@ -1,33 +1,35 @@
import logging
import os
from pathlib import Path
import einops
import pytest
import torch
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_previous_and_future_frames
from lerobot.common.transforms import Prod
from lerobot.common.utils import init_hydra_config
import logging
from lerobot.common.datasets.factory import make_dataset
from datasets import Dataset
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,dataset_id,policy_name",
[
("xarm", "xarm_lift_medium", "tdmpc"),
("pusht", "pusht", "diffusion"),
("aloha", "aloha_sim_insertion_human", "act"),
("aloha", "aloha_sim_insertion_scripted", "act"),
("aloha", "aloha_sim_transfer_cube_human", "act"),
("aloha", "aloha_sim_transfer_cube_scripted", "act"),
],
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import (
compute_stats,
get_stats_einops_patterns,
load_previous_and_future_frames,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE
@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
def test_factory(env_name, dataset_id, policy_name):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"]
overrides=[
f"env={env_name}",
f"dataset_id={dataset_id}",
f"policy={policy_name}",
f"device={DEVICE}",
],
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
@ -51,7 +53,7 @@ def test_factory(env_name, dataset_id, policy_name):
(key, 3, True),
)
assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required:
if key not in item:
@ -60,13 +62,13 @@ def test_factory(env_name, dataset_id, policy_name):
else:
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
continue
if delta_timestamps is not None and key in delta_timestamps:
assert item[key].ndim == ndim + 1, f"{key}"
assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
else:
assert item[key].ndim == ndim, f"{key}"
if key in image_keys:
assert item[key].dtype == torch.float32, f"{key}"
# TODO(rcadene): we assume for now that image normalization takes place in the model
@ -77,17 +79,16 @@ def test_factory(env_name, dataset_id, policy_name):
# test t,c,h,w
assert item[key].shape[1] == 3, f"{key}"
else:
# test c,h,w
# test c,h,w
assert item[key].shape[0] == 3, f"{key}"
if delta_timestamps is not None:
# test missing keys in delta_timestamps
for key in delta_timestamps:
assert key in item, f"{key}"
def test_compute_stats():
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
@ -95,20 +96,20 @@ def test_compute_stats():
"""
from lerobot.common.datasets.xarm import XarmDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=DATA_DIR,
root=data_dir,
transform=transform,
)
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
# get einops patterns to aggregate batches and compute statistics
@ -128,7 +129,9 @@ def test_compute_stats():
for k, pattern in stats_patterns.items():
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
@ -153,12 +156,14 @@ def test_compute_stats():
def test_load_previous_and_future_frames_within_tolerance():
hf_dataset = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
})
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
delta_timestamps = {"index": [-0.2, 0, 0.139]}
@ -168,13 +173,16 @@ def test_load_previous_and_future_frames_within_tolerance():
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
hf_dataset = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
})
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
delta_timestamps = {"index": [-0.2, 0, 0.141]}
@ -182,13 +190,16 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
hf_dataset = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
})
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
@ -196,6 +207,6 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"

View File

@ -1,49 +1,37 @@
import importlib
import gymnasium as gym
import pytest
import torch
from lerobot.common.datasets.factory import make_dataset
import gymnasium as gym
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.utils import init_hydra_config
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
@pytest.mark.parametrize(
"env_name, task, obs_type",
[
# ("AlohaInsertion-v0", "state"),
("aloha", "AlohaInsertion-v0", "pixels"),
("aloha", "AlohaInsertion-v0", "pixels_agent_pos"),
("aloha", "AlohaTransferCube-v0", "pixels"),
("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"),
("xarm", "XarmLift-v0", "state"),
("xarm", "XarmLift-v0", "pixels"),
("xarm", "XarmLift-v0", "pixels_agent_pos"),
("pusht", "PushT-v0", "state"),
("pusht", "PushT-v0", "pixels"),
("pusht", "PushT-v0", "pixels_agent_pos"),
],
)
def test_env(env_name, task, obs_type):
@pytest.mark.parametrize("obs_type", OBS_TYPES)
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
@require_env
def test_env(env_name, env_task, obs_type):
if env_name == "aloha" and obs_type == "state":
pytest.skip("`state` observations not available for aloha")
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
env = gym.make(f"{package_name}/{task}", obs_type=obs_type)
env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
check_env(env.unwrapped, skip_render_check=True)
env.close()
@pytest.mark.parametrize(
"env_name",
[
"pusht",
"xarm",
"aloha",
],
)
@pytest.mark.parametrize("env_name", lerobot.available_envs)
@require_env
def test_factory(env_name):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,

View File

@ -1,5 +1,5 @@
from pathlib import Path
import subprocess
from pathlib import Path
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
@ -10,7 +10,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
def _run_script(path):
subprocess.run(['python', path], check=True)
subprocess.run(["python", path], check=True)
def test_example_1():
@ -33,7 +33,7 @@ def test_examples_4_and_3():
path = "examples/4_train_policy.py"
with open(path, "r") as file:
with open(path) as file:
file_contents = file.read()
# Do fewer steps, use a smaller batch, use CPU, and don't complicate things with dataloader workers.
@ -55,7 +55,7 @@ def test_examples_4_and_3():
path = "examples/3_evaluate_pretrained_policy.py"
with open(path, "r") as file:
with open(path) as file:
file_contents = file.read()
# Do fewer evals, use CPU, and use the local model.
@ -74,4 +74,4 @@ def test_examples_4_and_3():
],
)
assert Path(f"outputs/train/example_pusht_diffusion").exists()
assert Path("outputs/train/example_pusht_diffusion").exists()

View File

@ -1,16 +1,18 @@
import pytest
import torch
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@ -21,10 +23,9 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
# TODO(aliberts): xarm not working with diffusion
# ("xarm", "diffusion", []),
],
)
@require_env
def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:

View File

@ -1,6 +1,37 @@
import os
import pytest
import torch
from lerobot.common.utils.import_utils import is_package_available
# Pass this as the first argument to init_hydra_config.
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def require_env(func):
"""
Decorator that skips the test if the required environment package is not installed.
As it needs 'env_name' among the test's arguments, it also checks that it is provided.
"""
from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
# Determine if 'env_name' is provided and extract its value
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
if "env_name" in arg_names:
# Get the index of 'env_name' and retrieve the value from args
index = arg_names.index("env_name")
env_name = args[index] if len(args) > index else kwargs.get("env_name")
else:
raise ValueError("Function does not have 'env_name' as an argument.")
# Perform the package check
package_name = f"gym_{env_name}"
if not is_package_available(package_name):
pytest.skip(f"gym-{env_name} not installed")
return func(*args, **kwargs)
return wrapper
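For illustration, `require_env` is meant to stack under a parametrization that supplies `env_name`, as done in `tests/test_available.py`; a minimal sketch (the test body is hypothetical):

```python
import importlib

import pytest

import lerobot
from tests.utils import require_env


@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
@require_env
def test_env_package_imports(env_name: str, task_name: str):
    # The wrapper skips this test when the gym_<env_name> package is missing.
    importlib.import_module(f"gym_{env_name}")
```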