From 0cb9345f06644a6da79e0d49c406422a624185d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:56:39 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/datasets/compute_stats.py | 13 +- lerobot/common/datasets/lerobot_dataset.py | 101 +++++++---- lerobot/common/datasets/utils.py | 17 +- lerobot/common/datasets/video_utils.py | 56 ++++--- lerobot/common/robot_devices/control_utils.py | 10 +- .../robot_devices/microphones/configs.py | 4 +- .../robot_devices/microphones/microphone.py | 157 ++++++++++-------- .../common/robot_devices/microphones/utils.py | 9 +- .../common/robot_devices/robots/configs.py | 2 +- .../robot_devices/robots/lekiwi_remote.py | 10 +- .../robot_devices/robots/manipulator.py | 8 +- .../robots/mobile_manipulator.py | 10 +- lerobot/common/robot_devices/utils.py | 4 +- tests/conftest.py | 6 +- tests/datasets/test_compute_stats.py | 24 ++- tests/datasets/test_datasets.py | 17 +- tests/fixtures/constants.py | 7 +- tests/fixtures/dataset_factories.py | 2 +- tests/microphones/mock_sounddevice.py | 30 ++-- tests/microphones/test_microphones.py | 26 ++- tests/robots/test_robots.py | 2 +- tests/utils.py | 10 +- 22 files changed, 329 insertions(+), 196 deletions(-) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 08ac4ae6..36606719 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -15,7 +15,8 @@ # limitations under the License. import numpy as np -from lerobot.common.datasets.utils import load_image_as_numpy, load_audio_from_path +from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy + def estimate_num_samples( dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 @@ -70,17 +71,19 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def sample_audio_from_path(audio_path: str) -> np.ndarray: +def sample_audio_from_path(audio_path: str) -> np.ndarray: data = load_audio_from_path(audio_path) sampled_indices = sample_indices(len(data)) - return(data[sampled_indices]) + return data[sampled_indices] + def sample_audio_from_data(data: np.ndarray) -> np.ndarray: sampled_indices = sample_indices(len(data)) return data[sampled_indices] + def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: return { "min": np.min(array, axis=axis, keepdims=keepdims), @@ -103,9 +106,9 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu elif features[key]["dtype"] == "audio": try: ep_ft_array = sample_audio_from_path(data[0]) - except TypeError: #Should only be triggered for LeKiwi robot + except TypeError: # Should only be triggered for LeKiwi robot ep_ft_array = sample_audio_from_data(data) - axes_to_reduce = 0 + axes_to_reduce = 0 keepdims = True else: ep_ft_array = data # data is already a np.ndarray diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index bc6689c7..f844eb72 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -23,6 +23,7 @@ import datasets import numpy as np import packaging.version import PIL.Image +import soundfile as sf import torch import torch.utils from datasets import concatenate_datasets, load_dataset @@ -34,13 +35,12 @@ from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( + DEFAULT_AUDIO_CHUNK_DURATION, DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, DEFAULT_RAW_AUDIO_PATH, - DEFAULT_COMPRESSED_AUDIO_PATH, INFO_PATH, TASKS_PATH, - DEFAULT_AUDIO_CHUNK_DURATION, append_jsonlines, backward_compatible_episodes_stats, check_delta_timestamps, @@ -70,17 +70,16 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import ( VideoFrame, - decode_video_frames, - encode_video_frames, - encode_audio, decode_audio, + decode_video_frames, + encode_audio, + encode_video_frames, + get_audio_info, get_safe_default_codec, get_video_info, - get_audio_info, ) -from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.microphones.utils import Microphone -import soundfile as sf +from lerobot.common.robot_devices.robots.utils import Robot CODEBASE_VERSION = "v2.1" @@ -149,10 +148,12 @@ class LeRobotDatasetMetadata: ep_chunk = self.get_episode_chunk(ep_index) fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) return Path(fpath) - + def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path: episode_chunk = self.get_episode_chunk(episode_index) - fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index) + fpath = self.audio_path.format( + episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index + ) return self.root / fpath def get_episode_chunk(self, ep_index: int) -> int: @@ -167,7 +168,7 @@ class LeRobotDatasetMetadata: def video_path(self) -> str | None: """Formattable string for the video files.""" return self.info["video_path"] - + @property def audio_path(self) -> str | None: """Formattable string for the audio files.""" @@ -202,16 +203,21 @@ class LeRobotDatasetMetadata: def camera_keys(self) -> list[str]: """Keys to access visual modalities (regardless of their storage method).""" return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] - + @property def audio_keys(self) -> list[str]: """Keys to access audio modalities.""" return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] - + @property def audio_camera_keys_mapping(self) -> dict[str, str]: """Mapping between camera keys and audio keys when both are linked.""" - return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None and self.features[camera_key]["dtype"] == "video"} + return { + self.features[camera_key]["audio"]: camera_key + for camera_key in self.camera_keys + if self.features[camera_key]["audio"] is not None + and self.features[camera_key]["dtype"] == "video" + } @property def names(self) -> dict[str, list | dict]: @@ -325,7 +331,9 @@ class LeRobotDatasetMetadata: been encoded the same way. Also, this means it assumes the first episode exists. """ for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()): - if not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]): + if not self.features[key].get("info", None) or ( + len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"] + ): audio_path = self.root / self.get_compressed_audio_file_path(0, key) self.info["features"][key]["info"] = get_audio_info(audio_path) @@ -518,7 +526,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() - self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal) + self.audio_backend = ( + audio_backend if audio_backend else "ffmpeg" + ) # Waiting for torchcodec release #TODO(CarolinePascal) self.delta_indices = None # Unused attributes @@ -543,7 +553,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) - self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video + self.download_episodes( + download_videos + ) # Sould load audio as well #TODO(CarolinePascal): separate audio from video self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -735,13 +747,13 @@ class LeRobotDataset(torch.utils.data.Dataset): query_timestamps[key] = [current_ts] return query_timestamps - - #TODO(CarolinePascal): add variable query durations + + # TODO(CarolinePascal): add variable query durations def _get_query_timestamps_audio( self, current_ts: float, query_indices: dict[str, list[int]] | None = None, - ) -> dict[str, list[float]]: + ) -> dict[str, list[float]]: query_timestamps = {} for key in self.meta.audio_keys: if query_indices is not None and key in query_indices: @@ -773,14 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset): return item - #TODO(CarolinePascal): add variable query durations - def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]: + # TODO(CarolinePascal): add variable query durations + def _query_audio( + self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int + ) -> dict[str, torch.Tensor]: item = {} for audio_key, query_ts in query_timestamps.items(): - #Audio stored with video in a single .mp4 file + # Audio stored with video in a single .mp4 file if audio_key in self.meta.audio_camera_keys_mapping: - audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key]) - #Audio stored alone in a separate .m4a file + audio_path = self.root / self.meta.get_video_file_path( + ep_idx, self.meta.audio_camera_keys_mapping[audio_key] + ) + # Audio stored alone in a separate .m4a file else: audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) @@ -855,7 +871,7 @@ class LeRobotDataset(torch.utils.data.Dataset): image_key=image_key, episode_index=episode_index, frame_index=frame_index ) return self.root / fpath - + def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path: fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index) return self.root / fpath @@ -929,11 +945,17 @@ class LeRobotDataset(torch.utils.data.Dataset): This function will start recording audio from the microphone and save it to disk. """ - audio_dir = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key).parent + audio_dir = self._get_raw_audio_file_path( + self.num_episodes, "observation.audio." + microphone_key + ).parent if not audio_dir.is_dir(): audio_dir.mkdir(parents=True, exist_ok=True) - - microphone.start_recording(output_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)) + + microphone.start_recording( + output_file=self._get_raw_audio_file_path( + self.num_episodes, "observation.audio." + microphone_key + ) + ) def save_episode(self, episode_data: dict | None = None) -> None: """ @@ -983,8 +1005,15 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.meta.robot_type.startswith("lekiwi"): for key in self.meta.audio_keys: - audio_path = self._get_raw_audio_file_path(episode_index=self.episode_buffer["episode_index"][0], audio_key=key) - with sf.SoundFile(audio_path, mode='w', samplerate=self.meta.features[key]["info"]["sample_rate"], channels=self.meta.features[key]["shape"][0]) as file: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"][0], audio_key=key + ) + with sf.SoundFile( + audio_path, + mode="w", + samplerate=self.meta.features[key]["info"]["sample_rate"], + channels=self.meta.features[key]["shape"][0], + ) as file: file.write(episode_buffer[key]) ep_stats = compute_episode_stats(episode_buffer, self.features) @@ -996,7 +1025,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if len(self.meta.audio_keys) > 0: _ = self.encode_episode_audio(episode_index) - + # `meta.save_episode` be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) @@ -1113,7 +1142,7 @@ class LeRobotDataset(torch.utils.data.Dataset): encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True) return video_paths - + def encode_episode_audio(self, episode_index: int) -> dict: """ Use ffmpeg to convert .wav raw audio files into .m4a audio files. @@ -1124,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()): input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key) - + audio_paths[audio_key] = str(output_audio_path) if output_audio_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. @@ -1180,7 +1209,9 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_indices = None obj.episode_data_index = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() - obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal) + obj.audio_backend = ( + audio_backend if audio_backend is not None else "ffmpeg" + ) # Waiting for torchcodec release #TODO(CarolinePascal) return obj diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 970c447d..416d5837 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -33,9 +33,8 @@ from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage -from torchvision import transforms - from soundfile import read +from torchvision import transforms from lerobot.common.datasets.backward_compatibility import ( V21_MESSAGE, @@ -260,10 +259,12 @@ def load_image_as_numpy( img_array /= 255.0 return img_array + def load_audio_from_path(fpath: str | Path) -> np.ndarray: audio_data, _ = read(fpath, dtype="float32") return audio_data + def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to @@ -731,7 +732,7 @@ def validate_features_presence( ): error_message = "" missing_features = expected_features - actual_features - missing_features = {feature for feature in missing_features if "observation.audio" not in feature} + missing_features = {feature for feature in missing_features if "observation.audio" not in feature} extra_features = actual_features - (expected_features | optional_features) if missing_features or extra_features: @@ -793,18 +794,24 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: return error_message + def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray): error_message = "" if isinstance(value, np.ndarray): actual_shape = value.shape c = expected_shape - if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]): #The number of frames might be different - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n" + if len(actual_shape) != 2 or ( + actual_shape[-1] != c[-1] and actual_shape[0] != c[0] + ): # The number of frames might be different + error_message += ( + f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n" + ) else: error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n" return error_message + def validate_feature_string(name: str, value: str): if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index e8e85411..0511610e 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -25,12 +25,11 @@ from typing import Any, ClassVar import pyarrow as pa import torch -import torchvision import torchaudio +import torchvision from datasets.features.features import register_feature -from PIL import Image - from numpy import ceil +from PIL import Image def get_safe_default_codec(): @@ -42,6 +41,7 @@ def get_safe_default_codec(): ) return "pyav" + def decode_audio( audio_path: Path | str, timestamps: list[float], @@ -68,30 +68,30 @@ def decode_audio( else: raise ValueError(f"Unsupported video backend: {backend}") + def decode_audio_torchvision( audio_path: Path | str, - timestamps: list[float], - duration: float, + timestamps: list[float], + duration: float, log_loaded_timestamps: bool = False, ) -> torch.Tensor: - - #TODO(CarolinePascal) : add channels selection + # TODO(CarolinePascal) : add channels selection audio_path = str(audio_path) reader = torchaudio.io.StreamReader(src=audio_path) audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate - #TODO(CarolinePascal) : sort timestamps ? + # TODO(CarolinePascal) : sort timestamps ? reader.add_basic_audio_stream( - frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough - buffer_chunk_size = -1, #No dropping frames - format = "fltp", #Format as float32 + frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough + buffer_chunk_size=-1, # No dropping frames + format="fltp", # Format as float32 ) audio_chunks = [] for ts in timestamps: - reader.seek(ts) #Default to closest audio sample + reader.seek(ts) # Default to closest audio sample status = reader.fill_buffer() if status != 0: logging.warning("Audio stream reached end of recording before decoding desired timestamps.") @@ -99,15 +99,18 @@ def decode_audio_torchvision( current_audio_chunk = reader.pop_chunks()[0] if log_loaded_timestamps: - logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}") - + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}" + ) + audio_chunks.append(current_audio_chunk) audio_chunks = torch.stack(audio_chunks) assert len(timestamps) == len(audio_chunks) return audio_chunks - + + def decode_video_frames( video_path: Path | str, timestamps: list[float], @@ -137,6 +140,7 @@ def decode_video_frames( else: raise ValueError(f"Unsupported video backend: {backend}") + def decode_video_frames_torchvision( video_path: Path | str, timestamps: list[float], @@ -234,6 +238,7 @@ def decode_video_frames_torchvision( assert len(timestamps) == len(closest_frames) return closest_frames + def decode_video_frames_torchcodec( video_path: Path | str, timestamps: list[float], @@ -308,6 +313,7 @@ def decode_video_frames_torchcodec( assert len(timestamps) == len(closest_frames) return closest_frames + def encode_audio( input_path: Path | str, output_path: Path | str, @@ -344,6 +350,7 @@ def encode_audio( f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" ) + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str, @@ -353,7 +360,7 @@ def encode_video_frames( pix_fmt: str = "yuv420p", g: int | None = 2, crf: int | None = 30, - acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options + acodec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options fast_decode: int = 0, log_level: str | None = "error", overwrite: bool = False, @@ -375,16 +382,18 @@ def encode_video_frames( if audio_path is not None: audio_path = Path(audio_path) audio_path.parent.mkdir(parents=True, exist_ok=True) - ffmpeg_audio_args.update(OrderedDict( - [ - ("-i", str(audio_path)), - ] - )) + ffmpeg_audio_args.update( + OrderedDict( + [ + ("-i", str(audio_path)), + ] + ) + ) ffmpeg_encoding_args = OrderedDict( [ ("-pix_fmt", pix_fmt), - ("-vcodec", vcodec), + ("-vcodec", vcodec), ] ) if g is not None: @@ -398,7 +407,7 @@ def encode_video_frames( if audio_path is not None: ffmpeg_encoding_args["-acodec"] = acodec - + if log_level is not None: ffmpeg_encoding_args["-loglevel"] = str(log_level) @@ -487,6 +496,7 @@ def get_audio_info(video_path: Path | str) -> dict: "audio.channel_layout": audio_stream_info.get("channel_layout", None), } + def get_video_info(video_path: Path | str) -> dict: ffprobe_video_cmd = [ "ffprobe", diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 7c8706a4..e49a4e71 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -77,8 +77,8 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f key = f"read_camera_{name}_dt_s" if key in robot.logs: log_dt(f"dtR{name}", robot.logs[key]) - - for name in robot.microphones: + + for name in robot.microphones: key = f"read_microphone_{name}_dt_s" if key in robot.logs: log_dt(f"dtR{name}", robot.logs[key]) @@ -252,9 +252,11 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() - if dataset is not None and not robot.robot_type.startswith("lekiwi"): #For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage) + if ( + dataset is not None and not robot.robot_type.startswith("lekiwi") + ): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage) for microphone_key, microphone in robot.microphones.items(): - #Start recording both in file writing and data reading mode + # Start recording both in file writing and data reading mode dataset.add_microphone_recording(microphone, microphone_key) else: for _, microphone in robot.microphones.items(): diff --git a/lerobot/common/robot_devices/microphones/configs.py b/lerobot/common/robot_devices/microphones/configs.py index c2700723..1b663b7a 100644 --- a/lerobot/common/robot_devices/microphones/configs.py +++ b/lerobot/common/robot_devices/microphones/configs.py @@ -17,12 +17,14 @@ from dataclasses import dataclass import draccus + @dataclass class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC): @property def type(self) -> str: return self.get_choice_name(self.__class__) + @MicrophoneConfigBase.register_subclass("microphone") @dataclass class MicrophoneConfig(MicrophoneConfigBase): @@ -33,4 +35,4 @@ class MicrophoneConfig(MicrophoneConfigBase): microphone_index: int sample_rate: int | None = None channels: list[int] | None = None - mock: bool = False \ No newline at end of file + mock: bool = False diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 2d75293a..947fdfea 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -13,36 +13,35 @@ # limitations under the License. """ -This file contains utilities for recording audio from a microhone. +This file contains utilities for recording audio from a microhone. """ import argparse -import soundfile as sf -import numpy as np import logging -from threading import Thread, Event -from multiprocessing import Process -from queue import Empty - -from queue import Queue as thread_Queue -from threading import Event as thread_Event -from multiprocessing import JoinableQueue as process_Queue -from multiprocessing import Event as process_Event - -from os import getcwd -from pathlib import Path import shutil import time +from multiprocessing import Event as process_Event +from multiprocessing import JoinableQueue as process_Queue +from multiprocessing import Process +from os import getcwd +from pathlib import Path +from queue import Empty +from queue import Queue as thread_Queue +from threading import Event, Thread +from threading import Event as thread_Event -from lerobot.common.utils.utils import capture_timestamp_utc +import numpy as np +import soundfile as sf from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig from lerobot.common.robot_devices.utils import ( RobotDeviceAlreadyConnectedError, + RobotDeviceAlreadyRecordingError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, - RobotDeviceAlreadyRecordingError, ) +from lerobot.common.utils.utils import capture_timestamp_utc + def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: microphones = [] @@ -69,11 +68,10 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: return microphones -def record_audio_from_microphones( - output_dir: Path, - microphone_ids: list[int] | None = None, - record_time_s: float = 2.0): +def record_audio_from_microphones( + output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0 +): if microphone_ids is None or len(microphone_ids) == 0: microphones = find_microphones() microphone_ids = [m["index"] for m in microphones] @@ -104,13 +102,14 @@ def record_audio_from_microphones( for microphone in microphones: microphone.stop_recording() - #Remark : recording may be resumed here if needed + # Remark : recording may be resumed here if needed for microphone in microphones: microphone.disconnect() print(f"Images have been saved to {output_dir}") + class Microphone: """ The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows). @@ -138,20 +137,20 @@ class Microphone: self.config = config self.microphone_index = config.microphone_index - #Store the recording sample rate and channels + # Store the recording sample rate and channels self.sample_rate = config.sample_rate self.channels = config.channels self.mock = config.mock - #Input audio stream + # Input audio stream self.stream = None - #Thread-safe concurrent queue to store the recorded/read audio + # Thread-safe concurrent queue to store the recorded/read audio self.record_queue = None self.read_queue = None - #Thread to handle data reading and file writing in a separate thread (safely) + # Thread to handle data reading and file writing in a separate thread (safely) self.record_thread = None self.record_stop_event = None @@ -162,14 +161,16 @@ class Microphone: def connect(self) -> None: if self.is_connected: - raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.") - + raise RobotDeviceAlreadyConnectedError( + f"Microphone {self.microphone_index} is already connected." + ) + if self.mock: import tests.microphones.mock_sounddevice as sd else: import sounddevice as sd - #Check if the provided microphone index does match an input device + # Check if the provided microphone index does match an input device is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0 if not is_index_input: @@ -178,17 +179,19 @@ class Microphone: raise OSError( f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}" ) - - #Check if provided recording parameters are compatible with the microphone + + # Check if provided recording parameters are compatible with the microphone actual_microphone = sd.query_devices(self.microphone_index) - if self.sample_rate is not None : + if self.sample_rate is not None: if self.sample_rate > actual_microphone["default_samplerate"]: raise OSError( f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}." ) elif self.sample_rate < actual_microphone["default_samplerate"]: - logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.") + logging.warning( + "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." + ) else: self.sample_rate = int(actual_microphone["default_samplerate"]) @@ -198,45 +201,52 @@ class Microphone: f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}." ) else: - self.channels = np.arange(1, actual_microphone["max_input_channels"]+1) + self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1) # Get channels index instead of number for slicing self.channels = np.array(self.channels) - 1 - #Create the audio stream + # Create the audio stream self.stream = sd.InputStream( device=self.microphone_index, samplerate=self.sample_rate, - channels=max(self.channels)+1, + channels=max(self.channels) + 1, dtype="float32", callback=self._audio_callback, ) - #Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers. - #However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency. - + # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers. + # However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency. + self.is_connected = True - def _audio_callback(self, indata, frames, time, status) -> None : + def _audio_callback(self, indata, frames, time, status) -> None: if status: logging.warning(status) - # Slicing makes copy unecessary + # Slicing makes copy unecessary # Two separate queues are necessary because .get() also pops the data from the queue if self.is_writing: - self.record_queue.put(indata[:,self.channels]) - self.read_queue.put(indata[:,self.channels]) + self.record_queue.put(indata[:, self.channels]) + self.read_queue.put(indata[:, self.channels]) @staticmethod def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: - #Can only be run on a single process/thread for file writing safety - with sf.SoundFile(output_file, mode='x', samplerate=sample_rate, - channels=max(channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: + # Can only be run on a single process/thread for file writing safety + with sf.SoundFile( + output_file, + mode="x", + samplerate=sample_rate, + channels=max(channels) + 1, + subtype=sf.default_subtype(output_file.suffix[1:]), + ) as file: while not event.is_set(): try: - file.write(queue.get(timeout=0.02)) #Timeout set as twice the usual sounddevice buffer size + file.write( + queue.get(timeout=0.02) + ) # Timeout set as twice the usual sounddevice buffer size queue.task_done() except Empty: continue - + def _read(self) -> np.ndarray: """ Gets audio data from the queue and coverts it to a numpy array. @@ -244,7 +254,7 @@ class Microphone: -> CONS : Reading duration does not scale well with the number of channels and reading duration """ audio_readings = np.empty((0, len(self.channels))) - + while True: try: audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0) @@ -256,12 +266,11 @@ class Microphone: return audio_readings def read(self) -> np.ndarray: - if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if not self.is_recording: raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") - + start_time = time.perf_counter() audio_readings = self._read() @@ -274,21 +283,22 @@ class Microphone: return audio_readings - def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False) -> None: - + def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if self.is_recording: - raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.") - - #Reset queues + raise RobotDeviceAlreadyRecordingError( + f"Microphone {self.microphone_index} is already recording." + ) + + # Reset queues self.read_queue = thread_Queue() if multiprocessing: self.record_queue = process_Queue() else: self.record_queue = thread_Queue() - #Write recordings into a file if output_file is provided + # Write recordings into a file if output_file is provided if output_file is not None: output_file = Path(output_file) if output_file.exists(): @@ -296,28 +306,45 @@ class Microphone: if multiprocessing: self.record_stop_event = process_Event() - self.record_thread = Process(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, )) + self.record_thread = Process( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) else: self.record_stop_event = thread_Event() - self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, )) + self.record_thread = Thread( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) self.record_thread.daemon = True self.record_thread.start() self.is_writing = True - + self.is_recording = True self.stream.start() def stop_recording(self) -> None: - if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if not self.is_recording: raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") - + if self.stream.active: - self.stream.stop() #Wait for all buffers to be processed - #Remark : stream.abort() flushes the buffers ! + self.stream.stop() # Wait for all buffers to be processed + # Remark : stream.abort() flushes the buffers ! self.is_recording = False if self.record_thread is not None: @@ -329,7 +356,6 @@ class Microphone: self.is_writing = False def disconnect(self) -> None: - if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") @@ -342,7 +368,8 @@ class Microphone: def __del__(self): if getattr(self, "is_connected", False): self.disconnect() - + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset." diff --git a/lerobot/common/robot_devices/microphones/utils.py b/lerobot/common/robot_devices/microphones/utils.py index 1b1ad099..fb1bac85 100644 --- a/lerobot/common/robot_devices/microphones/utils.py +++ b/lerobot/common/robot_devices/microphones/utils.py @@ -16,28 +16,33 @@ from typing import Protocol from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase + # Defines a microphone type class Microphone(Protocol): def connect(self): ... def disconnect(self): ... - def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False): ... + def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ... def stop_recording(self): ... + def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]: microphones = {} for key, cfg in microphone_configs.items(): if cfg.type == "microphone": from lerobot.common.robot_devices.microphones.microphone import Microphone + microphones[key] = Microphone(cfg) else: raise ValueError(f"The microphone type '{cfg.type}' is not valid.") return microphones + def make_microphone(microphone_type, **kwargs) -> Microphone: if microphone_type == "microphone": from lerobot.common.robot_devices.microphones.microphone import Microphone + return Microphone(MicrophoneConfig(**kwargs)) else: - raise ValueError(f"The microphone type '{microphone_type}' is not valid.") \ No newline at end of file + raise ValueError(f"The microphone type '{microphone_type}' is not valid.") diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index 66edc4f6..942586a0 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -23,12 +23,12 @@ from lerobot.common.robot_devices.cameras.configs import ( IntelRealSenseCameraConfig, OpenCVCameraConfig, ) +from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig from lerobot.common.robot_devices.motors.configs import ( DynamixelMotorsBusConfig, FeetechMotorsBusConfig, MotorsBusConfig, ) -from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig @dataclass diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py index 03576c0c..15023d8a 100644 --- a/lerobot/common/robot_devices/robots/lekiwi_remote.py +++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py @@ -51,6 +51,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event): latest_images_dict.update(local_dict) time.sleep(0.01) + def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event): while not stop_event.is_set(): local_dict = {} @@ -60,6 +61,7 @@ def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_even with audio_lock: latest_audio_dict.update(local_dict) + def calibrate_follower_arm(motors_bus, calib_dir_str): """ Calibrates the follower arm. Attempts to load an existing calibration file; @@ -149,12 +151,14 @@ def run_lekiwi(robot_config): cam_thread.start() # Start the microphone recording and capture thread. - #TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead ! + # TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead ! latest_audio_dict = {} audio_lock = threading.Lock() audio_stop_event = threading.Event() microphone_thread = threading.Thread( - target=run_microphone_capture, args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), daemon=True + target=run_microphone_capture, + args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), + daemon=True, ) for microphone in microphones.values(): microphone.start_recording() @@ -231,7 +235,7 @@ def run_lekiwi(robot_config): # Build the observation dictionary. observation = { "images": images_dict_copy, - "audio": audio_dict_copy, #TODO(CarolinePascal) : This is a nasty way to do it, sorry. + "audio": audio_dict_copy, # TODO(CarolinePascal) : This is a nasty way to do it, sorry. "present_speed": current_velocity, "follower_arm_state": follower_arm_state, } diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index dc68d609..b452be9d 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -201,7 +201,7 @@ class ManipulatorRobot: "names": state_names, }, } - + @property def microphone_features(self) -> dict: mic_ft = {} @@ -211,7 +211,7 @@ class ManipulatorRobot: "dtype": "audio", "shape": (len(mic.channels),), "names": "channels", - "info" : {"sample_rate": mic.sample_rate}, + "info": {"sample_rate": mic.sample_rate}, } return mic_ft @@ -226,11 +226,11 @@ class ManipulatorRobot: @property def num_cameras(self): return len(self.cameras) - + @property def has_microphone(self): return len(self.microphones) > 0 - + @property def num_microphones(self): return len(self.microphones) diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py index 98e4cdb1..7727abb9 100644 --- a/lerobot/common/robot_devices/robots/mobile_manipulator.py +++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py @@ -163,7 +163,7 @@ class MobileManipulator: "names": combined_names, }, } - + @property def microphone_features(self) -> dict: mic_ft = {} @@ -173,7 +173,7 @@ class MobileManipulator: "dtype": "audio", "shape": (len(mic.channels),), "names": "channels", - "info" : {"sample_rate": mic.sample_rate}, + "info": {"sample_rate": mic.sample_rate}, } return mic_ft @@ -188,11 +188,11 @@ class MobileManipulator: @property def num_cameras(self): return len(self.cameras) - + @property def has_microphone(self): return len(self.microphones) > 0 - + @property def num_microphones(self): return len(self.microphones) @@ -512,7 +512,7 @@ class MobileManipulator: # Create silence using the microphone's configured channels frame = np.zeros((1, len(microphone.channels)), dtype=np.float32) obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame) - + return obs_dict def send_action(self, action: torch.Tensor) -> torch.Tensor: diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index 01f9195e..5b2270e7 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -69,11 +69,13 @@ class RobotDeviceNotRecordingError(Exception): """Exception raised when the robot device is not recording.""" def __init__( - self, message="This robot device is not recording. Try calling `robot_device.start_recording()` first." + self, + message="This robot device is not recording. Try calling `robot_device.start_recording()` first.", ): self.message = message super().__init__(self.message) + class RobotDeviceAlreadyRecordingError(Exception): """Exception raised when the robot device is already recording.""" diff --git a/tests/conftest.py b/tests/conftest.py index adf80931..a0e9ae41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,9 @@ import traceback import pytest from serial import SerialException -from lerobot import available_cameras, available_motors, available_robots, available_microphones +from lerobot import available_cameras, available_microphones, available_motors, available_robots from lerobot.common.robot_devices.robots.utils import make_robot -from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone +from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus # Import fixture modules as plugins pytest_plugins = [ @@ -73,10 +73,12 @@ def is_robot_available(robot_type): def is_camera_available(camera_type): return _check_component_availability(camera_type, available_cameras, make_camera) + @pytest.fixture def is_microphone_available(microphone_type): return _check_component_availability(microphone_type, available_microphones, make_microphone) + @pytest.fixture def is_motor_available(motor_type): return _check_component_availability(motor_type, available_motors, make_motors_bus) diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 56cdc176..9cf9f760 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -25,9 +25,9 @@ from lerobot.common.datasets.compute_stats import ( compute_episode_stats, estimate_num_samples, get_feature_stats, - sample_images, - sample_audio_from_path, sample_audio_from_data, + sample_audio_from_path, + sample_images, sample_indices, ) @@ -35,8 +35,10 @@ from lerobot.common.datasets.compute_stats import ( def mock_load_image_as_numpy(path, dtype, channel_first): return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) + def mock_load_audio(path): - return np.ones((16000,2), dtype=np.float32) + return np.ones((16000, 2), dtype=np.float32) + @pytest.fixture def sample_array(): @@ -74,6 +76,7 @@ def test_sample_images(mock_load): assert images.dtype == np.uint8 assert len(images) == estimate_num_samples(100) + @patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio) def test_sample_audio_from_path(mock_load): audio_path = "audio.wav" @@ -83,6 +86,7 @@ def test_sample_audio_from_path(mock_load): assert audio_samples.dtype == np.float32 assert len(audio_samples) == estimate_num_samples(16000) + def test_sample_audio_from_data(mock_load): audio_data = np.ones((16000, 2), dtype=np.float32) audio_samples = sample_audio_from_data(audio_data) @@ -91,6 +95,7 @@ def test_sample_audio_from_data(mock_load): assert audio_samples.dtype == np.float32 assert len(audio_samples) == estimate_num_samples(16000) + def test_get_feature_stats_images(): data = np.random.rand(100, 3, 32, 32) stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) @@ -98,13 +103,15 @@ def test_get_feature_stats_images(): np.testing.assert_equal(stats["count"], np.array([100])) assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + def test_get_feature_stats_audio(): - data = np.random.uniform(-1, 1, (16000,2)) + data = np.random.uniform(-1, 1, (16000, 2)) stats = get_feature_stats(data, axis=0, keepdims=True) assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats np.testing.assert_equal(stats["count"], np.array([16000])) assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + def test_get_feature_stats_axis_0_keepdims(sample_array): expected = { "min": np.array([[1, 2, 3]]), @@ -172,10 +179,11 @@ def test_compute_episode_stats(): "observation.state": {"dtype": "numeric"}, } - with patch( - "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy - ), patch( - "lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio + with ( + patch( + "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + ), + patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio), ): stats = compute_episode_stats(episode_data, features) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index dde2ed06..a10a7d61 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -16,6 +16,7 @@ import json import logging import re +import time from copy import deepcopy from itertools import chain from pathlib import Path @@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import ( MultiLeRobotDataset, ) from lerobot.common.datasets.utils import ( + DEFAULT_AUDIO_CHUNK_DURATION, create_branch, flatten_dict, unflatten_dict, @@ -44,12 +46,9 @@ from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS -from tests.utils import require_x86_64_kernel +from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID +from tests.utils import make_microphone, require_x86_64_kernel -from tests.utils import make_microphone -import time -from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION @pytest.fixture def image_dataset(tmp_path, empty_lerobot_dataset_factory): @@ -66,6 +65,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): } return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + @pytest.fixture def audio_dataset(tmp_path, empty_lerobot_dataset_factory): features = { @@ -79,6 +79,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory): } return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated @@ -336,6 +337,7 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): with pytest.raises(ValueError): image_array_to_pil_image(image) + def test_add_frame_audio(audio_dataset): dataset = audio_dataset @@ -349,7 +351,10 @@ def test_add_frame_audio(audio_dataset): dataset.save_episode() - assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sample_rate),DUMMY_AUDIO_CHANNELS)) + assert dataset[0]["observation.audio.microphone"].shape == torch.Size( + (int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS) + ) + # TODO(aliberts): # - [ ] test various attributes & state from init and create diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 91942190..b95044df 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = { }, } DUMMY_CAMERA_FEATURES = { - "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"}, + "laptop": { + "shape": (480, 640, 3), + "names": ["height", "width", "channels"], + "info": None, + "audio": "laptop", + }, "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, } DEFAULT_FPS = 30 diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 80387d65..321dec46 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -26,10 +26,10 @@ import torch from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, + DEFAULT_COMPRESSED_AUDIO_PATH, DEFAULT_FEATURES, DEFAULT_PARQUET_PATH, DEFAULT_VIDEO_PATH, - DEFAULT_COMPRESSED_AUDIO_PATH, get_hf_features_from_features, hf_transform_to_torch, ) diff --git a/tests/microphones/mock_sounddevice.py b/tests/microphones/mock_sounddevice.py index f6007085..0220c88c 100644 --- a/tests/microphones/mock_sounddevice.py +++ b/tests/microphones/mock_sounddevice.py @@ -11,27 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import time from functools import cache - -from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DEFAULT_SAMPLE_RATE +from threading import Event, Thread import numpy as np + from lerobot.common.utils.utils import capture_timestamp_utc -from threading import Thread, Event -import time +from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS + @cache def _generate_sound(duration: float, sample_rate: int, channels: int): return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32) + def query_devices(query_index: int): return { - "name": "Mock Sound Device", - "index": query_index, - "max_input_channels": DUMMY_AUDIO_CHANNELS, - "default_samplerate": DEFAULT_SAMPLE_RATE, + "name": "Mock Sound Device", + "index": query_index, + "max_input_channels": DUMMY_AUDIO_CHANNELS, + "default_samplerate": DEFAULT_SAMPLE_RATE, } + class InputStream: def __init__(self, *args, **kwargs): self._mock_dict = { @@ -49,7 +52,12 @@ class InputStream: while not self.callback_thread_stop_event.is_set(): # Simulate audio data acquisition time.sleep(0.01) - self._audio_callback(_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), 0.01*DEFAULT_SAMPLE_RATE, capture_timestamp_utc(), None) + self._audio_callback( + _generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), + 0.01 * DEFAULT_SAMPLE_RATE, + capture_timestamp_utc(), + None, + ) def start(self): self.callback_thread_stop_event = Event() @@ -62,7 +70,7 @@ class InputStream: @property def active(self): return self._is_active - + def stop(self): if self.callback_thread_stop_event is not None: self.callback_thread_stop_event.set() @@ -78,5 +86,3 @@ class InputStream: def __del__(self): if self._is_active: self.stop() - - diff --git a/tests/microphones/test_microphones.py b/tests/microphones/test_microphones.py index 3ce29fa2..c7bdbe71 100644 --- a/tests/microphones/test_microphones.py +++ b/tests/microphones/test_microphones.py @@ -32,20 +32,27 @@ pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-Tr ``` """ -import numpy as np import time + +import numpy as np import pytest from soundfile import read -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceAlreadyRecordingError, + RobotDeviceNotConnectedError, + RobotDeviceNotRecordingError, +) from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone -#Maximum recording tie difference between two consecutive audio recordings of the same duration. -#Set to 0.02 seconds as twice the default size of sounddvice callback buffer (i.e. we tolerate the loss of one buffer). +# Maximum recording tie difference between two consecutive audio recordings of the same duration. +# Set to 0.02 seconds as twice the default size of sounddvice callback buffer (i.e. we tolerate the loss of one buffer). MAX_RECORDING_TIME_DIFFERENCE = 0.02 DUMMY_RECORDING = "test_recording.wav" + @pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES) @require_microphone def test_microphone(tmp_path, request, microphone_type, mock): @@ -92,7 +99,7 @@ def test_microphone(tmp_path, request, microphone_type, mock): fpath = tmp_path / DUMMY_RECORDING microphone.start_recording(fpath) assert microphone.is_recording - + # Test start_recording twice raises an error with pytest.raises(RobotDeviceAlreadyRecordingError): microphone.start_recording() @@ -126,10 +133,13 @@ def test_microphone(tmp_path, request, microphone_type, mock): error_msg = ( "Recording time difference between read() and stop_recording()", - (len(audio_chunk) - len(recorded_audio))/MAX_RECORDING_TIME_DIFFERENCE, + (len(audio_chunk) - len(recorded_audio)) / MAX_RECORDING_TIME_DIFFERENCE, ) np.testing.assert_allclose( - len(audio_chunk), len(recorded_audio), atol=recorded_sample_rate*MAX_RECORDING_TIME_DIFFERENCE, err_msg=error_msg + len(audio_chunk), + len(recorded_audio), + atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE, + err_msg=error_msg, ) # Test disconnecting @@ -139,4 +149,4 @@ def test_microphone(tmp_path, request, microphone_type, mock): # Test disconnecting with `__del__` microphone = make_microphone(**microphone_kwargs) microphone.connect() - del microphone \ No newline at end of file + del microphone diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py index 204aabca..8353fe29 100644 --- a/tests/robots/test_robots.py +++ b/tests/robots/test_robots.py @@ -143,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock): robot.send_action(action["action"]) # Test disconnecting - robot.disconnect() #Also handles microphone recording stop, life is beautiful + robot.disconnect() # Also handles microphone recording stop, life is beautiful assert not robot.is_connected for name in robot.follower_arms: assert not robot.follower_arms[name].is_connected diff --git a/tests/utils.py b/tests/utils.py index 8559ca0c..d93eb97c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,13 +22,13 @@ from pathlib import Path import pytest import torch -from lerobot import available_cameras, available_motors, available_robots, available_microphones +from lerobot import available_cameras, available_microphones, available_motors, available_robots from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device -from lerobot.common.robot_devices.motors.utils import MotorsBus -from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device from lerobot.common.robot_devices.microphones.utils import Microphone from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device +from lerobot.common.robot_devices.motors.utils import MotorsBus +from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device from lerobot.common.utils.import_utils import is_package_available DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" @@ -261,6 +261,7 @@ def require_camera(func): return wrapper + def require_microphone(func): @wraps(func) def wrapper(*args, **kwargs): @@ -283,6 +284,7 @@ def require_microphone(func): return wrapper + def require_motor(func): @wraps(func) def wrapper(*args, **kwargs): @@ -344,6 +346,7 @@ def make_camera(camera_type: str, **kwargs) -> Camera: else: raise ValueError(f"The camera type '{camera_type}' is not valid.") + def make_microphone(microphone_type: str, **kwargs) -> Microphone: if microphone_type == "microphone": microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX) @@ -351,6 +354,7 @@ def make_microphone(microphone_type: str, **kwargs) -> Microphone: else: raise ValueError(f"The microphone type '{microphone_type}' is not valid.") + # TODO(rcadene, aliberts): remove this dark pattern that overrides def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: if motor_type == "dynamixel":