diff --git a/lerobot/__init__.py b/lerobot/__init__.py index d61e4853..386c2fbb 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -190,6 +190,11 @@ available_cameras = [ "intelrealsense", ] +# lists all available microphones from `lerobot/common/robot_devices/microphones` +available_microphones = [ + "microphone", +] + # lists all available motors from `lerobot/common/robot_devices/motors` available_motors = [ "dynamixel", diff --git a/lerobot/common/datasets/audio_utils.py b/lerobot/common/datasets/audio_utils.py new file mode 100644 index 00000000..2f3537d2 --- /dev/null +++ b/lerobot/common/datasets/audio_utils.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import subprocess +from collections import OrderedDict +from pathlib import Path + +import torch +import torchaudio +import torchcodec +from numpy import ceil + + +def decode_audio( + audio_path: Path | str, + timestamps: list[float], + duration: float, + backend: str | None = "torchaudio", +) -> torch.Tensor: + """ + Decodes audio using the specified backend. + Args: + audio_path (Path): Path to the audio file. + timestamps (list[float]): List of (starting) timestamps to extract audio chunks. + duration (float): Duration of the audio chunks in seconds. + backend (str, optional): Backend to use for decoding. Defaults to "torchaudio". + + Returns: + torch.Tensor: Decoded audio chunks. + + Note: all supported backends decode through ffmpeg under the hood.
+ """ + if backend == "torchcodec": + # return decode_audio_torchcodec(audio_path, timestamps, duration) #TODO(CarolinePascal): uncomment this line at next torchcodec release + raise ValueError("torchcodec backend is not available yet.") + elif backend == "torchaudio": + return decode_audio_torchaudio(audio_path, timestamps, duration) + else: + raise ValueError(f"Unsupported video backend: {backend}") + + +def decode_audio_torchcodec( + audio_path: Path | str, + timestamps: list[float], + duration: float, + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + # TODO(CarolinePascal) : add channels selection + audio_decoder = torchcodec.decoders.AudioDecoder(audio_path) + + audio_chunks = [] + for ts in timestamps: + current_audio_chunk = audio_decoder.get_samples_played_in_range( + start_seconds=ts, stop_seconds=ts + duration + ) + + if log_loaded_timestamps: + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}" + ) + + audio_chunks.append(current_audio_chunk.data) + + audio_chunks = torch.stack(audio_chunks) + + assert len(timestamps) == len(audio_chunks) + return audio_chunks + + +def decode_audio_torchaudio( + audio_path: Path | str, + timestamps: list[float], + duration: float, + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + # TODO(CarolinePascal) : add channels selection + audio_path = str(audio_path) + + reader = torchaudio.io.StreamReader(src=audio_path) + audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate + + # TODO(CarolinePascal) : sort timestamps ? + reader.add_basic_audio_stream( + frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough + buffer_chunk_size=-1, # No dropping frames + format="fltp", # Format as float32 + ) + + audio_chunks = [] + for ts in timestamps: + reader.seek(ts) # Default to closest audio sample + status = reader.fill_buffer() + if status != 0: + logging.warning("Audio stream reached end of recording before decoding desired timestamps.") + + current_audio_chunk = reader.pop_chunks()[0] + + if log_loaded_timestamps: + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}" + ) + + audio_chunks.append(current_audio_chunk) + + audio_chunks = torch.stack(audio_chunks) + + assert len(timestamps) == len(audio_chunks) + return audio_chunks + + +def encode_audio( + input_path: Path | str, + output_path: Path | str, + codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options + log_level: str | None = "error", + overwrite: bool = False, +) -> None: + """Encodes an audio file using ffmpeg.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + ffmpeg_args = OrderedDict( + [ + ("-i", str(input_path)), + ("-acodec", codec), + ] + ) + + if log_level is not None: + ffmpeg_args["-loglevel"] = str(log_level) + + ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] + if overwrite: + ffmpeg_args.append("-y") + + ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)] + + # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal + subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) + + if not output_path.exists(): + raise OSError( + f"Audio encoding did not 
+ + + def encode_audio( + input_path: Path | str, + output_path: Path | str, + codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options + log_level: str | None = "error", + overwrite: bool = False, + ) -> None: + """Encodes an audio file using ffmpeg.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + ffmpeg_args = OrderedDict( + [ + ("-i", str(input_path)), + ("-acodec", codec), + ] + ) + + if log_level is not None: + ffmpeg_args["-loglevel"] = str(log_level) + + ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] + if overwrite: + ffmpeg_args.append("-y") + + ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)] + + # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal + subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) + + if not output_path.exists(): + raise OSError( + f"Audio encoding did not work. File not found: {output_path}. " + f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`" + ) + + + def get_audio_info(video_path: Path | str) -> dict: + ffprobe_audio_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "a:0", + "-show_entries", + "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + audio_stream_info = info["streams"][0] if info.get("streams") else None + if audio_stream_info is None: + return {"has_audio": False} + + # Return the information, defaulting to None if no audio stream is present + return { + "has_audio": True, + "audio.channels": audio_stream_info.get("channels", None), + "audio.codec": audio_stream_info.get("codec_name", None), + "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, + "audio.sample_rate": int(audio_stream_info["sample_rate"]) + if audio_stream_info.get("sample_rate") + else None, + "audio.bit_depth": audio_stream_info.get("bit_depth", None), + "audio.channel_layout": audio_stream_info.get("channel_layout", None), + } diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 1149ec83..2fab5a80 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -15,7 +15,7 @@ # limitations under the License. import numpy as np -from lerobot.common.datasets.utils import load_image_as_numpy +from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy def estimate_num_samples( @@ -72,6 +72,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images +def sample_audio_from_path(audio_path: str) -> np.ndarray: + """Samples audio data from an audio recording stored in a WAV file.""" + data = load_audio_from_path(audio_path) + sampled_indices = sample_indices(len(data)) + + return data[sampled_indices] + + +def sample_audio_from_data(data: np.ndarray) -> np.ndarray: + """Samples audio data from an audio recording stored in a numpy array.""" + sampled_indices = sample_indices(len(data)) + return data[sampled_indices] + + def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: return { "min": np.min(array, axis=axis, keepdims=keepdims), @@ -91,6 +105,13 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu ep_ft_array = sample_images(data) # data is a list of image paths axes_to_reduce = (0, 2, 3) # keep channel dim keepdims = True + elif features[key]["dtype"] == "audio": + try: + ep_ft_array = sample_audio_from_path(data[0]) + except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner + ep_ft_array = sample_audio_from_data(data) + axes_to_reduce = 0 + keepdims = True else: ep_ft_array = data # data is already a np.ndarray axes_to_reduce = 0 # compute stats over the first axis
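To make the audio branch of `compute_episode_stats` concrete: audio is reduced over the time axis with `keepdims=True`, so one statistic per channel is kept. A rough stand-alone equivalent on synthetic data (not the actual implementation):

```python
import numpy as np

# Hypothetical 2 s stereo recording at 16 kHz: shape (samples, channels).
data = np.random.randn(32000, 2).astype(np.float32)

# Subsample along time (akin to sample_indices), then reduce over axis 0.
sampled = data[np.linspace(0, len(data) - 1, num=100, dtype=int)]
stats = {
    "min": np.min(sampled, axis=0, keepdims=True),
    "max": np.max(sampled, axis=0, keepdims=True),
    "mean": np.mean(sampled, axis=0, keepdims=True),
    "std": np.std(sampled, axis=0, keepdims=True),
}
print(stats["mean"].shape)  # (1, 2): one value per audio channel
```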
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d8da85d6..27899870 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -23,6 +23,7 @@ import datasets import numpy as np import packaging.version import PIL.Image +import soundfile as sf import torch import torch.utils from datasets import concatenate_datasets, load_dataset @@ -31,11 +32,18 @@ from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.errors import RevisionNotFoundError from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.audio_utils import ( + decode_audio, + encode_audio, + get_audio_info, +) from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( + DEFAULT_AUDIO_CHUNK_DURATION, DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, + DEFAULT_RAW_AUDIO_PATH, INFO_PATH, TASKS_PATH, append_jsonlines, @@ -72,6 +80,7 @@ from lerobot.common.datasets.video_utils import ( get_safe_default_codec, get_video_info, ) +from lerobot.common.robot_devices.microphones.utils import Microphone from lerobot.common.robot_devices.robots.utils import Robot CODEBASE_VERSION = "v2.1" @@ -142,6 +151,14 @@ class LeRobotDatasetMetadata: fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) return Path(fpath) + def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path: + """Returns the path of the compressed (i.e. encoded) audio file.""" + episode_chunk = self.get_episode_chunk(episode_index) + fpath = self.audio_path.format( + episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index + ) + return self.root / fpath + def get_episode_chunk(self, ep_index: int) -> int: return ep_index // self.chunks_size @@ -155,6 +172,11 @@ class LeRobotDatasetMetadata: """Formattable string for the video files.""" return self.info["video_path"] + @property + def audio_path(self) -> str | None: + """Formattable string for the audio files.""" + return self.info["audio_path"] + @property def robot_type(self) -> str | None: """Robot type used in recording this dataset.""" @@ -185,6 +207,11 @@ class LeRobotDatasetMetadata: """Keys to access visual modalities (regardless of their storage method).""" return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + @property + def audio_keys(self) -> list[str]: + """Keys to access audio modalities.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] + @property def names(self) -> dict[str, list | dict]: """Names of the various dimensions of vector modalities.""" @@ -264,6 +291,10 @@ class LeRobotDatasetMetadata: if len(self.video_keys) > 0: self.update_video_info() + self.info["total_audio"] += len(self.audio_keys) + if len(self.audio_keys) > 0: + self.update_audio_info() + write_info(self.info, self.root) episode_dict = { @@ -288,6 +319,19 @@ class LeRobotDatasetMetadata: video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) self.info["features"][key]["info"] = get_video_info(video_path) + def update_audio_info(self) -> None: + """ + Warning: this function writes info from the first episode's audio, implicitly assuming that all audio files have been encoded the same way. It also assumes that the first episode exists.
+ """ + for key in self.audio_keys: + if ( + not self.features[key].get("info", None) + or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]) + ): # Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi) + audio_path = self.root / self.get_compressed_audio_file_path(0, key) + self.info["features"][key]["info"] = get_audio_info(audio_path) + def __repr__(self): feature_keys = list(self.features) return ( @@ -363,7 +407,9 @@ class LeRobotDataset(torch.utils.data.Dataset): revision: str | None = None, force_cache_sync: bool = False, download_videos: bool = True, + download_audio: bool = True, video_backend: str | None = None, + audio_backend: str | None = None, ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -394,7 +440,8 @@ class LeRobotDataset(torch.utils.data.Dataset): - tasks contains the prompts for each task of the dataset, which can be used for task-conditioned training. - hf_dataset (from datasets.Dataset), which will read any values from parquet files. - - videos (optional) from which frames are loaded to be synchronous with data from parquet files. + - videos (optional) from which frames and audio (if any) are loaded to be synchronous with data from parquet files and audio. + - audio (optional) from which audio is loaded to be synchronous with data from parquet files and videos. A typical LeRobotDataset looks like this from its root path: . @@ -415,17 +462,31 @@ class LeRobotDataset(torch.utils.data.Dataset): │ ├── info.json │ ├── stats.json │ └── tasks.jsonl - └── videos + ├── videos + │ ├── chunk-000 + │ │ ├── observation.images.laptop + │ │ │ ├── episode_000000.mp4 + │ │ │ ├── episode_000001.mp4 + │ │ │ ├── episode_000002.mp4 + │ │ │ └── ... + │ │ ├── observation.images.phone + │ │ │ ├── episode_000000.mp4 + │ │ │ ├── episode_000001.mp4 + │ │ │ ├── episode_000002.mp4 + │ │ │ └── ... + │ ├── chunk-001 + │ └── ... + └── audio ├── chunk-000 - │ ├── observation.images.laptop - │ │ ├── episode_000000.mp4 - │ │ ├── episode_000001.mp4 - │ │ ├── episode_000002.mp4 + │ ├── observation.audio.laptop + │ │ ├── episode_000000.m4a + │ │ ├── episode_000001.m4a + │ │ ├── episode_000002.m4a │ │ └── ... - │ ├── observation.images.phone - │ │ ├── episode_000000.mp4 - │ │ ├── episode_000001.mp4 - │ │ ├── episode_000002.mp4 + │ ├── observation.audio.phone + │ │ ├── episode_000000.m4a + │ │ ├── episode_000001.m4a + │ │ ├── episode_000002.m4a │ │ └── ... ├── chunk-001 └── ... @@ -463,8 +524,10 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. + download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchaudio'. 
""" super().__init__() self.repo_id = repo_id @@ -475,6 +538,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.audio_backend = ( + audio_backend if audio_backend else "trochaudio" + ) # Waiting for torchcodec release #TODO(CarolinePascal) self.delta_indices = None # Unused attributes @@ -499,7 +565,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) - self.download_episodes(download_videos) + self.download_episodes(download_videos, download_audio) self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -510,6 +576,8 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) + # TODO(CarolinePascal) : add check for audio duration with respect to episode duration BUT this will be CPU expensive if there are many episodes ! + # Setup delta_indices if self.delta_timestamps is not None: check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) @@ -522,6 +590,7 @@ class LeRobotDataset(torch.utils.data.Dataset): license: str | None = "apache-2.0", tag_version: bool = True, push_videos: bool = True, + push_audio: bool = True, private: bool = False, allow_patterns: list[str] | str | None = None, upload_large_folder: bool = False, @@ -530,6 +599,8 @@ class LeRobotDataset(torch.utils.data.Dataset): ignore_patterns = ["images/"] if not push_videos: ignore_patterns.append("videos/") + if not push_audio: + ignore_patterns.append("audio/") hub_api = HfApi() hub_api.create_repo( @@ -585,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ignore_patterns=ignore_patterns, ) - def download_episodes(self, download_videos: bool = True) -> None: + def download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None: """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole dataset will be downloaded. 
Thanks to the behavior of snapshot_download, if the files are already present @@ -594,7 +665,11 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads files = None - ignore_patterns = None if download_videos else "videos/" + ignore_patterns = [] + if not download_videos: + ignore_patterns.append("videos/") + if not download_audio: + ignore_patterns.append("audio/") if self.episodes is not None: files = self.get_episodes_file_paths() @@ -611,6 +686,14 @@ class LeRobotDataset(torch.utils.data.Dataset): ] fpaths += video_files + if len(self.meta.audio_keys) > 0: + audio_files = [ + str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key)) + for audio_key in self.meta.audio_keys + for ep_idx in episodes + ] + fpaths += audio_files + return fpaths def load_hf_dataset(self) -> datasets.Dataset: @@ -677,7 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset): } return query_indices, padding - def _get_query_timestamps( + def _get_query_timestamps_video( self, current_ts: float, query_indices: dict[str, list[int]] | None = None, @@ -692,11 +775,27 @@ class LeRobotDataset(torch.utils.data.Dataset): return query_timestamps + # TODO(CarolinePascal): add variable query durations + def _get_query_timestamps_audio( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.meta.audio_keys: + if query_indices is not None and key in query_indices: + timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: return { key: torch.stack(self.hf_dataset.select(q_idx)[key]) for key, q_idx in query_indices.items() - if key not in self.meta.video_keys + if key not in self.meta.video_keys and key not in self.meta.audio_keys } def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: @@ -713,6 +812,17 @@ class LeRobotDataset(torch.utils.data.Dataset): return item + # TODO(CarolinePascal): add variable query durations + def _query_audio( + self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int + ) -> dict[str, torch.Tensor]: + item = {} + for audio_key, query_ts in query_timestamps.items(): + audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) + audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) + item[audio_key] = audio_chunk.squeeze(0) + return item + def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: for key, val in padding.items(): item[key] = torch.BoolTensor(val) @@ -733,11 +843,16 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, val in query_result.items(): item[key] = val - if len(self.meta.video_keys) > 0: + if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0: current_ts = item["timestamp"].item() - query_timestamps = self._get_query_timestamps(current_ts, query_indices) + + query_timestamps = self._get_query_timestamps_video(current_ts, query_indices) video_frames = self._query_videos(query_timestamps, ep_idx) - item = {**video_frames, **item} + item = {**item, **video_frames} + + query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices) + 
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx) + item = {**item, **audio_chunks} if self.image_transforms is not None: image_keys = self.meta.camera_keys @@ -777,6 +892,10 @@ ) return self.root / fpath + def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path: + fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index) + return self.root / fpath + def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: if self.image_writer is None: if isinstance(image, torch.Tensor): @@ -827,11 +946,39 @@ img_path.parent.mkdir(parents=True, exist_ok=True) self._save_image(frame[key], img_path) self.episode_buffer[key].append(str(img_path)) + elif self.features[key]["dtype"] == "audio": + if ( + self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi") + ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner + self.episode_buffer[key].append(frame[key]) + else: # Otherwise, only the audio file path is stored in the episode buffer + if frame_index == 0: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"], audio_key=key + ) + self.episode_buffer[key].append(str(audio_path)) else: self.episode_buffer[key].append(frame[key]) self.episode_buffer["size"] += 1 + def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None: + """ + Starts recording audio data provided by the microphone and directly writes it to a .wav file. + """ + + audio_dir = self._get_raw_audio_file_path( + self.num_episodes, "observation.audio." + microphone_key + ).parent + if not audio_dir.is_dir(): + audio_dir.mkdir(parents=True, exist_ok=True) + + microphone.start_recording( + output_file=self._get_raw_audio_file_path( + self.num_episodes, "observation.audio." + microphone_key + ) + ) + def save_episode(self, episode_data: dict | None = None) -> None: """ This will save to disk the current episode in self.episode_buffer.
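Taken together with `add_frame`, the intended recording flow for a non-LeKiwi robot would look roughly like this (a hedged sketch; `robot`, `task` and the loop structure are illustrative, not the actual recording script):

```python
# Start streaming each microphone straight to its raw .wav file.
for mic_key, microphone in robot.microphones.items():
    dataset.add_microphone_recording(microphone, mic_key)

for _ in range(num_frames):
    frame = robot.capture_observation()
    dataset.add_frame({**frame, "task": task})  # audio keys: only the .wav path is buffered, once, at frame 0

for microphone in robot.microphones.values():
    microphone.stop_recording()
dataset.save_episode()  # writes the parquet table, then encodes videos and audio
```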
@@ -869,16 +1016,39 @@ # are processed separately by storing image path and frame info as meta data if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: continue + elif ft["dtype"] == "audio": + if ( + self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi") + ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner + episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0) + continue episode_buffer[key] = np.stack(episode_buffer[key]) self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) + + if ( + self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi") + ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner + for key in self.meta.audio_keys: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"][0], audio_key=key + ) + with sf.SoundFile( + audio_path, + mode="w", + samplerate=self.meta.features[key]["info"]["sample_rate"], + channels=self.meta.features[key]["shape"][0], + ) as file: + file.write(episode_buffer[key]) + ep_stats = compute_episode_stats(episode_buffer, self.features) if len(self.meta.video_keys) > 0: - video_paths = self.encode_episode_videos(episode_index) - for key in self.meta.video_keys: - episode_buffer[key] = video_paths[key] + self.encode_episode_videos(episode_index) + + if len(self.meta.audio_keys) > 0: + self.encode_episode_audio(episode_index) # `meta.save_episode` be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) @@ -904,6 +1074,13 @@ if img_dir.is_dir(): shutil.rmtree(self.root / "images") + # delete raw audio files + raw_audio_files = list(self.root.rglob("*.wav")) + for raw_audio_file in raw_audio_files: + raw_audio_file.unlink() + if len(list(raw_audio_file.parent.iterdir())) == 0: + raw_audio_file.parent.rmdir() + if not episode_data: # Reset the buffer self.episode_buffer = self.create_episode_buffer() @@ -971,19 +1148,40 @@ since video encoding with ffmpeg is already using multithreading. """ video_paths = {} - for key in self.meta.video_keys: - video_path = self.root / self.meta.get_video_file_path(episode_index, key) - video_paths[key] = str(video_path) + for video_key in self.meta.video_keys: + video_path = self.root / self.meta.get_video_file_path(episode_index, video_key) + video_paths[video_key] = str(video_path) if video_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. continue img_dir = self._get_image_file_path( - episode_index=episode_index, image_key=key, frame_index=0 + episode_index=episode_index, image_key=video_key, frame_index=0 ).parent + encode_video_frames(img_dir, video_path, self.fps, overwrite=True) return video_paths
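For orientation, the raw-to-compressed path mapping that the audio counterpart below relies on (templates from `lerobot/common/datasets/utils.py`; the episode index is illustrative):

```python
from lerobot.common.datasets.utils import DEFAULT_COMPRESSED_AUDIO_PATH, DEFAULT_RAW_AUDIO_PATH

audio_key, episode_index = "observation.audio.laptop", 7  # illustrative values
raw = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
compressed = DEFAULT_COMPRESSED_AUDIO_PATH.format(episode_chunk=0, audio_key=audio_key, episode_index=episode_index)
print(raw)         # audio/observation.audio.laptop/episode_000007.wav
print(compressed)  # audio/chunk-000/observation.audio.laptop/episode_000007.m4a
```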
+ """ + audio_paths = {} + for audio_key in self.meta.audio_keys: + input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) + output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key) + + audio_paths[audio_key] = str(output_audio_path) + if output_audio_path.is_file(): + # Skip if video is already encoded. Could be the case when resuming data recording. + continue + + encode_audio(input_audio_path, output_audio_path, overwrite=True) + + return audio_paths + @classmethod def create( cls, @@ -998,6 +1196,7 @@ class LeRobotDataset(torch.utils.data.Dataset): image_writer_processes: int = 0, image_writer_threads: int = 0, video_backend: str | None = None, + audio_backend: str | None = None, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" obj = cls.__new__(cls) @@ -1029,6 +1228,9 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_indices = None obj.episode_data_index = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj.audio_backend = ( + audio_backend if audio_backend is not None else "trochaudio" + ) # Waiting for torchcodec release #TODO(CarolinePascal) return obj @@ -1049,6 +1251,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): tolerances_s: dict | None = None, download_videos: bool = True, video_backend: str | None = None, + audio_backend: str | None = None, ): super().__init__() self.repo_ids = repo_ids @@ -1066,6 +1269,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): tolerance_s=self.tolerances_s[repo_id], download_videos=download_videos, video_backend=video_backend, + audio_backend=audio_backend, ) for repo_id in repo_ids ] diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 9d8a54db..7ac46013 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -33,6 +33,7 @@ from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage +from soundfile import read from torchvision import transforms from lerobot.common.datasets.backward_compatibility import ( @@ -55,6 +56,10 @@ TASKS_PATH = "meta/tasks.jsonl" DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" +DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav" +DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a" + +DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds DATASET_CARD_TEMPLATE = """ --- @@ -255,6 +260,11 @@ def load_image_as_numpy( return img_array +def load_audio_from_path(fpath: str | Path) -> np.ndarray: + audio_data, _ = read(fpath, dtype="float32") + return audio_data + + def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. 
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 9d8a54db..7ac46013 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -33,6 +33,7 @@ from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage +from soundfile import read from torchvision import transforms from lerobot.common.datasets.backward_compatibility import ( @@ -55,6 +56,10 @@ TASKS_PATH = "meta/tasks.jsonl" DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" +DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav" +DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a" + +DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds DATASET_CARD_TEMPLATE = """ --- @@ -255,6 +260,11 @@ return img_array + +def load_audio_from_path(fpath: str | Path) -> np.ndarray: + audio_data, _ = read(fpath, dtype="float32") + return audio_data + + def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to @@ -363,7 +373,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> def get_hf_features_from_features(features: dict) -> datasets.Features: hf_features = {} for key, ft in features.items(): - if ft["dtype"] == "video": + if ft["dtype"] == "video" or ft["dtype"] == "audio": continue elif ft["dtype"] == "image": hf_features[key] = datasets.Image() @@ -394,7 +404,7 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: key: {"dtype": "video" if use_videos else "image", **ft} for key, ft in robot.camera_features.items() } - return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} + return {**robot.motor_features, **camera_ft, **robot.microphone_features, **DEFAULT_FEATURES} def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: @@ -442,12 +452,14 @@ "total_frames": 0, "total_tasks": 0, "total_videos": 0, + "total_audio": 0, "total_chunks": 0, "chunks_size": DEFAULT_CHUNK_SIZE, "fps": fps, "splits": {}, "data_path": DEFAULT_PARQUET_PATH, "video_path": DEFAULT_VIDEO_PATH if use_videos else None, + "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH, "features": features, } @@ -721,6 +733,7 @@ def validate_features_presence( ): error_message = "" missing_features = expected_features - actual_features + missing_features = {feature for feature in missing_features if "observation.audio" not in feature} extra_features = actual_features - (expected_features | optional_features) if missing_features or extra_features: @@ -740,6 +753,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) elif expected_dtype in ["image", "video"]: return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "audio": + return validate_feature_audio(name, expected_shape, value) elif expected_dtype == "string": return validate_feature_string(name, value) else: @@ -781,6 +796,23 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: return error_message +def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray): + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c = expected_shape + if len(actual_shape) != 2 or ( + actual_shape[-1] != c[-1] and actual_shape[0] != c[0] + ): # The number of frames might be different + error_message += ( + f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n" + ) + else: + error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n" + + return error_message + + def validate_feature_string(name: str, value: str): if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index c38d570d..fbf0b48c 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -260,35 +260,39 @@ def encode_video_frames( imgs_dir = Path(imgs_dir) video_path.parent.mkdir(parents=True, exist_ok=True) - ffmpeg_args = OrderedDict( + ffmpeg_video_args = OrderedDict( [ ("-f", "image2"), ("-r", str(fps)), - ("-i", str(imgs_dir / "frame_%06d.png")), - ("-vcodec", vcodec), - ("-pix_fmt", pix_fmt), +
("-i", str(Path(imgs_dir) / "frame_%06d.png")), ] ) + ffmpeg_encoding_args = OrderedDict( + [ + ("-pix_fmt", pix_fmt), + ("-vcodec", vcodec), + ] + ) if g is not None: - ffmpeg_args["-g"] = str(g) - + ffmpeg_encoding_args["-g"] = str(g) if crf is not None: - ffmpeg_args["-crf"] = str(crf) - + ffmpeg_encoding_args["-crf"] = str(crf) if fast_decode: key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune" value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" - ffmpeg_args[key] = value + ffmpeg_encoding_args[key] = value if log_level is not None: - ffmpeg_args["-loglevel"] = str(log_level) + ffmpeg_encoding_args["-loglevel"] = str(log_level) - ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] + ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair] + ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair] if overwrite: ffmpeg_args.append("-y") ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] + # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) @@ -331,42 +335,6 @@ with warnings.catch_warnings(): register_feature(VideoFrame, "VideoFrame") -def get_audio_info(video_path: Path | str) -> dict: - ffprobe_audio_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "a:0", - "-show_entries", - "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") - - info = json.loads(result.stdout) - audio_stream_info = info["streams"][0] if info.get("streams") else None - if audio_stream_info is None: - return {"has_audio": False} - - # Return the information, defaulting to None if no audio stream is present - return { - "has_audio": True, - "audio.channels": audio_stream_info.get("channels", None), - "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, - "audio.sample_rate": int(audio_stream_info["sample_rate"]) - if audio_stream_info.get("sample_rate") - else None, - "audio.bit_depth": audio_stream_info.get("bit_depth", None), - "audio.channel_layout": audio_stream_info.get("channel_layout", None), - } - - def get_video_info(video_path: Path | str) -> dict: ffprobe_video_cmd = [ "ffprobe", @@ -402,7 +370,6 @@ def get_video_info(video_path: Path | str) -> dict: "video.codec": video_stream_info["codec_name"], "video.pix_fmt": video_stream_info["pix_fmt"], "video.is_depth_map": False, - **get_audio_info(video_path), } return video_info diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 4e42a989..111e21f0 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -78,6 +78,11 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f if key in robot.logs: log_dt(f"dtR{name}", robot.logs[key]) + for name in robot.microphones: + key = f"read_microphone_{name}_dt_s" + if key in robot.logs: + log_dt(f"dtR{name}", robot.logs[key]) + info_str = " ".join(log_items) logging.info(info_str) @@ -107,11 +112,15 @@ def predict_action(observation, policy, device, use_amp): torch.inference_mode(), 
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), ): - # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension if "image" in name: observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].permute(2, 0, 1).contiguous() + # Convert to pytorch format: channel first and float32 in [-1,1] with batch dimension + if "audio" in name: + observation[name] = observation[name].type(torch.float32) + observation[name] = observation[name].permute(1, 0).contiguous() observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].to(device) @@ -243,6 +252,18 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() + + if ( + dataset is not None and not robot.robot_type.startswith("lekiwi") + ): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunk loss, extended post-processing, and increased memory usage) + for microphone_key, microphone in robot.microphones.items(): + # Start recording both in file writing and data reading mode + dataset.add_microphone_recording(microphone, microphone_key) + else: + for _, microphone in robot.microphones.items(): + # Start recording only in data reading mode + microphone.start_recording() + while timestamp < control_time_s: start_loop_t = time.perf_counter() @@ -286,6 +307,9 @@ events["exit_early"] = False break + for _, microphone in robot.microphones.items(): + microphone.stop_recording() + def reset_environment(robot, events, reset_time_s, fps): # TODO(rcadene): refactor warmup_record and reset_environment diff --git a/lerobot/common/robot_devices/microphones/configs.py b/lerobot/common/robot_devices/microphones/configs.py new file mode 100644 index 00000000..1b663b7a --- /dev/null +++ b/lerobot/common/robot_devices/microphones/configs.py @@ -0,0 +1,38 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + + import abc + from dataclasses import dataclass + + import draccus + + + @dataclass + class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + + @MicrophoneConfigBase.register_subclass("microphone") + @dataclass + class MicrophoneConfig(MicrophoneConfigBase): + """ + Dataclass for microphone configuration. + """ + + microphone_index: int + sample_rate: int | None = None + channels: list[int] | None = None + mock: bool = False diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py new file mode 100644 index 00000000..a92be011 --- /dev/null +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -0,0 +1,425 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + + """ + This file contains utilities for recording audio from a microphone. + """ + + import argparse + import logging + import shutil + import time + from multiprocessing import Event as process_Event + from multiprocessing import JoinableQueue as process_Queue + from multiprocessing import Process + from os import getcwd + from pathlib import Path + from queue import Empty + from queue import Queue as thread_Queue + from threading import Event, Thread + from threading import Event as thread_Event + + import numpy as np + import soundfile as sf + + from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig + from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceAlreadyRecordingError, + RobotDeviceNotConnectedError, + RobotDeviceNotRecordingError, + ) + from lerobot.common.utils.utils import capture_timestamp_utc + + + def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: + """ + Finds and lists all microphones compatible with sounddevice (and the underlying PortAudio library). + Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows). + """ + microphones = [] + + if mock: + import tests.microphones.mock_sounddevice as sd + else: + import sounddevice as sd + + devices = sd.query_devices() + for device in devices: + if device["max_input_channels"] > 0: + microphones.append( + { + "index": device["index"], + "name": device["name"], + } + ) + + if raise_when_empty and len(microphones) == 0: + raise OSError( + "Not a single microphone was detected. Try re-plugging the microphone or check the microphone settings." + ) + + return microphones
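A quick way to sanity-check this helper (output is machine-dependent; `python -m sounddevice` lists the same indices):

```python
from lerobot.common.robot_devices.microphones.microphone import find_microphones

for mic in find_microphones():
    print(f"index={mic['index']}: {mic['name']}")
```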
+ + + def record_audio_from_microphones( + output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0 + ): + """ + Records audio from all the channels of the specified microphones for the specified duration. + If no microphone ids are provided, all available microphones will be used. + """ + + if microphone_ids is None or len(microphone_ids) == 0: + microphones = find_microphones() + microphone_ids = [m["index"] for m in microphones] + + microphones = [] + for microphone_id in microphone_ids: + config = MicrophoneConfig(microphone_index=microphone_id) + microphone = Microphone(config) + microphone.connect() + print( + f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sample_rate} Hz." + ) + microphones.append(microphone) + + output_dir = Path(output_dir) + if output_dir.exists(): + shutil.rmtree( + output_dir, + ) + output_dir.mkdir(parents=True, exist_ok=True) + print(f"Saving audio to {output_dir}") + + for microphone in microphones: + microphone.start_recording(getcwd() / output_dir / f"microphone_{microphone.microphone_index}.wav") + + time.sleep(record_time_s) + + for microphone in microphones: + microphone.stop_recording() + + # Remark : recording may be resumed here if needed + + for microphone in microphones: + microphone.disconnect() + + print(f"Audio recordings have been saved to {output_dir}") + + + class Microphone: + """ + The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows). + + A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels. + + Example of usage: + ```python + from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig + + config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1]) + microphone = Microphone(config) + + microphone.connect() + microphone.start_recording("some/output/file.wav") + ... + audio_readings = microphone.read() #Gets all recorded audio data since the last read or since the beginning of the recording + ... + microphone.stop_recording() + microphone.disconnect() + ``` + """ + + def __init__(self, config: MicrophoneConfig): + self.config = config + self.microphone_index = config.microphone_index + + # Store the recording sample rate and channels + self.sample_rate = config.sample_rate + self.channels = config.channels + + self.mock = config.mock + + # Input audio stream + self.stream = None + + # Thread/Process-safe concurrent queue to store the recorded/read audio + self.record_queue = None + self.read_queue = None + + # Thread/Process to handle data reading and file writing in a separate thread/process (safely) + self.record_thread = None + self.record_stop_event = None + + self.logs = {} + self.is_connected = False + self.is_recording = False + self.is_writing = False
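One point worth keeping in mind when reviewing `read()` below: it drains everything queued since the previous call, so chunk length varies with the control-loop period. A hedged sketch (assumes a connected `microphone`, as in the class docstring above):

```python
import time

microphone.start_recording()  # no output_file: data-reading mode only
for _ in range(10):
    time.sleep(1 / 30)         # ~30 Hz control loop
    chunk = microphone.read()  # shape: (samples_since_last_read, len(channels))
    print(chunk.shape)
microphone.stop_recording()
```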
+ def connect(self) -> None: + """ + Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone. + """ + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + f"Microphone {self.microphone_index} is already connected." + ) + + if self.mock: + import tests.microphones.mock_sounddevice as sd + else: + import sounddevice as sd + + # Check if the provided microphone index does match an input device + is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0 + + if not is_index_input: + microphones_info = find_microphones() + available_microphones = [m["index"] for m in microphones_info] + raise OSError( + f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}" + ) + + # Check if provided recording parameters are compatible with the microphone + actual_microphone = sd.query_devices(self.microphone_index) + + if self.sample_rate is not None: + if self.sample_rate > actual_microphone["default_samplerate"]: + raise OSError( + f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}." + ) + elif self.sample_rate < actual_microphone["default_samplerate"]: + logging.warning( + "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." + ) + else: + self.sample_rate = int(actual_microphone["default_samplerate"]) + + if self.channels is not None and len(self.channels) > 0: + if any(c > actual_microphone["max_input_channels"] for c in self.channels): + raise OSError( + f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}." + ) + else: + self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1) + + # Get channels index instead of number for slicing + self.channels_index = np.array(self.channels) - 1 + + # Create the audio stream + self.stream = sd.InputStream( + device=self.microphone_index, + samplerate=self.sample_rate, + channels=max(self.channels), + dtype="float32", + callback=self._audio_callback, + ) + # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receive same length buffers. + # However, this may lead to additional latency. We thus stick to blocksize=0 which means that audio_callback will receive varying length buffers, but with no additional latency. + + self.is_connected = True + + def _audio_callback(self, indata, frames, time, status) -> None: + """ + Low-level sounddevice callback. + """ + if status: + logging.warning(status) + # Slicing makes copy unnecessary + # Two separate queues are necessary because .get() also pops the data from the queue + if self.is_writing: + self.record_queue.put(indata[:, self.channels_index]) + self.read_queue.put(indata[:, self.channels_index]) + + @staticmethod + def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: + """ + Thread/Process-safe loop to write audio data into a file. + """ + # Can only be run on a single process/thread for file writing safety + with sf.SoundFile( + output_file, + mode="x", + samplerate=sample_rate, + channels=len(channels), # One file channel per recorded channel, matching the sliced callback data + subtype=sf.default_subtype(output_file.suffix[1:]), + ) as file: + while not event.is_set(): + try: + file.write( + queue.get(timeout=0.02) + ) # Timeout set as twice the usual sounddevice buffer size + queue.task_done() + except Empty: + continue + + def _read(self) -> np.ndarray: + """ + Thread/Process-safe callback to read available audio data + """ + audio_readings = np.empty((0, len(self.channels))) + + while True: + try: + audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0) + except Empty: + break + + self.read_queue = thread_Queue() + + return audio_readings + + def read(self) -> np.ndarray: + """ + Reads the last audio chunk recorded by the microphone, i.e. all samples recorded since the last read or since the beginning of the recording.
+ """ + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + if not self.is_recording: + raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") + + start_time = time.perf_counter() + + audio_readings = self._read() + + # log the number of seconds it took to read the audio chunk + self.logs["delta_timestamp_s"] = time.perf_counter() - start_time + + # log the utc time at which the audio chunk was received + self.logs["timestamp_utc"] = capture_timestamp_utc() + + return audio_readings + + def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None: + """ + Starts the recording of the microphone. If output_file is provided, the audio will be written to this file. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + if self.is_recording: + raise RobotDeviceAlreadyRecordingError( + f"Microphone {self.microphone_index} is already recording." + ) + + # Reset queues + self.read_queue = thread_Queue() + if multiprocessing: + self.record_queue = process_Queue() + else: + self.record_queue = thread_Queue() + + # Write recordings into a file if output_file is provided + if output_file is not None: + output_file = Path(output_file) + if output_file.exists(): + output_file.unlink() + + if multiprocessing: + self.record_stop_event = process_Event() + self.record_thread = Process( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) + else: + self.record_stop_event = thread_Event() + self.record_thread = Thread( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) + self.record_thread.daemon = True + self.record_thread.start() + + self.is_writing = True + + self.is_recording = True + self.stream.start() + + def stop_recording(self) -> None: + """ + Stops the recording of the microphones. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + if not self.is_recording: + raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") + + if self.stream.active: + self.stream.stop() # Wait for all buffers to be processed + # Remark : stream.abort() flushes the buffers ! + self.is_recording = False + + if self.record_thread is not None: + self.record_queue.join() + self.record_stop_event.set() + self.record_thread.join() + self.record_thread = None + self.record_stop_event = None + self.is_writing = False + + def disconnect(self) -> None: + """ + Disconnects the microphone and stops the recording. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + + if self.is_recording: + self.stop_recording() + + self.stream.close() + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset." + ) + parser.add_argument( + "--microphone-ids", + type=int, + nargs="*", + default=None, + help="List of microphones indices used to instantiate the `Microphone`. 
If not provided, find and use all available microphone indices.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default="outputs/audio_from_microphones", + help="Set directory to save an audio snippet for each microphone.", + ) + parser.add_argument( + "--record-time-s", + type=float, + default=4.0, + help="Set the number of seconds used to record the audio. By default, 4 seconds.", + ) + args = parser.parse_args() + record_audio_from_microphones(**vars(args)) diff --git a/lerobot/common/robot_devices/microphones/utils.py b/lerobot/common/robot_devices/microphones/utils.py new file mode 100644 index 00000000..fb1bac85 --- /dev/null +++ b/lerobot/common/robot_devices/microphones/utils.py @@ -0,0 +1,48 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase + + +# Defines a microphone type +class Microphone(Protocol): + def connect(self): ... + def disconnect(self): ... + def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ... + def stop_recording(self): ... + + +def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> dict[str, Microphone]: + microphones = {} + + for key, cfg in microphone_configs.items(): + if cfg.type == "microphone": + from lerobot.common.robot_devices.microphones.microphone import Microphone + + microphones[key] = Microphone(cfg) + else: + raise ValueError(f"The microphone type '{cfg.type}' is not valid.") + + return microphones + + +def make_microphone(microphone_type, **kwargs) -> Microphone: + if microphone_type == "microphone": + from lerobot.common.robot_devices.microphones.microphone import Microphone + + return Microphone(MicrophoneConfig(**kwargs)) + else: + raise ValueError(f"The microphone type '{microphone_type}' is not valid.") diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index e940b442..ab362ad1 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -23,6 +23,7 @@ from lerobot.common.robot_devices.cameras.configs import ( IntelRealSenseCameraConfig, OpenCVCameraConfig, ) +from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig from lerobot.common.robot_devices.motors.configs import ( DynamixelMotorsBusConfig, FeetechMotorsBusConfig, @@ -43,6 +44,7 @@ class ManipulatorRobotConfig(RobotConfig): leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) cameras: dict[str, CameraConfig] = field(default_factory=lambda: {}) + microphones: dict[str, MicrophoneConfig] = field(default_factory=lambda: {}) # Optionally limit the magnitude of the relative positional target vector for safety purposes.
 # Set this to a positive scalar to have the same value for all motors, or a list that is the same length
@@ -68,6 +70,9 @@ class ManipulatorRobotConfig(RobotConfig):
             for cam in self.cameras.values():
                 if not cam.mock:
                     cam.mock = True
+            for mic in self.microphones.values():
+                if not mic.mock:
+                    mic.mock = True
 
         if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
             for name in self.follower_arms:
@@ -491,6 +496,21 @@ class So100RobotConfig(ManipulatorRobotConfig):
         }
     )
 
+    microphones: dict[str, MicrophoneConfig] = field(
+        default_factory=lambda: {
+            "laptop": MicrophoneConfig(
+                microphone_index=0,
+                sample_rate=48000,
+                channels=[1],
+            ),
+            "headset": MicrophoneConfig(
+                microphone_index=1,
+                sample_rate=44100,
+                channels=[1],
+            ),
+        }
+    )
+
     mock: bool = False
diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py
index 7bf52d21..15023d8a 100644
--- a/lerobot/common/robot_devices/robots/lekiwi_remote.py
+++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py
@@ -52,6 +52,16 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
         time.sleep(0.01)
 
 
+def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
+    while not stop_event.is_set():
+        local_dict = {}
+        for name, microphone in microphones.items():
+            audio_readings = microphone.read()
+            local_dict[name] = audio_readings
+        with audio_lock:
+            latest_audio_dict.update(local_dict)
+
+
 def calibrate_follower_arm(motors_bus, calib_dir_str):
     """
     Calibrates the follower arm. Attempts to load an existing calibration file;
@@ -94,6 +104,7 @@ def run_lekiwi(robot_config):
     """
     # Import helper functions and classes
     from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
+    from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
     from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
 
     # Initialize cameras from the robot configuration.
@@ -101,6 +112,11 @@ def run_lekiwi(robot_config):
     for cam in cameras.values():
         cam.connect()
 
+    # Initialize microphones from the robot configuration.
+    microphones = make_microphones_from_configs(robot_config.microphones)
+    for microphone in microphones.values():
+        microphone.connect()
+
     # Initialize the motors bus using the follower arm configuration.
     motor_config = robot_config.follower_arms.get("main")
     if motor_config is None:
@@ -134,6 +150,20 @@ def run_lekiwi(robot_config):
     )
     cam_thread.start()
 
+    # Start the microphone recording and capture thread.
+    # TODO(CarolinePascal): Leverage multi-core processing with a multiprocessing.Process instead!
+    latest_audio_dict = {}
+    audio_lock = threading.Lock()
+    audio_stop_event = threading.Event()
+    microphone_thread = threading.Thread(
+        target=run_microphone_capture,
+        args=(microphones, audio_lock, latest_audio_dict, audio_stop_event),
+        daemon=True,
+    )
+    for microphone in microphones.values():
+        microphone.start_recording()
+    microphone_thread.start()
+
     last_cmd_time = time.time()
     print("LeKiwi robot server started. Waiting for commands...")
 
@@ -198,9 +228,14 @@ def run_lekiwi(robot_config):
             with images_lock:
                 images_dict_copy = dict(latest_images_dict)
 
+            # Get the latest audio data.
+            with audio_lock:
+                audio_dict_copy = dict(latest_audio_dict)
+
             # Build the observation dictionary.
             observation = {
                 "images": images_dict_copy,
+                "audio": audio_dict_copy,  # TODO(CarolinePascal): This is a nasty way to do it, sorry.
"present_speed": current_velocity, "follower_arm_state": follower_arm_state, } @@ -217,6 +252,9 @@ def run_lekiwi(robot_config): finally: stop_event.set() cam_thread.join() + microphone_thread.join() + for microphone in microphones.values(): + microphone.stop_recording() robot.stop() motors_bus.disconnect() cmd_socket.close() diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 9173abc6..00bcd3db 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -28,6 +28,7 @@ import numpy as np import torch from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs +from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig from lerobot.common.robot_devices.robots.utils import get_arm_id @@ -164,6 +165,7 @@ class ManipulatorRobot: self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) self.cameras = make_cameras_from_configs(self.config.cameras) + self.microphones = make_microphones_from_configs(self.config.microphones) self.is_connected = False self.logs = {} @@ -199,9 +201,24 @@ class ManipulatorRobot: }, } + @property + def microphone_features(self) -> dict: + mic_ft = {} + for mic_key, mic in self.microphones.items(): + key = f"observation.audio.{mic_key}" + mic_ft[key] = { + "dtype": "audio", + "shape": (len(mic.channels),), + "names": "channels", + "info": { + "sample_rate": mic.sample_rate + }, # we need to store the sample rate here in the case of audio chunks recording (for LeKiwi), as it will not be available anymore when writing the audio file + } + return mic_ft + @property def features(self): - return {**self.motor_features, **self.camera_features} + return {**self.motor_features, **self.camera_features, **self.microphone_features} @property def has_camera(self): @@ -211,6 +228,14 @@ class ManipulatorRobot: def num_cameras(self): return len(self.cameras) + @property + def has_microphone(self): + return len(self.microphones) > 0 + + @property + def num_microphones(self): + return len(self.microphones) + @property def available_arms(self): available_arms = [] @@ -228,7 +253,7 @@ class ManipulatorRobot: "ManipulatorRobot is already connected. Do not run `robot.connect()` twice." ) - if not self.leader_arms and not self.follower_arms and not self.cameras: + if not self.leader_arms and not self.follower_arms and not self.cameras and not self.microphones: raise ValueError( "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class." 
             )
 
@@ -289,6 +314,10 @@ class ManipulatorRobot:
         for name in self.cameras:
             self.cameras[name].connect()
 
+        # Connect the microphones
+        for name in self.microphones:
+            self.microphones[name].connect()
+
         self.is_connected = True
 
     def activate_calibration(self):
@@ -514,12 +543,23 @@ class ManipulatorRobot:
             self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
             self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
 
+        # Capture audio from microphones
+        audio = {}
+        for name in self.microphones:
+            before_audioread_t = time.perf_counter()
+            audio[name] = self.microphones[name].read()
+            audio[name] = torch.from_numpy(audio[name])
+            self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
+            self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t
+
         # Populate output dictionaries
         obs_dict, action_dict = {}, {}
         obs_dict["observation.state"] = state
         action_dict["action"] = action
         for name in self.cameras:
             obs_dict[f"observation.images.{name}"] = images[name]
+        for name in self.microphones:
+            obs_dict[f"observation.audio.{name}"] = audio[name]
 
         return obs_dict, action_dict
 
@@ -554,11 +594,22 @@ class ManipulatorRobot:
             self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
             self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
 
+        # Capture audio from microphones
+        audio = {}
+        for name in self.microphones:
+            before_audioread_t = time.perf_counter()
+            audio[name] = self.microphones[name].read()
+            audio[name] = torch.from_numpy(audio[name])
+            self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
+            self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t
+
         # Populate output dictionaries and format to pytorch
         obs_dict = {}
         obs_dict["observation.state"] = state
         for name in self.cameras:
             obs_dict[f"observation.images.{name}"] = images[name]
+        for name in self.microphones:
+            obs_dict[f"observation.audio.{name}"] = audio[name]
 
         return obs_dict
 
     def send_action(self, action: torch.Tensor) -> torch.Tensor:
@@ -620,6 +671,9 @@ class ManipulatorRobot:
         for name in self.cameras:
             self.cameras[name].disconnect()
 
+        for name in self.microphones:
+            self.microphones[name].disconnect()
+
         self.is_connected = False
 
     def __del__(self):
diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py
index 385e218b..4af008ed 100644
--- a/lerobot/common/robot_devices/robots/mobile_manipulator.py
+++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py
@@ -24,6 +24,7 @@ import torch
 import zmq
 
 from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
+from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
 from lerobot.common.robot_devices.motors.feetech import TorqueMode
 from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
 from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
@@ -79,6 +80,7 @@ class MobileManipulator:
         self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
 
         self.cameras = make_cameras_from_configs(self.config.cameras)
+        self.microphones = make_microphones_from_configs(self.config.microphones)
 
         self.is_connected = False
 
@@ -133,6 +135,7 @@ class MobileManipulator:
                 "shape": (cam.height, cam.width, cam.channels),
                 "names": ["height", "width", "channels"],
"info": None, + "audio": "observation.audio." + cam.microphone if cam.microphone is not None else None, } return cam_ft @@ -161,9 +164,22 @@ class MobileManipulator: }, } + @property + def microphone_features(self) -> dict: + mic_ft = {} + for mic_key, mic in self.microphones.items(): + key = f"observation.audio.{mic_key}" + mic_ft[key] = { + "dtype": "audio", + "shape": (len(mic.channels),), + "names": "channels", + "info": {"sample_rate": mic.sample_rate}, + } + return mic_ft + @property def features(self): - return {**self.motor_features, **self.camera_features} + return {**self.motor_features, **self.camera_features, **self.microphone_features} @property def has_camera(self): @@ -173,6 +189,14 @@ class MobileManipulator: def num_cameras(self): return len(self.cameras) + @property + def has_microphone(self): + return len(self.microphones) > 0 + + @property + def num_microphones(self): + return len(self.microphones) + @property def available_arms(self): available = [] @@ -344,6 +368,7 @@ class MobileManipulator: observation = json.loads(last_msg) images_dict = observation.get("images", {}) + audio_dict = observation.get("audio", {}) new_speed = observation.get("present_speed", {}) new_arm_state = observation.get("follower_arm_state", None) @@ -356,6 +381,11 @@ class MobileManipulator: if frame_candidate is not None: frames[cam_name] = frame_candidate + # Receive audio + for microphone_name, audio_data in audio_dict.items(): + if audio_data: + frames[microphone_name] = audio_data + # If remote_arm_state is None and frames is None there is no message then use the previous message if new_arm_state is not None and frames is not None: self.last_frames = frames @@ -475,6 +505,14 @@ class MobileManipulator: frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8) obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame) + # Loop over each configured microphone + for microphone_name, microphone in self.microphones.items(): + frame = frames.get(microphone_name, None) + if frame is None: + # Create silence using the microphone's configured channels + frame = np.zeros((1, len(microphone.channels)), dtype=np.float32) + obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame) + return obs_dict def send_action(self, action: torch.Tensor) -> torch.Tensor: diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index 837c9d2e..5b2270e7 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -63,3 +63,25 @@ class RobotDeviceAlreadyConnectedError(Exception): ): self.message = message super().__init__(self.message) + + +class RobotDeviceNotRecordingError(Exception): + """Exception raised when the robot device is not recording.""" + + def __init__( + self, + message="This robot device is not recording. Try calling `robot_device.start_recording()` first.", + ): + self.message = message + super().__init__(self.message) + + +class RobotDeviceAlreadyRecordingError(Exception): + """Exception raised when the robot device is already recording.""" + + def __init__( + self, + message="This robot device is already recording. 
Try not calling `robot_device.start_recording()` twice.", + ): + self.message = message + super().__init__(self.message) diff --git a/pyproject.toml b/pyproject.toml index 4b858634..73f68a72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,8 +67,11 @@ dependencies = [ "pynput>=1.7.7", "pyzmq>=26.2.1", "rerun-sdk>=0.21.0", + "sounddevice>=0.5.1", + "soundfile>=0.13.1", "termcolor>=2.4.0", "torch>=2.2.1", + "torchaudio>=2.6.0", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", "torchvision>=0.21.0", "wandb>=0.16.3", @@ -96,6 +99,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"] umi = ["imagecodecs>=2024.1.1"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] +audio = ["librosa>=0.11.0"] [tool.poetry] requires-poetry = ">=2.1" diff --git a/tests/conftest.py b/tests/conftest.py index 7eec94bf..a0e9ae41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,9 @@ import traceback import pytest from serial import SerialException -from lerobot import available_cameras, available_motors, available_robots +from lerobot import available_cameras, available_microphones, available_motors, available_robots from lerobot.common.robot_devices.robots.utils import make_robot -from tests.utils import DEVICE, make_camera, make_motors_bus +from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus # Import fixture modules as plugins pytest_plugins = [ @@ -74,6 +74,11 @@ def is_camera_available(camera_type): return _check_component_availability(camera_type, available_cameras, make_camera) +@pytest.fixture +def is_microphone_available(microphone_type): + return _check_component_availability(microphone_type, available_microphones, make_microphone) + + @pytest.fixture def is_motor_available(motor_type): return _check_component_availability(motor_type, available_motors, make_motors_bus) diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index d9032c8a..2ebf95f2 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -25,6 +25,8 @@ from lerobot.common.datasets.compute_stats import ( compute_episode_stats, estimate_num_samples, get_feature_stats, + sample_audio_from_data, + sample_audio_from_path, sample_images, sample_indices, ) @@ -34,6 +36,10 @@ def mock_load_image_as_numpy(path, dtype, channel_first): return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) +def mock_load_audio(path): + return np.ones((16000, 2), dtype=np.float32) + + @pytest.fixture def sample_array(): return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) @@ -71,6 +77,25 @@ def test_sample_images(mock_load): assert len(images) == estimate_num_samples(100) +@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio) +def test_sample_audio_from_path(mock_load): + audio_path = "audio.wav" + audio_samples = sample_audio_from_path(audio_path) + assert isinstance(audio_samples, np.ndarray) + assert audio_samples.shape[1] == 2 + assert audio_samples.dtype == np.float32 + assert len(audio_samples) == estimate_num_samples(16000) + + +def test_sample_audio_from_data(): + audio_data = np.ones((16000, 2), dtype=np.float32) + audio_samples = sample_audio_from_data(audio_data) + assert 
isinstance(audio_samples, np.ndarray) + assert audio_samples.shape[1] == 2 + assert audio_samples.dtype == np.float32 + assert len(audio_samples) == estimate_num_samples(16000) + + def test_get_feature_stats_images(): data = np.random.rand(100, 3, 32, 32) stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) @@ -79,6 +104,14 @@ def test_get_feature_stats_images(): assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape +def test_get_feature_stats_audio(): + data = np.random.uniform(-1, 1, (16000, 2)) + stats = get_feature_stats(data, axis=0, keepdims=True) + assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats + np.testing.assert_equal(stats["count"], np.array([16000])) + assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + + def test_get_feature_stats_axis_0_keepdims(sample_array): expected = { "min": np.array([[1, 2, 3]]), @@ -137,22 +170,29 @@ def test_get_feature_stats_single_value(): def test_compute_episode_stats(): episode_data = { "observation.image": [f"image_{i}.jpg" for i in range(100)], + "observation.audio": "audio.wav", "observation.state": np.random.rand(100, 10), } features = { "observation.image": {"dtype": "image"}, + "observation.audio": {"dtype": "audio"}, "observation.state": {"dtype": "numeric"}, } - with patch( - "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + with ( + patch( + "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + ), + patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio), ): stats = compute_episode_stats(episode_data, features) - assert "observation.image" in stats and "observation.state" in stats - assert stats["observation.image"]["count"].item() == 100 - assert stats["observation.state"]["count"].item() == 100 + assert "observation.image" in stats and "observation.state" in stats and "observation.audio" in stats + assert stats["observation.image"]["count"].item() == estimate_num_samples(100) + assert stats["observation.audio"]["count"].item() == estimate_num_samples(16000) + assert stats["observation.state"]["count"].item() == estimate_num_samples(100) assert stats["observation.image"]["mean"].shape == (3, 1, 1) + assert stats["observation.audio"]["mean"].shape == (1, 2) def test_assert_type_and_shape_valid(): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 81447089..a10a7d61 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -16,6 +16,7 @@ import json import logging import re +import time from copy import deepcopy from itertools import chain from pathlib import Path @@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import ( MultiLeRobotDataset, ) from lerobot.common.datasets.utils import ( + DEFAULT_AUDIO_CHUNK_DURATION, create_branch, flatten_dict, unflatten_dict, @@ -44,8 +46,8 @@ from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID -from tests.utils import require_x86_64_kernel +from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID +from tests.utils import make_microphone, require_x86_64_kernel 
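+
+# NOTE: test_add_frame_audio below relies on the mocked microphone (mock=True), so it runs
+# without physical audio hardware; see tests/microphones/mock_sounddevice.py.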
 
 
 @pytest.fixture
@@ -64,6 +66,20 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
     return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
 
 
+@pytest.fixture
+def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
+    features = {
+        "observation.audio.microphone": {
+            "dtype": "audio",
+            "shape": (DUMMY_AUDIO_CHANNELS,),
+            "names": [
+                "channels",
+            ],
+        }
+    }
+    return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+
+
 def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     """
     Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
@@ -322,6 +338,24 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
         image_array_to_pil_image(image)
 
 
+def test_add_frame_audio(audio_dataset):
+    dataset = audio_dataset
+
+    microphone = make_microphone(microphone_type="microphone", mock=True)
+    microphone.connect()
+
+    dataset.add_microphone_recording(microphone, "microphone")
+    time.sleep(1.0)
+    dataset.add_frame({"observation.audio.microphone": microphone.read(), "task": "Dummy task"})
+    microphone.stop_recording()
+
+    dataset.save_episode()
+
+    assert dataset[0]["observation.audio.microphone"].shape == torch.Size(
+        (int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS)
+    )
+
+
 # TODO(aliberts):
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames
@@ -354,6 +388,7 @@ def test_factory(env_name, repo_id, policy_name):
     dataset = make_dataset(cfg)
     delta_timestamps = dataset.delta_timestamps
     camera_keys = dataset.meta.camera_keys
+    audio_keys = dataset.meta.audio_keys
 
     item = dataset[0]
 
@@ -396,6 +431,11 @@ def test_factory(env_name, repo_id, policy_name):
         # test c,h,w
         assert item[key].shape[0] == 3, f"{key}"
 
+    for key in audio_keys:
+        assert item[key].dtype == torch.float32, f"{key}"
+        assert item[key].max() <= 1.0, f"{key}"
+        assert item[key].min() >= -1.0, f"{key}"
+
     if delta_timestamps is not None:
         # test missing keys in delta_timestamps
         for key in delta_timestamps:
diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py
index 5e5c762c..b95044df 100644
--- a/tests/fixtures/constants.py
+++ b/tests/fixtures/constants.py
@@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = {
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {
+        "shape": (480, 640, 3),
+        "names": ["height", "width", "channels"],
+        "info": None,
+        "audio": "laptop",
+    },
     "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
 }
 DEFAULT_FPS = 30
@@ -40,5 +45,18 @@ DUMMY_VIDEO_INFO = {
     "video.is_depth_map": False,
     "has_audio": False,
 }
+DUMMY_MICROPHONE_FEATURES = {
+    "laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
+    "phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
+}
+DEFAULT_SAMPLE_RATE = 48000
+DUMMY_AUDIO_CHANNELS = 2
+DUMMY_AUDIO_INFO = {
+    "has_audio": True,
+    "audio.sample_rate": DEFAULT_SAMPLE_RATE,
+    "audio.codec": "aac",
+    "audio.channels": DUMMY_AUDIO_CHANNELS,
+    "audio.channel_layout": "stereo",
+}
 DUMMY_CHW = (3, 96, 128)
 DUMMY_HWC = (96, 128, 3)
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 531977da..6de43ef1 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -26,6 +26,7 @@ import torch
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     DEFAULT_FEATURES,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
@@ -35,6 +36,7 @@ from lerobot.common.datasets.utils import (
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
+    DUMMY_MICROPHONE_FEATURES,
     DUMMY_MOTOR_FEATURES,
     DUMMY_REPO_ID,
     DUMMY_ROBOT_TYPE,
@@ -90,6 +92,7 @@ def features_factory():
     def _create_features(
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
+        audio_features: dict = DUMMY_MICROPHONE_FEATURES,
         use_videos: bool = True,
     ) -> dict:
         if use_videos:
@@ -101,6 +104,7 @@ def features_factory():
         return {
             **motor_features,
             **camera_ft,
+            **audio_features,
             **DEFAULT_FEATURES,
         }
 
@@ -117,15 +121,18 @@ def info_factory(features_factory):
         total_frames: int = 0,
         total_tasks: int = 0,
         total_videos: int = 0,
+        total_audio: int = 0,
         total_chunks: int = 0,
         chunks_size: int = DEFAULT_CHUNK_SIZE,
         data_path: str = DEFAULT_PARQUET_PATH,
         video_path: str = DEFAULT_VIDEO_PATH,
+        audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
+        audio_features: dict = DUMMY_MICROPHONE_FEATURES,
         use_videos: bool = True,
     ) -> dict:
-        features = features_factory(motor_features, camera_features, use_videos)
+        features = features_factory(motor_features, camera_features, audio_features, use_videos)
         return {
             "codebase_version": codebase_version,
             "robot_type": robot_type,
@@ -133,12 +140,14 @@ def info_factory(features_factory):
             "total_frames": total_frames,
             "total_tasks": total_tasks,
             "total_videos": total_videos,
+            "total_audio": total_audio,
             "total_chunks": total_chunks,
             "chunks_size": chunks_size,
             "fps": fps,
             "splits": {},
             "data_path": data_path,
             "video_path": video_path if use_videos else None,
+            "audio_path": audio_path,
             "features": features,
         }
 
@@ -162,6 +171,14 @@ def stats_factory():
                 "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
                 "count": [10],
             }
+        elif dtype == "audio":
+            stats[key] = {
+                "mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
+                "max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
+                "min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
+                "std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
+                "count": [10],
+            }
         else:
             stats[key] = {
                 "max": np.full(shape, 1, dtype=dtype).tolist(),
diff --git a/tests/microphones/mock_sounddevice.py b/tests/microphones/mock_sounddevice.py
new file mode 100644
index 00000000..0220c88c
--- /dev/null
+++ b/tests/microphones/mock_sounddevice.py
@@ -0,0 +1,88 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
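+
+"""Mocks the subset of the sounddevice API used by `Microphone` (`query_devices` and
+`InputStream`): a background thread periodically feeds the stream callback with random
+float32 samples, emulating a live input device."""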
+import time
+from functools import cache
+from threading import Event, Thread
+
+import numpy as np
+
+from lerobot.common.utils.utils import capture_timestamp_utc
+from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS
+
+
+@cache
+def _generate_sound(duration: float, sample_rate: int, channels: int):
+    return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
+
+
+def query_devices(query_index: int):
+    return {
+        "name": "Mock Sound Device",
+        "index": query_index,
+        "max_input_channels": DUMMY_AUDIO_CHANNELS,
+        "default_samplerate": DEFAULT_SAMPLE_RATE,
+    }
+
+
+class InputStream:
+    def __init__(self, *args, **kwargs):
+        self._mock_dict = {
+            "channels": DUMMY_AUDIO_CHANNELS,
+            "samplerate": DEFAULT_SAMPLE_RATE,
+        }
+        self._is_active = False
+        self._audio_callback = kwargs.get("callback")
+
+        self.callback_thread = None
+        self.callback_thread_stop_event = None
+
+    def _acquisition_loop(self):
+        if self._audio_callback is not None:
+            while not self.callback_thread_stop_event.is_set():
+                # Simulate audio data acquisition
+                time.sleep(0.01)
+                self._audio_callback(
+                    _generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS),
+                    int(0.01 * DEFAULT_SAMPLE_RATE),  # Number of frames, passed as an int like sounddevice does
+                    capture_timestamp_utc(),
+                    None,
+                )
+
+    def start(self):
+        self.callback_thread_stop_event = Event()
+        self.callback_thread = Thread(target=self._acquisition_loop, args=())
+        self.callback_thread.daemon = True
+        self.callback_thread.start()
+
+        self._is_active = True
+
+    @property
+    def active(self):
+        return self._is_active
+
+    def stop(self):
+        if self.callback_thread_stop_event is not None:
+            self.callback_thread_stop_event.set()
+            self.callback_thread.join()
+            self.callback_thread = None
+            self.callback_thread_stop_event = None
+        self._is_active = False
+
+    def close(self):
+        if self._is_active:
+            self.stop()
+
+    def __del__(self):
+        if self._is_active:
+            self.stop()
diff --git a/tests/microphones/test_microphones.py b/tests/microphones/test_microphones.py
new file mode 100644
index 00000000..c7bdbe71
--- /dev/null
+++ b/tests/microphones/test_microphones.py
@@ -0,0 +1,152 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for physical microphones and their mocked versions.
+If the physical microphone is not connected to the computer, or not working,
+the test will be skipped.
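+
+To list the audio devices available on your machine along with their indices, you can run
+`python -m sounddevice`.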
+
+Example of running a specific test:
+```bash
+pytest -sx tests/microphones/test_microphones.py::test_microphone
+```
+
+Example of running test on a real microphone connected to the computer:
+```bash
+pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-False]'
+```
+
+Example of running test on a mocked version of the microphone:
+```bash
+pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-True]'
+```
+"""
+
+import time
+
+import numpy as np
+import pytest
+from soundfile import read
+
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceAlreadyRecordingError,
+    RobotDeviceNotConnectedError,
+    RobotDeviceNotRecordingError,
+)
+from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
+
+# Maximum recording time difference between two consecutive audio recordings of the same duration.
+# Set to 0.02 seconds, i.e. twice the default duration of a sounddevice callback buffer (we tolerate the loss of one buffer).
+MAX_RECORDING_TIME_DIFFERENCE = 0.02
+
+DUMMY_RECORDING = "test_recording.wav"
+
+
+@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
+@require_microphone
+def test_microphone(tmp_path, request, microphone_type, mock):
+    """This test assumes that a recording obtained through microphone.start_recording(output_file) + stop_recording(),
+    or through microphone.read(), yields a sample whose length does not differ from the requested duration by more than 0.1 seconds.
+    """
+
+    microphone_kwargs = {"microphone_type": microphone_type, "mock": mock}
+
+    # Test instantiating
+    microphone = make_microphone(**microphone_kwargs)
+
+    # Test that start_recording, stop_recording, read and disconnect raise an error before connecting
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.start_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.stop_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.disconnect()
+
+    # Test deleting the object without connecting first
+    del microphone
+
+    # Test connecting
+    microphone = make_microphone(**microphone_kwargs)
+    microphone.connect()
+    assert microphone.is_connected
+    assert microphone.sample_rate is not None
+    assert microphone.channels is not None
+
+    # Test connecting twice raises an error
+    with pytest.raises(RobotDeviceAlreadyConnectedError):
+        microphone.connect()
+
+    # Test that reading or stopping the recording before starting it raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test start_recording
+    fpath = tmp_path / DUMMY_RECORDING
+    microphone.start_recording(fpath)
+    assert microphone.is_recording
+
+    # Test start_recording twice raises an error
+    with pytest.raises(RobotDeviceAlreadyRecordingError):
+        microphone.start_recording()
+
+    # Test reading from the microphone
+    time.sleep(1.0)
+    audio_chunk = microphone.read()
+    assert isinstance(audio_chunk, np.ndarray)
+    assert audio_chunk.ndim == 2
+    _, c = audio_chunk.shape
+    assert c == len(microphone.channels)
+
+    # Test stop_recording
+    microphone.stop_recording()
+    assert fpath.exists()
+    assert not microphone.stream.active
+    assert microphone.record_thread is None
+
+    # Test stop_recording twice raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test that reading and recording produce audio chunks of similar length
+    microphone.start_recording(tmp_path / DUMMY_RECORDING)
+    time.sleep(1.0)
+    audio_chunk = microphone.read()
+    microphone.stop_recording()
+
+    recorded_audio, recorded_sample_rate = read(fpath)
+    assert recorded_sample_rate == microphone.sample_rate
+
+    error_msg = (
+        "Recording time difference between read() and stop_recording(): "
+        f"{(len(audio_chunk) - len(recorded_audio)) / recorded_sample_rate:.4f} seconds"
+    )
+    np.testing.assert_allclose(
+        len(audio_chunk),
+        len(recorded_audio),
+        atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
+        err_msg=error_msg,
+    )
+
+    # Test disconnecting
+    microphone.disconnect()
+    assert not microphone.is_connected
+
+    # Test disconnecting with `__del__`
+    microphone = make_microphone(**microphone_kwargs)
+    microphone.connect()
+    del microphone
diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py
index 71343eba..8353fe29 100644
--- a/tests/robots/test_robots.py
+++ b/tests/robots/test_robots.py
@@ -100,6 +100,9 @@ def test_robot(tmp_path, request, robot_type, mock):
     robot.teleop_step()
 
     # Test data recorded during teleop are well formatted
+    for microphone in robot.microphones.values():
+        microphone.start_recording()
+
     observation, action = robot.teleop_step(record_data=True)
     # State
     assert "observation.state" in observation
@@ -112,6 +115,11 @@ def test_robot(tmp_path, request, robot_type, mock):
         assert f"observation.images.{name}" in observation
         assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
         assert observation[f"observation.images.{name}"].ndim == 3
+    # Microphones
+    for name in robot.microphones:
+        assert f"observation.audio.{name}" in observation
+        assert isinstance(observation[f"observation.audio.{name}"], torch.Tensor)
+        assert observation[f"observation.audio.{name}"].ndim == 2
     # Action
     assert "action" in action
     assert isinstance(action["action"], torch.Tensor)
@@ -124,8 +132,9 @@ def test_robot(tmp_path, request, robot_type, mock):
     captured_observation = robot.capture_observation()
     assert set(captured_observation.keys()) == set(observation.keys())
     for name in captured_observation:
-        if "image" in name:
+        if "image" in name or "audio" in name:
             # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
+            # Also skipping audio, as consecutive audio chunks may have different lengths
             continue
         torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
         assert captured_observation[name].shape == observation[name].shape
@@ -134,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
     robot.send_action(action["action"])
 
     # Test disconnecting
-    robot.disconnect()
+    robot.disconnect()  # Also stops any ongoing microphone recording
     assert not robot.is_connected
     for name in robot.follower_arms:
         assert not robot.follower_arms[name].is_connected
@@ -142,3 +151,5 @@ def test_robot(tmp_path, request, robot_type, mock):
         assert not robot.leader_arms[name].is_connected
     for name in robot.cameras:
         assert not robot.cameras[name].is_connected
+    for name in robot.microphones:
+        assert not robot.microphones[name].is_connected
diff --git a/tests/utils.py b/tests/utils.py
index c49b5b9f..d93eb97c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -22,9 +22,11 @@ from pathlib import Path
 import pytest
 import torch
 
-from lerobot import available_cameras, available_motors, available_robots
+from lerobot import available_cameras, available_microphones, available_motors, available_robots
 from lerobot.common.robot_devices.cameras.utils import Camera
 from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
+from lerobot.common.robot_devices.microphones.utils import Microphone
+from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
 from lerobot.common.robot_devices.motors.utils import MotorsBus
 from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
 from lerobot.common.utils.import_utils import is_package_available
@@ -39,6 +41,10 @@ TEST_CAMERA_TYPES = []
 for camera_type in available_cameras:
     TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
 
+TEST_MICROPHONE_TYPES = []
+for microphone_type in available_microphones:
+    TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]
+
 TEST_MOTOR_TYPES = []
 for motor_type in available_motors:
     TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
@@ -47,6 +53,9 @@ for motor_type in available_motors:
 OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
 INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
 
+# Microphone index used for connecting a physical microphone
+MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))
+
 DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
 DYNAMIXEL_MOTORS = {
     "shoulder_pan": [1, "xl430-w250"],
@@ -253,6 +262,29 @@ def require_camera(func):
     return wrapper
 
 
+def require_microphone(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Access the pytest request context to get the is_microphone_available fixture
+        request = kwargs.get("request")
+        microphone_type = kwargs.get("microphone_type")
+        mock = kwargs.get("mock")
+
+        if request is None:
+            raise ValueError("The 'request' fixture must be an argument of the test function.")
+        if microphone_type is None:
+            raise ValueError("The 'microphone_type' must be an argument of the test function.")
+        if mock is None:
+            raise ValueError("The 'mock' variable must be an argument of the test function.")
+
+        if not mock and not request.getfixturevalue("is_microphone_available"):
+            pytest.skip(f"A {microphone_type} microphone is not available.")
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def require_motor(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -315,6 +347,14 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
         raise ValueError(f"The camera type '{camera_type}' is not valid.")
 
 
+def make_microphone(microphone_type: str, **kwargs) -> Microphone:
+    if microphone_type == "microphone":
+        microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
+        return make_microphone_device(microphone_type, microphone_index=microphone_index, **kwargs)
+    else:
+        raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
+
+
 # TODO(rcadene, aliberts): remove this dark pattern that overrides
 def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
     if motor_type == "dynamixel":