Organize test folders (#856)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Simon Alibert 2025-03-13 14:05:55 +01:00 committed by GitHub
parent a36ed39487
commit 974028bd28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
79 changed files with 63 additions and 106 deletions

View File

@ -73,7 +73,7 @@ pip-log.txt
pip-delete-this-directory.txt pip-delete-this-directory.txt
# Unit test / coverage reports # Unit test / coverage reports
!tests/data !tests/artifacts
htmlcov/ htmlcov/
.tox/ .tox/
.nox/ .nox/

2
.gitignore vendored
View File

@ -78,7 +78,7 @@ pip-log.txt
pip-delete-this-directory.txt pip-delete-this-directory.txt
# Unit test / coverage reports # Unit test / coverage reports
!tests/data !tests/artifacts
htmlcov/ htmlcov/
.tox/ .tox/
.nox/ .nox/

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
exclude: ^(tests/data) exclude: "tests/artifacts/.*\\.safetensors$"
default_language_version: default_language_version:
python: python3.10 python: python3.10
repos: repos:

View File

@ -291,7 +291,7 @@ sudo apt-get install git-lfs
git lfs install git lfs install
``` ```
Pull artifacts if they're not in [tests/data](tests/data) Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
```bash ```bash
git lfs pull git lfs pull
``` ```

View File

@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
connected to the computer. connected to the computer.
""" """
if mock: if mock:
import tests.mock_pyrealsense2 as rs import tests.cameras.mock_pyrealsense2 as rs
else: else:
import pyrealsense2 as rs import pyrealsense2 as rs
@ -100,7 +100,7 @@ def save_images_from_cameras(
serial_numbers = [cam["serial_number"] for cam in camera_infos] serial_numbers = [cam["serial_number"] for cam in camera_infos]
if mock: if mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2
@ -253,7 +253,7 @@ class IntelRealSenseCamera:
self.logs = {} self.logs = {}
if self.mock: if self.mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2
@ -287,7 +287,7 @@ class IntelRealSenseCamera:
) )
if self.mock: if self.mock:
import tests.mock_pyrealsense2 as rs import tests.cameras.mock_pyrealsense2 as rs
else: else:
import pyrealsense2 as rs import pyrealsense2 as rs
@ -375,7 +375,7 @@ class IntelRealSenseCamera:
) )
if self.mock: if self.mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2

View File

@ -80,7 +80,7 @@ def _find_cameras(
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]: ) -> list[int | str]:
if mock: if mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2
@ -269,7 +269,7 @@ class OpenCVCamera:
self.logs = {} self.logs = {}
if self.mock: if self.mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2
@ -286,7 +286,7 @@ class OpenCVCamera:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
if self.mock: if self.mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2
@ -398,7 +398,7 @@ class OpenCVCamera:
# so we convert the image color from BGR to RGB. # so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb": if requested_color_mode == "rgb":
if self.mock: if self.mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2

View File

@ -332,7 +332,7 @@ class DynamixelMotorsBus:
) )
if self.mock: if self.mock:
import tests.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
@ -356,7 +356,7 @@ class DynamixelMotorsBus:
def reconnect(self): def reconnect(self):
if self.mock: if self.mock:
import tests.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
@ -646,7 +646,7 @@ class DynamixelMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock: if self.mock:
import tests.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
@ -691,7 +691,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter() start_time = time.perf_counter()
if self.mock: if self.mock:
import tests.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
@ -757,7 +757,7 @@ class DynamixelMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock: if self.mock:
import tests.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
@ -793,7 +793,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter() start_time = time.perf_counter()
if self.mock: if self.mock:
import tests.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
import dynamixel_sdk as dxl import dynamixel_sdk as dxl

View File

@ -313,7 +313,7 @@ class FeetechMotorsBus:
) )
if self.mock: if self.mock:
import tests.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
import scservo_sdk as scs import scservo_sdk as scs
@ -337,7 +337,7 @@ class FeetechMotorsBus:
def reconnect(self): def reconnect(self):
if self.mock: if self.mock:
import tests.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
import scservo_sdk as scs import scservo_sdk as scs
@ -664,7 +664,7 @@ class FeetechMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock: if self.mock:
import tests.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
import scservo_sdk as scs import scservo_sdk as scs
@ -702,7 +702,7 @@ class FeetechMotorsBus:
def read(self, data_name, motor_names: str | list[str] | None = None): def read(self, data_name, motor_names: str | list[str] | None = None):
if self.mock: if self.mock:
import tests.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
import scservo_sdk as scs import scservo_sdk as scs
@ -782,7 +782,7 @@ class FeetechMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock: if self.mock:
import tests.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
import scservo_sdk as scs import scservo_sdk as scs
@ -818,7 +818,7 @@ class FeetechMotorsBus:
start_time = time.perf_counter() start_time = time.perf_counter()
if self.mock: if self.mock:
import tests.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
import scservo_sdk as scs import scservo_sdk as scs

View File

@ -102,30 +102,7 @@ requires-poetry = ">=2.1"
[tool.ruff] [tool.ruff]
line-length = 110 line-length = 110
target-version = "py310" target-version = "py310"
exclude = [ exclude = ["tests/artifacts/**/*.safetensors"]
"tests/data",
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
[tool.ruff.lint] [tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]

View File

@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts. doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
Example usage: Example usage:
`python tests/scripts/save_dataset_to_safetensors.py` `python tests/artifacts/datasets/save_dataset_to_safetensors.py`
""" """
import shutil import shutil
@ -88,4 +88,4 @@ if __name__ == "__main__":
"lerobot/nyu_franka_play_dataset", "lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch", "lerobot/cmu_stretch",
]: ]:
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset) save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset)

View File

@ -27,7 +27,7 @@ from lerobot.common.datasets.transforms import (
) )
from lerobot.common.utils.random_utils import seeded_context from lerobot.common.utils.random_utils import seeded_context
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors") ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp" DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"

View File

@ -141,5 +141,5 @@ if __name__ == "__main__":
raise RuntimeError("No policies were provided!") raise RuntimeError("No policies were provided!")
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg: for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1] ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}" output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs) save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)

View File

@ -146,7 +146,7 @@ def test_camera(request, camera_type, mock):
camera.connect() camera.connect()
if mock: if mock:
import tests.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
else: else:
import cv2 import cv2

View File

@ -473,12 +473,12 @@ def test_flatten_unflatten_dict():
) )
@require_x86_64_kernel @require_x86_64_kernel
def test_backward_compatibility(repo_id): def test_backward_compatibility(repo_id):
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`.""" """The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""
# TODO(rcadene, aliberts): remove dataset download # TODO(rcadene, aliberts): remove dataset download
dataset = LeRobotDataset(repo_id, episodes=[0]) dataset = LeRobotDataset(repo_id, episodes=[0])
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id test_dir = Path("tests/artifacts/datasets") / repo_id
def load_and_compare(i): def load_and_compare(i):
new_frame = dataset[i] # noqa: B023 new_frame = dataset[i] # noqa: B023

View File

@ -33,7 +33,7 @@ from lerobot.scripts.visualize_image_transforms import (
save_all_transforms, save_all_transforms,
save_each_transform, save_each_transform,
) )
from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.utils import require_x86_64_kernel from tests.utils import require_x86_64_kernel

View File

@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -11,13 +13,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch import torch
from datasets import Dataset from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
hf_transform_to_torch,
)
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
def test_calculate_episode_data_index(): def test_calculate_episode_data_index():

View File

@ -23,8 +23,7 @@ from gymnasium.utils.env_checker import check_env
import lerobot import lerobot
from lerobot.common.envs.factory import make_env, make_env_config from lerobot.common.envs.factory import make_env, make_env_config
from lerobot.common.envs.utils import preprocess_observation from lerobot.common.envs.utils import preprocess_observation
from tests.utils import require_env
from .utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]

View File

@ -1,38 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import DatasetCard
from lerobot.common.datasets.utils import create_lerobot_dataset_card
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]

View File

@ -40,7 +40,7 @@ from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@ -407,12 +407,10 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
should be updated. should be updated.
4. Check that this test now passes. 4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`. 6. Remember to stage and commit the resulting changes to `tests/artifacts`.
""" """
ds_name = ds_repo_id.split("/")[-1] ds_name = ds_repo_id.split("/")[-1]
artifact_dir = ( artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors") saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors") saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors") saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")

View File

@ -51,7 +51,7 @@ from lerobot.common.robot_devices.control_configs import (
) )
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot from tests.robots.test_robots import make_robot
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot