From 140d4f04310282a3f7bcb21936d090b771612aff Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Mon, 24 Mar 2025 15:24:01 +0100
Subject: [PATCH 01/32] Adding audio dependencies

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 4b858634..0830b4ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,6 +96,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
 umi = ["imagecodecs>=2024.1.1"]
 video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
 xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
+audio = ["sounddevice>=0.5.1", "soundfile>=0.13.1", "librosa>=0.11.0", "torchaudio>=2.6.0"]

 [tool.poetry]
 requires-poetry = ">=2.1"

From 58cd0bdf862d2d3d5f96cb4ac16422cd8f192dee Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Thu, 27 Mar 2025 18:08:27 +0100
Subject: [PATCH 02/32] Implementing basic integration of microphones using
 soundfile and sounddevice

---
 .../robot_devices/microphones/microphone.py | 303 ++++++++++++++++++
 1 file changed, 303 insertions(+)
 create mode 100644 lerobot/common/robot_devices/microphones/microphone.py

diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py
new file mode 100644
index 00000000..5b136aab
--- /dev/null
+++ b/lerobot/common/robot_devices/microphones/microphone.py
@@ -0,0 +1,303 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file contains utilities for recording audio from a microphone.
+"""
+
+import argparse
+import soundfile as sf
+import numpy as np
+import logging
+from threading import Thread, Event
+from queue import Queue
+from os.path import splitext
+from os import remove, getcwd
+from pathlib import Path
+import shutil
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+from lerobot.common.utils.utils import capture_timestamp_utc
+
+from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceNotConnectedError,
+    busy_wait,
+)
+
+def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
+    microphones = []
+
+    if mock:
+        #TODO(CarolinePascal): Implement mock microphones
+        pass
+    else:
+        import sounddevice as sd
+
+        devices = sd.query_devices()
+        for device in devices:
+            if device["max_input_channels"] > 0:
+                microphones.append(
+                    {
+                        "index": device["index"],
+                        "name": device["name"],
+                    }
+                )
+
+    if raise_when_empty and len(microphones) == 0:
+        raise OSError(
+            "Not a single microphone was detected. Try re-plugging the microphone or check the microphone settings."
+        )
+
+    return microphones
+
+def record_audio_from_microphones(
+    output_dir: Path,
+    microphone_ids: list[int] | None = None,
+    record_time_s: float = 2.0):
+
+    if microphone_ids is None or len(microphone_ids) == 0:
+        microphones = find_microphones()
+        microphone_ids = [m["index"] for m in microphones]
+
+    microphones = []
+    for microphone_id in microphone_ids:
+        config = MicrophoneConfig(microphone_index=microphone_id)
+        microphone = Microphone(config)
+        microphone.connect()
+        print(
+            f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sampling_rate} Hz."
+        )
+        microphones.append(microphone)
+
+    output_dir = Path(output_dir)
+    if output_dir.exists():
+        shutil.rmtree(
+            output_dir,
+        )
+    output_dir.mkdir(parents=True, exist_ok=True)
+    print(f"Saving audio to {output_dir}")
+
+    for microphone in microphones:
+        microphone.start_recording(getcwd() / output_dir / f"microphone_{microphone.microphone_index}.wav")
+
+    time.sleep(record_time_s)
+
+    for microphone in microphones:
+        microphone.stop_recording()
+
+    #Remark : recording may be resumed here if needed
+
+    for microphone in microphones:
+        microphone.disconnect()
+
+    print(f"Audio files have been saved to {output_dir}")
+
+class Microphone:
+    """
+    The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
+
+    A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sampling rate as well as the list of recorded channels.
+
+    Example of usage:
+    ```python
+    from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
+
+    config = MicrophoneConfig(microphone_index=0, sampling_rate=16000, channels=[1], data_type="int16")
+    microphone = Microphone(config)
+
+    microphone.start_recording("some/output/file.wav")
+    ...
+    microphone.stop_recording()
+
+    #OR
+
+    microphone.start_recording()
+    ...
+    microphone.stop_recording()
+    last_recorded_audio_chunk = microphone.queue.get()
+    ```
+    """
+
+    def __init__(self, config: MicrophoneConfig):
+        self.config = config
+        self.microphone_index = config.microphone_index
+
+        #Store the recording sampling rate and channels
+        self.sampling_rate = config.sampling_rate
+        self.channels = config.channels
+        self.data_type = config.data_type
+
+        self.mock = config.mock
+
+        #Input audio stream
+        self.stream = None
+        #Thread-safe concurrent queue to store the recorded audio
+        self.queue = Queue()
+        self.thread = None
+        self.stop_event = None
+        self.logs = {}
+
+        self.is_connected = False
+
+    def connect(self) -> None:
+        if self.is_connected:
+            raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
+
+        if self.mock:
+            #TODO(CarolinePascal): Implement mock microphones
+            pass
+        else:
+            import sounddevice as sd
+
+            #Check if the provided microphone index does match an input device
+            is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0
+
+            if not is_index_input:
+                microphones_info = find_microphones()
+                available_microphones = [m["index"] for m in microphones_info]
+                raise OSError(
+                    f"Microphone index {self.microphone_index} does not match an input device (microphone). 
Available input devices : {available_microphones}" + ) + + #Check if provided recording parameters are compatible with the microphone + actual_microphone = sd.query_devices(self.microphone_index) + + if self.sampling_rate is not None : + if self.sampling_rate > actual_microphone["default_samplerate"]: + raise OSError( + f"Provided sampling rate {self.sampling_rate} is higher than the sampling rate of the microphone {actual_microphone['default_samplerate']}." + ) + elif self.sampling_rate < actual_microphone["default_samplerate"]: + logging.warning("Provided sampling rate is lower than the sampling rate of the microphone. Performance may be impacted.") + else: + self.sampling_rate = int(actual_microphone["default_samplerate"]) + + if self.channels is not None: + if any(c > actual_microphone["max_input_channels"] for c in self.channels): + raise OSError( + f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}." + ) + else: + self.channels = np.arange(1, actual_microphone["max_input_channels"]+1) + + # Get channels index instead of number for slicing + self.channels = np.array(self.channels) - 1 + + #Create the audio stream + self.stream = sd.InputStream( + device=self.microphone_index, + samplerate=self.sampling_rate, + channels=max(self.channels)+1, + dtype=self.data_type, + callback=self._audio_callback, + ) + #Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers. + #However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency. + + self.is_connected = True + + def _audio_callback(self, indata, frames, time, status) -> None : + if status: + logging.warning(status) + #slicing makes copy unecessary + self.queue.put(indata[:,self.channels]) + + def _read_write_loop(self, output_file : Path) -> None: + output_file = Path(output_file) + if output_file.exists(): + shutil.rmtree( + output_file, + ) + with sf.SoundFile(output_file, mode='x', samplerate=self.sampling_rate, + channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: + while not self.stop_event.is_set(): + file.write(self.queue.get()) + + def start_recording(self, output_file : str | None = None) -> None: + + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + + if output_file is not None: + self.stop_event = Event() + self.thread = Thread(target=self._read_write_loop, args=(output_file,)) + self.thread.daemon = True + self.thread.start() + + self.stream.start() + + self.logs["start_timestamp"] = capture_timestamp_utc() + + def stop_recording(self) -> None: + + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + + self.logs["stop_timestamp"] = capture_timestamp_utc() + + if self.thread is not None: + self.stop_event.set() + self.thread.join() + self.thread = None + self.stop_event = None + + if self.stream.active: + self.stream.stop() #Wait for all buffers to be processed + #Remark : stream.abort() flushes the buffers ! 
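+            #stop() is preferred over abort() here so that the last buffers delivered by PortAudio are still passed to _audio_callback instead of being dropped.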
+ + self.logs["duration"] = self.logs["stop_timestamp"] - self.logs["start_timestamp"] + + def disconnect(self) -> None: + + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + + if self.stream.active: + self.stop_recording() + + self.stream.close() + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset." + ) + parser.add_argument( + "--microphone-ids", + type=int, + nargs="*", + default=None, + help="List of microphones indices used to instantiate the `Microphone`. If not provided, find and use all available microphones indices.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default="outputs/audio_from_microphones", + help="Set directory to save an audio snipet for each microphone.", + ) + parser.add_argument( + "--record-time-s", + type=float, + default=4.0, + help="Set the number of seconds used to record the audio. By default, 4 seconds.", + ) + args = parser.parse_args() + record_audio_from_microphones(**vars(args)) From e6ea8e75c3170960546a3eee0e6571c6a75a35f1 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 28 Mar 2025 17:15:19 +0100 Subject: [PATCH 03/32] Integrate microphones in Robot class --- .../robot_devices/microphones/configs.py | 37 ++++++++++++++++ .../common/robot_devices/microphones/utils.py | 43 +++++++++++++++++++ .../common/robot_devices/robots/configs.py | 4 ++ .../robot_devices/robots/manipulator.py | 31 ++++++++++++- 4 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 lerobot/common/robot_devices/microphones/configs.py create mode 100644 lerobot/common/robot_devices/microphones/utils.py diff --git a/lerobot/common/robot_devices/microphones/configs.py b/lerobot/common/robot_devices/microphones/configs.py new file mode 100644 index 00000000..eb59be49 --- /dev/null +++ b/lerobot/common/robot_devices/microphones/configs.py @@ -0,0 +1,37 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from dataclasses import dataclass + +import draccus + +@dataclass +class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + +@MicrophoneConfigBase.register_subclass("microphone") +@dataclass +class MicrophoneConfig(MicrophoneConfigBase): + """ + Dataclass for microphone configuration. 
+ """ + + microphone_index: int + sampling_rate: int | None = None + channels: list[int] | None = None + data_type: str | None = None + mock: bool = False \ No newline at end of file diff --git a/lerobot/common/robot_devices/microphones/utils.py b/lerobot/common/robot_devices/microphones/utils.py new file mode 100644 index 00000000..ea5790dc --- /dev/null +++ b/lerobot/common/robot_devices/microphones/utils.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase + +# Defines a microphone type +class Microphone(Protocol): + def connect(self): ... + def disconnect(self): ... + def start_recording(self, output_file: str | None = None): ... + def stop_recording(self): ... + +def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]: + microphones = {} + + for key, cfg in microphone_configs.items(): + if cfg.type == "microphone": + from lerobot.common.robot_devices.microphones.microphone import Microphone + microphones[key] = Microphone(cfg) + else: + raise ValueError(f"The microphone type '{cfg.type}' is not valid.") + + return microphones + +def make_microphone(microphone_type, **kwargs) -> Microphone: + if microphone_type == "microphone": + from lerobot.common.robot_devices.microphones.microphone import Microphone + return Microphone(MicrophoneConfig(**kwargs)) + else: + raise ValueError(f"The microphone type '{microphone_type}' is not valid.") \ No newline at end of file diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index e940b442..9d72a46b 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -28,6 +28,7 @@ from lerobot.common.robot_devices.motors.configs import ( FeetechMotorsBusConfig, MotorsBusConfig, ) +from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig @dataclass @@ -68,6 +69,9 @@ class ManipulatorRobotConfig(RobotConfig): for cam in self.cameras.values(): if not cam.mock: cam.mock = True + for mic in self.microphones.values(): + if not mic.mock: + mic.mock = True if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence): for name in self.follower_arms: diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 9173abc6..9212bb9c 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -28,6 +28,7 @@ import numpy as np import torch from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs +from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from 
lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig from lerobot.common.robot_devices.robots.utils import get_arm_id @@ -164,6 +165,7 @@ class ManipulatorRobot: self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) self.cameras = make_cameras_from_configs(self.config.cameras) + self.microphones = make_microphones_from_configs(self.config.microphones) self.is_connected = False self.logs = {} @@ -198,6 +200,18 @@ class ManipulatorRobot: "names": state_names, }, } + + @property + def microphones_features(self) -> dict: + mic_ft = {} + for mic_key, mic in self.microphones.items(): + key = f"observation.audio.{mic_key}" + mic_ft[key] = { + "dtype": mic.data_type, + "shape": (mic.channels,), + "info": None, + } + return mic_ft @property def features(self): @@ -210,6 +224,14 @@ class ManipulatorRobot: @property def num_cameras(self): return len(self.cameras) + + @property + def has_microphone(self): + return len(self.microphones) > 0 + + @property + def num_microphones(self): + return len(self.microphones) @property def available_arms(self): @@ -228,7 +250,7 @@ class ManipulatorRobot: "ManipulatorRobot is already connected. Do not run `robot.connect()` twice." ) - if not self.leader_arms and not self.follower_arms and not self.cameras: + if not self.leader_arms and not self.follower_arms and not self.cameras and not self.microphones: raise ValueError( "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class." ) @@ -289,6 +311,10 @@ class ManipulatorRobot: for name in self.cameras: self.cameras[name].connect() + # Connect the microphones + for name in self.microphones: + self.microphones[name].connect() + self.is_connected = True def activate_calibration(self): @@ -620,6 +646,9 @@ class ManipulatorRobot: for name in self.cameras: self.cameras[name].disconnect() + for name in self.microphones: + self.microphones[name].disconnect() + self.is_connected = False def __del__(self): From 8ddfb299fdd0a8255dd9cb838e1e8533e264be33 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 1 Apr 2025 20:36:29 +0200 Subject: [PATCH 04/32] Link cameras with their corresponding microphones for joint data handling --- lerobot/common/robot_devices/cameras/configs.py | 2 ++ lerobot/common/robot_devices/cameras/opencv.py | 2 ++ lerobot/common/robot_devices/robots/manipulator.py | 7 ++++--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/lerobot/common/robot_devices/cameras/configs.py b/lerobot/common/robot_devices/cameras/configs.py index 013419a9..b1bb588c 100644 --- a/lerobot/common/robot_devices/cameras/configs.py +++ b/lerobot/common/robot_devices/cameras/configs.py @@ -48,6 +48,8 @@ class OpenCVCameraConfig(CameraConfig): rotation: int | None = None mock: bool = False + microphone: str | None = None + def __post_init__(self): if self.color_mode not in ["rgb", "bgr"]: raise ValueError( diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index f279f315..757b3d9f 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -281,6 +281,8 @@ class OpenCVCamera: elif config.rotation == 180: self.rotation = cv2.ROTATE_180 + self.microphone = config.microphone + def connect(self): if self.is_connected: raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") diff --git 
a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 9212bb9c..443466ed 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -181,6 +181,7 @@ class ManipulatorRobot: "shape": (cam.height, cam.width, cam.channels), "names": ["height", "width", "channels"], "info": None, + "audio": "observation.audio." + cam.microphone if cam.microphone is not None else None, } return cam_ft @@ -207,9 +208,9 @@ class ManipulatorRobot: for mic_key, mic in self.microphones.items(): key = f"observation.audio.{mic_key}" mic_ft[key] = { - "dtype": mic.data_type, - "shape": (mic.channels,), - "info": None, + "shape": (len(mic.channels),), + "names": "channels", + "info" : None, } return mic_ft From 8ee61bb81fd27bbdae2c5cee672663346e3fd2fa Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 28 Mar 2025 17:16:51 +0100 Subject: [PATCH 05/32] Adding audio modality in LeRobotDatasets --- lerobot/common/datasets/lerobot_dataset.py | 141 +++++++++++++++-- lerobot/common/datasets/utils.py | 16 +- lerobot/common/datasets/video_utils.py | 148 ++++++++++++++++-- .../robot_devices/robots/manipulator.py | 2 +- tests/fixtures/dataset_factories.py | 3 + 5 files changed, 282 insertions(+), 28 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d8da85d6..488b7696 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -36,8 +36,11 @@ from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, + DEFAULT_RAW_AUDIO_PATH, + DEFAULT_COMPRESSED_AUDIO_PATH, INFO_PATH, TASKS_PATH, + DEFAULT_AUDIO_CHUNK_DURATION, append_jsonlines, backward_compatible_episodes_stats, check_delta_timestamps, @@ -69,8 +72,11 @@ from lerobot.common.datasets.video_utils import ( VideoFrame, decode_video_frames, encode_video_frames, + encode_audio, + decode_audio, get_safe_default_codec, get_video_info, + get_audio_info, ) from lerobot.common.robot_devices.robots.utils import Robot @@ -141,6 +147,11 @@ class LeRobotDatasetMetadata: ep_chunk = self.get_episode_chunk(ep_index) fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) return Path(fpath) + + def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path: + episode_chunk = self.get_episode_chunk(episode_index) + fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index) + return self.root / fpath def get_episode_chunk(self, ep_index: int) -> int: return ep_index // self.chunks_size @@ -154,6 +165,11 @@ class LeRobotDatasetMetadata: def video_path(self) -> str | None: """Formattable string for the video files.""" return self.info["video_path"] + + @property + def audio_path(self) -> str | None: + """Formattable string for the audio files.""" + return self.info["audio_path"] @property def robot_type(self) -> str | None: @@ -184,6 +200,11 @@ class LeRobotDatasetMetadata: def camera_keys(self) -> list[str]: """Keys to access visual modalities (regardless of their storage method).""" return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + + @property + def audio_keys(self) -> list[str]: + """Keys to access audio modalities.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] @property 
def names(self) -> dict[str, list | dict]: @@ -264,6 +285,9 @@ class LeRobotDatasetMetadata: if len(self.video_keys) > 0: self.update_video_info() + if len(self.audio_keys) > 0: + self.update_audio_info() + write_info(self.info, self.root) episode_dict = { @@ -288,6 +312,17 @@ class LeRobotDatasetMetadata: video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) self.info["features"][key]["info"] = get_video_info(video_path) + def update_audio_info(self) -> None: + """ + Warning: this function writes info from first episode audio, implicitly assuming that all audio have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None} + for key in set(self.audio_keys) - bound_audio_keys: + if not self.features[key].get("info", None): + audio_path = self.root / self.get_compressed_audio_file_path(0, key) + self.info["features"][key]["info"] = get_audio_info(audio_path) + def __repr__(self): feature_keys = list(self.features) return ( @@ -364,6 +399,7 @@ class LeRobotDataset(torch.utils.data.Dataset): force_cache_sync: bool = False, download_videos: bool = True, video_backend: str | None = None, + audio_backend: str | None = None, ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -465,6 +501,7 @@ class LeRobotDataset(torch.utils.data.Dataset): True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'. 
""" super().__init__() self.repo_id = repo_id @@ -475,6 +512,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal) self.delta_indices = None # Unused attributes @@ -499,7 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) - self.download_episodes(download_videos) + self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -677,7 +715,7 @@ class LeRobotDataset(torch.utils.data.Dataset): } return query_indices, padding - def _get_query_timestamps( + def _get_query_timestamps_video( self, current_ts: float, query_indices: dict[str, list[int]] | None = None, @@ -691,6 +729,22 @@ class LeRobotDataset(torch.utils.data.Dataset): query_timestamps[key] = [current_ts] return query_timestamps + + #TODO(CarolinePascal): add variable query durations + def _get_query_timestamps_audio( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.meta.audio_keys: + if query_indices is not None and key in query_indices: + timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: return { @@ -713,6 +767,21 @@ class LeRobotDataset(torch.utils.data.Dataset): return item + #TODO(CarolinePascal): add variable query durations + def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]: + item = {} + bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None} + for audio_key, query_ts in query_timestamps.items(): + #Audio stored with video in a single .mp4 file + if audio_key in bound_audio_keys_mapping.keys(): + audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key]) + #Audio stored alone in a separate .m4a file + else: + audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) + audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) + item[audio_key] = audio_chunk.squeeze(0) + return item + def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: for key, val in padding.items(): item[key] = torch.BoolTensor(val) @@ -733,11 +802,16 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, val in query_result.items(): item[key] = val - if len(self.meta.video_keys) > 0: + if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0: current_ts = item["timestamp"].item() - query_timestamps = self._get_query_timestamps(current_ts, query_indices) + + query_timestamps = self._get_query_timestamps_video(current_ts, query_indices) video_frames = 
self._query_videos(query_timestamps, ep_idx) - item = {**video_frames, **item} + item = {**item, **video_frames} + + query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices) + audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx) + item = {**item, **audio_chunks} if self.image_transforms is not None: image_keys = self.meta.camera_keys @@ -776,6 +850,10 @@ class LeRobotDataset(torch.utils.data.Dataset): image_key=image_key, episode_index=episode_index, frame_index=frame_index ) return self.root / fpath + + def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path: + fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index) + return self.root / fpath def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: if self.image_writer is None: @@ -867,7 +945,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, ft in self.features.items(): # index, episode_index, task_index are already processed above, and image and video # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]: continue episode_buffer[key] = np.stack(episode_buffer[key]) @@ -880,6 +958,9 @@ class LeRobotDataset(torch.utils.data.Dataset): for key in self.meta.video_keys: episode_buffer[key] = video_paths[key] + if len(self.meta.audio_keys) > 0: + _ = self.encode_episode_audio(episode_index) + # `meta.save_episode` be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) @@ -904,6 +985,13 @@ class LeRobotDataset(torch.utils.data.Dataset): if img_dir.is_dir(): shutil.rmtree(self.root / "images") + # delete raw audio + raw_audio_files = list(self.root.rglob("*.wav")) + for raw_audio_file in raw_audio_files: + raw_audio_file.unlink() + if len(list(raw_audio_file.parent.iterdir())) == 0: + raw_audio_file.parent.rmdir() + if not episode_data: # Reset the buffer self.episode_buffer = self.create_episode_buffer() @@ -971,18 +1059,45 @@ class LeRobotDataset(torch.utils.data.Dataset): since video encoding with ffmpeg is already using multithreading. """ video_paths = {} - for key in self.meta.video_keys: - video_path = self.root / self.meta.get_video_file_path(episode_index, key) - video_paths[key] = str(video_path) + for video_key in self.meta.video_keys: + video_path = self.root / self.meta.get_video_file_path(episode_index, video_key) + video_paths[video_key] = str(video_path) if video_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. continue img_dir = self._get_image_file_path( - episode_index=episode_index, image_key=key, frame_index=0 + episode_index=episode_index, image_key=video_key, frame_index=0 ).parent - encode_video_frames(img_dir, video_path, self.fps, overwrite=True) + + audio_path = None + if self.meta.features[video_key]["audio"] is not None: + audio_key = self.meta.features[video_key]["audio"] + audio_path = self._get_raw_audio_file_path(episode_index, audio_key) + + encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True) return video_paths + + def encode_episode_audio(self, episode_index: int) -> dict: + """ + Use ffmpeg to convert .wav raw audio files into .m4a audio files. 
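+        Only microphones that are not bound to a camera are encoded here: audio bound to a camera is muxed directly into the corresponding .mp4 file by `encode_episode_videos`.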
+        Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speed up encoding,
+        since audio encoding with ffmpeg is already using multithreading.
+        """
+        audio_paths = {}
+        bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
+        for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
+            input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
+            output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
+
+            audio_paths[audio_key] = str(output_audio_path)
+            if output_audio_path.is_file():
+                # Skip if audio is already encoded. Could be the case when resuming data recording.
+                continue
+
+            encode_audio(input_audio_path, output_audio_path, overwrite=True)
+
+        return audio_paths

     @classmethod
     def create(
@@ -998,6 +1113,7 @@
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
@@ -1029,6 +1145,7 @@
         obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)

         return obj

@@ -1049,6 +1166,7 @@
         tolerances_s: dict | None = None,
         download_videos: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         super().__init__()
         self.repo_ids = repo_ids
@@ -1066,6 +1184,7 @@
                 tolerance_s=self.tolerances_s[repo_id],
                 download_videos=download_videos,
                 video_backend=video_backend,
+                audio_backend=audio_backend,
             )
             for repo_id in repo_ids
         ]
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 9d8a54db..827e711b 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -55,6 +55,10 @@ TASKS_PATH = "meta/tasks.jsonl"
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
 DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
+DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
+
+DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds

 DATASET_CARD_TEMPLATE = """
 ---
@@ -363,7 +367,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
 def get_hf_features_from_features(features: dict) -> datasets.Features:
     hf_features = {}
     for key, ft in features.items():
-        if ft["dtype"] == "video":
+        if ft["dtype"] == "video" or ft["dtype"] == "audio":
             continue
         elif ft["dtype"] == "image":
             hf_features[key] = datasets.Image()
@@ -394,7 +398,13 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
         key: {"dtype": "video" if use_videos else "image", **ft}
         for key, ft in robot.camera_features.items()
     }
-    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
+    microphones_ft = {}
+    if 
robot.microphones:
+        microphones_ft = {
+            key: {"dtype": "audio", **ft}
+            for key, ft in robot.microphones_features.items()
+        }
+    return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES}


 def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
@@ -448,6 +458,7 @@ def create_empty_dataset_info(
         "splits": {},
         "data_path": DEFAULT_PARQUET_PATH,
         "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
         "features": features,
     }
@@ -721,6 +732,7 @@ def validate_features_presence(
 ):
     error_message = ""
     missing_features = expected_features - actual_features
+    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
     extra_features = actual_features - (expected_features | optional_features)

     if missing_features or extra_features:
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index c38d570d..44d5a1a5 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -26,9 +26,12 @@ from typing import Any, ClassVar
 import pyarrow as pa
 import torch
 import torchvision
+import torchaudio
 from datasets.features.features import register_feature
 from PIL import Image

+from numpy import ceil
+

 def get_safe_default_codec():
     if importlib.util.find_spec("torchcodec"):
@@ -39,7 +42,72 @@ def get_safe_default_codec():
     )
     return "pyav"

+def decode_audio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    backend: str | None = "ffmpeg",
+) -> torch.Tensor:
+    """
+    Decodes audio using the specified backend.
+    Args:
+        audio_path (Path): Path to the audio file.
+        timestamps (list[float]): List of timestamps at which audio chunks are extracted.
+        duration (float): Duration, in seconds, of each extracted audio chunk.
+        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
+    Returns:
+        torch.Tensor: Decoded audio chunks.
+
+    Currently only supports the "ffmpeg" backend, which relies on torchaudio.
+    """
+    if backend == "torchcodec":
+        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
+    elif backend == "ffmpeg":
+        return decode_audio_torchvision(audio_path, timestamps, duration)
+    else:
+        raise ValueError(f"Unsupported audio backend: {backend}")
+
+def decode_audio_torchvision(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+
+    #TODO(CarolinePascal) : add channels selection
+    audio_path = str(audio_path)
+
+    reader = torchaudio.io.StreamReader(src=audio_path)
+    audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
+
+    #TODO(CarolinePascal) : sort timestamps ?
+
+    reader.add_basic_audio_stream(
+        frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
+        buffer_chunk_size = -1, #No dropping frames
+    )
+
+    audio_chunks = []
+    for ts in timestamps:
+        reader.seek(ts) #Default to closest audio sample
+        status = reader.fill_buffer()
+        if status != 0:
+            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
+
+        current_audio_chunk = reader.pop_chunks()[0]
+
+        if log_loaded_timestamps:
+            logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk.pts:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}")
+
+        audio_chunks.append(current_audio_chunk)
+
+    audio_chunks = torch.stack(audio_chunks)
+    #TODO(CarolinePascal) : pytorch format conversion ? 
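+    #Resulting shape : (number of query timestamps, frames per chunk, number of channels)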
+
+    assert len(timestamps) == len(audio_chunks)
+    return audio_chunks
+
 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
@@ -69,7 +137,6 @@ def decode_video_frames(
     else:
         raise ValueError(f"Unsupported video backend: {backend}")

-
 def decode_video_frames_torchvision(
     video_path: Path | str,
     timestamps: list[float],
@@ -167,7 +234,6 @@ def decode_video_frames_torchvision(
     assert len(timestamps) == len(closest_frames)
     return closest_frames

-
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
@@ -242,15 +308,52 @@ def decode_video_frames_torchcodec(
     assert len(timestamps) == len(closest_frames)
     return closest_frames

+def encode_audio(
+    input_path: Path | str,
+    output_path: Path | str,
+    codec: str = "aac",
+    log_level: str | None = "error",
+    overwrite: bool = False,
+) -> None:
+    """Encodes an audio file using ffmpeg."""
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    ffmpeg_args = OrderedDict(
+        [
+            ("-i", str(input_path)),
+            ("-acodec", codec),
+        ]
+    )
+
+    if log_level is not None:
+        ffmpeg_args["-loglevel"] = str(log_level)
+
+    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    if overwrite:
+        ffmpeg_args.append("-y")
+
+    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
+
+    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
+
+    if not output_path.exists():
+        raise OSError(
+            f"Audio encoding did not work. File not found: {output_path}. "
+            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
+        )

 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
     fps: int,
+    audio_path: Path | str | None = None,
     vcodec: str = "libsvtav1",
     pix_fmt: str = "yuv420p",
     g: int | None = 2,
     crf: int | None = 30,
+    acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) /variable (quality control) bitrate options
     fast_decode: int = 0,
     log_level: str | None = "error",
     overwrite: bool = False,
@@ -260,35 +363,53 @@ def encode_video_frames(
     imgs_dir = Path(imgs_dir)
     video_path.parent.mkdir(parents=True, exist_ok=True)

-    ffmpeg_args = OrderedDict(
+    ffmpeg_video_args = OrderedDict(
         [
             ("-f", "image2"),
             ("-r", str(fps)),
-            ("-i", str(imgs_dir / "frame_%06d.png")),
-            ("-vcodec", vcodec),
-            ("-pix_fmt", pix_fmt),
+            ("-i", str(Path(imgs_dir) / "frame_%06d.png")),
         ]
     )

+    ffmpeg_audio_args = OrderedDict()
+    if audio_path is not None:
+        audio_path = Path(audio_path)
+        audio_path.parent.mkdir(parents=True, exist_ok=True)
+        ffmpeg_audio_args.update(OrderedDict(
+            [
+                ("-i", str(audio_path)),
+            ]
+        ))
+
+    ffmpeg_encoding_args = OrderedDict(
+        [
+            ("-pix_fmt", pix_fmt),
+            ("-vcodec", vcodec),
+        ]
+    )
     if g is not None:
-        ffmpeg_args["-g"] = str(g)
-
+        ffmpeg_encoding_args["-g"] = str(g)
     if crf is not None:
-        ffmpeg_args["-crf"] = str(crf)
-
+        ffmpeg_encoding_args["-crf"] = str(crf)
     if fast_decode:
         key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
         value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
-        ffmpeg_args[key] = value
+        ffmpeg_encoding_args[key] = value
+
+    if audio_path is not None:
+        ffmpeg_encoding_args["-acodec"] = acodec
+
     if log_level is not None:
-        ffmpeg_args["-loglevel"] = str(log_level)
+        ffmpeg_encoding_args["-loglevel"] = str(log_level)

-    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    ffmpeg_args = 
[item for pair in ffmpeg_video_args.items() for item in pair] + ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair] + ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair] if overwrite: ffmpeg_args.append("-y") ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] + # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) @@ -366,7 +487,6 @@ def get_audio_info(video_path: Path | str) -> dict: "audio.channel_layout": audio_stream_info.get("channel_layout", None), } - def get_video_info(video_path: Path | str) -> dict: ffprobe_video_cmd = [ "ffprobe", diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 443466ed..afa4006a 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -216,7 +216,7 @@ class ManipulatorRobot: @property def features(self): - return {**self.motor_features, **self.camera_features} + return {**self.motor_features, **self.camera_features, **self.microphones_features} @property def has_camera(self): diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 531977da..fbd7480f 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -29,6 +29,7 @@ from lerobot.common.datasets.utils import ( DEFAULT_FEATURES, DEFAULT_PARQUET_PATH, DEFAULT_VIDEO_PATH, + DEFAULT_COMPRESSED_AUDIO_PATH, get_hf_features_from_features, hf_transform_to_torch, ) @@ -121,6 +122,7 @@ def info_factory(features_factory): chunks_size: int = DEFAULT_CHUNK_SIZE, data_path: str = DEFAULT_PARQUET_PATH, video_path: str = DEFAULT_VIDEO_PATH, + audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH, motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, use_videos: bool = True, @@ -139,6 +141,7 @@ def info_factory(features_factory): "splits": {}, "data_path": data_path, "video_path": video_path if use_videos else None, + "audio_path": audio_path, "features": features, } From e4eebd0680cf6ddb9f1d88208df6e9a3f443321f Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 28 Mar 2025 17:16:17 +0100 Subject: [PATCH 06/32] Adding microphone recording in control loop --- lerobot/common/datasets/lerobot_dataset.py | 12 ++++++++++++ lerobot/common/robot_devices/control_utils.py | 13 +++++++++++++ .../common/robot_devices/microphones/microphone.py | 3 +-- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 488b7696..a207845f 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -79,6 +79,7 @@ from lerobot.common.datasets.video_utils import ( get_audio_info, ) from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.microphones.utils import Microphone CODEBASE_VERSION = "v2.1" @@ -910,6 +911,17 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episode_buffer["size"] += 1 + def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None: + """ + This function will start recording audio from the microphone and save it to disk. + """ + + audio_dir = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." 
+ microphone_key).parent + if not audio_dir.is_dir(): + audio_dir.mkdir(parents=True, exist_ok=True) + + microphone.start_recording(output_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)) + def save_episode(self, episode_data: dict | None = None) -> None: """ This will save to disk the current episode in self.episode_buffer. diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 4e42a989..335b6430 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -77,6 +77,11 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f key = f"read_camera_{name}_dt_s" if key in robot.logs: log_dt(f"dtR{name}", robot.logs[key]) + + for name in robot.microphones: + key = f"read_microphone_{name}_dt_s" + if key in robot.logs: + log_dt(f"dtR{name}", robot.logs[key]) info_str = " ".join(log_items) logging.info(info_str) @@ -243,6 +248,11 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() + + if teleoperate and dataset is not None: + for microphone_key, microphone in robot.microphones.items(): + dataset.add_microphone_recording(microphone, microphone_key) + while timestamp < control_time_s: start_loop_t = time.perf_counter() @@ -286,6 +296,9 @@ def control_loop( events["exit_early"] = False break + for _, microphone in robot.microphones.items(): + microphone.stop_recording() + def reset_environment(robot, events, reset_time_s, fps): # TODO(rcadene): refactor warmup_record and reset_environment diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 5b136aab..c26d03c0 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -238,9 +238,8 @@ class Microphone: self.thread.daemon = True self.thread.start() - self.stream.start() - self.logs["start_timestamp"] = capture_timestamp_utc() + self.stream.start() def stop_recording(self) -> None: From 058478a74d06c105cd8f41f83b27bc7df66f1bb7 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 3 Apr 2025 11:41:05 +0200 Subject: [PATCH 07/32] Adding audio frames reading capability --- lerobot/common/robot_devices/control_utils.py | 11 +- .../robot_devices/microphones/microphone.py | 119 ++++++++++++------ .../robot_devices/robots/manipulator.py | 18 +++ 3 files changed, 111 insertions(+), 37 deletions(-) diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 335b6430..d54a9e13 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -249,9 +249,14 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() - if teleoperate and dataset is not None: + if dataset is not None: for microphone_key, microphone in robot.microphones.items(): + #Start recording both in file writing and data reading mode dataset.add_microphone_recording(microphone, microphone_key) + else: + for _, microphone in robot.microphones.items(): + # Start recording only in data reading mode + microphone.start_recording() while timestamp < control_time_s: start_loop_t = time.perf_counter() @@ -271,7 +276,9 @@ def control_loop( action = {"action": action} if dataset is not None: - frame = {**observation, **action, "task": single_task} + #Remove audio frames which are directly written in a dedicated file + 
audioless_observation = {key: observation[key] for key in observation if key not in robot.microphones}
+            frame = {**audioless_observation, **action, "task": single_task}
             dataset.add_frame(frame)

         # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py
index c26d03c0..b751f2b6 100644
--- a/lerobot/common/robot_devices/microphones/microphone.py
+++ b/lerobot/common/robot_devices/microphones/microphone.py
@@ -119,16 +119,13 @@ class Microphone:
     config = MicrophoneConfig(microphone_index=0, sampling_rate=16000, channels=[1], data_type="int16")
     microphone = Microphone(config)

+    microphone.connect()
     microphone.start_recording("some/output/file.wav")
     ...
-    microphone.stop_recording()
-
-    #OR
-
-    microphone.start_recording()
+    audio_readings = microphone.read() #Gets all recorded audio data since the last read or since the beginning of the recording
     ...
     microphone.stop_recording()
-    last_recorded_audio_chunk = microphone.queue.get()
+    microphone.disconnect()
     ```
     """
@@ -145,12 +142,16 @@ class Microphone:

         #Input audio stream
         self.stream = None
-        #Thread-safe concurrent queue to store the recorded audio
-        self.queue = Queue()
-        self.thread = None
-        self.stop_event = None
-        self.logs = {}
+        #Thread-safe concurrent queue to store the recorded/read audio
+        self.record_queue = Queue()
+        self.read_queue = Queue()
+
+        #Thread to handle data reading and file writing in a separate thread (safely)
+        self.record_thread = None
+        self.record_stop_event = None

+        self.logs = {}
         self.is_connected = False

     def connect(self) -> None:
@@ -213,53 +214,101 @@ class Microphone:
     def _audio_callback(self, indata, frames, time, status) -> None :
         if status:
             logging.warning(status)
-        #slicing makes copy unecessary
-        self.queue.put(indata[:,self.channels])
+        # Slicing makes copy unnecessary
+        # Two separate queues are necessary because .get() also pops the data from the queue
+        self.record_queue.put(indata[:,self.channels])
+        self.read_queue.put(indata[:,self.channels])

-    def _read_write_loop(self, output_file : Path) -> None:
-        output_file = Path(output_file)
-        if output_file.exists():
-            shutil.rmtree(
-                output_file,
-            )
-        with sf.SoundFile(output_file, mode='x', samplerate=self.sampling_rate,
-                        channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
-            while not self.stop_event.is_set():
-                file.write(self.queue.get())
+    def _record_loop(self, output_file: Path) -> None:
+        with sf.SoundFile(output_file, mode='x', samplerate=self.sampling_rate,
+                        channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
+            while not self.record_stop_event.is_set():
+                file.write(self.record_queue.get())
+                #self.record_queue.task_done()
+
+    def _read(self) -> np.ndarray:
+        """
+        Gets audio data from the queue and converts it to a numpy array. 
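+        (The whole underlying deque is captured at once, then the queue is swapped for a fresh one.)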
+ -> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage + -> CONS : Reading duration does not scale well with the number of channels and reading duration + """ + try: + audio_readings = self.read_queue.queue + except Queue.Empty: + audio_readings = np.empty((0, len(self.channels))) + else: + #TODO(CarolinePascal): Check if this is the fastest way to do it + self.read_queue = Queue() + with self.read_queue.mutex: + self.read_queue.queue.clear() + #self.read_queue.all_tasks_done.notify_all() + audio_readings = np.array(audio_readings).reshape(-1, len(self.channels)) + + return audio_readings + + def read(self) -> np.ndarray: + + if not self.is_connected: + raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + if not self.stream.active: + raise RuntimeError(f"Microphone {self.microphone_index} is not recording.") + + start_time = time.perf_counter() + + audio_readings = self._read() + + # log the number of seconds it took to read the audio chunk + self.logs["delta_timestamp_s"] = time.perf_counter() - start_time + + # log the utc time at which the audio chunk was received + self.logs["timestamp_utc"] = capture_timestamp_utc() + + return audio_readings def start_recording(self, output_file : str | None = None) -> None: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + self.read_queue = Queue() + with self.read_queue.mutex: + self.read_queue.queue.clear() + #self.read_queue.all_tasks_done.notify_all() + + self.record_queue = Queue() + with self.record_queue.mutex: + self.record_queue.queue.clear() + #self.record_queue.all_tasks_done.notify_all() + + #Recording case if output_file is not None: - self.stop_event = Event() - self.thread = Thread(target=self._read_write_loop, args=(output_file,)) - self.thread.daemon = True - self.thread.start() + output_file = Path(output_file) + if output_file.exists(): + output_file.unlink() + + self.record_stop_event = Event() + self.record_thread = Thread(target=self._record_loop, args=(output_file,)) + self.record_thread.daemon = True + self.record_thread.start() - self.logs["start_timestamp"] = capture_timestamp_utc() self.stream.start() def stop_recording(self) -> None: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") - - self.logs["stop_timestamp"] = capture_timestamp_utc() - if self.thread is not None: - self.stop_event.set() - self.thread.join() - self.thread = None - self.stop_event = None + if self.record_thread is not None: + #self.record_queue.join() + self.record_stop_event.set() + self.record_thread.join() + self.record_thread = None + self.record_stop_event = None if self.stream.active: self.stream.stop() #Wait for all buffers to be processed #Remark : stream.abort() flushes the buffers ! 
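+            #Note : the record thread is joined while the stream is still active, so the blocking record_queue.get() in _record_loop is always unblocked by further _audio_callback calls.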
- self.logs["duration"] = self.logs["stop_timestamp"] - self.logs["start_timestamp"] - def disconnect(self) -> None: if not self.is_connected: diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index afa4006a..b5dc60c8 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -541,6 +541,15 @@ class ManipulatorRobot: self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + # Capture audio from microphones + audio = {} + for name in self.microphones: + before_audioread_t = time.perf_counter() + audio[name] = self.microphones[name].read() + audio[name] = torch.from_numpy(audio[name]) + self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"] + self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t + # Populate output dictionaries obs_dict, action_dict = {}, {} obs_dict["observation.state"] = state @@ -581,6 +590,15 @@ class ManipulatorRobot: self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + # Capture audio from microphones + audio = {} + for name in self.microphones: + before_audioread_t = time.perf_counter() + audio[name] = self.microphones[name].read() + audio[name] = torch.from_numpy(audio[name]) + self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"] + self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t + # Populate output dictionaries and format to pytorch obs_dict = {} obs_dict["observation.state"] = state From 44af02a334f8b95d101e60d9cf803efb39c8ffa2 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 4 Apr 2025 18:24:04 +0200 Subject: [PATCH 08/32] Remove variable audio recordings data types (will be converted to float32 anyway) --- lerobot/common/robot_devices/microphones/configs.py | 1 - lerobot/common/robot_devices/microphones/microphone.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lerobot/common/robot_devices/microphones/configs.py b/lerobot/common/robot_devices/microphones/configs.py index eb59be49..e90da519 100644 --- a/lerobot/common/robot_devices/microphones/configs.py +++ b/lerobot/common/robot_devices/microphones/configs.py @@ -33,5 +33,4 @@ class MicrophoneConfig(MicrophoneConfigBase): microphone_index: int sampling_rate: int | None = None channels: list[int] | None = None - data_type: str | None = None mock: bool = False \ No newline at end of file diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index b751f2b6..923d9b5e 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -116,7 +116,7 @@ class Microphone: ```python from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig - config = MicrophoneConfig(microphone_index=0, sampling_rate=16000, channels=[1], data_type="int16") + config = MicrophoneConfig(microphone_index=0, sampling_rate=16000, channels=[1]) microphone = Microphone(config) microphone.connect() @@ -136,7 +136,6 @@ class Microphone: #Store the recording sampling rate and channels self.sampling_rate = config.sampling_rate self.channels = 
config.channels
-        self.data_type = config.data_type
 
         self.mock = config.mock
 
@@ -203,7 +202,7 @@ class Microphone:
             device=self.microphone_index,
             samplerate=self.sampling_rate,
             channels=max(self.channels)+1,
-            dtype=self.data_type,
+            dtype="float32",
             callback=self._audio_callback,
         )
         #Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receives same-length buffers.
@@ -220,6 +219,7 @@ class Microphone:
             self.read_queue.put(indata[:,self.channels])
 
     def _record_loop(self, output_file: Path) -> None:
+        #Can only be run on a single process/thread for file writing safety
         with sf.SoundFile(output_file, mode='x', samplerate=self.sampling_rate,
                           channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
             while not self.record_stop_event.is_set():

From a18d0e46782418ce5e9681f7897c4e1c582d9478 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Fri, 4 Apr 2025 18:31:00 +0200
Subject: [PATCH 09/32] Adding pytorch compatible conversion for audio

---
 lerobot/common/datasets/video_utils.py        | 2 +-
 lerobot/common/robot_devices/control_utils.py | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index 44d5a1a5..4c96a400 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -86,6 +86,7 @@ def decode_audio_torchvision(
     reader.add_basic_audio_stream(
         frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
         buffer_chunk_size = -1, #No dropping frames
+        format = "fltp", #Format as float32
     )
 
     audio_chunks = []
@@ -103,7 +104,6 @@ def decode_audio_torchvision(
         audio_chunks.append(current_audio_chunk)
 
     audio_chunks = torch.stack(audio_chunks)
-    #TODO(CarolinePascal) : pytorch format conversion ?
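For context on the decode path this patch touches, here is a minimal, hedged sketch of chunked float32 decoding with `torchaudio.io.StreamReader`, using the same calls as `decode_audio_torchvision`; the file name, seek offset, and chunk duration are placeholder values, and an FFmpeg-enabled torchaudio build is assumed:

```python
from math import ceil

import torchaudio

reader = torchaudio.io.StreamReader(src="audio.m4a")  # placeholder path
sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate

reader.add_basic_audio_stream(
    frames_per_chunk=int(ceil(0.5 * sample_rate)),  # ~0.5 s per chunk
    buffer_chunk_size=-1,  # unbounded buffering: no frames are dropped
    format="fltp",  # decode to float32 samples
)

reader.seek(1.0)  # seek near the requested timestamp, in seconds
reader.fill_buffer()
(chunk,) = reader.pop_chunks()
print(chunk.shape, chunk.dtype)  # (frames, channels), torch.float32
```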
assert len(timestamps) == len(audio_chunks) return audio_chunks diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index d54a9e13..9c2f0f47 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -112,11 +112,14 @@ def predict_action(observation, policy, device, use_amp): torch.inference_mode(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), ): - # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension if "image" in name: observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].permute(2, 0, 1).contiguous() + # Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension + if "audio" in name: + observation[name] = observation[name].permute(1, 0).contiguous() observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].to(device) From b00e866c603809775a0dc3f3435cbd116188301c Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 4 Apr 2025 19:48:57 +0200 Subject: [PATCH 10/32] Adding missing features for audio frames verification and stats --- lerobot/common/datasets/compute_stats.py | 13 +++++++++++-- lerobot/common/datasets/lerobot_dataset.py | 6 ++++++ lerobot/common/datasets/utils.py | 18 ++++++++++++++++++ lerobot/common/robot_devices/control_utils.py | 4 +--- .../robot_devices/microphones/microphone.py | 2 +- .../common/robot_devices/robots/manipulator.py | 4 ++++ 6 files changed, 41 insertions(+), 6 deletions(-) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 1149ec83..b24dbaf8 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -15,8 +15,7 @@ # limitations under the License. 
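Before the stats diff below, a compact sketch of what the audio branch of `compute_episode_stats` amounts to: subsample frames from the waveform, then reduce over the time axis while keeping the channel dimension. The subsampling helper here is a simplified stand-in for the repo's `sample_indices`/`estimate_num_samples`, and the waveform values are placeholders:

```python
import numpy as np

def subsample_indices(data_len: int, num: int = 100) -> np.ndarray:
    # Evenly spaced indices, standing in for sample_indices/estimate_num_samples.
    return np.round(np.linspace(0, data_len - 1, num)).astype(int)

waveform = np.random.uniform(-1, 1, size=(16000, 2)).astype(np.float32)
sampled = waveform[subsample_indices(len(waveform))]

# Reduce over axis 0 (time) with keepdims=True, as the audio branch does.
stats = {
    "min": sampled.min(axis=0, keepdims=True),
    "max": sampled.max(axis=0, keepdims=True),
    "mean": sampled.mean(axis=0, keepdims=True),
    "std": sampled.std(axis=0, keepdims=True),
    "count": np.array([len(sampled)]),
}
assert stats["mean"].shape == (1, 2)  # channel dimension preserved
```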
import numpy as np -from lerobot.common.datasets.utils import load_image_as_numpy - +from lerobot.common.datasets.utils import load_image_as_numpy, load_audio def estimate_num_samples( dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 @@ -71,6 +70,12 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images +def sample_audio(audio_path: str) -> np.ndarray: + + data = load_audio(audio_path) + sampled_indices = sample_indices(len(data)) + + return(data[sampled_indices]) def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: return { @@ -91,6 +96,10 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu ep_ft_array = sample_images(data) # data is a list of image paths axes_to_reduce = (0, 2, 3) # keep channel dim keepdims = True + elif features[key]["dtype"] == "audio": + ep_ft_array = sample_audio(data[0]) + axes_to_reduce = 0 + keepdims = True else: ep_ft_array = data # data is already a np.ndarray axes_to_reduce = 0 # compute stats over the first axis diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index a207845f..411a4676 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -906,6 +906,12 @@ class LeRobotDataset(torch.utils.data.Dataset): img_path.parent.mkdir(parents=True, exist_ok=True) self._save_image(frame[key], img_path) self.episode_buffer[key].append(str(img_path)) + elif self.features[key]["dtype"] == "audio": + if frame_index == 0: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"], audio_key=key + ) + self.episode_buffer[key].append(str(audio_path)) else: self.episode_buffer[key].append(frame[key]) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 827e711b..a7bda7f2 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -35,6 +35,8 @@ from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage from torchvision import transforms +from soundfile import read + from lerobot.common.datasets.backward_compatibility import ( V21_MESSAGE, BackwardCompatibilityError, @@ -258,6 +260,9 @@ def load_image_as_numpy( img_array /= 255.0 return img_array +def load_audio(fpath: str | Path) -> np.ndarray: + audio_data, _ = read(fpath, dtype="float32") + return audio_data def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) @@ -752,6 +757,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) elif expected_dtype in ["image", "video"]: return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "audio": + return validate_feature_audio(name, expected_shape, value) elif expected_dtype == "string": return validate_feature_string(name, value) else: @@ -792,6 +799,17 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: return error_message +def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray): + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c = expected_shape + if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]): #The number of frames might be different 
+ error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n" + else: + error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n" + + return error_message def validate_feature_string(name: str, value: str): if not isinstance(value, str): diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 9c2f0f47..7f71b1ee 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -279,9 +279,7 @@ def control_loop( action = {"action": action} if dataset is not None: - #Remove audio frames which are directly written in a dedicated file - audioless_observation = {key: observation[key] for key in observation if key not in robot.microphones} - frame = {**audioless_observation, **action, "task": single_task} + frame = {**observation, **action, "task": single_task} dataset.add_frame(frame) # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon) diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 923d9b5e..8ab8b362 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -242,7 +242,7 @@ class Microphone: with self.read_queue.mutex: self.read_queue.queue.clear() #self.read_queue.all_tasks_done.notify_all() - audio_readings = np.array(audio_readings).reshape(-1, len(self.channels)) + audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels)) return audio_readings diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index b5dc60c8..7e849914 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -556,6 +556,8 @@ class ManipulatorRobot: action_dict["action"] = action for name in self.cameras: obs_dict[f"observation.images.{name}"] = images[name] + for name in self.microphones: + obs_dict[f"observation.audio.{name}"] = audio[name] return obs_dict, action_dict @@ -604,6 +606,8 @@ class ManipulatorRobot: obs_dict["observation.state"] = state for name in self.cameras: obs_dict[f"observation.images.{name}"] = images[name] + for name in self.microphones: + obs_dict[f"observation.audio.{name}"] = audio[name] return obs_dict def send_action(self, action: torch.Tensor) -> torch.Tensor: From 17ad249335df3bf3e2e2b4994b5f5d72408ae298 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 7 Apr 2025 13:21:55 +0200 Subject: [PATCH 11/32] Cleaning up bound/linked audio keys mapping recovery --- lerobot/common/datasets/lerobot_dataset.py | 16 +++++++++------- lerobot/common/robot_devices/robots/configs.py | 1 + 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 411a4676..1ca61520 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -206,6 +206,11 @@ class LeRobotDatasetMetadata: def audio_keys(self) -> list[str]: """Keys to access audio modalities.""" return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] + + @property + def audio_camera_keys_mapping(self) -> dict[str, str]: + """Mapping between camera keys and audio keys when both are 
linked.""" + return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None} @property def names(self) -> dict[str, list | dict]: @@ -318,8 +323,7 @@ class LeRobotDatasetMetadata: Warning: this function writes info from first episode audio, implicitly assuming that all audio have been encoded the same way. Also, this means it assumes the first episode exists. """ - bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None} - for key in set(self.audio_keys) - bound_audio_keys: + for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()): if not self.features[key].get("info", None): audio_path = self.root / self.get_compressed_audio_file_path(0, key) self.info["features"][key]["info"] = get_audio_info(audio_path) @@ -771,11 +775,10 @@ class LeRobotDataset(torch.utils.data.Dataset): #TODO(CarolinePascal): add variable query durations def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]: item = {} - bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None} for audio_key, query_ts in query_timestamps.items(): #Audio stored with video in a single .mp4 file - if audio_key in bound_audio_keys_mapping.keys(): - audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key]) + if audio_key in self.meta.audio_camera_keys_mapping: + audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key]) #Audio stored alone in a separate .m4a file else: audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) @@ -1103,8 +1106,7 @@ class LeRobotDataset(torch.utils.data.Dataset): since video encoding with ffmpeg is already using multithreading. """ audio_paths = {} - bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None} - for audio_key in set(self.meta.audio_keys) - bound_audio_keys: + for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()): input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key) diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index 9d72a46b..6bb3b93b 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -44,6 +44,7 @@ class ManipulatorRobotConfig(RobotConfig): leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) cameras: dict[str, CameraConfig] = field(default_factory=lambda: {}) + microphones: dict[str, MicrophoneConfig] = field(default_factory=lambda: {}) # Optionally limit the magnitude of the relative positional target vector for safety purposes. 
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length From 96ed10f90dae4397927c3da367614592d23c3018 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 7 Apr 2025 16:33:16 +0200 Subject: [PATCH 12/32] Fixing sounddevice stream active state recovery and adding corresponding exceptions --- .../robot_devices/microphones/microphone.py | 18 +++++++++++++---- lerobot/common/robot_devices/utils.py | 20 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 8ab8b362..38a6e5f9 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -35,6 +35,8 @@ from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig from lerobot.common.robot_devices.utils import ( RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, + RobotDeviceNotRecordingError, + RobotDeviceAlreadyRecordingError, busy_wait, ) @@ -152,6 +154,7 @@ class Microphone: self.logs = {} self.is_connected = False + self.is_recording = False def connect(self) -> None: if self.is_connected: @@ -250,8 +253,8 @@ class Microphone: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") - if not self.stream.active: - raise RuntimeError(f"Microphone {self.microphone_index} is not recording.") + if not self.is_recording: + raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") start_time = time.perf_counter() @@ -269,6 +272,8 @@ class Microphone: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") + if self.is_recording: + raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.") self.read_queue = Queue() with self.read_queue.mutex: @@ -291,13 +296,16 @@ class Microphone: self.record_thread.daemon = True self.record_thread.start() + self.is_recording = True self.stream.start() def stop_recording(self) -> None: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") - + if not self.is_recording: + raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") + if self.record_thread is not None: #self.record_queue.join() self.record_stop_event.set() @@ -309,12 +317,14 @@ class Microphone: self.stream.stop() #Wait for all buffers to be processed #Remark : stream.abort() flushes the buffers ! + self.is_recording = False + def disconnect(self) -> None: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") - if self.stream.active: + if self.is_recording: self.stop_recording() self.stream.close() diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index 837c9d2e..01f9195e 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -63,3 +63,23 @@ class RobotDeviceAlreadyConnectedError(Exception): ): self.message = message super().__init__(self.message) + + +class RobotDeviceNotRecordingError(Exception): + """Exception raised when the robot device is not recording.""" + + def __init__( + self, message="This robot device is not recording. Try calling `robot_device.start_recording()` first." 
+ ): + self.message = message + super().__init__(self.message) + +class RobotDeviceAlreadyRecordingError(Exception): + """Exception raised when the robot device is already recording.""" + + def __init__( + self, + message="This robot device is already recording. Try not calling `robot_device.start_recording()` twice.", + ): + self.message = message + super().__init__(self.message) From e743f846a79c0789c3acfa33eb745dbbd542154c Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 7 Apr 2025 19:08:53 +0200 Subject: [PATCH 13/32] Adding dtype="audio" by default in microphone features --- lerobot/common/datasets/utils.py | 8 +------- lerobot/common/robot_devices/robots/manipulator.py | 5 +++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index a7bda7f2..80143939 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -403,13 +403,7 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: key: {"dtype": "video" if use_videos else "image", **ft} for key, ft in robot.camera_features.items() } - microphones_ft = {} - if robot.microphones: - microphones_ft = { - key: {"dtype": "audio", **ft} - for key, ft in robot.microphones_features.items() - } - return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES} + return {**robot.motor_features, **camera_ft, **robot.microphone_features, **DEFAULT_FEATURES} def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 7e849914..16edc4fb 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -203,11 +203,12 @@ class ManipulatorRobot: } @property - def microphones_features(self) -> dict: + def microphone_features(self) -> dict: mic_ft = {} for mic_key, mic in self.microphones.items(): key = f"observation.audio.{mic_key}" mic_ft[key] = { + "dtype": "audio", "shape": (len(mic.channels),), "names": "channels", "info" : None, @@ -216,7 +217,7 @@ class ManipulatorRobot: @property def features(self): - return {**self.motor_features, **self.camera_features, **self.microphones_features} + return {**self.motor_features, **self.camera_features, **self.microphone_features} @property def has_camera(self): From 7c832fa2a7f6d448943b45f3b9567a7da8918b29 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 7 Apr 2025 19:14:43 +0200 Subject: [PATCH 14/32] Taking into account situation where visual data is stored as images --- lerobot/common/datasets/lerobot_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 1ca61520..1966b8e7 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -210,7 +210,7 @@ class LeRobotDatasetMetadata: @property def audio_camera_keys_mapping(self) -> dict[str, str]: """Mapping between camera keys and audio keys when both are linked.""" - return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None} + return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None and self.features[camera_key]["dtype"] == "video"} @property def names(self) -> dict[str, 
list | dict]: From 8c69b0b9cdbf5daa30bc3ef2ce1ebd28b1cbb4da Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 7 Apr 2025 16:36:04 +0200 Subject: [PATCH 15/32] Adding audio tests --- lerobot/__init__.py | 5 + .../robot_devices/microphones/microphone.py | 6 +- tests/conftest.py | 7 +- tests/datasets/test_compute_stats.py | 29 +++- tests/datasets/test_datasets.py | 37 ++++- tests/fixtures/constants.py | 15 +- tests/fixtures/dataset_factories.py | 14 +- tests/microphones/mock_sounddevice.py | 82 ++++++++++ tests/microphones/test_microphones.py | 142 ++++++++++++++++++ tests/robots/test_robots.py | 15 +- tests/utils.py | 38 ++++- 11 files changed, 375 insertions(+), 15 deletions(-) create mode 100644 tests/microphones/mock_sounddevice.py create mode 100644 tests/microphones/test_microphones.py diff --git a/lerobot/__init__.py b/lerobot/__init__.py index d61e4853..386c2fbb 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -190,6 +190,11 @@ available_cameras = [ "intelrealsense", ] +# lists all available microphones from `lerobot/common/robot_devices/microphones` +available_microphones = [ + "microphone", +] + # lists all available motors from `lerobot/common/robot_devices/motors` available_motors = [ "dynamixel", diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 38a6e5f9..c4cd7bac 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -44,8 +44,7 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: microphones = [] if mock: - #TODO(CarolinePascal): Implement mock microphones - pass + import tests.microphones.mock_sounddevice as sd else: import sounddevice as sd @@ -161,8 +160,7 @@ class Microphone: raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.") if self.mock: - #TODO(CarolinePascal): Implement mock microphones - pass + import tests.microphones.mock_sounddevice as sd else: import sounddevice as sd diff --git a/tests/conftest.py b/tests/conftest.py index 7eec94bf..adf80931 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,9 @@ import traceback import pytest from serial import SerialException -from lerobot import available_cameras, available_motors, available_robots +from lerobot import available_cameras, available_motors, available_robots, available_microphones from lerobot.common.robot_devices.robots.utils import make_robot -from tests.utils import DEVICE, make_camera, make_motors_bus +from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone # Import fixture modules as plugins pytest_plugins = [ @@ -73,6 +73,9 @@ def is_robot_available(robot_type): def is_camera_available(camera_type): return _check_component_availability(camera_type, available_cameras, make_camera) +@pytest.fixture +def is_microphone_available(microphone_type): + return _check_component_availability(microphone_type, available_microphones, make_microphone) @pytest.fixture def is_motor_available(motor_type): diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index d9032c8a..113944b3 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -26,6 +26,7 @@ from lerobot.common.datasets.compute_stats import ( estimate_num_samples, get_feature_stats, sample_images, + sample_audio, sample_indices, ) @@ -33,6 +34,8 @@ from lerobot.common.datasets.compute_stats import ( def 
mock_load_image_as_numpy(path, dtype, channel_first): return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) +def mock_load_audio(path): + return np.ones((16000,2), dtype=np.float32) @pytest.fixture def sample_array(): @@ -70,6 +73,14 @@ def test_sample_images(mock_load): assert images.dtype == np.uint8 assert len(images) == estimate_num_samples(100) +@patch("lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio) +def test_sample_audio(mock_load): + audio_path = "audio.wav" + audio_samples = sample_audio(audio_path) + assert isinstance(audio_samples, np.ndarray) + assert audio_samples.shape[1] == 2 + assert audio_samples.dtype == np.float32 + assert len(audio_samples) == estimate_num_samples(16000) def test_get_feature_stats_images(): data = np.random.rand(100, 3, 32, 32) @@ -78,6 +89,12 @@ def test_get_feature_stats_images(): np.testing.assert_equal(stats["count"], np.array([100])) assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape +def test_get_feature_stats_audio(): + data = np.random.uniform(-1, 1, (16000,2)) + stats = get_feature_stats(data, axis=0, keepdims=True) + assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats + np.testing.assert_equal(stats["count"], np.array([16000])) + assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape def test_get_feature_stats_axis_0_keepdims(sample_array): expected = { @@ -137,22 +154,28 @@ def test_get_feature_stats_single_value(): def test_compute_episode_stats(): episode_data = { "observation.image": [f"image_{i}.jpg" for i in range(100)], + "observation.audio": "audio.wav", "observation.state": np.random.rand(100, 10), } features = { "observation.image": {"dtype": "image"}, + "observation.audio": {"dtype": "audio"}, "observation.state": {"dtype": "numeric"}, } with patch( "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + ), patch( + "lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio ): stats = compute_episode_stats(episode_data, features) - assert "observation.image" in stats and "observation.state" in stats - assert stats["observation.image"]["count"].item() == 100 - assert stats["observation.state"]["count"].item() == 100 + assert "observation.image" in stats and "observation.state" in stats and "observation.audio" in stats + assert stats["observation.image"]["count"].item() == estimate_num_samples(100) + assert stats["observation.audio"]["count"].item() == estimate_num_samples(16000) + assert stats["observation.state"]["count"].item() == estimate_num_samples(100) assert stats["observation.image"]["mean"].shape == (3, 1, 1) + assert stats["observation.audio"]["mean"].shape == (1, 2) def test_assert_type_and_shape_valid(): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 81447089..1ce0e7d8 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -44,9 +44,12 @@ from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID +from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS from tests.utils import require_x86_64_kernel +from 
tests.utils import make_microphone +import time +from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION @pytest.fixture def image_dataset(tmp_path, empty_lerobot_dataset_factory): @@ -63,6 +66,18 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): } return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) +@pytest.fixture +def audio_dataset(tmp_path, empty_lerobot_dataset_factory): + features = { + "observation.audio.microphone": { + "dtype": "audio", + "shape": (DUMMY_AUDIO_CHANNELS,), + "names": [ + "channels", + ], + } + } + return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ @@ -321,6 +336,20 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): with pytest.raises(ValueError): image_array_to_pil_image(image) +def test_add_frame_audio(audio_dataset): + dataset = audio_dataset + + microphone = make_microphone(microphone_type="microphone", mock=True) + microphone.connect() + + dataset.add_microphone_recording(microphone, "microphone") + time.sleep(1.0) + dataset.add_frame({"observation.audio.microphone": microphone.read(), "task": "Dummy task"}) + microphone.stop_recording() + + dataset.save_episode() + + assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sampling_rate),DUMMY_AUDIO_CHANNELS)) # TODO(aliberts): # - [ ] test various attributes & state from init and create @@ -354,6 +383,7 @@ def test_factory(env_name, repo_id, policy_name): dataset = make_dataset(cfg) delta_timestamps = dataset.delta_timestamps camera_keys = dataset.meta.camera_keys + audio_keys = dataset.meta.audio_keys item = dataset[0] @@ -396,6 +426,11 @@ def test_factory(env_name, repo_id, policy_name): # test c,h,w assert item[key].shape[0] == 3, f"{key}" + for key in audio_keys: + assert item[key].dtype == torch.float32, f"{key}" + assert item[key].max() <= 1.0, f"{key}" + assert item[key].min() >= -1.0, f"{key}" + if delta_timestamps is not None: # test missing keys in delta_timestamps for key in delta_timestamps: diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 5e5c762c..91942190 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -29,7 +29,7 @@ DUMMY_MOTOR_FEATURES = { }, } DUMMY_CAMERA_FEATURES = { - "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, + "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"}, "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, } DEFAULT_FPS = 30 @@ -40,5 +40,18 @@ DUMMY_VIDEO_INFO = { "video.is_depth_map": False, "has_audio": False, } +DUMMY_MICROPHONE_FEATURES = { + "laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None}, + "phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None}, +} +DEFAULT_SAMPLE_RATE = 48000 +DUMMY_AUDIO_CHANNELS = 2 +DUMMY_AUDIO_INFO = { + "has_audio": True, + "audio.sample_rate": DEFAULT_SAMPLE_RATE, + "audio.codec": "aac", + "audio.channels": DUMMY_AUDIO_CHANNELS, + "audio.channel_layout": "stereo", +} DUMMY_CHW = (3, 96, 128) DUMMY_HWC = (96, 128, 3) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index fbd7480f..80387d65 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -36,6 +36,7 @@ from lerobot.common.datasets.utils 
import ( from tests.fixtures.constants import ( DEFAULT_FPS, DUMMY_CAMERA_FEATURES, + DUMMY_MICROPHONE_FEATURES, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE, @@ -91,6 +92,7 @@ def features_factory(): def _create_features( motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, + audio_features: dict = DUMMY_MICROPHONE_FEATURES, use_videos: bool = True, ) -> dict: if use_videos: @@ -102,6 +104,7 @@ def features_factory(): return { **motor_features, **camera_ft, + **audio_features, **DEFAULT_FEATURES, } @@ -125,9 +128,10 @@ def info_factory(features_factory): audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH, motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, + audio_features: dict = DUMMY_MICROPHONE_FEATURES, use_videos: bool = True, ) -> dict: - features = features_factory(motor_features, camera_features, use_videos) + features = features_factory(motor_features, camera_features, audio_features, use_videos) return { "codebase_version": codebase_version, "robot_type": robot_type, @@ -165,6 +169,14 @@ def stats_factory(): "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), "count": [10], } + elif dtype == "audio": + stats[key] = { + "mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(), + "max": np.full((shape[0],), 1, dtype=np.float32).tolist(), + "min": np.full((shape[0],), -1, dtype=np.float32).tolist(), + "std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(), + "count": [10], + } else: stats[key] = { "max": np.full(shape, 1, dtype=dtype).tolist(), diff --git a/tests/microphones/mock_sounddevice.py b/tests/microphones/mock_sounddevice.py new file mode 100644 index 00000000..f6007085 --- /dev/null +++ b/tests/microphones/mock_sounddevice.py @@ -0,0 +1,82 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
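Before the mock device below, a quick sanity sketch of the audio invariants these fixtures encode: mocked chunks are float32 in [-1, 1] with shape (n_samples, n_channels), and statistics keep one value per channel. The sizes used here are placeholders:

```python
import numpy as np

chunk = np.random.uniform(-1, 1, size=(48000, 2)).astype(np.float32)  # 1 s of stereo
assert chunk.dtype == np.float32
assert chunk.min() >= -1.0 and chunk.max() <= 1.0

per_channel_mean = chunk.mean(axis=0)  # one entry per channel
assert per_channel_mean.shape == (2,)
```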
+from functools import cache + +from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DEFAULT_SAMPLE_RATE + +import numpy as np +from lerobot.common.utils.utils import capture_timestamp_utc +from threading import Thread, Event +import time + +@cache +def _generate_sound(duration: float, sample_rate: int, channels: int): + return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32) + +def query_devices(query_index: int): + return { + "name": "Mock Sound Device", + "index": query_index, + "max_input_channels": DUMMY_AUDIO_CHANNELS, + "default_samplerate": DEFAULT_SAMPLE_RATE, + } + +class InputStream: + def __init__(self, *args, **kwargs): + self._mock_dict = { + "channels": DUMMY_AUDIO_CHANNELS, + "samplerate": DEFAULT_SAMPLE_RATE, + } + self._is_active = False + self._audio_callback = kwargs.get("callback") + + self.callback_thread = None + self.callback_thread_stop_event = None + + def _acquisition_loop(self): + if self._audio_callback is not None: + while not self.callback_thread_stop_event.is_set(): + # Simulate audio data acquisition + time.sleep(0.01) + self._audio_callback(_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), 0.01*DEFAULT_SAMPLE_RATE, capture_timestamp_utc(), None) + + def start(self): + self.callback_thread_stop_event = Event() + self.callback_thread = Thread(target=self._acquisition_loop, args=()) + self.callback_thread.daemon = True + self.callback_thread.start() + + self._is_active = True + + @property + def active(self): + return self._is_active + + def stop(self): + if self.callback_thread_stop_event is not None: + self.callback_thread_stop_event.set() + self.callback_thread.join() + self.callback_thread = None + self.callback_thread_stop_event = None + self._is_active = False + + def close(self): + if self._is_active: + self.stop() + + def __del__(self): + if self._is_active: + self.stop() + + diff --git a/tests/microphones/test_microphones.py b/tests/microphones/test_microphones.py new file mode 100644 index 00000000..50c65119 --- /dev/null +++ b/tests/microphones/test_microphones.py @@ -0,0 +1,142 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for physical microphones and their mocked versions. +If the physical microphone is not connected to the computer, or not working, +the test will be skipped. 
+
+Example of running a specific test:
+```bash
+pytest -sx tests/microphones/test_microphones.py::test_microphone
+```
+
+Example of running test on a real microphone connected to the computer:
+```bash
+pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-False]'
+```
+
+Example of running test on a mocked version of the microphone:
+```bash
+pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-True]'
+```
+"""
+
+import numpy as np
+import time
+import pytest
+from soundfile import read
+
+from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError
+from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
+
+#Maximum recording time difference between two consecutive audio recordings of the same duration.
+#Set to 0.02 seconds as twice the default size of the sounddevice callback buffer (i.e. we tolerate the loss of one buffer).
+MAX_RECORDING_TIME_DIFFERENCE = 0.02
+
+DUMMY_RECORDING = "test_recording.wav"
+
+@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
+@require_microphone
+def test_microphone(tmp_path, request, microphone_type, mock):
+    """Test assumes that a recording handled with microphone.start_recording(output_file) and stop_recording() or microphone.read()
+    leads to a sample that does not differ from the requested duration by more than 0.1 seconds.
+    """
+
+    microphone_kwargs = {"microphone_type": microphone_type, "mock": mock}
+
+    # Test instantiating
+    microphone = make_microphone(**microphone_kwargs)
+
+    # Test that start_recording, stop_recording, read and disconnecting before connecting raise an error
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.start_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.stop_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.disconnect()
+
+    # Test deleting the object without connecting first
+    del microphone
+
+    # Test connecting
+    microphone = make_microphone(**microphone_kwargs)
+    microphone.connect()
+    assert microphone.is_connected
+    assert microphone.sampling_rate is not None
+    assert microphone.channels is not None
+
+    # Test that connecting twice raises an error
+    with pytest.raises(RobotDeviceAlreadyConnectedError):
+        microphone.connect()
+
+    # Test that reading or stopping the recording before starting a recording raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test start_recording
+    fpath = tmp_path / DUMMY_RECORDING
+    microphone.start_recording(fpath)
+    assert microphone.is_recording
+
+    # Test that calling start_recording twice raises an error
+    with pytest.raises(RobotDeviceAlreadyRecordingError):
+        microphone.start_recording()
+
+    # Test reading from the microphone
+    time.sleep(1.0)
+    audio_chunk = microphone.read()
+    assert isinstance(audio_chunk, np.ndarray)
+    assert audio_chunk.ndim == 2
+    _, c = audio_chunk.shape
+    assert c == len(microphone.channels)
+
+    # Test stop_recording
+    microphone.stop_recording()
+    assert fpath.exists()
+    assert not microphone.stream.active
+    assert microphone.record_thread is None
+
+    # Test that calling stop_recording twice raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test reading and recording output similar
length audio chunks + microphone.start_recording(tmp_path / DUMMY_RECORDING) + time.sleep(1.0) + audio_chunk = microphone.read() + microphone.stop_recording() + + recorded_audio, recorded_sample_rate = read(fpath) + assert recorded_sample_rate == microphone.sampling_rate + + error_msg = ( + "Recording time difference between read() and stop_recording()", + (len(audio_chunk) - len(recorded_audio))/MAX_RECORDING_TIME_DIFFERENCE, + ) + np.testing.assert_allclose( + len(audio_chunk), len(recorded_audio), atol=recorded_sample_rate*MAX_RECORDING_TIME_DIFFERENCE, err_msg=error_msg + ) + + # Test disconnecting + microphone.disconnect() + assert not microphone.is_connected + + # Test disconnecting with `__del__` + microphone = make_microphone(**microphone_kwargs) + microphone.connect() + del microphone \ No newline at end of file diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py index 71343eba..204aabca 100644 --- a/tests/robots/test_robots.py +++ b/tests/robots/test_robots.py @@ -100,6 +100,9 @@ def test_robot(tmp_path, request, robot_type, mock): robot.teleop_step() # Test data recorded during teleop are well formatted + for _, microphone in robot.microphones.items(): + microphone.start_recording() + observation, action = robot.teleop_step(record_data=True) # State assert "observation.state" in observation @@ -112,6 +115,11 @@ def test_robot(tmp_path, request, robot_type, mock): assert f"observation.images.{name}" in observation assert isinstance(observation[f"observation.images.{name}"], torch.Tensor) assert observation[f"observation.images.{name}"].ndim == 3 + # Microphones + for name in robot.microphones: + assert f"observation.audio.{name}" in observation + assert isinstance(observation[f"observation.audio.{name}"], torch.Tensor) + assert observation[f"observation.audio.{name}"].ndim == 2 # Action assert "action" in action assert isinstance(action["action"], torch.Tensor) @@ -124,8 +132,9 @@ def test_robot(tmp_path, request, robot_type, mock): captured_observation = robot.capture_observation() assert set(captured_observation.keys()) == set(observation.keys()) for name in captured_observation: - if "image" in name: + if "image" in name or "audio" in name: # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames + # Also skipping for audio as audio chunks may be of different length continue torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1) assert captured_observation[name].shape == observation[name].shape @@ -134,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock): robot.send_action(action["action"]) # Test disconnecting - robot.disconnect() + robot.disconnect() #Also handles microphone recording stop, life is beautiful assert not robot.is_connected for name in robot.follower_arms: assert not robot.follower_arms[name].is_connected @@ -142,3 +151,5 @@ def test_robot(tmp_path, request, robot_type, mock): assert not robot.leader_arms[name].is_connected for name in robot.cameras: assert not robot.cameras[name].is_connected + for name in robot.microphones: + assert not robot.microphones[name].is_connected diff --git a/tests/utils.py b/tests/utils.py index c49b5b9f..8559ca0c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,11 +22,13 @@ from pathlib import Path import pytest import torch -from lerobot import available_cameras, available_motors, available_robots +from lerobot import available_cameras, available_motors, available_robots, available_microphones from 
lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device from lerobot.common.robot_devices.motors.utils import MotorsBus from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device +from lerobot.common.robot_devices.microphones.utils import Microphone +from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device from lerobot.common.utils.import_utils import is_package_available DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" @@ -39,6 +41,10 @@ TEST_CAMERA_TYPES = [] for camera_type in available_cameras: TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)] +TEST_MICROPHONE_TYPES = [] +for microphone_type in available_microphones: + TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)] + TEST_MOTOR_TYPES = [] for motor_type in available_motors: TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)] @@ -47,6 +53,9 @@ for motor_type in available_motors: OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)) +# Microphone indices used for connecting physical microphones +MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0)) + DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081") DYNAMIXEL_MOTORS = { "shoulder_pan": [1, "xl430-w250"], @@ -252,6 +261,27 @@ def require_camera(func): return wrapper +def require_microphone(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Access the pytest request context to get the is_microphone_available fixture + request = kwargs.get("request") + microphone_type = kwargs.get("microphone_type") + mock = kwargs.get("mock") + + if request is None: + raise ValueError("The 'request' fixture must be an argument of the test function.") + if microphone_type is None: + raise ValueError("The 'microphone_type' must be an argument of the test function.") + if mock is None: + raise ValueError("The 'mock' variable must be an argument of the test function.") + + if not mock and not request.getfixturevalue("is_microphone_available"): + pytest.skip(f"A {microphone_type} microphone is not available.") + + return func(*args, **kwargs) + + return wrapper def require_motor(func): @wraps(func) @@ -314,6 +344,12 @@ def make_camera(camera_type: str, **kwargs) -> Camera: else: raise ValueError(f"The camera type '{camera_type}' is not valid.") +def make_microphone(microphone_type: str, **kwargs) -> Microphone: + if microphone_type == "microphone": + microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX) + return make_microphone_device(microphone_type, microphone_index=microphone_index, **kwargs) + else: + raise ValueError(f"The microphone type '{microphone_type}' is not valid.") # TODO(rcadene, aliberts): remove this dark pattern that overrides def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: From 43a82e2aef0507201a84b1f22c1cce3265605def Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 8 Apr 2025 16:24:44 +0200 Subject: [PATCH 16/32] Renamming sampling rate to sample rate for consistency --- lerobot/common/datasets/video_utils.py | 6 ++--- .../robot_devices/microphones/configs.py | 2 +- .../robot_devices/microphones/microphone.py | 26 +++++++++---------- tests/datasets/test_datasets.py | 2 +- 
tests/microphones/test_microphones.py | 4 +-- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 4c96a400..e8e85411 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -79,12 +79,12 @@ def decode_audio_torchvision( audio_path = str(audio_path) reader = torchaudio.io.StreamReader(src=audio_path) - audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate + audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate #TODO(CarolinePascal) : sort timestamps ? reader.add_basic_audio_stream( - frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough + frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough buffer_chunk_size = -1, #No dropping frames format = "fltp", #Format as float32 ) @@ -99,7 +99,7 @@ def decode_audio_torchvision( current_audio_chunk = reader.pop_chunks()[0] if log_loaded_timestamps: - logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}") + logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}") audio_chunks.append(current_audio_chunk) diff --git a/lerobot/common/robot_devices/microphones/configs.py b/lerobot/common/robot_devices/microphones/configs.py index e90da519..c2700723 100644 --- a/lerobot/common/robot_devices/microphones/configs.py +++ b/lerobot/common/robot_devices/microphones/configs.py @@ -31,6 +31,6 @@ class MicrophoneConfig(MicrophoneConfigBase): """ microphone_index: int - sampling_rate: int | None = None + sample_rate: int | None = None channels: list[int] | None = None mock: bool = False \ No newline at end of file diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index c4cd7bac..f2a10ef3 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -80,7 +80,7 @@ def record_audio_from_microphones( microphone = Microphone(config) microphone.connect() print( - f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sampling_rate} Hz." + f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sample_rate} Hz." ) microphones.append(microphone) @@ -111,13 +111,13 @@ class Microphone: """ The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows). - A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sampling rate as well as the list of recorded channels. + A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels. 
Example of usage: ```python from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig - config = MicrophoneConfig(microphone_index=0, sampling_rate=16000, channels=[1]) + config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1]) microphone = Microphone(config) microphone.connect() @@ -134,8 +134,8 @@ class Microphone: self.config = config self.microphone_index = config.microphone_index - #Store the recording sampling rate and channels - self.sampling_rate = config.sampling_rate + #Store the recording sample rate and channels + self.sample_rate = config.sample_rate self.channels = config.channels self.mock = config.mock @@ -177,15 +177,15 @@ class Microphone: #Check if provided recording parameters are compatible with the microphone actual_microphone = sd.query_devices(self.microphone_index) - if self.sampling_rate is not None : - if self.sampling_rate > actual_microphone["default_samplerate"]: + if self.sample_rate is not None : + if self.sample_rate > actual_microphone["default_samplerate"]: raise OSError( - f"Provided sampling rate {self.sampling_rate} is higher than the sampling rate of the microphone {actual_microphone['default_samplerate']}." + f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}." ) - elif self.sampling_rate < actual_microphone["default_samplerate"]: - logging.warning("Provided sampling rate is lower than the sampling rate of the microphone. Performance may be impacted.") + elif self.sample_rate < actual_microphone["default_samplerate"]: + logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.") else: - self.sampling_rate = int(actual_microphone["default_samplerate"]) + self.sample_rate = int(actual_microphone["default_samplerate"]) if self.channels is not None: if any(c > actual_microphone["max_input_channels"] for c in self.channels): @@ -201,7 +201,7 @@ class Microphone: #Create the audio stream self.stream = sd.InputStream( device=self.microphone_index, - samplerate=self.sampling_rate, + samplerate=self.sample_rate, channels=max(self.channels)+1, dtype="float32", callback=self._audio_callback, @@ -221,7 +221,7 @@ class Microphone: def _record_loop(self, output_file: Path) -> None: #Can only be run on a single process/thread for file writing safety - with sf.SoundFile(output_file, mode='x', samplerate=self.sampling_rate, + with sf.SoundFile(output_file, mode='x', samplerate=self.sample_rate, channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: while not self.record_stop_event.is_set(): file.write(self.record_queue.get()) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 1ce0e7d8..dde2ed06 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -349,7 +349,7 @@ def test_add_frame_audio(audio_dataset): dataset.save_episode() - assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sampling_rate),DUMMY_AUDIO_CHANNELS)) + assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sample_rate),DUMMY_AUDIO_CHANNELS)) # TODO(aliberts): # - [ ] test various attributes & state from init and create diff --git a/tests/microphones/test_microphones.py b/tests/microphones/test_microphones.py index 50c65119..3ce29fa2 100644 --- a/tests/microphones/test_microphones.py +++ 
b/tests/microphones/test_microphones.py @@ -75,7 +75,7 @@ def test_microphone(tmp_path, request, microphone_type, mock): microphone = make_microphone(**microphone_kwargs) microphone.connect() assert microphone.is_connected - assert microphone.sampling_rate is not None + assert microphone.sample_rate is not None assert microphone.channels is not None # Test connecting twice raises an error @@ -122,7 +122,7 @@ def test_microphone(tmp_path, request, microphone_type, mock): microphone.stop_recording() recorded_audio, recorded_sample_rate = read(fpath) - assert recorded_sample_rate == microphone.sampling_rate + assert recorded_sample_rate == microphone.sample_rate error_msg = ( "Recording time difference between read() and stop_recording()", From d53035d0473eb368460f608ae41451240005ad32 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 9 Apr 2025 14:59:29 +0200 Subject: [PATCH 17/32] Adding multiprocessing support for audio recording --- .../robot_devices/microphones/microphone.py | 95 ++++++++++--------- .../common/robot_devices/microphones/utils.py | 2 +- 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index f2a10ef3..25a46049 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -21,13 +21,18 @@ import soundfile as sf import numpy as np import logging from threading import Thread, Event -from queue import Queue -from os.path import splitext -from os import remove, getcwd +from multiprocessing import Process +from queue import Empty + +from queue import Queue as thread_Queue +from threading import Event as thread_Event +from multiprocessing import JoinableQueue as process_Queue +from multiprocessing import Event as process_Event + +from os import getcwd from pathlib import Path import shutil import time -from concurrent.futures import ThreadPoolExecutor from lerobot.common.utils.utils import capture_timestamp_utc @@ -37,7 +42,6 @@ from lerobot.common.robot_devices.utils import ( RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError, - busy_wait, ) def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: @@ -144,8 +148,8 @@ class Microphone: self.stream = None #Thread-safe concurrent queue to store the recorded/read audio - self.record_queue = Queue() - self.read_queue = Queue() + self.record_queue = None + self.read_queue = None #Thread to handle data reading and file writing in a separate thread (safely) self.record_thread = None @@ -219,13 +223,17 @@ class Microphone: self.record_queue.put(indata[:,self.channels]) self.read_queue.put(indata[:,self.channels]) - def _record_loop(self, output_file: Path) -> None: + @staticmethod + def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: #Can only be run on a single process/thread for file writing safety - with sf.SoundFile(output_file, mode='x', samplerate=self.sample_rate, - channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: - while not self.record_stop_event.is_set(): - file.write(self.record_queue.get()) - #self.record_queue.task_done() + with sf.SoundFile(output_file, mode='x', samplerate=sample_rate, + channels=max(channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: + while not event.is_set(): + try: + file.write(queue.get(timeout=0.02)) #Timeout set as twice the 
usual sounddevice buffer size + queue.task_done() + except Empty: + continue def _read(self) -> np.ndarray: """ @@ -233,17 +241,15 @@ class Microphone: -> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage -> CONS : Reading duration does not scale well with the number of channels and reading duration """ - try: - audio_readings = self.read_queue.queue - except Queue.Empty: - audio_readings = np.empty((0, len(self.channels))) - else: - #TODO(CarolinePascal): Check if this is the fastest way to do it - self.read_queue = Queue() - with self.read_queue.mutex: - self.read_queue.queue.clear() - #self.read_queue.all_tasks_done.notify_all() - audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels)) + audio_readings = np.empty((0, len(self.channels))) + + while True: + try: + audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0) + except Empty: + break + + self.read_queue = thread_Queue() return audio_readings @@ -266,31 +272,32 @@ class Microphone: return audio_readings - def start_recording(self, output_file : str | None = None) -> None: + def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False) -> None: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if self.is_recording: raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.") - self.read_queue = Queue() - with self.read_queue.mutex: - self.read_queue.queue.clear() - #self.read_queue.all_tasks_done.notify_all() + #Reset queues + self.read_queue = thread_Queue() + if multiprocessing: + self.record_queue = process_Queue() + else: + self.record_queue = thread_Queue() - self.record_queue = Queue() - with self.record_queue.mutex: - self.record_queue.queue.clear() - #self.record_queue.all_tasks_done.notify_all() - - #Recording case + #Write recordings into a file if output_file is provided if output_file is not None: output_file = Path(output_file) if output_file.exists(): output_file.unlink() - self.record_stop_event = Event() - self.record_thread = Thread(target=self._record_loop, args=(output_file,)) + if multiprocessing: + self.record_stop_event = process_Event() + self.record_thread = Process(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, )) + else: + self.record_stop_event = thread_Event() + self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, )) self.record_thread.daemon = True self.record_thread.start() @@ -304,18 +311,18 @@ class Microphone: if not self.is_recording: raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") + if self.stream.active: + self.stream.stop() #Wait for all buffers to be processed + #Remark : stream.abort() flushes the buffers ! + self.is_recording = False + if self.record_thread is not None: - #self.record_queue.join() + self.record_queue.join() self.record_stop_event.set() self.record_thread.join() self.record_thread = None self.record_stop_event = None - - if self.stream.active: - self.stream.stop() #Wait for all buffers to be processed - #Remark : stream.abort() flushes the buffers ! 
- - self.is_recording = False + self.is_writing = False def disconnect(self) -> None: diff --git a/lerobot/common/robot_devices/microphones/utils.py b/lerobot/common/robot_devices/microphones/utils.py index ea5790dc..1b1ad099 100644 --- a/lerobot/common/robot_devices/microphones/utils.py +++ b/lerobot/common/robot_devices/microphones/utils.py @@ -20,7 +20,7 @@ from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, M class Microphone(Protocol): def connect(self): ... def disconnect(self): ... - def start_recording(self, output_file: str | None = None): ... + def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False): ... def stop_recording(self): ... def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]: From 1e5e6317438b8c133da4de65ed93d57cb1a89cc6 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 9 Apr 2025 15:00:18 +0200 Subject: [PATCH 18/32] Adding flag for file writing recording case --- lerobot/common/robot_devices/microphones/microphone.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 25a46049..2d75293a 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -158,6 +158,7 @@ class Microphone: self.logs = {} self.is_connected = False self.is_recording = False + self.is_writing = False def connect(self) -> None: if self.is_connected: @@ -220,7 +221,8 @@ class Microphone: logging.warning(status) # Slicing makes copy unecessary # Two separate queues are necessary because .get() also pops the data from the queue - self.record_queue.put(indata[:,self.channels]) + if self.is_writing: + self.record_queue.put(indata[:,self.channels]) self.read_queue.put(indata[:,self.channels]) @staticmethod def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: @@ -300,6 +302,8 @@ class Microphone: self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, )) self.record_thread.daemon = True self.record_thread.start() + + self.is_writing = True self.is_recording = True self.stream.start() From ec8943db370e3cdef1f242b8b5e8cd83960cf136 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 9 Apr 2025 17:53:08 +0200 Subject: [PATCH 19/32] Adding support for audio data recording and broadcasting for LeKiwi --- lerobot/common/datasets/compute_stats.py | 15 +++++-- lerobot/common/datasets/lerobot_dataset.py | 29 ++++++++++---- lerobot/common/datasets/utils.py | 2 +- lerobot/common/robot_devices/control_utils.py | 2 +- .../robot_devices/robots/lekiwi_remote.py | 34 ++++++++++++++++ .../robot_devices/robots/manipulator.py | 2 +- .../robots/mobile_manipulator.py | 40 ++++++++++++++++++- tests/datasets/test_compute_stats.py | 19 ++++++--- 8 files changed, 123 insertions(+), 20 deletions(-) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index b24dbaf8..08ac4ae6 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -15,7 +15,7 @@ # limitations under the License.
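Patches 17 and 18 together reshape the recording API: `start_recording` now takes a `multiprocessing` flag that swaps the writer `Thread` for a writer `Process`, and the new `is_writing` flag makes file writing optional so that `read()` can be used on its own. A minimal usage sketch of that API (the device index, sample rate, and file name are illustrative placeholders, not values from the patches):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.microphone import Microphone

# Hypothetical device settings, for illustration only.
microphone = Microphone(MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1]))
microphone.connect()

# Data-reading only: with no output_file, is_writing stays False and the
# stream callback feeds read_queue alone.
microphone.start_recording()
chunk = microphone.read()  # np.ndarray of shape (n_frames, n_channels)
microphone.stop_recording()

# File writing as well, with the writer loop in a separate process.
microphone.start_recording(output_file="recording.wav", multiprocessing=True)
microphone.stop_recording()

microphone.disconnect()
```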
import numpy as np -from lerobot.common.datasets.utils import load_image_as_numpy, load_audio +from lerobot.common.datasets.utils import load_image_as_numpy, load_audio_from_path def estimate_num_samples( dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 @@ -70,13 +70,17 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def sample_audio(audio_path: str) -> np.ndarray: +def sample_audio_from_path(audio_path: str) -> np.ndarray: - data = load_audio(audio_path) + data = load_audio_from_path(audio_path) sampled_indices = sample_indices(len(data)) return(data[sampled_indices]) +def sample_audio_from_data(data: np.ndarray) -> np.ndarray: + sampled_indices = sample_indices(len(data)) + return data[sampled_indices] + def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: return { "min": np.min(array, axis=axis, keepdims=keepdims), @@ -97,7 +101,10 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu axes_to_reduce = (0, 2, 3) # keep channel dim keepdims = True elif features[key]["dtype"] == "audio": - ep_ft_array = sample_audio(data[0]) + try: + ep_ft_array = sample_audio_from_path(data[0]) + except TypeError: #Should only be triggered for LeKiwi robot + ep_ft_array = sample_audio_from_data(data) axes_to_reduce = 0 keepdims = True else: diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 1966b8e7..bc6689c7 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -80,6 +80,7 @@ from lerobot.common.datasets.video_utils import ( ) from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.microphones.utils import Microphone +import soundfile as sf CODEBASE_VERSION = "v2.1" @@ -324,7 +325,7 @@ class LeRobotDatasetMetadata: been encoded the same way. Also, this means it assumes the first episode exists. 
""" for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()): - if not self.features[key].get("info", None): + if not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]): audio_path = self.root / self.get_compressed_audio_file_path(0, key) self.info["features"][key]["info"] = get_audio_info(audio_path) @@ -910,11 +911,14 @@ class LeRobotDataset(torch.utils.data.Dataset): self._save_image(frame[key], img_path) self.episode_buffer[key].append(str(img_path)) elif self.features[key]["dtype"] == "audio": - if frame_index == 0: - audio_path = self._get_raw_audio_file_path( - episode_index=self.episode_buffer["episode_index"], audio_key=key - ) - self.episode_buffer[key].append(str(audio_path)) + if self.meta.robot_type.startswith("lekiwi"): + self.episode_buffer[key].append(frame[key]) + else: + if frame_index == 0: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"], audio_key=key + ) + self.episode_buffer[key].append(str(audio_path)) else: self.episode_buffer[key].append(frame[key]) @@ -966,12 +970,23 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, ft in self.features.items(): # index, episode_index, task_index are already processed above, and image and video # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]: + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + continue + elif ft["dtype"] == "audio": + if self.meta.robot_type.startswith("lekiwi"): + episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0) continue episode_buffer[key] = np.stack(episode_buffer[key]) self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) + + if self.meta.robot_type.startswith("lekiwi"): + for key in self.meta.audio_keys: + audio_path = self._get_raw_audio_file_path(episode_index=self.episode_buffer["episode_index"][0], audio_key=key) + with sf.SoundFile(audio_path, mode='w', samplerate=self.meta.features[key]["info"]["sample_rate"], channels=self.meta.features[key]["shape"][0]) as file: + file.write(episode_buffer[key]) + ep_stats = compute_episode_stats(episode_buffer, self.features) if len(self.meta.video_keys) > 0: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 80143939..970c447d 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -260,7 +260,7 @@ def load_image_as_numpy( img_array /= 255.0 return img_array -def load_audio(fpath: str | Path) -> np.ndarray: +def load_audio_from_path(fpath: str | Path) -> np.ndarray: audio_data, _ = read(fpath, dtype="float32") return audio_data diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 7f71b1ee..7c8706a4 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -252,7 +252,7 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() - if dataset is not None: + if dataset is not None and not robot.robot_type.startswith("lekiwi"): #For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage) for microphone_key, microphone in robot.microphones.items(): #Start recording both in file writing and data reading mode 
dataset.add_microphone_recording(microphone, microphone_key) diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py index 7bf52d21..03576c0c 100644 --- a/lerobot/common/robot_devices/robots/lekiwi_remote.py +++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py @@ -51,6 +51,14 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event): latest_images_dict.update(local_dict) time.sleep(0.01) +def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event): + while not stop_event.is_set(): + local_dict = {} + for name, microphone in microphones.items(): + audio_readings = microphone.read() + local_dict[name] = audio_readings + with audio_lock: + latest_audio_dict.update(local_dict) def calibrate_follower_arm(motors_bus, calib_dir_str): """ @@ -94,6 +102,7 @@ def run_lekiwi(robot_config): """ # Import helper functions and classes from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs + from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode # Initialize cameras from the robot configuration. @@ -101,6 +110,11 @@ def run_lekiwi(robot_config): for cam in cameras.values(): cam.connect() + # Initialize microphones from the robot configuration. + microphones = make_microphones_from_configs(robot_config.microphones) + for microphone in microphones.values(): + microphone.connect() + # Initialize the motors bus using the follower arm configuration. motor_config = robot_config.follower_arms.get("main") if motor_config is None: @@ -134,6 +148,18 @@ def run_lekiwi(robot_config): ) cam_thread.start() + # Start the microphone recording and capture thread. + #TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead ! + latest_audio_dict = {} + audio_lock = threading.Lock() + audio_stop_event = threading.Event() + microphone_thread = threading.Thread( + target=run_microphone_capture, args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), daemon=True + ) + for microphone in microphones.values(): + microphone.start_recording() + microphone_thread.start() + last_cmd_time = time.time() print("LeKiwi robot server started. Waiting for commands...") @@ -198,9 +224,14 @@ def run_lekiwi(robot_config): with images_lock: images_dict_copy = dict(latest_images_dict) + # Get the latest audio data. + with audio_lock: + audio_dict_copy = dict(latest_audio_dict) + # Build the observation dictionary. observation = { "images": images_dict_copy, + "audio": audio_dict_copy, #TODO(CarolinePascal) : This is a nasty way to do it, sorry. 
"present_speed": current_velocity, "follower_arm_state": follower_arm_state, } @@ -217,6 +248,9 @@ def run_lekiwi(robot_config): finally: stop_event.set() cam_thread.join() + microphone_thread.join() + for microphone in microphones.values(): + microphone.stop_recording() robot.stop() motors_bus.disconnect() cmd_socket.close() diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 16edc4fb..dc68d609 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -211,7 +211,7 @@ class ManipulatorRobot: "dtype": "audio", "shape": (len(mic.channels),), "names": "channels", - "info" : None, + "info" : {"sample_rate": mic.sample_rate}, } return mic_ft diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py index 385e218b..98e4cdb1 100644 --- a/lerobot/common/robot_devices/robots/mobile_manipulator.py +++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py @@ -24,6 +24,7 @@ import torch import zmq from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs +from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs from lerobot.common.robot_devices.motors.feetech import TorqueMode from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig @@ -79,6 +80,7 @@ class MobileManipulator: self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) self.cameras = make_cameras_from_configs(self.config.cameras) + self.microphones = make_microphones_from_configs(self.config.microphones) self.is_connected = False @@ -133,6 +135,7 @@ class MobileManipulator: "shape": (cam.height, cam.width, cam.channels), "names": ["height", "width", "channels"], "info": None, + "audio": "observation.audio." 
+ cam.microphone if cam.microphone is not None else None, } return cam_ft @@ -160,10 +163,23 @@ class MobileManipulator: "names": combined_names, }, } + + @property + def microphone_features(self) -> dict: + mic_ft = {} + for mic_key, mic in self.microphones.items(): + key = f"observation.audio.{mic_key}" + mic_ft[key] = { + "dtype": "audio", + "shape": (len(mic.channels),), + "names": "channels", + "info" : {"sample_rate": mic.sample_rate}, + } + return mic_ft @property def features(self): - return {**self.motor_features, **self.camera_features} + return {**self.motor_features, **self.camera_features, **self.microphone_features} @property def has_camera(self): @@ -172,6 +188,14 @@ class MobileManipulator: @property def num_cameras(self): return len(self.cameras) + + @property + def has_microphone(self): + return len(self.microphones) > 0 + + @property + def num_microphones(self): + return len(self.microphones) @property def available_arms(self): @@ -344,6 +368,7 @@ class MobileManipulator: observation = json.loads(last_msg) images_dict = observation.get("images", {}) + audio_dict = observation.get("audio", {}) new_speed = observation.get("present_speed", {}) new_arm_state = observation.get("follower_arm_state", None) @@ -356,6 +381,11 @@ class MobileManipulator: if frame_candidate is not None: frames[cam_name] = frame_candidate + # Receive audio + for microphone_name, audio_data in audio_dict.items(): + if audio_data: + frames[microphone_name] = audio_data + # If remote_arm_state is None and frames is None there is no message then use the previous message if new_arm_state is not None and frames is not None: self.last_frames = frames @@ -475,6 +505,14 @@ class MobileManipulator: frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8) obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame) + # Loop over each configured microphone + for microphone_name, microphone in self.microphones.items(): + frame = frames.get(microphone_name, None) + if frame is None: + # Create silence using the microphone's configured channels + frame = np.zeros((1, len(microphone.channels)), dtype=np.float32) + obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame) + return obs_dict def send_action(self, action: torch.Tensor) -> torch.Tensor: diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 113944b3..56cdc176 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -26,7 +26,8 @@ from lerobot.common.datasets.compute_stats import ( estimate_num_samples, get_feature_stats, sample_images, - sample_audio, + sample_audio_from_path, + sample_audio_from_data, sample_indices, ) @@ -73,10 +74,18 @@ def test_sample_images(mock_load): assert images.dtype == np.uint8 assert len(images) == estimate_num_samples(100) -@patch("lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio) -def test_sample_audio(mock_load): +@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio) +def test_sample_audio_from_path(mock_load): audio_path = "audio.wav" - audio_samples = sample_audio(audio_path) + audio_samples = sample_audio_from_path(audio_path) + assert isinstance(audio_samples, np.ndarray) + assert audio_samples.shape[1] == 2 + assert audio_samples.dtype == np.float32 + assert len(audio_samples) == estimate_num_samples(16000) + +def test_sample_audio_from_data(): + audio_data = np.ones((16000, 2), dtype=np.float32) + audio_samples = 
sample_audio_from_data(audio_data) assert isinstance(audio_samples, np.ndarray) assert audio_samples.shape[1] == 2 assert audio_samples.dtype == np.float32 @@ -166,7 +175,7 @@ def test_compute_episode_stats(): with patch( "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy ), patch( - "lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio + "lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio ): stats = compute_episode_stats(episode_data, features) From 9c667d347ce960ed383b4b60963feb5fc01acabf Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 10 Apr 2025 11:04:03 +0200 Subject: [PATCH 20/32] Adding sample SO100 configuration for testing --- lerobot/common/robot_devices/robots/configs.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index 6bb3b93b..66edc4f6 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -486,6 +486,7 @@ class So100RobotConfig(ManipulatorRobotConfig): fps=30, width=640, height=480, + microphone="laptop", ), "phone": OpenCVCameraConfig( camera_index=1, @@ -496,6 +497,21 @@ class So100RobotConfig(ManipulatorRobotConfig): } ) + microphones: dict[str, MicrophoneConfig] = field( + default_factory=lambda: { + "laptop": MicrophoneConfig( + microphone_index=0, + sample_rate=48000, + channels=[1], + ), + "headset": MicrophoneConfig( + microphone_index=1, + sample_rate=44100, + channels=[1], + ), + } + ) + mock: bool = False From 0cb9345f06644a6da79e0d49c406422a624185d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:56:39 +0000 Subject: [PATCH 21/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/datasets/compute_stats.py | 13 +- lerobot/common/datasets/lerobot_dataset.py | 101 +++++++---- lerobot/common/datasets/utils.py | 17 +- lerobot/common/datasets/video_utils.py | 56 ++++--- lerobot/common/robot_devices/control_utils.py | 10 +- .../robot_devices/microphones/configs.py | 4 +- .../robot_devices/microphones/microphone.py | 157 ++++++++++-------- .../common/robot_devices/microphones/utils.py | 9 +- .../common/robot_devices/robots/configs.py | 2 +- .../robot_devices/robots/lekiwi_remote.py | 10 +- .../robot_devices/robots/manipulator.py | 8 +- .../robots/mobile_manipulator.py | 10 +- lerobot/common/robot_devices/utils.py | 4 +- tests/conftest.py | 6 +- tests/datasets/test_compute_stats.py | 24 ++- tests/datasets/test_datasets.py | 17 +- tests/fixtures/constants.py | 7 +- tests/fixtures/dataset_factories.py | 2 +- tests/microphones/mock_sounddevice.py | 30 ++-- tests/microphones/test_microphones.py | 26 ++- tests/robots/test_robots.py | 2 +- tests/utils.py | 10 +- 22 files changed, 329 insertions(+), 196 deletions(-) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 08ac4ae6..36606719 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -15,7 +15,8 @@ # limitations under the License. 
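The `microphone_index` values in the sample SO100 configuration above refer to entries in sounddevice's device list. A quick way to inspect the available indices, mirroring what `find_microphones()` does in the earlier patch (a sketch; nothing here is SO100-specific):

```python
import sounddevice as sd

# Print every input-capable device with its index and default sample rate,
# so each MicrophoneConfig can point at the right hardware.
for device in sd.query_devices():
    if device["max_input_channels"] > 0:
        print(device["index"], device["name"], device["default_samplerate"])
```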
import numpy as np -from lerobot.common.datasets.utils import load_image_as_numpy, load_audio_from_path +from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy + def estimate_num_samples( dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 @@ -70,17 +71,19 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def sample_audio_from_path(audio_path: str) -> np.ndarray: +def sample_audio_from_path(audio_path: str) -> np.ndarray: data = load_audio_from_path(audio_path) sampled_indices = sample_indices(len(data)) - return(data[sampled_indices]) + return data[sampled_indices] + def sample_audio_from_data(data: np.ndarray) -> np.ndarray: sampled_indices = sample_indices(len(data)) return data[sampled_indices] + def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: return { "min": np.min(array, axis=axis, keepdims=keepdims), @@ -103,9 +106,9 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu elif features[key]["dtype"] == "audio": try: ep_ft_array = sample_audio_from_path(data[0]) - except TypeError: #Should only be triggered for LeKiwi robot + except TypeError: # Should only be triggered for LeKiwi robot ep_ft_array = sample_audio_from_data(data) - axes_to_reduce = 0 + axes_to_reduce = 0 keepdims = True else: ep_ft_array = data # data is already a np.ndarray diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index bc6689c7..f844eb72 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -23,6 +23,7 @@ import datasets import numpy as np import packaging.version import PIL.Image +import soundfile as sf import torch import torch.utils from datasets import concatenate_datasets, load_dataset @@ -34,13 +35,12 @@ from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( + DEFAULT_AUDIO_CHUNK_DURATION, DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, DEFAULT_RAW_AUDIO_PATH, - DEFAULT_COMPRESSED_AUDIO_PATH, INFO_PATH, TASKS_PATH, - DEFAULT_AUDIO_CHUNK_DURATION, append_jsonlines, backward_compatible_episodes_stats, check_delta_timestamps, @@ -70,17 +70,16 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import ( VideoFrame, - decode_video_frames, - encode_video_frames, - encode_audio, decode_audio, + decode_video_frames, + encode_audio, + encode_video_frames, + get_audio_info, get_safe_default_codec, get_video_info, - get_audio_info, ) -from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.microphones.utils import Microphone -import soundfile as sf +from lerobot.common.robot_devices.robots.utils import Robot CODEBASE_VERSION = "v2.1" @@ -149,10 +148,12 @@ class LeRobotDatasetMetadata: ep_chunk = self.get_episode_chunk(ep_index) fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) return Path(fpath) - + def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path: episode_chunk = self.get_episode_chunk(episode_index) - fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index) + fpath = self.audio_path.format( + episode_chunk=episode_chunk, 
audio_key=audio_key, episode_index=episode_index + ) return self.root / fpath def get_episode_chunk(self, ep_index: int) -> int: @@ -167,7 +168,7 @@ def video_path(self) -> str | None: """Formattable string for the video files.""" return self.info["video_path"] - + @property def audio_path(self) -> str | None: """Formattable string for the audio files.""" @@ -202,16 +203,21 @@ def camera_keys(self) -> list[str]: """Keys to access visual modalities (regardless of their storage method).""" return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] - + @property def audio_keys(self) -> list[str]: """Keys to access audio modalities.""" return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] - + @property def audio_camera_keys_mapping(self) -> dict[str, str]: """Mapping between camera keys and audio keys when both are linked.""" - return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None and self.features[camera_key]["dtype"] == "video"} + return { + self.features[camera_key]["audio"]: camera_key + for camera_key in self.camera_keys + if self.features[camera_key]["audio"] is not None + and self.features[camera_key]["dtype"] == "video" + } @property def names(self) -> dict[str, list | dict]: @@ -325,7 +331,9 @@ been encoded the same way. Also, this means it assumes the first episode exists. """ for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()): - if not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]): + if not self.features[key].get("info", None) or ( + len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"] + ): audio_path = self.root / self.get_compressed_audio_file_path(0, key) self.info["features"][key]["info"] = get_audio_info(audio_path) @@ -518,7 +526,9 @@ self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() - self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal) + self.audio_backend = ( + audio_backend if audio_backend else "ffmpeg" + ) # Waiting for torchcodec release #TODO(CarolinePascal) self.delta_indices = None # Unused attributes @@ -543,7 +553,9 @@ self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) - self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video + self.download_episodes( + download_videos + ) # Should load audio as well #TODO(CarolinePascal): separate audio from video self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -735,13 +747,13 @@ query_timestamps[key] = [current_ts] return query_timestamps - - #TODO(CarolinePascal): add variable query durations + + # TODO(CarolinePascal): add variable query durations def _get_query_timestamps_audio( self, current_ts: float, query_indices: dict[str, list[int]] | None = None, - ) -> dict[str, list[float]]: + ) -> 
dict[str, list[float]]: query_timestamps = {} for key in self.meta.audio_keys: if query_indices is not None and key in query_indices: @@ -773,14 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset): return item - #TODO(CarolinePascal): add variable query durations - def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]: + # TODO(CarolinePascal): add variable query durations + def _query_audio( + self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int + ) -> dict[str, torch.Tensor]: item = {} for audio_key, query_ts in query_timestamps.items(): - #Audio stored with video in a single .mp4 file + # Audio stored with video in a single .mp4 file if audio_key in self.meta.audio_camera_keys_mapping: - audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key]) - #Audio stored alone in a separate .m4a file + audio_path = self.root / self.meta.get_video_file_path( + ep_idx, self.meta.audio_camera_keys_mapping[audio_key] + ) + # Audio stored alone in a separate .m4a file else: audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) @@ -855,7 +871,7 @@ class LeRobotDataset(torch.utils.data.Dataset): image_key=image_key, episode_index=episode_index, frame_index=frame_index ) return self.root / fpath - + def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path: fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index) return self.root / fpath @@ -929,11 +945,17 @@ class LeRobotDataset(torch.utils.data.Dataset): This function will start recording audio from the microphone and save it to disk. """ - audio_dir = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key).parent + audio_dir = self._get_raw_audio_file_path( + self.num_episodes, "observation.audio." + microphone_key + ).parent if not audio_dir.is_dir(): audio_dir.mkdir(parents=True, exist_ok=True) - - microphone.start_recording(output_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)) + + microphone.start_recording( + output_file=self._get_raw_audio_file_path( + self.num_episodes, "observation.audio." 
+ microphone_key + ) + ) def save_episode(self, episode_data: dict | None = None) -> None: """ @@ -983,8 +1005,15 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.meta.robot_type.startswith("lekiwi"): for key in self.meta.audio_keys: - audio_path = self._get_raw_audio_file_path(episode_index=self.episode_buffer["episode_index"][0], audio_key=key) - with sf.SoundFile(audio_path, mode='w', samplerate=self.meta.features[key]["info"]["sample_rate"], channels=self.meta.features[key]["shape"][0]) as file: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"][0], audio_key=key + ) + with sf.SoundFile( + audio_path, + mode="w", + samplerate=self.meta.features[key]["info"]["sample_rate"], + channels=self.meta.features[key]["shape"][0], + ) as file: file.write(episode_buffer[key]) ep_stats = compute_episode_stats(episode_buffer, self.features) @@ -996,7 +1025,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if len(self.meta.audio_keys) > 0: _ = self.encode_episode_audio(episode_index) - + # `meta.save_episode` be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) @@ -1113,7 +1142,7 @@ class LeRobotDataset(torch.utils.data.Dataset): encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True) return video_paths - + def encode_episode_audio(self, episode_index: int) -> dict: """ Use ffmpeg to convert .wav raw audio files into .m4a audio files. @@ -1124,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()): input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key) - + audio_paths[audio_key] = str(output_audio_path) if output_audio_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. @@ -1180,7 +1209,9 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_indices = None obj.episode_data_index = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() - obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal) + obj.audio_backend = ( + audio_backend if audio_backend is not None else "ffmpeg" + ) # Waiting for torchcodec release #TODO(CarolinePascal) return obj diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 970c447d..416d5837 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -33,9 +33,8 @@ from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage -from torchvision import transforms - from soundfile import read +from torchvision import transforms from lerobot.common.datasets.backward_compatibility import ( V21_MESSAGE, @@ -260,10 +259,12 @@ def load_image_as_numpy( img_array /= 255.0 return img_array + def load_audio_from_path(fpath: str | Path) -> np.ndarray: audio_data, _ = read(fpath, dtype="float32") return audio_data + def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. 
Importantly, images are converted from PIL, which corresponds to @@ -731,7 +732,7 @@ def validate_features_presence( ): error_message = "" missing_features = expected_features - actual_features - missing_features = {feature for feature in missing_features if "observation.audio" not in feature} + missing_features = {feature for feature in missing_features if "observation.audio" not in feature} extra_features = actual_features - (expected_features | optional_features) if missing_features or extra_features: @@ -793,18 +794,24 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: return error_message + def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray): error_message = "" if isinstance(value, np.ndarray): actual_shape = value.shape c = expected_shape - if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]): #The number of frames might be different - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n" + if len(actual_shape) != 2 or ( + actual_shape[-1] != c[-1] and actual_shape[0] != c[0] + ): # The number of frames might be different + error_message += ( + f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n" + ) else: error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n" return error_message + def validate_feature_string(name: str, value: str): if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index e8e85411..0511610e 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -25,12 +25,11 @@ from typing import Any, ClassVar import pyarrow as pa import torch -import torchvision import torchaudio +import torchvision from datasets.features.features import register_feature -from PIL import Image - from numpy import ceil +from PIL import Image def get_safe_default_codec(): @@ -42,6 +41,7 @@ def get_safe_default_codec(): ) return "pyav" + def decode_audio( audio_path: Path | str, timestamps: list[float], @@ -68,30 +68,30 @@ def decode_audio( else: raise ValueError(f"Unsupported video backend: {backend}") + def decode_audio_torchvision( audio_path: Path | str, - timestamps: list[float], - duration: float, + timestamps: list[float], + duration: float, log_loaded_timestamps: bool = False, ) -> torch.Tensor: - - #TODO(CarolinePascal) : add channels selection + # TODO(CarolinePascal) : add channels selection audio_path = str(audio_path) reader = torchaudio.io.StreamReader(src=audio_path) audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate - #TODO(CarolinePascal) : sort timestamps ? + # TODO(CarolinePascal) : sort timestamps ? 
reader.add_basic_audio_stream( - frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough - buffer_chunk_size = -1, #No dropping frames - format = "fltp", #Format as float32 + frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough + buffer_chunk_size=-1, # No dropping frames + format="fltp", # Format as float32 ) audio_chunks = [] for ts in timestamps: - reader.seek(ts) #Default to closest audio sample + reader.seek(ts) # Default to closest audio sample status = reader.fill_buffer() if status != 0: logging.warning("Audio stream reached end of recording before decoding desired timestamps.") @@ -99,15 +99,18 @@ current_audio_chunk = reader.pop_chunks()[0] if log_loaded_timestamps: - logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}") - + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}" + ) + audio_chunks.append(current_audio_chunk) audio_chunks = torch.stack(audio_chunks) assert len(timestamps) == len(audio_chunks) return audio_chunks - + + def decode_video_frames( video_path: Path | str, timestamps: list[float], @@ -137,6 +140,7 @@ else: raise ValueError(f"Unsupported video backend: {backend}") + def decode_video_frames_torchvision( ... @@ -234,6 +238,7 @@ assert len(timestamps) == len(closest_frames) return closest_frames + def decode_video_frames_torchcodec( ... @@ -308,6 +313,7 @@ assert len(timestamps) == len(closest_frames) return closest_frames + def encode_audio( input_path: Path | str, output_path: Path | str, ... @@ -344,6 +350,7 @@ f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" ) + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str, @@ -353,7 +360,7 @@ pix_fmt: str = "yuv420p", g: int | None = 2, crf: int | None = 30, - acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options + acodec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options fast_decode: int = 0, log_level: str | None = "error", overwrite: bool = False, @@ -375,16 +382,18 @@ if audio_path is not None: audio_path = Path(audio_path) audio_path.parent.mkdir(parents=True, exist_ok=True) - ffmpeg_audio_args.update(OrderedDict( - [ - ("-i", str(audio_path)), - ] - )) + ffmpeg_audio_args.update( + OrderedDict( + [ + ("-i", str(audio_path)), + ] + ) + ) ffmpeg_encoding_args = OrderedDict( [ ("-pix_fmt", pix_fmt), - ("-vcodec", vcodec), + ("-vcodec", vcodec), ] ) if g is not None: @@ -398,7 +407,7 @@ if audio_path is not None: ffmpeg_encoding_args["-acodec"] = acodec - + if log_level is not None: ffmpeg_encoding_args["-loglevel"] = str(log_level) @@ -487,6 +496,7 @@ "audio.channel_layout": audio_stream_info.get("channel_layout", None), } + def get_video_info(video_path: Path | str) -> dict: ffprobe_video_cmd = [ 
"ffprobe", diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 7c8706a4..e49a4e71 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -77,8 +77,8 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f key = f"read_camera_{name}_dt_s" if key in robot.logs: log_dt(f"dtR{name}", robot.logs[key]) - - for name in robot.microphones: + + for name in robot.microphones: key = f"read_microphone_{name}_dt_s" if key in robot.logs: log_dt(f"dtR{name}", robot.logs[key]) @@ -252,9 +252,11 @@ def control_loop( timestamp = 0 start_episode_t = time.perf_counter() - if dataset is not None and not robot.robot_type.startswith("lekiwi"): #For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage) + if ( + dataset is not None and not robot.robot_type.startswith("lekiwi") + ): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage) for microphone_key, microphone in robot.microphones.items(): - #Start recording both in file writing and data reading mode + # Start recording both in file writing and data reading mode dataset.add_microphone_recording(microphone, microphone_key) else: for _, microphone in robot.microphones.items(): diff --git a/lerobot/common/robot_devices/microphones/configs.py b/lerobot/common/robot_devices/microphones/configs.py index c2700723..1b663b7a 100644 --- a/lerobot/common/robot_devices/microphones/configs.py +++ b/lerobot/common/robot_devices/microphones/configs.py @@ -17,12 +17,14 @@ from dataclasses import dataclass import draccus + @dataclass class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC): @property def type(self) -> str: return self.get_choice_name(self.__class__) + @MicrophoneConfigBase.register_subclass("microphone") @dataclass class MicrophoneConfig(MicrophoneConfigBase): @@ -33,4 +35,4 @@ class MicrophoneConfig(MicrophoneConfigBase): microphone_index: int sample_rate: int | None = None channels: list[int] | None = None - mock: bool = False \ No newline at end of file + mock: bool = False diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 2d75293a..947fdfea 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -13,36 +13,35 @@ # limitations under the License. """ -This file contains utilities for recording audio from a microhone. +This file contains utilities for recording audio from a microhone. 
""" import argparse -import soundfile as sf -import numpy as np import logging -from threading import Thread, Event -from multiprocessing import Process -from queue import Empty - -from queue import Queue as thread_Queue -from threading import Event as thread_Event -from multiprocessing import JoinableQueue as process_Queue -from multiprocessing import Event as process_Event - -from os import getcwd -from pathlib import Path import shutil import time +from multiprocessing import Event as process_Event +from multiprocessing import JoinableQueue as process_Queue +from multiprocessing import Process +from os import getcwd +from pathlib import Path +from queue import Empty +from queue import Queue as thread_Queue +from threading import Event, Thread +from threading import Event as thread_Event -from lerobot.common.utils.utils import capture_timestamp_utc +import numpy as np +import soundfile as sf from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig from lerobot.common.robot_devices.utils import ( RobotDeviceAlreadyConnectedError, + RobotDeviceAlreadyRecordingError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, - RobotDeviceAlreadyRecordingError, ) +from lerobot.common.utils.utils import capture_timestamp_utc + def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: microphones = [] @@ -69,11 +68,10 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: return microphones -def record_audio_from_microphones( - output_dir: Path, - microphone_ids: list[int] | None = None, - record_time_s: float = 2.0): +def record_audio_from_microphones( + output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0 +): if microphone_ids is None or len(microphone_ids) == 0: microphones = find_microphones() microphone_ids = [m["index"] for m in microphones] @@ -104,13 +102,14 @@ def record_audio_from_microphones( for microphone in microphones: microphone.stop_recording() - #Remark : recording may be resumed here if needed + # Remark : recording may be resumed here if needed for microphone in microphones: microphone.disconnect() print(f"Images have been saved to {output_dir}") + class Microphone: """ The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows). @@ -138,20 +137,20 @@ class Microphone: self.config = config self.microphone_index = config.microphone_index - #Store the recording sample rate and channels + # Store the recording sample rate and channels self.sample_rate = config.sample_rate self.channels = config.channels self.mock = config.mock - #Input audio stream + # Input audio stream self.stream = None - #Thread-safe concurrent queue to store the recorded/read audio + # Thread-safe concurrent queue to store the recorded/read audio self.record_queue = None self.read_queue = None - #Thread to handle data reading and file writing in a separate thread (safely) + # Thread to handle data reading and file writing in a separate thread (safely) self.record_thread = None self.record_stop_event = None @@ -162,14 +161,16 @@ class Microphone: def connect(self) -> None: if self.is_connected: - raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.") - + raise RobotDeviceAlreadyConnectedError( + f"Microphone {self.microphone_index} is already connected." 
+ ) + if self.mock: import tests.microphones.mock_sounddevice as sd else: import sounddevice as sd - #Check if the provided microphone index does match an input device + # Check if the provided microphone index does match an input device is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0 if not is_index_input: @@ -178,17 +179,19 @@ microphones_info = find_microphones() available_microphones = [m["index"] for m in microphones_info] raise OSError( f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}" ) - - #Check if provided recording parameters are compatible with the microphone + + # Check if provided recording parameters are compatible with the microphone actual_microphone = sd.query_devices(self.microphone_index) - if self.sample_rate is not None : + if self.sample_rate is not None: if self.sample_rate > actual_microphone["default_samplerate"]: raise OSError( f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}." ) elif self.sample_rate < actual_microphone["default_samplerate"]: - logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.") + logging.warning( + "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." + ) else: self.sample_rate = int(actual_microphone["default_samplerate"]) @@ -198,45 +201,52 @@ f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}." ) else: - self.channels = np.arange(1, actual_microphone["max_input_channels"]+1) + self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1) # Get channels index instead of number for slicing self.channels = np.array(self.channels) - 1 - #Create the audio stream + # Create the audio stream self.stream = sd.InputStream( device=self.microphone_index, samplerate=self.sample_rate, - channels=max(self.channels)+1, + channels=max(self.channels) + 1, dtype="float32", callback=self._audio_callback, ) - #Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers. - #However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency. - + # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receives same-length buffers. + # However, this may lead to additional latency. We thus stick to blocksize=0 which means that audio_callback will receive varying-length buffers, but with no additional latency. 
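For reference, the tradeoff described in the remark above can be made explicit by passing a fixed blocksize; the patch deliberately keeps the default blocksize=0. An illustration only, with placeholder device settings:

```python
import sounddevice as sd

# A fixed blocksize yields constant 1024-frame callbacks at the cost of
# extra latency; blocksize=0 lets PortAudio deliver variable-length buffers.
stream = sd.InputStream(
    device=0,  # placeholder index
    samplerate=16000,
    channels=1,
    dtype="float32",
    blocksize=1024,
)
```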
+ self.is_connected = True - def _audio_callback(self, indata, frames, time, status) -> None : + def _audio_callback(self, indata, frames, time, status) -> None: if status: logging.warning(status) - # Slicing makes copy unecessary + # Slicing makes copy unnecessary # Two separate queues are necessary because .get() also pops the data from the queue if self.is_writing: - self.record_queue.put(indata[:,self.channels]) - self.read_queue.put(indata[:,self.channels]) + self.record_queue.put(indata[:, self.channels]) + self.read_queue.put(indata[:, self.channels]) @staticmethod def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: - #Can only be run on a single process/thread for file writing safety - with sf.SoundFile(output_file, mode='x', samplerate=sample_rate, - channels=max(channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: + # Can only be run on a single process/thread for file writing safety + with sf.SoundFile( + output_file, + mode="x", + samplerate=sample_rate, + channels=max(channels) + 1, + subtype=sf.default_subtype(output_file.suffix[1:]), + ) as file: while not event.is_set(): try: - file.write(queue.get(timeout=0.02)) #Timeout set as twice the usual sounddevice buffer size + file.write( + queue.get(timeout=0.02) + ) # Timeout set as twice the usual sounddevice buffer size queue.task_done() except Empty: continue - + def _read(self) -> np.ndarray: """ Gets audio data from the queue and converts it to a numpy array. -> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage -> CONS : Reading duration does not scale well with the number of channels and reading duration """ audio_readings = np.empty((0, len(self.channels))) - + while True: try: audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0) @@ -256,12 +266,11 @@ return audio_readings def read(self) -> np.ndarray: - if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if not self.is_recording: raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") - + start_time = time.perf_counter() audio_readings = self._read() @@ -274,21 +283,22 @@ return audio_readings - def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False) -> None: - + def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None: if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if self.is_recording: - raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.") - + raise RobotDeviceAlreadyRecordingError( + f"Microphone {self.microphone_index} is already recording." 
+ ) + + # Reset queues self.read_queue = thread_Queue() if multiprocessing: self.record_queue = process_Queue() else: self.record_queue = thread_Queue() - #Write recordings into a file if output_file is provided + # Write recordings into a file if output_file is provided if output_file is not None: output_file = Path(output_file) if output_file.exists(): @@ -296,28 +306,45 @@ class Microphone: if multiprocessing: self.record_stop_event = process_Event() - self.record_thread = Process(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, )) + self.record_thread = Process( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) else: self.record_stop_event = thread_Event() - self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, )) + self.record_thread = Thread( + target=Microphone._record_loop, + args=( + self.record_queue, + self.record_stop_event, + self.sample_rate, + self.channels, + output_file, + ), + ) self.record_thread.daemon = True self.record_thread.start() self.is_writing = True - + self.is_recording = True self.stream.start() def stop_recording(self) -> None: - if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") if not self.is_recording: raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.") - + if self.stream.active: - self.stream.stop() #Wait for all buffers to be processed - #Remark : stream.abort() flushes the buffers ! + self.stream.stop() # Wait for all buffers to be processed + # Remark : stream.abort() flushes the buffers ! self.is_recording = False if self.record_thread is not None: @@ -329,7 +356,6 @@ class Microphone: self.is_writing = False def disconnect(self) -> None: - if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") @@ -342,7 +368,8 @@ class Microphone: def __del__(self): if getattr(self, "is_connected", False): self.disconnect() - + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset." diff --git a/lerobot/common/robot_devices/microphones/utils.py b/lerobot/common/robot_devices/microphones/utils.py index 1b1ad099..fb1bac85 100644 --- a/lerobot/common/robot_devices/microphones/utils.py +++ b/lerobot/common/robot_devices/microphones/utils.py @@ -16,28 +16,33 @@ from typing import Protocol from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase + # Defines a microphone type class Microphone(Protocol): def connect(self): ... def disconnect(self): ... - def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False): ... + def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ... def stop_recording(self): ... 
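The Protocol above is satisfied by the concrete `Microphone` class, and the two factory helpers that follow build instances from configs. A hedged sketch of both entry points (index, rate, and key name are placeholders):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.utils import (
    make_microphone,
    make_microphones_from_configs,
)

# Single device from keyword arguments.
mic = make_microphone("microphone", microphone_index=0, sample_rate=16000, channels=[1])

# Keyed dict of devices, as a robot config would provide them.
mics = make_microphones_from_configs(
    {"laptop": MicrophoneConfig(microphone_index=0, sample_rate=48000, channels=[1])}
)
```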
+ def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]: microphones = {} for key, cfg in microphone_configs.items(): if cfg.type == "microphone": from lerobot.common.robot_devices.microphones.microphone import Microphone + microphones[key] = Microphone(cfg) else: raise ValueError(f"The microphone type '{cfg.type}' is not valid.") return microphones + def make_microphone(microphone_type, **kwargs) -> Microphone: if microphone_type == "microphone": from lerobot.common.robot_devices.microphones.microphone import Microphone + return Microphone(MicrophoneConfig(**kwargs)) else: - raise ValueError(f"The microphone type '{microphone_type}' is not valid.") \ No newline at end of file + raise ValueError(f"The microphone type '{microphone_type}' is not valid.") diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index 66edc4f6..942586a0 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -23,12 +23,12 @@ from lerobot.common.robot_devices.cameras.configs import ( IntelRealSenseCameraConfig, OpenCVCameraConfig, ) +from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig from lerobot.common.robot_devices.motors.configs import ( DynamixelMotorsBusConfig, FeetechMotorsBusConfig, MotorsBusConfig, ) -from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig @dataclass diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py index 03576c0c..15023d8a 100644 --- a/lerobot/common/robot_devices/robots/lekiwi_remote.py +++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py @@ -51,6 +51,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event): latest_images_dict.update(local_dict) time.sleep(0.01) + def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event): while not stop_event.is_set(): local_dict = {} @@ -60,6 +61,7 @@ def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_even with audio_lock: latest_audio_dict.update(local_dict) + def calibrate_follower_arm(motors_bus, calib_dir_str): """ Calibrates the follower arm. Attempts to load an existing calibration file; @@ -149,12 +151,14 @@ def run_lekiwi(robot_config): cam_thread.start() # Start the microphone recording and capture thread. - #TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead ! + # TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead ! latest_audio_dict = {} audio_lock = threading.Lock() audio_stop_event = threading.Event() microphone_thread = threading.Thread( - target=run_microphone_capture, args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), daemon=True + target=run_microphone_capture, + args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), + daemon=True, ) for microphone in microphones.values(): microphone.start_recording() @@ -231,7 +235,7 @@ def run_lekiwi(robot_config): # Build the observation dictionary. observation = { "images": images_dict_copy, - "audio": audio_dict_copy, #TODO(CarolinePascal) : This is a nasty way to do it, sorry. + "audio": audio_dict_copy, # TODO(CarolinePascal) : This is a nasty way to do it, sorry. 
"present_speed": current_velocity, "follower_arm_state": follower_arm_state, } diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index dc68d609..b452be9d 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -201,7 +201,7 @@ class ManipulatorRobot: "names": state_names, }, } - + @property def microphone_features(self) -> dict: mic_ft = {} @@ -211,7 +211,7 @@ class ManipulatorRobot: "dtype": "audio", "shape": (len(mic.channels),), "names": "channels", - "info" : {"sample_rate": mic.sample_rate}, + "info": {"sample_rate": mic.sample_rate}, } return mic_ft @@ -226,11 +226,11 @@ class ManipulatorRobot: @property def num_cameras(self): return len(self.cameras) - + @property def has_microphone(self): return len(self.microphones) > 0 - + @property def num_microphones(self): return len(self.microphones) diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py index 98e4cdb1..7727abb9 100644 --- a/lerobot/common/robot_devices/robots/mobile_manipulator.py +++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py @@ -163,7 +163,7 @@ class MobileManipulator: "names": combined_names, }, } - + @property def microphone_features(self) -> dict: mic_ft = {} @@ -173,7 +173,7 @@ class MobileManipulator: "dtype": "audio", "shape": (len(mic.channels),), "names": "channels", - "info" : {"sample_rate": mic.sample_rate}, + "info": {"sample_rate": mic.sample_rate}, } return mic_ft @@ -188,11 +188,11 @@ class MobileManipulator: @property def num_cameras(self): return len(self.cameras) - + @property def has_microphone(self): return len(self.microphones) > 0 - + @property def num_microphones(self): return len(self.microphones) @@ -512,7 +512,7 @@ class MobileManipulator: # Create silence using the microphone's configured channels frame = np.zeros((1, len(microphone.channels)), dtype=np.float32) obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame) - + return obs_dict def send_action(self, action: torch.Tensor) -> torch.Tensor: diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index 01f9195e..5b2270e7 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -69,11 +69,13 @@ class RobotDeviceNotRecordingError(Exception): """Exception raised when the robot device is not recording.""" def __init__( - self, message="This robot device is not recording. Try calling `robot_device.start_recording()` first." + self, + message="This robot device is not recording. 
Try calling `robot_device.start_recording()` first.", ): self.message = message super().__init__(self.message) + class RobotDeviceAlreadyRecordingError(Exception): """Exception raised when the robot device is already recording.""" diff --git a/tests/conftest.py b/tests/conftest.py index adf80931..a0e9ae41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,9 @@ import traceback import pytest from serial import SerialException -from lerobot import available_cameras, available_motors, available_robots, available_microphones +from lerobot import available_cameras, available_microphones, available_motors, available_robots from lerobot.common.robot_devices.robots.utils import make_robot -from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone +from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus # Import fixture modules as plugins pytest_plugins = [ @@ -73,10 +73,12 @@ def is_robot_available(robot_type): def is_camera_available(camera_type): return _check_component_availability(camera_type, available_cameras, make_camera) + @pytest.fixture def is_microphone_available(microphone_type): return _check_component_availability(microphone_type, available_microphones, make_microphone) + @pytest.fixture def is_motor_available(motor_type): return _check_component_availability(motor_type, available_motors, make_motors_bus) diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 56cdc176..9cf9f760 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -25,9 +25,9 @@ from lerobot.common.datasets.compute_stats import ( compute_episode_stats, estimate_num_samples, get_feature_stats, - sample_images, - sample_audio_from_path, sample_audio_from_data, + sample_audio_from_path, + sample_images, sample_indices, ) @@ -35,8 +35,10 @@ from lerobot.common.datasets.compute_stats import ( def mock_load_image_as_numpy(path, dtype, channel_first): return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) + def mock_load_audio(path): - return np.ones((16000,2), dtype=np.float32) + return np.ones((16000, 2), dtype=np.float32) + @pytest.fixture def sample_array(): @@ -74,6 +76,7 @@ def test_sample_images(mock_load): assert images.dtype == np.uint8 assert len(images) == estimate_num_samples(100) + @patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio) def test_sample_audio_from_path(mock_load): audio_path = "audio.wav" @@ -83,6 +86,7 @@ def test_sample_audio_from_path(mock_load): assert audio_samples.dtype == np.float32 assert len(audio_samples) == estimate_num_samples(16000) + def test_sample_audio_from_data(mock_load): audio_data = np.ones((16000, 2), dtype=np.float32) audio_samples = sample_audio_from_data(audio_data) @@ -91,6 +95,7 @@ def test_sample_audio_from_data(mock_load): assert audio_samples.dtype == np.float32 assert len(audio_samples) == estimate_num_samples(16000) + def test_get_feature_stats_images(): data = np.random.rand(100, 3, 32, 32) stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) @@ -98,13 +103,15 @@ def test_get_feature_stats_images(): np.testing.assert_equal(stats["count"], np.array([100])) assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + def test_get_feature_stats_audio(): - data = np.random.uniform(-1, 1, (16000,2)) + data = np.random.uniform(-1, 1, (16000, 2)) stats = get_feature_stats(data, axis=0, keepdims=True) 
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats np.testing.assert_equal(stats["count"], np.array([16000])) assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + def test_get_feature_stats_axis_0_keepdims(sample_array): expected = { "min": np.array([[1, 2, 3]]), @@ -172,10 +179,11 @@ def test_compute_episode_stats(): "observation.state": {"dtype": "numeric"}, } - with patch( - "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy - ), patch( - "lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio + with ( + patch( + "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + ), + patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio), ): stats = compute_episode_stats(episode_data, features) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index dde2ed06..a10a7d61 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -16,6 +16,7 @@ import json import logging import re +import time from copy import deepcopy from itertools import chain from pathlib import Path @@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import ( MultiLeRobotDataset, ) from lerobot.common.datasets.utils import ( + DEFAULT_AUDIO_CHUNK_DURATION, create_branch, flatten_dict, unflatten_dict, @@ -44,12 +46,9 @@ from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS -from tests.utils import require_x86_64_kernel +from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID +from tests.utils import make_microphone, require_x86_64_kernel -from tests.utils import make_microphone -import time -from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION @pytest.fixture def image_dataset(tmp_path, empty_lerobot_dataset_factory): @@ -66,6 +65,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): } return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + @pytest.fixture def audio_dataset(tmp_path, empty_lerobot_dataset_factory): features = { @@ -79,6 +79,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory): } return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated @@ -336,6 +337,7 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): with pytest.raises(ValueError): image_array_to_pil_image(image) + def test_add_frame_audio(audio_dataset): dataset = audio_dataset @@ -349,7 +351,10 @@ def test_add_frame_audio(audio_dataset): dataset.save_episode() - assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sample_rate),DUMMY_AUDIO_CHANNELS)) + assert dataset[0]["observation.audio.microphone"].shape == torch.Size( + (int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS) + ) + # TODO(aliberts): # - [ ] test various attributes & state from init and create diff --git 
a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 91942190..b95044df 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = { }, } DUMMY_CAMERA_FEATURES = { - "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"}, + "laptop": { + "shape": (480, 640, 3), + "names": ["height", "width", "channels"], + "info": None, + "audio": "laptop", + }, "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, } DEFAULT_FPS = 30 diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 80387d65..321dec46 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -26,10 +26,10 @@ import torch from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, + DEFAULT_COMPRESSED_AUDIO_PATH, DEFAULT_FEATURES, DEFAULT_PARQUET_PATH, DEFAULT_VIDEO_PATH, - DEFAULT_COMPRESSED_AUDIO_PATH, get_hf_features_from_features, hf_transform_to_torch, ) diff --git a/tests/microphones/mock_sounddevice.py b/tests/microphones/mock_sounddevice.py index f6007085..0220c88c 100644 --- a/tests/microphones/mock_sounddevice.py +++ b/tests/microphones/mock_sounddevice.py @@ -11,27 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import time from functools import cache - -from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DEFAULT_SAMPLE_RATE +from threading import Event, Thread import numpy as np + from lerobot.common.utils.utils import capture_timestamp_utc -from threading import Thread, Event -import time +from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS + @cache def _generate_sound(duration: float, sample_rate: int, channels: int): return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32) + def query_devices(query_index: int): return { - "name": "Mock Sound Device", - "index": query_index, - "max_input_channels": DUMMY_AUDIO_CHANNELS, - "default_samplerate": DEFAULT_SAMPLE_RATE, + "name": "Mock Sound Device", + "index": query_index, + "max_input_channels": DUMMY_AUDIO_CHANNELS, + "default_samplerate": DEFAULT_SAMPLE_RATE, } + class InputStream: def __init__(self, *args, **kwargs): self._mock_dict = { @@ -49,7 +52,12 @@ class InputStream: while not self.callback_thread_stop_event.is_set(): # Simulate audio data acquisition time.sleep(0.01) - self._audio_callback(_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), 0.01*DEFAULT_SAMPLE_RATE, capture_timestamp_utc(), None) + self._audio_callback( + _generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), + 0.01 * DEFAULT_SAMPLE_RATE, + capture_timestamp_utc(), + None, + ) def start(self): self.callback_thread_stop_event = Event() @@ -62,7 +70,7 @@ class InputStream: @property def active(self): return self._is_active - + def stop(self): if self.callback_thread_stop_event is not None: self.callback_thread_stop_event.set() @@ -78,5 +86,3 @@ class InputStream: def __del__(self): if self._is_active: self.stop() - - diff --git a/tests/microphones/test_microphones.py b/tests/microphones/test_microphones.py index 3ce29fa2..c7bdbe71 100644 --- a/tests/microphones/test_microphones.py +++ 
b/tests/microphones/test_microphones.py
@@ -32,20 +32,27 @@ pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-Tr
 ```
 """
 
-import numpy as np
 import time
+
+import numpy as np
 import pytest
 from soundfile import read
 
-from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceAlreadyRecordingError,
+    RobotDeviceNotConnectedError,
+    RobotDeviceNotRecordingError,
+)
 from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
 
-#Maximum recording tie difference between two consecutive audio recordings of the same duration.
-#Set to 0.02 seconds as twice the default size of sounddvice callback buffer (i.e. we tolerate the loss of one buffer).
+# Maximum recording time difference between two consecutive audio recordings of the same duration.
+# Set to 0.02 seconds as twice the default size of sounddevice callback buffer (i.e. we tolerate the loss of one buffer).
 MAX_RECORDING_TIME_DIFFERENCE = 0.02
 
 DUMMY_RECORDING = "test_recording.wav"
 
+
 @pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
 @require_microphone
 def test_microphone(tmp_path, request, microphone_type, mock):
@@ -92,7 +99,7 @@ def test_microphone(tmp_path, request, microphone_type, mock):
     fpath = tmp_path / DUMMY_RECORDING
     microphone.start_recording(fpath)
     assert microphone.is_recording
-    
+
     # Test start_recording twice raises an error
     with pytest.raises(RobotDeviceAlreadyRecordingError):
         microphone.start_recording()
@@ -126,10 +133,13 @@ def test_microphone(tmp_path, request, microphone_type, mock):
 
     error_msg = (
         "Recording time difference between read() and stop_recording()",
-        (len(audio_chunk) - len(recorded_audio))/MAX_RECORDING_TIME_DIFFERENCE,
+        (len(audio_chunk) - len(recorded_audio)) / MAX_RECORDING_TIME_DIFFERENCE,
     )
     np.testing.assert_allclose(
-        len(audio_chunk), len(recorded_audio), atol=recorded_sample_rate*MAX_RECORDING_TIME_DIFFERENCE, err_msg=error_msg
+        len(audio_chunk),
+        len(recorded_audio),
+        atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
+        err_msg=error_msg,
     )
 
     # Test disconnecting
@@ -139,4 +149,4 @@ def test_microphone(tmp_path, request, microphone_type, mock):
     # Test disconnecting with `__del__`
     microphone = make_microphone(**microphone_kwargs)
     microphone.connect()
-    del microphone
\ No newline at end of file
+    del microphone
diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py
index 204aabca..8353fe29 100644
--- a/tests/robots/test_robots.py
+++ b/tests/robots/test_robots.py
@@ -143,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
         robot.send_action(action["action"])
 
     # Test disconnecting
-    robot.disconnect() #Also handles microphone recording stop, life is beautiful
+    robot.disconnect()  # Also handles microphone recording stop, life is beautiful
     assert not robot.is_connected
     for name in robot.follower_arms:
         assert not robot.follower_arms[name].is_connected
diff --git a/tests/utils.py b/tests/utils.py
index 8559ca0c..d93eb97c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -22,13 +22,13 @@ from pathlib import Path
 import pytest
 import torch
 
-from lerobot import available_cameras, available_motors, available_robots, available_microphones
+from lerobot import available_cameras, available_microphones, available_motors, available_robots
 from lerobot.common.robot_devices.cameras.utils import 
Camera from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device -from lerobot.common.robot_devices.motors.utils import MotorsBus -from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device from lerobot.common.robot_devices.microphones.utils import Microphone from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device +from lerobot.common.robot_devices.motors.utils import MotorsBus +from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device from lerobot.common.utils.import_utils import is_package_available DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" @@ -261,6 +261,7 @@ def require_camera(func): return wrapper + def require_microphone(func): @wraps(func) def wrapper(*args, **kwargs): @@ -283,6 +284,7 @@ def require_microphone(func): return wrapper + def require_motor(func): @wraps(func) def wrapper(*args, **kwargs): @@ -344,6 +346,7 @@ def make_camera(camera_type: str, **kwargs) -> Camera: else: raise ValueError(f"The camera type '{camera_type}' is not valid.") + def make_microphone(microphone_type: str, **kwargs) -> Microphone: if microphone_type == "microphone": microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX) @@ -351,6 +354,7 @@ def make_microphone(microphone_type: str, **kwargs) -> Microphone: else: raise ValueError(f"The microphone type '{microphone_type}' is not valid.") + # TODO(rcadene, aliberts): remove this dark pattern that overrides def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: if motor_type == "dynamixel": From a08b5c4105a7f58416cdd90a6ab4786851eab5b5 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 13:43:17 +0200 Subject: [PATCH 22/32] Adding last missing audio features in LeRobotDataset --- lerobot/common/datasets/lerobot_dataset.py | 66 ++++++++++++++++------ 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f844eb72..da5874fb 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -206,7 +206,7 @@ class LeRobotDatasetMetadata: @property def audio_keys(self) -> list[str]: - """Keys to access audio modalities.""" + """Keys to access audio modalities (wether they are linked to a camera or not).""" return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] @property @@ -219,6 +219,16 @@ class LeRobotDatasetMetadata: and self.features[camera_key]["dtype"] == "video" } + @property + def linked_audio_keys(self) -> list[str]: + """Keys to access audio modalities linked to a camera.""" + return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping] + + @property + def unlinked_audio_keys(self) -> list[str]: + """Keys to access audio modalities not linked to a camera.""" + return [key for key in self.audio_keys if key not in self.audio_camera_keys_mapping] + @property def names(self) -> dict[str, list | dict]: """Names of the various dimensions of vector modalities.""" @@ -298,7 +308,8 @@ class LeRobotDatasetMetadata: if len(self.video_keys) > 0: self.update_video_info() - if len(self.audio_keys) > 0: + self.info["total_audio"] += len(self.audio_keys) + if len(self.unlinked_audio_keys) > 0: self.update_audio_info() write_info(self.info, self.root) @@ -330,10 +341,10 @@ class LeRobotDatasetMetadata: Warning: this function writes info from first 
episode audio, implicitly assuming that all audio have been encoded the same way. Also, this means it assumes the first episode exists. """ - for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()): + for key in self.unlinked_audio_keys: if not self.features[key].get("info", None) or ( len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"] - ): + ): #Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi) audio_path = self.root / self.get_compressed_audio_file_path(0, key) self.info["features"][key]["info"] = get_audio_info(audio_path) @@ -412,6 +423,7 @@ class LeRobotDataset(torch.utils.data.Dataset): revision: str | None = None, force_cache_sync: bool = False, download_videos: bool = True, + download_audio: bool = True, video_backend: str | None = None, audio_backend: str | None = None, ): @@ -444,7 +456,8 @@ class LeRobotDataset(torch.utils.data.Dataset): - tasks contains the prompts for each task of the dataset, which can be used for task-conditioned training. - hf_dataset (from datasets.Dataset), which will read any values from parquet files. - - videos (optional) from which frames are loaded to be synchronous with data from parquet files. + - videos (optional) from which frames and audio (if any) are loaded to be synchronous with data from parquet files and audio. + - audio (optional) from which audio is loaded to be synchronous with data from parquet files and videos. A typical LeRobotDataset looks like this from its root path: . @@ -513,9 +526,10 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. + download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. - audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'. 
""" super().__init__() self.repo_id = repo_id @@ -554,8 +568,9 @@ class LeRobotDataset(torch.utils.data.Dataset): except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) self.download_episodes( - download_videos - ) # Sould load audio as well #TODO(CarolinePascal): separate audio from video + download_videos, + download_audio + ) #Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -578,6 +593,7 @@ class LeRobotDataset(torch.utils.data.Dataset): license: str | None = "apache-2.0", tag_version: bool = True, push_videos: bool = True, + push_audio: bool = True, private: bool = False, allow_patterns: list[str] | str | None = None, upload_large_folder: bool = False, @@ -585,7 +601,9 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> None: ignore_patterns = ["images/"] if not push_videos: - ignore_patterns.append("videos/") + ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed + if not push_audio: + ignore_patterns.append("audio/") hub_api = HfApi() hub_api.create_repo( @@ -641,7 +659,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ignore_patterns=ignore_patterns, ) - def download_episodes(self, download_videos: bool = True) -> None: + def download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None: """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present @@ -650,7 +668,11 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads files = None - ignore_patterns = None if download_videos else "videos/" + ignore_patterns = [] + if not download_videos: + ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded + if not download_audio: + ignore_patterns.append("audio/") if self.episodes is not None: files = self.get_episodes_file_paths() @@ -667,6 +689,14 @@ class LeRobotDataset(torch.utils.data.Dataset): ] fpaths += video_files + if len(self.meta.unlinked_audio_keys) > 0: + audio_files = [ + str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key)) + for audio_key in self.meta.unlinked_audio_keys + for ep_idx in episodes + ] + fpaths += audio_files + return fpaths def load_hf_dataset(self) -> datasets.Dataset: @@ -755,7 +785,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices: dict[str, list[int]] | None = None, ) -> dict[str, list[float]]: query_timestamps = {} - for key in self.meta.audio_keys: + for key in self.meta.audio_keys: #Standalone audio and audio embedded in video as well ! 
if query_indices is not None and key in query_indices: timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] query_timestamps[key] = torch.stack(timestamps).tolist() @@ -768,7 +798,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return { key: torch.stack(self.hf_dataset.select(q_idx)[key]) for key, q_idx in query_indices.items() - if key not in self.meta.video_keys + if key not in self.meta.video_keys and key not in self.meta.audio_keys } def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: @@ -791,12 +821,12 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> dict[str, torch.Tensor]: item = {} for audio_key, query_ts in query_timestamps.items(): - # Audio stored with video in a single .mp4 file - if audio_key in self.meta.audio_camera_keys_mapping: + #Audio stored with video in a single .mp4 file + if audio_key in self.meta.linked_audio_keys: audio_path = self.root / self.meta.get_video_file_path( ep_idx, self.meta.audio_camera_keys_mapping[audio_key] ) - # Audio stored alone in a separate .m4a file + #Audio stored alone in a separate .m4a file else: audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) @@ -1023,7 +1053,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for key in self.meta.video_keys: episode_buffer[key] = video_paths[key] - if len(self.meta.audio_keys) > 0: + if len(self.meta.unlinked_audio_keys) > 0: #Linked audio is already encoded in the video files _ = self.encode_episode_audio(episode_index) # `meta.save_episode` be executed after encoding the videos @@ -1150,7 +1180,7 @@ class LeRobotDataset(torch.utils.data.Dataset): since video encoding with ffmpeg is already using multithreading. 
""" audio_paths = {} - for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()): + for audio_key in self.meta.unlinked_audio_keys: input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key) From 5384309e6ffda1fcef3601155cb9d7eaae8f73e0 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 13:46:34 +0200 Subject: [PATCH 23/32] docs: add methods descriptions and comments on tricky parts --- lerobot/common/datasets/compute_stats.py | 4 +- lerobot/common/datasets/lerobot_dataset.py | 53 ++++++++++++------- lerobot/common/datasets/video_utils.py | 15 +++--- .../robot_devices/microphones/microphone.py | 48 +++++++++++++---- .../robots/mobile_manipulator.py | 2 +- 5 files changed, 82 insertions(+), 40 deletions(-) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 36606719..2fab5a80 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -73,6 +73,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray: def sample_audio_from_path(audio_path: str) -> np.ndarray: + """Samples audio data from an audio recording stored in a WAV file.""" data = load_audio_from_path(audio_path) sampled_indices = sample_indices(len(data)) @@ -80,6 +81,7 @@ def sample_audio_from_path(audio_path: str) -> np.ndarray: def sample_audio_from_data(data: np.ndarray) -> np.ndarray: + """Samples audio data from an audio recording stored in a numpy array.""" sampled_indices = sample_indices(len(data)) return data[sampled_indices] @@ -106,7 +108,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu elif features[key]["dtype"] == "audio": try: ep_ft_array = sample_audio_from_path(data[0]) - except TypeError: # Should only be triggered for LeKiwi robot + except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner ep_ft_array = sample_audio_from_data(data) axes_to_reduce = 0 keepdims = True diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index da5874fb..e51a163d 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -150,6 +150,7 @@ class LeRobotDatasetMetadata: return Path(fpath) def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path: + """Returns the path of the compressed (i.e. 
encoded) audio file.""" episode_chunk = self.get_episode_chunk(episode_index) fpath = self.audio_path.format( episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index @@ -206,7 +207,7 @@ class LeRobotDatasetMetadata: @property def audio_keys(self) -> list[str]: - """Keys to access audio modalities (wether they are linked to a camera or not).""" + """Keys to access audio modalities (whether they are linked to a camera or not).""" return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] @property @@ -223,7 +224,7 @@ class LeRobotDatasetMetadata: def linked_audio_keys(self) -> list[str]: """Keys to access audio modalities linked to a camera.""" return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping] - + @property def unlinked_audio_keys(self) -> list[str]: """Keys to access audio modalities not linked to a camera.""" @@ -342,9 +343,10 @@ class LeRobotDatasetMetadata: been encoded the same way. Also, this means it assumes the first episode exists. """ for key in self.unlinked_audio_keys: - if not self.features[key].get("info", None) or ( - len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"] - ): #Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi) + if ( + not self.features[key].get("info", None) + or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]) + ): # Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi) audio_path = self.root / self.get_compressed_audio_file_path(0, key) self.info["features"][key]["info"] = get_audio_info(audio_path) @@ -568,9 +570,8 @@ class LeRobotDataset(torch.utils.data.Dataset): except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) self.download_episodes( - download_videos, - download_audio - ) #Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio + download_videos, download_audio + ) # Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -581,6 +582,8 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) + # TODO(CarolinePascal) : add check for audio duration with respect to video duration and episode duration. 
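A rough sketch of the check this TODO describes, as one possible implementation (the `check_audio_duration` helper below is hypothetical, not part of the PR; it assumes `soundfile` can read the recording header and that episode duration is frame count divided by fps):

```python
import soundfile as sf


def check_audio_duration(audio_path, num_frames: int, fps: float, tolerance_s: float) -> None:
    """Hypothetical sketch: verify that a recording roughly covers its episode."""
    info = sf.info(audio_path)  # reads only the file header, not the samples
    audio_duration_s = info.frames / info.samplerate
    episode_duration_s = num_frames / fps
    if abs(audio_duration_s - episode_duration_s) > tolerance_s:
        raise ValueError(
            f"Audio duration {audio_duration_s:.3f}s deviates from episode duration "
            f"{episode_duration_s:.3f}s by more than {tolerance_s}s for {audio_path}."
        )
```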
+ # Setup delta_indices if self.delta_timestamps is not None: check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) @@ -601,7 +604,9 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> None: ignore_patterns = ["images/"] if not push_videos: - ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed + ignore_patterns.append( + "videos/" + ) # Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed if not push_audio: ignore_patterns.append("audio/") @@ -670,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset): files = None ignore_patterns = [] if not download_videos: - ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded + ignore_patterns.append( + "videos/" + ) # Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded if not download_audio: ignore_patterns.append("audio/") if self.episodes is not None: @@ -785,7 +792,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices: dict[str, list[int]] | None = None, ) -> dict[str, list[float]]: query_timestamps = {} - for key in self.meta.audio_keys: #Standalone audio and audio embedded in video as well ! + for key in self.meta.audio_keys: # Standalone audio and audio embedded in video as well ! if query_indices is not None and key in query_indices: timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] query_timestamps[key] = torch.stack(timestamps).tolist() @@ -821,12 +828,12 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> dict[str, torch.Tensor]: item = {} for audio_key, query_ts in query_timestamps.items(): - #Audio stored with video in a single .mp4 file + # Audio stored with video in a single .mp4 file if audio_key in self.meta.linked_audio_keys: audio_path = self.root / self.meta.get_video_file_path( ep_idx, self.meta.audio_camera_keys_mapping[audio_key] ) - #Audio stored alone in a separate .m4a file + # Audio stored alone in a separate .m4a file else: audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) @@ -957,9 +964,11 @@ class LeRobotDataset(torch.utils.data.Dataset): self._save_image(frame[key], img_path) self.episode_buffer[key].append(str(img_path)) elif self.features[key]["dtype"] == "audio": - if self.meta.robot_type.startswith("lekiwi"): + if self.meta.robot_type.startswith( + "lekiwi" + ): # Rw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner self.episode_buffer[key].append(frame[key]) - else: + else: # Otherwise, only the audio file path is stored in the episode buffer if frame_index == 0: audio_path = self._get_raw_audio_file_path( episode_index=self.episode_buffer["episode_index"], audio_key=key @@ -972,7 +981,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None: """ - This function will start recording audio from the microphone and save it to disk. + Starts recording audio data provided by the microphone and directly writes it in a .wav file. 
""" audio_dir = self._get_raw_audio_file_path( @@ -1025,7 +1034,9 @@ class LeRobotDataset(torch.utils.data.Dataset): if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: continue elif ft["dtype"] == "audio": - if self.meta.robot_type.startswith("lekiwi"): + if self.meta.robot_type.startswith( + "lekiwi" + ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0) continue episode_buffer[key] = np.stack(episode_buffer[key]) @@ -1033,7 +1044,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) - if self.meta.robot_type.startswith("lekiwi"): + if self.meta.robot_type.startswith( + "lekiwi" + ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner for key in self.meta.audio_keys: audio_path = self._get_raw_audio_file_path( episode_index=self.episode_buffer["episode_index"][0], audio_key=key @@ -1053,7 +1066,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for key in self.meta.video_keys: episode_buffer[key] = video_paths[key] - if len(self.meta.unlinked_audio_keys) > 0: #Linked audio is already encoded in the video files + if len(self.meta.unlinked_audio_keys) > 0: # Linked audio is already encoded in the video files _ = self.encode_episode_audio(episode_index) # `meta.save_episode` be executed after encoding the videos @@ -1080,7 +1093,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if img_dir.is_dir(): shutil.rmtree(self.root / "images") - # delete raw audio + # delete raw audio files raw_audio_files = list(self.root.rglob("*.wav")) for raw_audio_file in raw_audio_files: raw_audio_file.unlink() diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 0511610e..d1b25023 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -52,14 +52,14 @@ def decode_audio( Decodes audio using the specified backend. Args: audio_path (Path): Path to the audio file. - timestamps (list[float]): List of timestamps to extract frames. - tolerance_s (float): Allowed deviation in seconds for frame retrieval. - backend (str, optional): Backend to use for decoding. Defaults to "pyav". + timestamps (list[float]): List of (starting) timestamps to extract audio chunks. + duration (float): Duration of the audio chunks in seconds. + backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg". Returns: - torch.Tensor: Decoded frames. + torch.Tensor: Decoded audio chunks. - Currently supports pyav. + Currently supports ffmpeg. """ if backend == "torchcodec": raise NotImplementedError("torchcodec is not yet supported for audio decoding") @@ -82,7 +82,6 @@ def decode_audio_torchvision( audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate # TODO(CarolinePascal) : sort timestamps ? 
- reader.add_basic_audio_stream( frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough buffer_chunk_size=-1, # No dropping frames @@ -317,7 +316,7 @@ def decode_video_frames_torchcodec( def encode_audio( input_path: Path | str, output_path: Path | str, - codec: str = "aac", + codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options log_level: str | None = "error", overwrite: bool = False, ) -> None: @@ -346,7 +345,7 @@ def encode_audio( if not output_path.exists(): raise OSError( - f"Video encoding did not work. File not found: {output_path}. " + f"Audio encoding did not work. File not found: {output_path}. " f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" ) diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 947fdfea..b08842c2 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -44,6 +44,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: + """ + Finds and lists all microphones compatible with sounddevice (and the underlying PortAudio library). + Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows). + """ microphones = [] if mock: @@ -72,6 +76,11 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]: def record_audio_from_microphones( output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0 ): + """ + Records audio from all the channels of the specified microphones for the specified duration. + If no microphone ids are provided, all available microphones will be used. + """ + if microphone_ids is None or len(microphone_ids) == 0: microphones = find_microphones() microphone_ids = [m["index"] for m in microphones] @@ -112,7 +121,7 @@ def record_audio_from_microphones( class Microphone: """ - The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows). + The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows). A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels. @@ -146,11 +155,11 @@ class Microphone: # Input audio stream self.stream = None - # Thread-safe concurrent queue to store the recorded/read audio + # Thread/Process-safe concurrent queue to store the recorded/read audio self.record_queue = None self.read_queue = None - # Thread to handle data reading and file writing in a separate thread (safely) + # Thread/Process to handle data reading and file writing in a separate thread/process (safely) self.record_thread = None self.record_stop_event = None @@ -160,6 +169,9 @@ class Microphone: self.is_writing = False def connect(self) -> None: + """ + Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone. 
+        """
         if self.is_connected:
             raise RobotDeviceAlreadyConnectedError(
                 f"Microphone {self.microphone_index} is already connected."
@@ -214,15 +226,18 @@ class Microphone:
             dtype="float32",
             callback=self._audio_callback,
         )
-        # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers.
-        # However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency.
+        # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receives same length buffers.
+        # However, this may lead to additional latency. We thus stick to blocksize=0 which means that audio_callback will receive varying length buffers, but with no additional latency.
 
         self.is_connected = True
 
     def _audio_callback(self, indata, frames, time, status) -> None:
+        """
+        Low-level sounddevice callback.
+        """
         if status:
             logging.warning(status)
-        # Slicing makes copy unecessary
+        # Slicing makes copy unnecessary
         # Two separate queues are necessary because .get() also pops the data from the queue
         if self.is_writing:
             self.record_queue.put(indata[:, self.channels])
@@ -230,6 +245,9 @@ class Microphone:
 
     @staticmethod
     def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
+        """
+        Thread/Process-safe loop to write audio data into a file.
+        """
         # Can only be run on a single process/thread for file writing safety
         with sf.SoundFile(
             output_file,
@@ -249,9 +267,7 @@ class Microphone:
 
     def _read(self) -> np.ndarray:
         """
-        Gets audio data from the queue and coverts it to a numpy array.
-        -> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage
-        -> CONS : Reading duration does not scale well with the number of channels and reading duration
+        Thread/Process-safe callback to read available audio data.
         """
         audio_readings = np.empty((0, len(self.channels)))
 
@@ -266,6 +282,9 @@ class Microphone:
         return audio_readings
 
     def read(self) -> np.ndarray:
+        """
+        Reads the last audio chunk recorded by the microphone, i.e. all samples recorded since the last read or since the beginning of the recording.
+        """
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if not self.is_recording:
@@ -284,6 +303,9 @@ class Microphone:
         return audio_readings
 
     def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
+        """
+        Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
+        """
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if self.is_recording:
@@ -337,6 +359,9 @@ class Microphone:
         self.stream.start()
 
     def stop_recording(self) -> None:
+        """
+        Stops the recording of the microphone.
+        """
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if not self.is_recording:
@@ -356,6 +381,9 @@ class Microphone:
         self.is_writing = False
 
     def disconnect(self) -> None:
+        """
+        Disconnects the microphone and stops the recording. 
+ """ if not self.is_connected: raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.") @@ -385,7 +413,7 @@ if __name__ == "__main__": "--output-dir", type=Path, default="outputs/audio_from_microphones", - help="Set directory to save an audio snipet for each microphone.", + help="Set directory to save an audio snippet for each microphone.", ) parser.add_argument( "--record-time-s", diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py index 7727abb9..4af008ed 100644 --- a/lerobot/common/robot_devices/robots/mobile_manipulator.py +++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py @@ -381,7 +381,7 @@ class MobileManipulator: if frame_candidate is not None: frames[cam_name] = frame_candidate - # Recieve audio + # Receive audio for microphone_name, audio_data in audio_dict.items(): if audio_data: frames[microphone_name] = audio_data From 083f72c1d0d3bd4a6b46425d4925e8b897e14fea Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 15:46:01 +0200 Subject: [PATCH 24/32] fix: Check if robot_type is not None before getting its value --- lerobot/common/datasets/lerobot_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index e51a163d..5e13f9e5 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -964,8 +964,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self._save_image(frame[key], img_path) self.episode_buffer[key].append(str(img_path)) elif self.features[key]["dtype"] == "audio": - if self.meta.robot_type.startswith( - "lekiwi" + if ( + self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi") ): # Rw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner self.episode_buffer[key].append(frame[key]) else: # Otherwise, only the audio file path is stored in the episode buffer @@ -1034,8 +1034,8 @@ class LeRobotDataset(torch.utils.data.Dataset): if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: continue elif ft["dtype"] == "audio": - if self.meta.robot_type.startswith( - "lekiwi" + if ( + self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi") ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0) continue @@ -1044,8 +1044,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) - if self.meta.robot_type.startswith( - "lekiwi" + if ( + self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi") ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner for key in self.meta.audio_keys: audio_path = self._get_raw_audio_file_path( From 8bea50ecdd7637b4e3b4afeb3c14380b45db5613 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 15:46:41 +0200 Subject: [PATCH 25/32] fix: adding microphone argument as None for realsense cameras --- lerobot/common/robot_devices/cameras/intelrealsense.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py 
b/lerobot/common/robot_devices/cameras/intelrealsense.py
index 7a21661a..ac0e8ac7 100644
--- a/lerobot/common/robot_devices/cameras/intelrealsense.py
+++ b/lerobot/common/robot_devices/cameras/intelrealsense.py
@@ -265,6 +265,8 @@ class IntelRealSenseCamera:
         elif config.rotation == 180:
             self.rotation = cv2.ROTATE_180
 
+        self.microphone = None  # No microphones on realsense cameras, sorry
+
     def find_serial_number_from_name(self, name):
         camera_infos = find_cameras()
         camera_names = [cam["name"] for cam in camera_infos]

From c7930d617852572cf0e4786952a43a12e338a9c6 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Fri, 11 Apr 2025 15:48:23 +0200
Subject: [PATCH 26/32] fix: adding proper definition for "total_audio" in LeRobotDataset

---
 lerobot/common/datasets/utils.py    | 1 +
 tests/fixtures/dataset_factories.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 416d5837..7ac46013 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -452,6 +452,7 @@ def create_empty_dataset_info(
         "total_frames": 0,
         "total_tasks": 0,
         "total_videos": 0,
+        "total_audio": 0,
         "total_chunks": 0,
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": fps,
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 321dec46..6de43ef1 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -121,6 +121,7 @@ def info_factory(features_factory):
         total_frames: int = 0,
         total_tasks: int = 0,
         total_videos: int = 0,
+        total_audio: int = 0,
         total_chunks: int = 0,
         chunks_size: int = DEFAULT_CHUNK_SIZE,
         data_path: str = DEFAULT_PARQUET_PATH,
@@ -139,6 +140,7 @@ def info_factory(features_factory):
         "total_frames": total_frames,
         "total_tasks": total_tasks,
         "total_videos": total_videos,
+        "total_audio": total_audio,
         "total_chunks": total_chunks,
         "chunks_size": chunks_size,
         "fps": fps,

From cb1a625617e3ed5e06a672d990ffe1e2cfe37698 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Fri, 11 Apr 2025 15:55:55 +0200
Subject: [PATCH 27/32] fix: switching audio dependencies from the audio extra to the core dependencies so a minimal pytest install succeeds

---
 pyproject.toml | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0830b4ee..73f68a72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,8 +67,11 @@ dependencies = [
     "pynput>=1.7.7",
     "pyzmq>=26.2.1",
     "rerun-sdk>=0.21.0",
+    "sounddevice>=0.5.1",
+    "soundfile>=0.13.1",
     "termcolor>=2.4.0",
     "torch>=2.2.1",
+    "torchaudio>=2.6.0",
     "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
     "torchvision>=0.21.0",
     "wandb>=0.16.3",
@@ -96,7 +99,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
 umi = ["imagecodecs>=2024.1.1"]
 video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
 xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
-audio = ["sounddevice>=0.5.1", "soundfile>=0.13.1", "librosa>=0.11.0", "torchaudio>=2.6.0"]
+audio = ["librosa>=0.11.0"]
 
 [tool.poetry]
 requires-poetry = ">=2.1"

From 5267829b5ab122cce71ff3085d8346cb94796824 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Fri, 11 Apr 2025 18:51:34 +0200
Subject: [PATCH 28/32] fix: fixing microphone channel numbering and status recovery on stop_recording

---
 .../robot_devices/microphones/microphone.py   | 18 
+++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index b08842c2..33d772a7 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -207,22 +207,22 @@ class Microphone: else: self.sample_rate = int(actual_microphone["default_samplerate"]) - if self.channels is not None: + if self.channels is not None and len(self.channels) > 0: if any(c > actual_microphone["max_input_channels"] for c in self.channels): raise OSError( f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}." ) else: self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1) - + # Get channels index instead of number for slicing - self.channels = np.array(self.channels) - 1 + self.channels_index = np.array(self.channels) - 1 # Create the audio stream self.stream = sd.InputStream( device=self.microphone_index, samplerate=self.sample_rate, - channels=max(self.channels) + 1, + channels=max(self.channels), dtype="float32", callback=self._audio_callback, ) @@ -240,8 +240,8 @@ class Microphone: # Slicing makes copy unnecessary # Two separate queues are necessary because .get() also pops the data from the queue if self.is_writing: - self.record_queue.put(indata[:, self.channels]) - self.read_queue.put(indata[:, self.channels]) + self.record_queue.put(indata[:, self.channels_index]) + self.read_queue.put(indata[:, self.channels_index]) @staticmethod def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None: @@ -253,7 +253,7 @@ class Microphone: output_file, mode="x", samplerate=sample_rate, - channels=max(channels) + 1, + channels=max(channels), subtype=sf.default_subtype(output_file.suffix[1:]), ) as file: while not event.is_set(): @@ -370,7 +370,7 @@ class Microphone: if self.stream.active: self.stream.stop() # Wait for all buffers to be processed # Remark : stream.abort() flushes the buffers ! 
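To make the `stop()` / `abort()` remark above concrete: `stop()` lets already-captured buffers reach the callback before returning, while `abort()` discards them. A sketch of the shutdown ordering this relies on (illustrative standalone function; the argument names mirror the class attributes but it is not part of the diff):

```python
import sounddevice as sd


def stop_stream_gracefully(stream: sd.InputStream, record_queue, stop_event, record_thread) -> None:
    # Illustrative ordering only, not part of the patch.
    if stream.active:
        stream.stop()  # drains pending buffers into the callback first
        # stream.abort() would return sooner but drop buffered audio
    record_queue.join()  # block until every queued chunk has been written
    stop_event.set()  # then allow the writer loop to exit
    record_thread.join()
```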
- self.is_recording = False + self.is_recording = False if self.record_thread is not None: self.record_queue.join() @@ -378,7 +378,7 @@ class Microphone: self.record_thread.join() self.record_thread = None self.record_stop_event = None - self.is_writing = False + self.is_writing = False def disconnect(self) -> None: """ From 89697d86e748dbeb48225feed181af84f4c5cf78 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 18:51:58 +0200 Subject: [PATCH 29/32] fix: fixing typos --- lerobot/common/datasets/video_utils.py | 4 ++-- tests/datasets/test_compute_stats.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index d1b25023..c7a5da61 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -64,12 +64,12 @@ def decode_audio( if backend == "torchcodec": raise NotImplementedError("torchcodec is not yet supported for audio decoding") elif backend == "ffmpeg": - return decode_audio_torchvision(audio_path, timestamps, duration) + return decode_audio_torchaudio(audio_path, timestamps, duration) else: raise ValueError(f"Unsupported video backend: {backend}") -def decode_audio_torchvision( +def decode_audio_torchaudio( audio_path: Path | str, timestamps: list[float], duration: float, diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 9cf9f760..2ebf95f2 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -87,7 +87,7 @@ def test_sample_audio_from_path(mock_load): assert len(audio_samples) == estimate_num_samples(16000) -def test_sample_audio_from_data(mock_load): +def test_sample_audio_from_data(): audio_data = np.ones((16000, 2), dtype=np.float32) audio_samples = sample_audio_from_data(audio_data) assert isinstance(audio_samples, np.ndarray) From 4ddba296f7e33ba3d68c6acfc391cf5d12b0900e Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 18:55:37 +0200 Subject: [PATCH 30/32] fix: default float64 type must be cast into float32 for audio --- lerobot/common/robot_devices/control_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index e49a4e71..111e21f0 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -117,8 +117,9 @@ def predict_action(observation, policy, device, use_amp): if "image" in name: observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].permute(2, 0, 1).contiguous() - # Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension + # Convert to pytorch format: channel first and float32 in [-1,1] with batch dimension if "audio" in name: + observation[name] = observation[name].type(torch.float32) observation[name] = observation[name].permute(1, 0).contiguous() observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].to(device) From 6cf9cb35ba8f05d649c5e1148d0533dc0a9a3a0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Apr 2025 17:13:25 +0000 Subject: [PATCH 31/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/robot_devices/microphones/microphone.py | 2 +- 1 file changed, 1 insertion(+), 
1 deletion(-) diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py index 33d772a7..a92be011 100644 --- a/lerobot/common/robot_devices/microphones/microphone.py +++ b/lerobot/common/robot_devices/microphones/microphone.py @@ -214,7 +214,7 @@ class Microphone: ) else: self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1) - + # Get channels index instead of number for slicing self.channels_index = np.array(self.channels) - 1 From ca716ed1960aecc5da835264e0313db477e68ca1 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 15 Apr 2025 17:14:55 +0200 Subject: [PATCH 32/32] fix(audio): separate audio from video --- lerobot/common/datasets/audio_utils.py | 165 ++++++++++++++++++ lerobot/common/datasets/lerobot_dataset.py | 114 +++++------- lerobot/common/datasets/video_utils.py | 162 ----------------- .../common/robot_devices/cameras/configs.py | 2 - .../robot_devices/cameras/intelrealsense.py | 2 - .../common/robot_devices/cameras/opencv.py | 2 - .../common/robot_devices/robots/configs.py | 1 - .../robot_devices/robots/manipulator.py | 5 +- 8 files changed, 213 insertions(+), 240 deletions(-) create mode 100644 lerobot/common/datasets/audio_utils.py diff --git a/lerobot/common/datasets/audio_utils.py b/lerobot/common/datasets/audio_utils.py new file mode 100644 index 00000000..901fad52 --- /dev/null +++ b/lerobot/common/datasets/audio_utils.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import subprocess +from collections import OrderedDict +from pathlib import Path + +import torch +import torchaudio +from numpy import ceil + + +def decode_audio( + audio_path: Path | str, + timestamps: list[float], + duration: float, + backend: str | None = "ffmpeg", +) -> torch.Tensor: + """ + Decodes audio using the specified backend. + Args: + audio_path (Path): Path to the audio file. + timestamps (list[float]): List of (starting) timestamps to extract audio chunks. + duration (float): Duration of the audio chunks in seconds. + backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg". + + Returns: + torch.Tensor: Decoded audio chunks. + + Currently supports ffmpeg. 
+ """ + if backend == "torchcodec": + raise NotImplementedError("torchcodec is not yet supported for audio decoding") + elif backend == "ffmpeg": + return decode_audio_torchaudio(audio_path, timestamps, duration) + else: + raise ValueError(f"Unsupported video backend: {backend}") + + +def decode_audio_torchaudio( + audio_path: Path | str, + timestamps: list[float], + duration: float, + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + # TODO(CarolinePascal) : add channels selection + audio_path = str(audio_path) + + reader = torchaudio.io.StreamReader(src=audio_path) + audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate + + # TODO(CarolinePascal) : sort timestamps ? + reader.add_basic_audio_stream( + frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough + buffer_chunk_size=-1, # No dropping frames + format="fltp", # Format as float32 + ) + + audio_chunks = [] + for ts in timestamps: + reader.seek(ts) # Default to closest audio sample + status = reader.fill_buffer() + if status != 0: + logging.warning("Audio stream reached end of recording before decoding desired timestamps.") + + current_audio_chunk = reader.pop_chunks()[0] + + if log_loaded_timestamps: + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}" + ) + + audio_chunks.append(current_audio_chunk) + + audio_chunks = torch.stack(audio_chunks) + + assert len(timestamps) == len(audio_chunks) + return audio_chunks + + +def encode_audio( + input_path: Path | str, + output_path: Path | str, + codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options + log_level: str | None = "error", + overwrite: bool = False, +) -> None: + """Encodes an audio file using ffmpeg.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + ffmpeg_args = OrderedDict( + [ + ("-i", str(input_path)), + ("-acodec", codec), + ] + ) + + if log_level is not None: + ffmpeg_args["-loglevel"] = str(log_level) + + ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] + if overwrite: + ffmpeg_args.append("-y") + + ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)] + + # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal + subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) + + if not output_path.exists(): + raise OSError( + f"Audio encoding did not work. File not found: {output_path}. 
" + f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" + ) + + +def get_audio_info(video_path: Path | str) -> dict: + ffprobe_audio_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "a:0", + "-show_entries", + "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + audio_stream_info = info["streams"][0] if info.get("streams") else None + if audio_stream_info is None: + return {"has_audio": False} + + # Return the information, defaulting to None if no audio stream is present + return { + "has_audio": True, + "audio.channels": audio_stream_info.get("channels", None), + "audio.codec": audio_stream_info.get("codec_name", None), + "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, + "audio.sample_rate": int(audio_stream_info["sample_rate"]) + if audio_stream_info.get("sample_rate") + else None, + "audio.bit_depth": audio_stream_info.get("bit_depth", None), + "audio.channel_layout": audio_stream_info.get("channel_layout", None), + } diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 5e13f9e5..6c7af6a3 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -32,6 +32,11 @@ from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.errors import RevisionNotFoundError from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.audio_utils import ( + decode_audio, + encode_audio, + get_audio_info, +) from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( @@ -70,11 +75,8 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import ( VideoFrame, - decode_audio, decode_video_frames, - encode_audio, encode_video_frames, - get_audio_info, get_safe_default_codec, get_video_info, ) @@ -207,29 +209,9 @@ class LeRobotDatasetMetadata: @property def audio_keys(self) -> list[str]: - """Keys to access audio modalities (whether they are linked to a camera or not).""" + """Keys to access audio modalities.""" return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] - @property - def audio_camera_keys_mapping(self) -> dict[str, str]: - """Mapping between camera keys and audio keys when both are linked.""" - return { - self.features[camera_key]["audio"]: camera_key - for camera_key in self.camera_keys - if self.features[camera_key]["audio"] is not None - and self.features[camera_key]["dtype"] == "video" - } - - @property - def linked_audio_keys(self) -> list[str]: - """Keys to access audio modalities linked to a camera.""" - return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping] - - @property - def unlinked_audio_keys(self) -> list[str]: - """Keys to access audio modalities not linked to a camera.""" - return [key for key in self.audio_keys if key not in self.audio_camera_keys_mapping] - @property def names(self) -> dict[str, list | dict]: """Names of the various dimensions of vector modalities.""" @@ -310,7 +292,7 @@ class LeRobotDatasetMetadata: 
self.update_video_info() self.info["total_audio"] += len(self.audio_keys) - if len(self.unlinked_audio_keys) > 0: + if len(self.audio_keys) > 0: self.update_audio_info() write_info(self.info, self.root) @@ -342,7 +324,7 @@ class LeRobotDatasetMetadata: Warning: this function writes info from first episode audio, implicitly assuming that all audio have been encoded the same way. Also, this means it assumes the first episode exists. """ - for key in self.unlinked_audio_keys: + for key in self.audio_keys: if ( not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]) @@ -480,17 +462,31 @@ class LeRobotDataset(torch.utils.data.Dataset): │ ├── info.json │ ├── stats.json │ └── tasks.jsonl - └── videos + ├── videos + │ ├── chunk-000 + │ │ ├── observation.images.laptop + │ │ │ ├── episode_000000.mp4 + │ │ │ ├── episode_000001.mp4 + │ │ │ ├── episode_000002.mp4 + │ │ │ └── ... + │ │ ├── observation.images.phone + │ │ │ ├── episode_000000.mp4 + │ │ │ ├── episode_000001.mp4 + │ │ │ ├── episode_000002.mp4 + │ │ │ └── ... + │ ├── chunk-001 + │ └── ... + └── audio ├── chunk-000 - │ ├── observation.images.laptop - │ │ ├── episode_000000.mp4 - │ │ ├── episode_000001.mp4 - │ │ ├── episode_000002.mp4 + │ ├── observation.audio.laptop + │ │ ├── episode_000000.m4a + │ │ ├── episode_000001.m4a + │ │ ├── episode_000002.m4a │ │ └── ... - │ ├── observation.images.phone - │ │ ├── episode_000000.mp4 - │ │ ├── episode_000001.mp4 - │ │ ├── episode_000002.mp4 + │ ├── observation.audio.phone + │ │ ├── episode_000000.m4a + │ │ ├── episode_000001.m4a + │ │ ├── episode_000002.m4a │ │ └── ... ├── chunk-001 └── ... @@ -569,9 +565,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) - self.download_episodes( - download_videos, download_audio - ) # Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio + self.download_episodes(download_videos, download_audio) self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -582,7 +576,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) - # TODO(CarolinePascal) : add check for audio duration with respect to video duration and episode duration. + # TODO(CarolinePascal) : add check for audio duration with respect to episode duration BUT this will be CPU expensive if there are many episodes ! 
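The duration check left as a TODO above could look roughly like the sketch below, reusing the same ffprobe flags as get_audio_info; the helper names and the 0.1 s tolerance are illustrative, not part of this patch:

```python
import json
import subprocess
from pathlib import Path


def audio_duration_s(audio_path: Path) -> float:
    """Return the duration of the first audio stream, queried via ffprobe."""
    cmd = [
        "ffprobe", "-v", "error", "-select_streams", "a:0",
        "-show_entries", "stream=duration", "-of", "json", str(audio_path),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    return float(json.loads(result.stdout)["streams"][0]["duration"])


def check_audio_covers_episode(audio_path: Path, num_frames: int, fps: int, tol_s: float = 0.1) -> None:
    """Raise if the audio file is shorter than the episode it belongs to."""
    episode_s = num_frames / fps
    if audio_duration_s(audio_path) + tol_s < episode_s:
        raise ValueError(f"{audio_path} is shorter than the {episode_s:.2f} s episode")
```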
# Setup delta_indices if self.delta_timestamps is not None: @@ -604,9 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> None: ignore_patterns = ["images/"] if not push_videos: - ignore_patterns.append( - "videos/" - ) # Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed + ignore_patterns.append("videos/") if not push_audio: ignore_patterns.append("audio/") @@ -675,9 +667,7 @@ class LeRobotDataset(torch.utils.data.Dataset): files = None ignore_patterns = [] if not download_videos: - ignore_patterns.append( - "videos/" - ) # Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded + ignore_patterns.append("videos/") if not download_audio: ignore_patterns.append("audio/") if self.episodes is not None: @@ -696,10 +686,10 @@ class LeRobotDataset(torch.utils.data.Dataset): ] fpaths += video_files - if len(self.meta.unlinked_audio_keys) > 0: + if len(self.meta.audio_keys) > 0: audio_files = [ str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key)) - for audio_key in self.meta.unlinked_audio_keys + for audio_key in self.meta.audio_keys for ep_idx in episodes ] fpaths += audio_files @@ -792,7 +782,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices: dict[str, list[int]] | None = None, ) -> dict[str, list[float]]: query_timestamps = {} - for key in self.meta.audio_keys: # Standalone audio and audio embedded in video as well ! + for key in self.meta.audio_keys: if query_indices is not None and key in query_indices: timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] query_timestamps[key] = torch.stack(timestamps).tolist() @@ -828,14 +818,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> dict[str, torch.Tensor]: item = {} for audio_key, query_ts in query_timestamps.items(): - # Audio stored with video in a single .mp4 file - if audio_key in self.meta.linked_audio_keys: - audio_path = self.root / self.meta.get_video_file_path( - ep_idx, self.meta.audio_camera_keys_mapping[audio_key] - ) - # Audio stored alone in a separate .m4a file - else: - audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) + audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) item[audio_key] = audio_chunk.squeeze(0) return item @@ -966,7 +949,7 @@ class LeRobotDataset(torch.utils.data.Dataset): elif self.features[key]["dtype"] == "audio": if ( self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi") - ): # Rw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner + ): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner self.episode_buffer[key].append(frame[key]) else: # Otherwise, only the audio file path is stored in the episode buffer if frame_index == 0: @@ -1062,12 +1045,10 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_stats = compute_episode_stats(episode_buffer, self.features) if len(self.meta.video_keys) > 0: - video_paths = self.encode_episode_videos(episode_index) - for key in self.meta.video_keys: - episode_buffer[key] = video_paths[key] + self.encode_episode_videos(episode_index) - if len(self.meta.unlinked_audio_keys) > 0: # Linked audio is already encoded in the video files - _ = self.encode_episode_audio(episode_index) + if len(self.meta.audio_keys) > 0: + 
self.encode_episode_audio(episode_index) # `meta.save_episode` be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) @@ -1177,12 +1158,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index=episode_index, image_key=video_key, frame_index=0 ).parent - audio_path = None - if self.meta.features[video_key]["audio"] is not None: - audio_key = self.meta.features[video_key]["audio"] - audio_path = self._get_raw_audio_file_path(episode_index, audio_key) - - encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True) + encode_video_frames(img_dir, video_path, self.fps, overwrite=True) return video_paths @@ -1193,7 +1169,7 @@ class LeRobotDataset(torch.utils.data.Dataset): since video encoding with ffmpeg is already using multithreading. """ audio_paths = {} - for audio_key in self.meta.unlinked_audio_keys: + for audio_key in self.meta.audio_keys: input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index c7a5da61..fbf0b48c 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -25,10 +25,8 @@ from typing import Any, ClassVar import pyarrow as pa import torch -import torchaudio import torchvision from datasets.features.features import register_feature -from numpy import ceil from PIL import Image @@ -42,74 +40,6 @@ def get_safe_default_codec(): return "pyav" -def decode_audio( - audio_path: Path | str, - timestamps: list[float], - duration: float, - backend: str | None = "ffmpeg", -) -> torch.Tensor: - """ - Decodes audio using the specified backend. - Args: - audio_path (Path): Path to the audio file. - timestamps (list[float]): List of (starting) timestamps to extract audio chunks. - duration (float): Duration of the audio chunks in seconds. - backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg". - - Returns: - torch.Tensor: Decoded audio chunks. - - Currently supports ffmpeg. - """ - if backend == "torchcodec": - raise NotImplementedError("torchcodec is not yet supported for audio decoding") - elif backend == "ffmpeg": - return decode_audio_torchaudio(audio_path, timestamps, duration) - else: - raise ValueError(f"Unsupported video backend: {backend}") - - -def decode_audio_torchaudio( - audio_path: Path | str, - timestamps: list[float], - duration: float, - log_loaded_timestamps: bool = False, -) -> torch.Tensor: - # TODO(CarolinePascal) : add channels selection - audio_path = str(audio_path) - - reader = torchaudio.io.StreamReader(src=audio_path) - audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate - - # TODO(CarolinePascal) : sort timestamps ? 
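For reference, the StreamReader pattern being removed from video_utils.py here (and now housed in audio_utils.py above) reduces to the following standalone snippet; the file name and timestamps are illustrative:

```python
from math import ceil

import torchaudio

path, start_s, duration_s = "episode_000000.m4a", 1.0, 0.5

reader = torchaudio.io.StreamReader(src=path)
sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate

reader.add_basic_audio_stream(
    frames_per_chunk=int(ceil(duration_s * sample_rate)),  # one chunk per request
    buffer_chunk_size=-1,  # unbounded buffer, no dropped frames
    format="fltp",  # float32 samples
)

reader.seek(start_s)  # seeks to the closest audio sample
reader.fill_buffer()
(chunk,) = reader.pop_chunks()  # tensor of shape (frames, channels)
print(chunk.shape)
```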
- reader.add_basic_audio_stream( - frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough - buffer_chunk_size=-1, # No dropping frames - format="fltp", # Format as float32 - ) - - audio_chunks = [] - for ts in timestamps: - reader.seek(ts) # Default to closest audio sample - status = reader.fill_buffer() - if status != 0: - logging.warning("Audio stream reached end of recording before decoding desired timestamps.") - - current_audio_chunk = reader.pop_chunks()[0] - - if log_loaded_timestamps: - logging.info( - f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}" - ) - - audio_chunks.append(current_audio_chunk) - - audio_chunks = torch.stack(audio_chunks) - - assert len(timestamps) == len(audio_chunks) - return audio_chunks - - def decode_video_frames( video_path: Path | str, timestamps: list[float], @@ -313,53 +243,14 @@ def decode_video_frames_torchcodec( return closest_frames -def encode_audio( - input_path: Path | str, - output_path: Path | str, - codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options - log_level: str | None = "error", - overwrite: bool = False, -) -> None: - """Encodes an audio file using ffmpeg.""" - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - ffmpeg_args = OrderedDict( - [ - ("-i", str(input_path)), - ("-acodec", codec), - ] - ) - - if log_level is not None: - ffmpeg_args["-loglevel"] = str(log_level) - - ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] - if overwrite: - ffmpeg_args.append("-y") - - ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)] - - # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal - subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) - - if not output_path.exists(): - raise OSError( - f"Audio encoding did not work. File not found: {output_path}. 
" - f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" - ) - - def encode_video_frames( imgs_dir: Path | str, video_path: Path | str, fps: int, - audio_path: Path | str | None = None, vcodec: str = "libsvtav1", pix_fmt: str = "yuv420p", g: int | None = 2, crf: int | None = 30, - acodec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options fast_decode: int = 0, log_level: str | None = "error", overwrite: bool = False, @@ -377,18 +268,6 @@ def encode_video_frames( ] ) - ffmpeg_audio_args = OrderedDict() - if audio_path is not None: - audio_path = Path(audio_path) - audio_path.parent.mkdir(parents=True, exist_ok=True) - ffmpeg_audio_args.update( - OrderedDict( - [ - ("-i", str(audio_path)), - ] - ) - ) - ffmpeg_encoding_args = OrderedDict( [ ("-pix_fmt", pix_fmt), @@ -404,14 +283,10 @@ def encode_video_frames( value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" ffmpeg_encoding_args[key] = value - if audio_path is not None: - ffmpeg_encoding_args["-acodec"] = acodec - if log_level is not None: ffmpeg_encoding_args["-loglevel"] = str(log_level) ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair] - ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair] ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair] if overwrite: ffmpeg_args.append("-y") @@ -460,42 +335,6 @@ with warnings.catch_warnings(): register_feature(VideoFrame, "VideoFrame") -def get_audio_info(video_path: Path | str) -> dict: - ffprobe_audio_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "a:0", - "-show_entries", - "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") - - info = json.loads(result.stdout) - audio_stream_info = info["streams"][0] if info.get("streams") else None - if audio_stream_info is None: - return {"has_audio": False} - - # Return the information, defaulting to None if no audio stream is present - return { - "has_audio": True, - "audio.channels": audio_stream_info.get("channels", None), - "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, - "audio.sample_rate": int(audio_stream_info["sample_rate"]) - if audio_stream_info.get("sample_rate") - else None, - "audio.bit_depth": audio_stream_info.get("bit_depth", None), - "audio.channel_layout": audio_stream_info.get("channel_layout", None), - } - - def get_video_info(video_path: Path | str) -> dict: ffprobe_video_cmd = [ "ffprobe", @@ -531,7 +370,6 @@ def get_video_info(video_path: Path | str) -> dict: "video.codec": video_stream_info["codec_name"], "video.pix_fmt": video_stream_info["pix_fmt"], "video.is_depth_map": False, - **get_audio_info(video_path), } return video_info diff --git a/lerobot/common/robot_devices/cameras/configs.py b/lerobot/common/robot_devices/cameras/configs.py index b1bb588c..013419a9 100644 --- a/lerobot/common/robot_devices/cameras/configs.py +++ b/lerobot/common/robot_devices/cameras/configs.py @@ -48,8 +48,6 @@ class OpenCVCameraConfig(CameraConfig): rotation: int | None = None mock: bool = False - microphone: str | None 
= None - def __post_init__(self): if self.color_mode not in ["rgb", "bgr"]: raise ValueError( diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py index ac0e8ac7..7a21661a 100644 --- a/lerobot/common/robot_devices/cameras/intelrealsense.py +++ b/lerobot/common/robot_devices/cameras/intelrealsense.py @@ -265,8 +265,6 @@ class IntelRealSenseCamera: elif config.rotation == 180: self.rotation = cv2.ROTATE_180 - self.microphone = None # No microphones on realsense cameras, sorry - def find_serial_number_from_name(self, name): camera_infos = find_cameras() camera_names = [cam["name"] for cam in camera_infos] diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index 757b3d9f..f279f315 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -281,8 +281,6 @@ class OpenCVCamera: elif config.rotation == 180: self.rotation = cv2.ROTATE_180 - self.microphone = config.microphone - def connect(self): if self.is_connected: raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index 942586a0..ab362ad1 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -486,7 +486,6 @@ class So100RobotConfig(ManipulatorRobotConfig): fps=30, width=640, height=480, - microphone="laptop", ), "phone": OpenCVCameraConfig( camera_index=1, diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index b452be9d..00bcd3db 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -181,7 +181,6 @@ class ManipulatorRobot: "shape": (cam.height, cam.width, cam.channels), "names": ["height", "width", "channels"], "info": None, - "audio": "observation.audio." + cam.microphone if cam.microphone is not None else None, } return cam_ft @@ -211,7 +210,9 @@ class ManipulatorRobot: "dtype": "audio", "shape": (len(mic.channels),), "names": "channels", - "info": {"sample_rate": mic.sample_rate}, + "info": { + "sample_rate": mic.sample_rate + }, # we need to store the sample rate here in the case of audio chunks recording (for LeKiwi), as it will not be available anymore when writing the audio file } return mic_ft
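To make the microphone feature concrete: for a hypothetical stereo microphone keyed "laptop" sampling at 16 kHz, the block above would produce an entry like the following (values illustrative). Keeping sample_rate in info matters because, for LeKiwi-style chunked recording, the rate is no longer available by the time the audio file is written:

```python
mic_ft = {
    "observation.audio.laptop": {
        "dtype": "audio",
        "shape": (2,),  # one entry per recorded channel
        "names": "channels",
        "info": {"sample_rate": 16000},  # needed later to write chunked (LeKiwi) audio to file
    }
}
```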