Adding audio tests
This commit is contained in:
parent
7c832fa2a7
commit
8c69b0b9cd
|
@ -190,6 +190,11 @@ available_cameras = [
|
|||
"intelrealsense",
|
||||
]
|
||||
|
||||
# lists all available microphones from `lerobot/common/robot_devices/microphones`
|
||||
available_microphones = [
|
||||
"microphone",
|
||||
]
|
||||
|
||||
# lists all available motors from `lerobot/common/robot_devices/motors`
|
||||
available_motors = [
|
||||
"dynamixel",
|
||||
|
|
|
@ -44,8 +44,7 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
|
|||
microphones = []
|
||||
|
||||
if mock:
|
||||
#TODO(CarolinePascal): Implement mock microphones
|
||||
pass
|
||||
import tests.microphones.mock_sounddevice as sd
|
||||
else:
|
||||
import sounddevice as sd
|
||||
|
||||
|
@ -161,8 +160,7 @@ class Microphone:
|
|||
raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
|
||||
|
||||
if self.mock:
|
||||
#TODO(CarolinePascal): Implement mock microphones
|
||||
pass
|
||||
import tests.microphones.mock_sounddevice as sd
|
||||
else:
|
||||
import sounddevice as sd
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@ import traceback
|
|||
import pytest
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot import available_cameras, available_motors, available_robots, available_microphones
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from tests.utils import DEVICE, make_camera, make_motors_bus
|
||||
from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone
|
||||
|
||||
# Import fixture modules as plugins
|
||||
pytest_plugins = [
|
||||
|
@ -73,6 +73,9 @@ def is_robot_available(robot_type):
|
|||
def is_camera_available(camera_type):
|
||||
return _check_component_availability(camera_type, available_cameras, make_camera)
|
||||
|
||||
@pytest.fixture
|
||||
def is_microphone_available(microphone_type):
|
||||
return _check_component_availability(microphone_type, available_microphones, make_microphone)
|
||||
|
||||
@pytest.fixture
|
||||
def is_motor_available(motor_type):
|
||||
|
|
|
@ -26,6 +26,7 @@ from lerobot.common.datasets.compute_stats import (
|
|||
estimate_num_samples,
|
||||
get_feature_stats,
|
||||
sample_images,
|
||||
sample_audio,
|
||||
sample_indices,
|
||||
)
|
||||
|
||||
|
@ -33,6 +34,8 @@ from lerobot.common.datasets.compute_stats import (
|
|||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
def mock_load_audio(path):
|
||||
return np.ones((16000,2), dtype=np.float32)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_array():
|
||||
|
@ -70,6 +73,14 @@ def test_sample_images(mock_load):
|
|||
assert images.dtype == np.uint8
|
||||
assert len(images) == estimate_num_samples(100)
|
||||
|
||||
@patch("lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio)
|
||||
def test_sample_audio(mock_load):
|
||||
audio_path = "audio.wav"
|
||||
audio_samples = sample_audio(audio_path)
|
||||
assert isinstance(audio_samples, np.ndarray)
|
||||
assert audio_samples.shape[1] == 2
|
||||
assert audio_samples.dtype == np.float32
|
||||
assert len(audio_samples) == estimate_num_samples(16000)
|
||||
|
||||
def test_get_feature_stats_images():
|
||||
data = np.random.rand(100, 3, 32, 32)
|
||||
|
@ -78,6 +89,12 @@ def test_get_feature_stats_images():
|
|||
np.testing.assert_equal(stats["count"], np.array([100]))
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
def test_get_feature_stats_audio():
|
||||
data = np.random.uniform(-1, 1, (16000,2))
|
||||
stats = get_feature_stats(data, axis=0, keepdims=True)
|
||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||
np.testing.assert_equal(stats["count"], np.array([16000]))
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
expected = {
|
||||
|
@ -137,22 +154,28 @@ def test_get_feature_stats_single_value():
|
|||
def test_compute_episode_stats():
|
||||
episode_data = {
|
||||
"observation.image": [f"image_{i}.jpg" for i in range(100)],
|
||||
"observation.audio": "audio.wav",
|
||||
"observation.state": np.random.rand(100, 10),
|
||||
}
|
||||
features = {
|
||||
"observation.image": {"dtype": "image"},
|
||||
"observation.audio": {"dtype": "audio"},
|
||||
"observation.state": {"dtype": "numeric"},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
|
||||
), patch(
|
||||
"lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio
|
||||
):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
assert "observation.image" in stats and "observation.state" in stats
|
||||
assert stats["observation.image"]["count"].item() == 100
|
||||
assert stats["observation.state"]["count"].item() == 100
|
||||
assert "observation.image" in stats and "observation.state" in stats and "observation.audio" in stats
|
||||
assert stats["observation.image"]["count"].item() == estimate_num_samples(100)
|
||||
assert stats["observation.audio"]["count"].item() == estimate_num_samples(16000)
|
||||
assert stats["observation.state"]["count"].item() == estimate_num_samples(100)
|
||||
assert stats["observation.image"]["mean"].shape == (3, 1, 1)
|
||||
assert stats["observation.audio"]["mean"].shape == (1, 2)
|
||||
|
||||
|
||||
def test_assert_type_and_shape_valid():
|
||||
|
|
|
@ -44,9 +44,12 @@ from lerobot.common.policies.factory import make_policy_config
|
|||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
from tests.utils import make_microphone
|
||||
import time
|
||||
from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
|
||||
@pytest.fixture
|
||||
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
@ -63,6 +66,18 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
|||
}
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
@pytest.fixture
|
||||
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {
|
||||
"observation.audio.microphone": {
|
||||
"dtype": "audio",
|
||||
"shape": (DUMMY_AUDIO_CHANNELS,),
|
||||
"names": [
|
||||
"channels",
|
||||
],
|
||||
}
|
||||
}
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
"""
|
||||
|
@ -321,6 +336,20 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
|||
with pytest.raises(ValueError):
|
||||
image_array_to_pil_image(image)
|
||||
|
||||
def test_add_frame_audio(audio_dataset):
|
||||
dataset = audio_dataset
|
||||
|
||||
microphone = make_microphone(microphone_type="microphone", mock=True)
|
||||
microphone.connect()
|
||||
|
||||
dataset.add_microphone_recording(microphone, "microphone")
|
||||
time.sleep(1.0)
|
||||
dataset.add_frame({"observation.audio.microphone": microphone.read(), "task": "Dummy task"})
|
||||
microphone.stop_recording()
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sampling_rate),DUMMY_AUDIO_CHANNELS))
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
|
@ -354,6 +383,7 @@ def test_factory(env_name, repo_id, policy_name):
|
|||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
audio_keys = dataset.meta.audio_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
|
@ -396,6 +426,11 @@ def test_factory(env_name, repo_id, policy_name):
|
|||
# test c,h,w
|
||||
assert item[key].shape[0] == 3, f"{key}"
|
||||
|
||||
for key in audio_keys:
|
||||
assert item[key].dtype == torch.float32, f"{key}"
|
||||
assert item[key].max() <= 1.0, f"{key}"
|
||||
assert item[key].min() >= -1.0, f"{key}"
|
||||
|
||||
if delta_timestamps is not None:
|
||||
# test missing keys in delta_timestamps
|
||||
for key in delta_timestamps:
|
||||
|
|
|
@ -29,7 +29,7 @@ DUMMY_MOTOR_FEATURES = {
|
|||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"},
|
||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
|
@ -40,5 +40,18 @@ DUMMY_VIDEO_INFO = {
|
|||
"video.is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
DUMMY_MICROPHONE_FEATURES = {
|
||||
"laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
|
||||
"phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
|
||||
}
|
||||
DEFAULT_SAMPLE_RATE = 48000
|
||||
DUMMY_AUDIO_CHANNELS = 2
|
||||
DUMMY_AUDIO_INFO = {
|
||||
"has_audio": True,
|
||||
"audio.sample_rate": DEFAULT_SAMPLE_RATE,
|
||||
"audio.codec": "aac",
|
||||
"audio.channels": DUMMY_AUDIO_CHANNELS,
|
||||
"audio.channel_layout": "stereo",
|
||||
}
|
||||
DUMMY_CHW = (3, 96, 128)
|
||||
DUMMY_HWC = (96, 128, 3)
|
||||
|
|
|
@ -36,6 +36,7 @@ from lerobot.common.datasets.utils import (
|
|||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_MICROPHONE_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
|
@ -91,6 +92,7 @@ def features_factory():
|
|||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
|
@ -102,6 +104,7 @@ def features_factory():
|
|||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
**audio_features,
|
||||
**DEFAULT_FEATURES,
|
||||
}
|
||||
|
||||
|
@ -125,9 +128,10 @@ def info_factory(features_factory):
|
|||
audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
features = features_factory(motor_features, camera_features, audio_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
|
@ -165,6 +169,14 @@ def stats_factory():
|
|||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||
"count": [10],
|
||||
}
|
||||
elif dtype == "audio":
|
||||
stats[key] = {
|
||||
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
|
||||
"max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
|
||||
"min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
|
||||
"std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
|
||||
"count": [10],
|
||||
}
|
||||
else:
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import cache
|
||||
|
||||
from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DEFAULT_SAMPLE_RATE
|
||||
|
||||
import numpy as np
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
from threading import Thread, Event
|
||||
import time
|
||||
|
||||
@cache
|
||||
def _generate_sound(duration: float, sample_rate: int, channels: int):
|
||||
return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
|
||||
|
||||
def query_devices(query_index: int):
|
||||
return {
|
||||
"name": "Mock Sound Device",
|
||||
"index": query_index,
|
||||
"max_input_channels": DUMMY_AUDIO_CHANNELS,
|
||||
"default_samplerate": DEFAULT_SAMPLE_RATE,
|
||||
}
|
||||
|
||||
class InputStream:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._mock_dict = {
|
||||
"channels": DUMMY_AUDIO_CHANNELS,
|
||||
"samplerate": DEFAULT_SAMPLE_RATE,
|
||||
}
|
||||
self._is_active = False
|
||||
self._audio_callback = kwargs.get("callback")
|
||||
|
||||
self.callback_thread = None
|
||||
self.callback_thread_stop_event = None
|
||||
|
||||
def _acquisition_loop(self):
|
||||
if self._audio_callback is not None:
|
||||
while not self.callback_thread_stop_event.is_set():
|
||||
# Simulate audio data acquisition
|
||||
time.sleep(0.01)
|
||||
self._audio_callback(_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), 0.01*DEFAULT_SAMPLE_RATE, capture_timestamp_utc(), None)
|
||||
|
||||
def start(self):
|
||||
self.callback_thread_stop_event = Event()
|
||||
self.callback_thread = Thread(target=self._acquisition_loop, args=())
|
||||
self.callback_thread.daemon = True
|
||||
self.callback_thread.start()
|
||||
|
||||
self._is_active = True
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
return self._is_active
|
||||
|
||||
def stop(self):
|
||||
if self.callback_thread_stop_event is not None:
|
||||
self.callback_thread_stop_event.set()
|
||||
self.callback_thread.join()
|
||||
self.callback_thread = None
|
||||
self.callback_thread_stop_event = None
|
||||
self._is_active = False
|
||||
|
||||
def close(self):
|
||||
if self._is_active:
|
||||
self.stop()
|
||||
|
||||
def __del__(self):
|
||||
if self._is_active:
|
||||
self.stop()
|
||||
|
||||
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for physical microphones and their mocked versions.
|
||||
If the physical microphone is not connected to the computer, or not working,
|
||||
the test will be skipped.
|
||||
|
||||
Example of running a specific test:
|
||||
```bash
|
||||
pytest -sx tests/microphones/test_microphones.py::test_microphone
|
||||
```
|
||||
|
||||
Example of running test on a real microphone connected to the computer:
|
||||
```bash
|
||||
pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-False]'
|
||||
```
|
||||
|
||||
Example of running test on a mocked version of the microphone:
|
||||
```bash
|
||||
pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-True]'
|
||||
```
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import pytest
|
||||
from soundfile import read
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError
|
||||
from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
|
||||
|
||||
#Maximum recording tie difference between two consecutive audio recordings of the same duration.
|
||||
#Set to 0.02 seconds as twice the default size of sounddvice callback buffer (i.e. we tolerate the loss of one buffer).
|
||||
MAX_RECORDING_TIME_DIFFERENCE = 0.02
|
||||
|
||||
DUMMY_RECORDING = "test_recording.wav"
|
||||
|
||||
@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
|
||||
@require_microphone
|
||||
def test_microphone(tmp_path, request, microphone_type, mock):
|
||||
"""Test assumes that a recroding handled with microphone.start_recording(output_file) and stop_recording() or microphone.read()
|
||||
leqds to a sample that does not differ from the requested duration by more than 0.1 seconds.
|
||||
"""
|
||||
|
||||
microphone_kwargs = {"microphone_type": microphone_type, "mock": mock}
|
||||
|
||||
# Test instantiating
|
||||
microphone = make_microphone(**microphone_kwargs)
|
||||
|
||||
# Test start_recording, stop_recording, read and disconnecting before connecting raises an error
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
microphone.start_recording()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
microphone.stop_recording()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
microphone.read()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
microphone.disconnect()
|
||||
|
||||
# Test deleting the object without connecting first
|
||||
del microphone
|
||||
|
||||
# Test connecting
|
||||
microphone = make_microphone(**microphone_kwargs)
|
||||
microphone.connect()
|
||||
assert microphone.is_connected
|
||||
assert microphone.sampling_rate is not None
|
||||
assert microphone.channels is not None
|
||||
|
||||
# Test connecting twice raises an error
|
||||
with pytest.raises(RobotDeviceAlreadyConnectedError):
|
||||
microphone.connect()
|
||||
|
||||
# Test reading or stop recording before starting recording raises an error
|
||||
with pytest.raises(RobotDeviceNotRecordingError):
|
||||
microphone.read()
|
||||
with pytest.raises(RobotDeviceNotRecordingError):
|
||||
microphone.stop_recording()
|
||||
|
||||
# Test start_recording
|
||||
fpath = tmp_path / DUMMY_RECORDING
|
||||
microphone.start_recording(fpath)
|
||||
assert microphone.is_recording
|
||||
|
||||
# Test start_recording twice raises an error
|
||||
with pytest.raises(RobotDeviceAlreadyRecordingError):
|
||||
microphone.start_recording()
|
||||
|
||||
# Test reading from the microphone
|
||||
time.sleep(1.0)
|
||||
audio_chunk = microphone.read()
|
||||
assert isinstance(audio_chunk, np.ndarray)
|
||||
assert audio_chunk.ndim == 2
|
||||
_, c = audio_chunk.shape
|
||||
assert c == len(microphone.channels)
|
||||
|
||||
# Test stop_recording
|
||||
microphone.stop_recording()
|
||||
assert fpath.exists()
|
||||
assert not microphone.stream.active
|
||||
assert microphone.record_thread is None
|
||||
|
||||
# Test stop_recording twice raises an error
|
||||
with pytest.raises(RobotDeviceNotRecordingError):
|
||||
microphone.stop_recording()
|
||||
|
||||
# Test reading and recording output similar length audio chunks
|
||||
microphone.start_recording(tmp_path / DUMMY_RECORDING)
|
||||
time.sleep(1.0)
|
||||
audio_chunk = microphone.read()
|
||||
microphone.stop_recording()
|
||||
|
||||
recorded_audio, recorded_sample_rate = read(fpath)
|
||||
assert recorded_sample_rate == microphone.sampling_rate
|
||||
|
||||
error_msg = (
|
||||
"Recording time difference between read() and stop_recording()",
|
||||
(len(audio_chunk) - len(recorded_audio))/MAX_RECORDING_TIME_DIFFERENCE,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
len(audio_chunk), len(recorded_audio), atol=recorded_sample_rate*MAX_RECORDING_TIME_DIFFERENCE, err_msg=error_msg
|
||||
)
|
||||
|
||||
# Test disconnecting
|
||||
microphone.disconnect()
|
||||
assert not microphone.is_connected
|
||||
|
||||
# Test disconnecting with `__del__`
|
||||
microphone = make_microphone(**microphone_kwargs)
|
||||
microphone.connect()
|
||||
del microphone
|
|
@ -100,6 +100,9 @@ def test_robot(tmp_path, request, robot_type, mock):
|
|||
robot.teleop_step()
|
||||
|
||||
# Test data recorded during teleop are well formatted
|
||||
for _, microphone in robot.microphones.items():
|
||||
microphone.start_recording()
|
||||
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
# State
|
||||
assert "observation.state" in observation
|
||||
|
@ -112,6 +115,11 @@ def test_robot(tmp_path, request, robot_type, mock):
|
|||
assert f"observation.images.{name}" in observation
|
||||
assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
|
||||
assert observation[f"observation.images.{name}"].ndim == 3
|
||||
# Microphones
|
||||
for name in robot.microphones:
|
||||
assert f"observation.audio.{name}" in observation
|
||||
assert isinstance(observation[f"observation.audio.{name}"], torch.Tensor)
|
||||
assert observation[f"observation.audio.{name}"].ndim == 2
|
||||
# Action
|
||||
assert "action" in action
|
||||
assert isinstance(action["action"], torch.Tensor)
|
||||
|
@ -124,8 +132,9 @@ def test_robot(tmp_path, request, robot_type, mock):
|
|||
captured_observation = robot.capture_observation()
|
||||
assert set(captured_observation.keys()) == set(observation.keys())
|
||||
for name in captured_observation:
|
||||
if "image" in name:
|
||||
if "image" in name or "audio" in name:
|
||||
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
|
||||
# Also skipping for audio as audio chunks may be of different length
|
||||
continue
|
||||
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
|
||||
assert captured_observation[name].shape == observation[name].shape
|
||||
|
@ -134,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
|
|||
robot.send_action(action["action"])
|
||||
|
||||
# Test disconnecting
|
||||
robot.disconnect()
|
||||
robot.disconnect() #Also handles microphone recording stop, life is beautiful
|
||||
assert not robot.is_connected
|
||||
for name in robot.follower_arms:
|
||||
assert not robot.follower_arms[name].is_connected
|
||||
|
@ -142,3 +151,5 @@ def test_robot(tmp_path, request, robot_type, mock):
|
|||
assert not robot.leader_arms[name].is_connected
|
||||
for name in robot.cameras:
|
||||
assert not robot.cameras[name].is_connected
|
||||
for name in robot.microphones:
|
||||
assert not robot.microphones[name].is_connected
|
||||
|
|
|
@ -22,11 +22,13 @@ from pathlib import Path
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot import available_cameras, available_motors, available_robots, available_microphones
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
|
||||
from lerobot.common.robot_devices.microphones.utils import Microphone
|
||||
from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
|
||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||
|
@ -39,6 +41,10 @@ TEST_CAMERA_TYPES = []
|
|||
for camera_type in available_cameras:
|
||||
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
|
||||
|
||||
TEST_MICROPHONE_TYPES = []
|
||||
for microphone_type in available_microphones:
|
||||
TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]
|
||||
|
||||
TEST_MOTOR_TYPES = []
|
||||
for motor_type in available_motors:
|
||||
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
|
||||
|
@ -47,6 +53,9 @@ for motor_type in available_motors:
|
|||
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
||||
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
|
||||
|
||||
# Microphone indices used for connecting physical microphones
|
||||
MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))
|
||||
|
||||
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
|
||||
DYNAMIXEL_MOTORS = {
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
|
@ -252,6 +261,27 @@ def require_camera(func):
|
|||
|
||||
return wrapper
|
||||
|
||||
def require_microphone(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Access the pytest request context to get the is_microphone_available fixture
|
||||
request = kwargs.get("request")
|
||||
microphone_type = kwargs.get("microphone_type")
|
||||
mock = kwargs.get("mock")
|
||||
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be an argument of the test function.")
|
||||
if microphone_type is None:
|
||||
raise ValueError("The 'microphone_type' must be an argument of the test function.")
|
||||
if mock is None:
|
||||
raise ValueError("The 'mock' variable must be an argument of the test function.")
|
||||
|
||||
if not mock and not request.getfixturevalue("is_microphone_available"):
|
||||
pytest.skip(f"A {microphone_type} microphone is not available.")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def require_motor(func):
|
||||
@wraps(func)
|
||||
|
@ -314,6 +344,12 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
|
|||
else:
|
||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||
|
||||
def make_microphone(microphone_type: str, **kwargs) -> Microphone:
|
||||
if microphone_type == "microphone":
|
||||
microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
|
||||
return make_microphone_device(microphone_type, microphone_index=microphone_index, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
|
||||
|
||||
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
|
|
Loading…
Reference in New Issue