Include tests in pre-commit formatting

This commit is contained in:
Simon Alibert 2024-04-18 12:53:23 +02:00
parent 37efcea3eb
commit d407ce21aa
7 changed files with 82 additions and 63 deletions

View File

@ -1,4 +1,4 @@
exclude: ^(data/|tests/) exclude: ^(data/|tests/data)
default_language_version: default_language_version:
python: python3.10 python: python3.10
repos: repos:

View File

@ -1,9 +1,9 @@
import importlib import importlib
import pytest
import lerobot
import gymnasium as gym
from lerobot.common.utils.import_utils import is_package_available import gymnasium as gym
import pytest
import lerobot
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@ -21,7 +21,7 @@ def test_available_env_task(env_name: str, task_name: list):
package_name = f"gym_{env_name}" package_name = f"gym_{env_name}"
importlib.import_module(package_name) importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}" gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry.keys(), gym_handle assert gym_handle in gym.envs.registry, gym_handle
def test_available_policies(): def test_available_policies():

View File

@ -1,24 +1,35 @@
import logging
import os import os
from pathlib import Path from pathlib import Path
import einops import einops
import pytest import pytest
import torch import torch
from datasets import Dataset
import lerobot import lerobot
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_previous_and_future_frames from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import (
compute_stats,
get_stats_einops_patterns,
load_previous_and_future_frames,
)
from lerobot.common.transforms import Prod from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
import logging
from lerobot.common.datasets.factory import make_dataset from .utils import DEFAULT_CONFIG_PATH, DEVICE
from datasets import Dataset
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets) @pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
def test_factory(env_name, dataset_id, policy_name): def test_factory(env_name, dataset_id, policy_name):
cfg = init_hydra_config( cfg = init_hydra_config(
DEFAULT_CONFIG_PATH, DEFAULT_CONFIG_PATH,
overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"] overrides=[
f"env={env_name}",
f"dataset_id={dataset_id}",
f"policy={policy_name}",
f"device={DEVICE}",
],
) )
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps delta_timestamps = dataset.delta_timestamps
@ -71,7 +82,6 @@ def test_factory(env_name, dataset_id, policy_name):
# test c,h,w # test c,h,w
assert item[key].shape[0] == 3, f"{key}" assert item[key].shape[0] == 3, f"{key}"
if delta_timestamps is not None: if delta_timestamps is not None:
# test missing keys in delta_timestamps # test missing keys in delta_timestamps
for key in delta_timestamps: for key in delta_timestamps:
@ -86,14 +96,14 @@ def test_compute_stats_on_xarm():
""" """
from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.xarm import XarmDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1] # get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset( dataset = XarmDataset(
dataset_id="xarm_lift_medium", dataset_id="xarm_lift_medium",
root=DATA_DIR, root=data_dir,
transform=transform, transform=transform,
) )
@ -119,7 +129,9 @@ def test_compute_stats_on_xarm():
for k, pattern in stats_patterns.items(): for k, pattern in stats_patterns.items():
expected_stats[k] = {} expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean") expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")) expected_stats[k]["std"] = torch.sqrt(
einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min") expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max") expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
@ -144,12 +156,14 @@ def test_compute_stats_on_xarm():
def test_load_previous_and_future_frames_within_tolerance(): def test_load_previous_and_future_frames_within_tolerance():
hf_dataset = Dataset.from_dict({ hf_dataset = Dataset.from_dict(
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], {
"index": [0, 1, 2, 3, 4], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"episode_data_index_from": [0, 0, 0, 0, 0], "index": [0, 1, 2, 3, 4],
"episode_data_index_to": [5, 5, 5, 5, 5], "episode_data_index_from": [0, 0, 0, 0, 0],
}) "episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch") hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2] item = hf_dataset[2]
delta_timestamps = {"index": [-0.2, 0, 0.139]} delta_timestamps = {"index": [-0.2, 0, 0.139]}
@ -161,12 +175,14 @@ def test_load_previous_and_future_frames_within_tolerance():
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(): def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
hf_dataset = Dataset.from_dict({ hf_dataset = Dataset.from_dict(
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], {
"index": [0, 1, 2, 3, 4], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"episode_data_index_from": [0, 0, 0, 0, 0], "index": [0, 1, 2, 3, 4],
"episode_data_index_to": [5, 5, 5, 5, 5], "episode_data_index_from": [0, 0, 0, 0, 0],
}) "episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch") hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2] item = hf_dataset[2]
delta_timestamps = {"index": [-0.2, 0, 0.141]} delta_timestamps = {"index": [-0.2, 0, 0.141]}
@ -176,12 +192,14 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range(): def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
hf_dataset = Dataset.from_dict({ hf_dataset = Dataset.from_dict(
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], {
"index": [0, 1, 2, 3, 4], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"episode_data_index_from": [0, 0, 0, 0, 0], "index": [0, 1, 2, 3, 4],
"episode_data_index_to": [5, 5, 5, 5, 5], "episode_data_index_from": [0, 0, 0, 0, 0],
}) "episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch") hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2] item = hf_dataset[2]
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
@ -189,6 +207,6 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"] data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values" assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"

View File

@ -1,18 +1,17 @@
import importlib import importlib
import gymnasium as gym
import pytest import pytest
import torch import torch
from lerobot.common.datasets.factory import make_dataset
import gymnasium as gym
from gymnasium.utils.env_checker import check_env from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.utils.import_utils import is_package_available from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
import lerobot from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
from lerobot.common.envs.utils import preprocess_observation
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]

View File

@ -1,5 +1,5 @@
from pathlib import Path
import subprocess import subprocess
from pathlib import Path
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str: def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
@ -10,7 +10,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
def _run_script(path): def _run_script(path):
subprocess.run(['python', path], check=True) subprocess.run(["python", path], check=True)
def test_example_1(): def test_example_1():
@ -33,7 +33,7 @@ def test_examples_4_and_3():
path = "examples/4_train_policy.py" path = "examples/4_train_policy.py"
with open(path, "r") as file: with open(path) as file:
file_contents = file.read() file_contents = file.read()
# Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
@ -55,7 +55,7 @@ def test_examples_4_and_3():
path = "examples/3_evaluate_pretrained_policy.py" path = "examples/3_evaluate_pretrained_policy.py"
with open(path, "r") as file: with open(path) as file:
file_contents = file.read() file_contents = file.read()
# Do less evals, use CPU, and use the local model. # Do less evals, use CPU, and use the local model.
@ -74,4 +74,4 @@ def test_examples_4_and_3():
], ],
) )
assert Path(f"outputs/train/example_pusht_diffusion").exists() assert Path("outputs/train/example_pusht_diffusion").exists()

View File

@ -1,14 +1,15 @@
import pytest import pytest
import torch import torch
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
# TODO(aliberts): refactor using lerobot/__init__.py variables # TODO(aliberts): refactor using lerobot/__init__.py variables

View File

@ -8,6 +8,7 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def require_env(func): def require_env(func):
""" """
Decorator that skips the test if the required environment package is not installed. Decorator that skips the test if the required environment package is not installed.
@ -18,11 +19,11 @@ def require_env(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
# Determine if 'env_name' is provided and extract its value # Determine if 'env_name' is provided and extract its value
arg_names = func.__code__.co_varnames[:func.__code__.co_argcount] arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
if 'env_name' in arg_names: if "env_name" in arg_names:
# Get the index of 'env_name' and retrieve the value from args # Get the index of 'env_name' and retrieve the value from args
index = arg_names.index('env_name') index = arg_names.index("env_name")
env_name = args[index] if len(args) > index else kwargs.get('env_name') env_name = args[index] if len(args) > index else kwargs.get("env_name")
else: else:
raise ValueError("Function does not have 'env_name' as an argument.") raise ValueError("Function does not have 'env_name' as an argument.")