From 8c69b0b9cdbf5daa30bc3ef2ce1ebd28b1cbb4da Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 7 Apr 2025 16:36:04 +0200 Subject: [PATCH] Adding audio tests --- lerobot/__init__.py | 5 + .../robot_devices/microphones/microphone.py | 6 +- tests/conftest.py | 7 +- tests/datasets/test_compute_stats.py | 29 +++- tests/datasets/test_datasets.py | 37 ++++- tests/fixtures/constants.py | 15 +- tests/fixtures/dataset_factories.py | 14 +- tests/microphones/mock_sounddevice.py | 82 ++++++++++ tests/microphones/test_microphones.py | 142 ++++++++++++++++++ tests/robots/test_robots.py | 15 +- tests/utils.py | 38 ++++- 11 files changed, 375 insertions(+), 15 deletions(-) create mode 100644 tests/microphones/mock_sounddevice.py create mode 100644 tests/microphones/test_microphones.py diff --git a/lerobot/__init__.py b/lerobot/__init__.py index d61e4853..386c2fbb 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -190,6 +190,11 @@ available_cameras = [ "intelrealsense", ] +# lists all available microphones from `lerobot/common/robot_devices/microphones` +available_microphones = [ + "microphone", +] + # lists all available motors from `lerobot/common/robot_devices/motors` available_motors = [ "dynamixel", diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 38a6e5f9..c4cd7bac 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -44,8 +44,7 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: microphones = [] if mock: - #TODO(CarolinePascal): Implement mock microphones - pass + import tests.microphones.mock_sounddevice as sd else: import sounddevice as sd @@ -161,8 +160,7 @@ class Microphone: raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.") if self.mock: - #TODO(CarolinePascal): Implement mock microphones - pass + import 
tests.microphones.mock_sounddevice as sd else: import sounddevice as sd diff --git a/tests/conftest.py b/tests/conftest.py index 7eec94bf..adf80931 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,9 @@ import traceback import pytest from serial import SerialException -from lerobot import available_cameras, available_motors, available_robots +from lerobot import available_cameras, available_motors, available_robots, available_microphones from lerobot.common.robot_devices.robots.utils import make_robot -from tests.utils import DEVICE, make_camera, make_motors_bus +from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone # Import fixture modules as plugins pytest_plugins = [ @@ -73,6 +73,9 @@ def is_robot_available(robot_type): def is_camera_available(camera_type): return _check_component_availability(camera_type, available_cameras, make_camera) +@pytest.fixture +def is_microphone_available(microphone_type): + return _check_component_availability(microphone_type, available_microphones, make_microphone) @pytest.fixture def is_motor_available(motor_type): diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index d9032c8a..113944b3 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -26,6 +26,7 @@ from lerobot.common.datasets.compute_stats import ( estimate_num_samples, get_feature_stats, sample_images, + sample_audio, sample_indices, ) @@ -33,6 +34,8 @@ from lerobot.common.datasets.compute_stats import ( def mock_load_image_as_numpy(path, dtype, channel_first): return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) +def mock_load_audio(path): + return np.ones((16000,2), dtype=np.float32) @pytest.fixture def sample_array(): @@ -70,6 +73,14 @@ def test_sample_images(mock_load): assert images.dtype == np.uint8 assert len(images) == estimate_num_samples(100) 
+@patch("lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio) +def test_sample_audio(mock_load): + audio_path = "audio.wav" + audio_samples = sample_audio(audio_path) + assert isinstance(audio_samples, np.ndarray) + assert audio_samples.shape[1] == 2 + assert audio_samples.dtype == np.float32 + assert len(audio_samples) == estimate_num_samples(16000) def test_get_feature_stats_images(): data = np.random.rand(100, 3, 32, 32) @@ -78,6 +89,12 @@ def test_get_feature_stats_images(): np.testing.assert_equal(stats["count"], np.array([100])) assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape +def test_get_feature_stats_audio(): + data = np.random.uniform(-1, 1, (16000,2)) + stats = get_feature_stats(data, axis=0, keepdims=True) + assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats + np.testing.assert_equal(stats["count"], np.array([16000])) + assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape def test_get_feature_stats_axis_0_keepdims(sample_array): expected = { @@ -137,22 +154,28 @@ def test_get_feature_stats_single_value(): def test_compute_episode_stats(): episode_data = { "observation.image": [f"image_{i}.jpg" for i in range(100)], + "observation.audio": "audio.wav", "observation.state": np.random.rand(100, 10), } features = { "observation.image": {"dtype": "image"}, + "observation.audio": {"dtype": "audio"}, "observation.state": {"dtype": "numeric"}, } with patch( "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + ), patch( + "lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio ): stats = compute_episode_stats(episode_data, features) - assert "observation.image" in stats and "observation.state" in stats - assert stats["observation.image"]["count"].item() == 100 - assert stats["observation.state"]["count"].item() == 100 + assert 
"observation.image" in stats and "observation.state" in stats and "observation.audio" in stats + assert stats["observation.image"]["count"].item() == estimate_num_samples(100) + assert stats["observation.audio"]["count"].item() == estimate_num_samples(16000) + assert stats["observation.state"]["count"].item() == estimate_num_samples(100) assert stats["observation.image"]["mean"].shape == (3, 1, 1) + assert stats["observation.audio"]["mean"].shape == (1, 2) def test_assert_type_and_shape_valid(): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 81447089..1ce0e7d8 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -44,9 +44,12 @@ from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID +from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS from tests.utils import require_x86_64_kernel +from tests.utils import make_microphone +import time +from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION @pytest.fixture def image_dataset(tmp_path, empty_lerobot_dataset_factory): @@ -63,6 +66,18 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): } return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) +@pytest.fixture +def audio_dataset(tmp_path, empty_lerobot_dataset_factory): + features = { + "observation.audio.microphone": { + "dtype": "audio", + "shape": (DUMMY_AUDIO_CHANNELS,), + "names": [ + "channels", + ], + } + } + return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ @@ -321,6 +336,20 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): with 
pytest.raises(ValueError): image_array_to_pil_image(image) +def test_add_frame_audio(audio_dataset): + dataset = audio_dataset + + microphone = make_microphone(microphone_type="microphone", mock=True) + microphone.connect() + + dataset.add_microphone_recording(microphone, "microphone") + time.sleep(1.0) + dataset.add_frame({"observation.audio.microphone": microphone.read(), "task": "Dummy task"}) + microphone.stop_recording() + + dataset.save_episode() + + assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sampling_rate),DUMMY_AUDIO_CHANNELS)) # TODO(aliberts): # - [ ] test various attributes & state from init and create @@ -354,6 +383,7 @@ def test_factory(env_name, repo_id, policy_name): dataset = make_dataset(cfg) delta_timestamps = dataset.delta_timestamps camera_keys = dataset.meta.camera_keys + audio_keys = dataset.meta.audio_keys item = dataset[0] @@ -396,6 +426,11 @@ def test_factory(env_name, repo_id, policy_name): # test c,h,w assert item[key].shape[0] == 3, f"{key}" + for key in audio_keys: + assert item[key].dtype == torch.float32, f"{key}" + assert item[key].max() <= 1.0, f"{key}" + assert item[key].min() >= -1.0, f"{key}" + if delta_timestamps is not None: # test missing keys in delta_timestamps for key in delta_timestamps: diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 5e5c762c..91942190 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -29,7 +29,7 @@ DUMMY_MOTOR_FEATURES = { }, } DUMMY_CAMERA_FEATURES = { - "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, + "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"}, "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, } DEFAULT_FPS = 30 @@ -40,5 +40,18 @@ DUMMY_VIDEO_INFO = { "video.is_depth_map": False, "has_audio": False, } 
+DUMMY_MICROPHONE_FEATURES = { + "laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None}, + "phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None}, +} +DEFAULT_SAMPLE_RATE = 48000 +DUMMY_AUDIO_CHANNELS = 2 +DUMMY_AUDIO_INFO = { + "has_audio": True, + "audio.sample_rate": DEFAULT_SAMPLE_RATE, + "audio.codec": "aac", + "audio.channels": DUMMY_AUDIO_CHANNELS, + "audio.channel_layout": "stereo", +} DUMMY_CHW = (3, 96, 128) DUMMY_HWC = (96, 128, 3) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index fbd7480f..80387d65 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -36,6 +36,7 @@ from lerobot.common.datasets.utils import ( from tests.fixtures.constants import ( DEFAULT_FPS, DUMMY_CAMERA_FEATURES, + DUMMY_MICROPHONE_FEATURES, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE, @@ -91,6 +92,7 @@ def features_factory(): def _create_features( motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, + audio_features: dict = DUMMY_MICROPHONE_FEATURES, use_videos: bool = True, ) -> dict: if use_videos: @@ -102,6 +104,7 @@ def features_factory(): return { **motor_features, **camera_ft, + **audio_features, **DEFAULT_FEATURES, } @@ -125,9 +128,10 @@ def info_factory(features_factory): audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH, motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, + audio_features: dict = DUMMY_MICROPHONE_FEATURES, use_videos: bool = True, ) -> dict: - features = features_factory(motor_features, camera_features, use_videos) + features = features_factory(motor_features, camera_features, audio_features, use_videos) return { "codebase_version": codebase_version, "robot_type": robot_type, @@ -165,6 +169,14 @@ def stats_factory(): "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), "count": [10], } + elif dtype == "audio": + stats[key] = { + 
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(), + "max": np.full((shape[0],), 1, dtype=np.float32).tolist(), + "min": np.full((shape[0],), -1, dtype=np.float32).tolist(), + "std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(), + "count": [10], + } else: stats[key] = { "max": np.full(shape, 1, dtype=dtype).tolist(), diff --git a/tests/microphones/mock_sounddevice.py b/tests/microphones/mock_sounddevice.py new file mode 100644 index 00000000..f6007085 --- /dev/null +++ b/tests/microphones/mock_sounddevice.py @@ -0,0 +1,82 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import cache
+from threading import Event, Thread
+import time
+from typing import Optional
+
+import numpy as np
+
+from lerobot.common.utils.utils import capture_timestamp_utc
+from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS
+
+# Duration (in seconds) of each simulated audio buffer handed to the stream callback.
+_CHUNK_DURATION = 0.01
+
+
+@cache
+def _generate_sound(duration: float, sample_rate: int, channels: int):
+    """Return a random float32 signal in [-1, 1] of shape (int(duration * sample_rate), channels).
+
+    The result is cached per argument triple, so repeated calls with the same arguments return
+    the very same array object: callers must treat it as read-only.
+    """
+    return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
+
+
+def _describe_device(index: int) -> dict:
+    """Build the description dictionary of the mock sound device registered at `index`."""
+    return {
+        "name": "Mock Sound Device",
+        "index": index,
+        "max_input_channels": DUMMY_AUDIO_CHANNELS,
+        "default_samplerate": DEFAULT_SAMPLE_RATE,
+    }
+
+
+def query_devices(query_index: Optional[int] = None):
+    """Mock of `sounddevice.query_devices`.
+
+    Args:
+        query_index: Index of the device to describe. When omitted, every mock device (there is
+            a single one, at index 0) is listed, mirroring a call without arguments to the real function.
+
+    Returns:
+        The description dictionary of the requested device, or a list of description dictionaries
+        when `query_index` is None.
+    """
+    if query_index is None:
+        return [_describe_device(0)]
+    return _describe_device(query_index)
+
+
+class InputStream:
+    """Mock of `sounddevice.InputStream`.
+
+    While the stream is active, a background thread periodically calls the `callback` keyword
+    argument with a buffer of random audio samples, as `callback(indata, frames, time, status)`.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._mock_dict = {
+            "channels": DUMMY_AUDIO_CHANNELS,
+            "samplerate": DEFAULT_SAMPLE_RATE,
+        }
+        self._is_active = False
+        self._audio_callback = kwargs.get("callback")
+
+        self.callback_thread = None
+        self.callback_thread_stop_event = None
+
+    def _acquisition_loop(self):
+        """Feed the callback with simulated audio buffers until the stop event is set."""
+        if self._audio_callback is None:
+            return
+
+        # Keep a local reference: `stop()` resets the attribute to None once the thread is joined.
+        stop_event = self.callback_thread_stop_event
+        while not stop_event.is_set():
+            # Simulate audio data acquisition
+            time.sleep(_CHUNK_DURATION)
+            chunk = _generate_sound(_CHUNK_DURATION, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS)
+            # `frames` is an integer sample count: derive it from the buffer so both always agree.
+            self._audio_callback(chunk, len(chunk), capture_timestamp_utc(), None)
+
+    def start(self):
+        """Start the acquisition thread. Starting an already active stream does nothing."""
+        if self._is_active:
+            # Spawning a second thread would leak the first one, which could never be stopped.
+            return
+
+        self.callback_thread_stop_event = Event()
+        self.callback_thread = Thread(target=self._acquisition_loop, args=(), daemon=True)
+        self.callback_thread.start()
+
+        self._is_active = True
+
+    @property
+    def active(self):
+        """Whether the stream has been started and not stopped since."""
+        return self._is_active
+
+    def stop(self):
+        """Signal the acquisition thread to exit, wait for it, and mark the stream inactive."""
+        if self.callback_thread_stop_event is not None:
+            self.callback_thread_stop_event.set()
+            self.callback_thread.join()
+            self.callback_thread = None
+            self.callback_thread_stop_event = None
+        self._is_active = False
+
+    def close(self):
+        """Stop the stream if it is still active."""
+        if self._is_active:
+            self.stop()
+
+    def __del__(self):
+        # `_is_active` is missing if `__init__` did not run to completion.
+        if getattr(self, "_is_active", False):
+            self.stop()
diff --git a/tests/microphones/test_microphones.py b/tests/microphones/test_microphones.py
new file mode 100644 index 00000000..50c65119 --- /dev/null +++ b/tests/microphones/test_microphones.py @@ -0,0 +1,142 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for physical microphones and their mocked versions. +If the physical microphone is not connected to the computer, or not working, +the test will be skipped. + +Example of running a specific test: +```bash +pytest -sx tests/microphones/test_microphones.py::test_microphone +``` + +Example of running test on a real microphone connected to the computer: +```bash +pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-False]' +``` + +Example of running test on a mocked version of the microphone: +```bash +pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-True]' +``` +""" + +import numpy as np +import time +import pytest +from soundfile import read + +from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError +from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone + +#Maximum recording time difference between two consecutive audio recordings of the same duration. +#Set to 0.02 seconds as twice the default size of sounddevice callback buffer (i.e. we tolerate the loss of one buffer).
+MAX_RECORDING_TIME_DIFFERENCE = 0.02
+
+DUMMY_RECORDING = "test_recording.wav"
+
+
+@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
+@require_microphone
+def test_microphone(tmp_path, request, microphone_type, mock):
+    """Exercise the connect / record / read / disconnect lifecycle of a microphone.
+
+    The test assumes that, over the same recording period, the audio written to file by
+    microphone.start_recording(output_file) and microphone.stop_recording() and the audio returned
+    by microphone.read() do not differ in duration by more than MAX_RECORDING_TIME_DIFFERENCE seconds.
+
+    `request` is not used in the body: it is consumed by the `require_microphone` decorator.
+    """
+
+    microphone_kwargs = {"microphone_type": microphone_type, "mock": mock}
+
+    # Test instantiating
+    microphone = make_microphone(**microphone_kwargs)
+
+    # Test start_recording, stop_recording, read and disconnecting before connecting raises an error
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.start_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.stop_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.disconnect()
+
+    # Test deleting the object without connecting first
+    del microphone
+
+    # Test connecting
+    microphone = make_microphone(**microphone_kwargs)
+    microphone.connect()
+    assert microphone.is_connected
+    assert microphone.sampling_rate is not None
+    assert microphone.channels is not None
+
+    # Test connecting twice raises an error
+    with pytest.raises(RobotDeviceAlreadyConnectedError):
+        microphone.connect()
+
+    # Test reading or stop recording before starting recording raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test start_recording
+    fpath = tmp_path / DUMMY_RECORDING
+    microphone.start_recording(fpath)
+    assert microphone.is_recording
+
+    # Test start_recording twice raises an error
+    with pytest.raises(RobotDeviceAlreadyRecordingError):
+        microphone.start_recording()
+
+    # Test reading from the microphone
+    time.sleep(1.0)
+    audio_chunk = microphone.read()
+    assert isinstance(audio_chunk, np.ndarray)
+    assert audio_chunk.ndim == 2
+    _, c = audio_chunk.shape
+    assert c == len(microphone.channels)
+
+    # Test stop_recording
+    microphone.stop_recording()
+    assert fpath.exists()
+    assert not microphone.stream.active
+    assert microphone.record_thread is None
+
+    # Test stop_recording twice raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test reading and recording output similar length audio chunks
+    microphone.start_recording(fpath)
+    time.sleep(1.0)
+    audio_chunk = microphone.read()
+    microphone.stop_recording()
+
+    recorded_audio, recorded_sample_rate = read(fpath)
+    assert recorded_sample_rate == microphone.sampling_rate
+
+    # `err_msg` must be a string. The gap is reported in seconds, hence the division by the
+    # sample rate (both lengths are sample counts).
+    time_difference = abs(len(audio_chunk) - len(recorded_audio)) / recorded_sample_rate
+    error_msg = (
+        f"Recording time difference between read() and stop_recording(): {time_difference:.4f} s "
+        f"(tolerance: {MAX_RECORDING_TIME_DIFFERENCE} s)"
+    )
+    np.testing.assert_allclose(
+        len(audio_chunk),
+        len(recorded_audio),
+        atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
+        err_msg=error_msg,
+    )
+
+    # Test disconnecting
+    microphone.disconnect()
+    assert not microphone.is_connected
+
+    # Test disconnecting with `__del__`
+    microphone = make_microphone(**microphone_kwargs)
+    microphone.connect()
+    del microphone
\ No newline at end of file
diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py index 71343eba..204aabca 100644 --- a/tests/robots/test_robots.py +++ b/tests/robots/test_robots.py @@ -100,6 +100,9 @@ def test_robot(tmp_path, request, robot_type, mock): robot.teleop_step() # Test data recorded during teleop are well formatted + for _, microphone in robot.microphones.items(): + microphone.start_recording() + observation, action = robot.teleop_step(record_data=True) # State assert "observation.state" in observation @@ -112,6 +115,11 @@ def test_robot(tmp_path, request, robot_type, mock): assert f"observation.images.{name}" in observation
assert isinstance(observation[f"observation.images.{name}"], torch.Tensor) assert observation[f"observation.images.{name}"].ndim == 3 + # Microphones + for name in robot.microphones: + assert f"observation.audio.{name}" in observation + assert isinstance(observation[f"observation.audio.{name}"], torch.Tensor) + assert observation[f"observation.audio.{name}"].ndim == 2 # Action assert "action" in action assert isinstance(action["action"], torch.Tensor) @@ -124,8 +132,9 @@ def test_robot(tmp_path, request, robot_type, mock): captured_observation = robot.capture_observation() assert set(captured_observation.keys()) == set(observation.keys()) for name in captured_observation: - if "image" in name: + if "image" in name or "audio" in name: # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames + # Also skipping for audio as audio chunks may be of different length continue torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1) assert captured_observation[name].shape == observation[name].shape @@ -134,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock): robot.send_action(action["action"]) # Test disconnecting - robot.disconnect() + robot.disconnect() #Also handles microphone recording stop, life is beautiful assert not robot.is_connected for name in robot.follower_arms: assert not robot.follower_arms[name].is_connected @@ -142,3 +151,5 @@ def test_robot(tmp_path, request, robot_type, mock): assert not robot.leader_arms[name].is_connected for name in robot.cameras: assert not robot.cameras[name].is_connected + for name in robot.microphones: + assert not robot.microphones[name].is_connected diff --git a/tests/utils.py b/tests/utils.py index c49b5b9f..8559ca0c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,11 +22,13 @@ from pathlib import Path import pytest import torch -from lerobot import available_cameras, available_motors, available_robots +from lerobot import 
available_cameras, available_motors, available_robots, available_microphones from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device from lerobot.common.robot_devices.motors.utils import MotorsBus from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device +from lerobot.common.robot_devices.microphones.utils import Microphone +from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device from lerobot.common.utils.import_utils import is_package_available DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" @@ -39,6 +41,10 @@ TEST_CAMERA_TYPES = [] for camera_type in available_cameras: TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)] +TEST_MICROPHONE_TYPES = [] +for microphone_type in available_microphones: + TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)] + TEST_MOTOR_TYPES = [] for motor_type in available_motors: TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)] @@ -47,6 +53,9 @@ for motor_type in available_motors: OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)) +# Microphone indices used for connecting physical microphones +MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0)) + DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081") DYNAMIXEL_MOTORS = { "shoulder_pan": [1, "xl430-w250"], @@ -252,6 +261,27 @@ def require_camera(func): return wrapper +def require_microphone(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Access the pytest request context to get the is_microphone_available fixture + request = kwargs.get("request") + microphone_type = kwargs.get("microphone_type") + mock = 
kwargs.get("mock") + + if request is None: + raise ValueError("The 'request' fixture must be an argument of the test function.") + if microphone_type is None: + raise ValueError("The 'microphone_type' must be an argument of the test function.") + if mock is None: + raise ValueError("The 'mock' variable must be an argument of the test function.") + + if not mock and not request.getfixturevalue("is_microphone_available"): + pytest.skip(f"A {microphone_type} microphone is not available.") + + return func(*args, **kwargs) + + return wrapper def require_motor(func): @wraps(func) @@ -314,6 +344,12 @@ def make_camera(camera_type: str, **kwargs) -> Camera: else: raise ValueError(f"The camera type '{camera_type}' is not valid.") +def make_microphone(microphone_type: str, **kwargs) -> Microphone: + if microphone_type == "microphone": + microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX) + return make_microphone_device(microphone_type, microphone_index=microphone_index, **kwargs) + else: + raise ValueError(f"The microphone type '{microphone_type}' is not valid.") # TODO(rcadene, aliberts): remove this dark pattern that overrides def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: