Adding multiprocessing support for audio recording

This commit is contained in:
CarolinePascal 2025-04-09 14:59:29 +02:00
parent 43a82e2aef
commit d53035d047
No known key found for this signature in database
2 changed files with 52 additions and 45 deletions

View File

@ -21,13 +21,18 @@ import soundfile as sf
import numpy as np
import logging
from threading import Thread, Event
from queue import Queue
from os.path import splitext
from os import remove, getcwd
from multiprocessing import Process
from queue import Empty
from queue import Queue as thread_Queue
from threading import Event as thread_Event
from multiprocessing import JoinableQueue as process_Queue
from multiprocessing import Event as process_Event
from os import getcwd
from pathlib import Path
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from lerobot.common.utils.utils import capture_timestamp_utc
@ -37,7 +42,6 @@ from lerobot.common.robot_devices.utils import (
RobotDeviceNotConnectedError,
RobotDeviceNotRecordingError,
RobotDeviceAlreadyRecordingError,
busy_wait,
)
def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
@ -144,8 +148,8 @@ class Microphone:
self.stream = None
#Thread-safe concurrent queue to store the recorded/read audio
self.record_queue = Queue()
self.read_queue = Queue()
self.record_queue = None
self.read_queue = None
#Thread to handle data reading and file writing in a separate thread (safely)
self.record_thread = None
@ -219,13 +223,17 @@ class Microphone:
self.record_queue.put(indata[:,self.channels])
self.read_queue.put(indata[:,self.channels])
def _record_loop(self, output_file: Path) -> None:
@staticmethod
def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
#Can only be run on a single process/thread for file writing safety
with sf.SoundFile(output_file, mode='x', samplerate=self.sample_rate,
channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
while not self.record_stop_event.is_set():
file.write(self.record_queue.get())
#self.record_queue.task_done()
with sf.SoundFile(output_file, mode='x', samplerate=sample_rate,
channels=max(channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
while not event.is_set():
try:
file.write(queue.get(timeout=0.02)) #Timeout set as twice the usual sounddevice buffer size
queue.task_done()
except Empty:
continue
def _read(self) -> np.ndarray:
"""
@ -233,17 +241,15 @@ class Microphone:
-> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage
-> CONS : Reading duration does not scale well with the number of channels and reading duration
"""
try:
audio_readings = self.read_queue.queue
except Queue.Empty:
audio_readings = np.empty((0, len(self.channels)))
else:
#TODO(CarolinePascal): Check if this is the fastest way to do it
self.read_queue = Queue()
with self.read_queue.mutex:
self.read_queue.queue.clear()
#self.read_queue.all_tasks_done.notify_all()
audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels))
audio_readings = np.empty((0, len(self.channels)))
while True:
try:
audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
except Empty:
break
self.read_queue = thread_Queue()
return audio_readings
@ -266,31 +272,32 @@ class Microphone:
return audio_readings
def start_recording(self, output_file : str | None = None) -> None:
def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False) -> None:
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if self.is_recording:
raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
self.read_queue = Queue()
with self.read_queue.mutex:
self.read_queue.queue.clear()
#self.read_queue.all_tasks_done.notify_all()
#Reset queues
self.read_queue = thread_Queue()
if multiprocessing:
self.record_queue = process_Queue()
else:
self.record_queue = thread_Queue()
self.record_queue = Queue()
with self.record_queue.mutex:
self.record_queue.queue.clear()
#self.record_queue.all_tasks_done.notify_all()
#Recording case
#Write recordings into a file if output_file is provided
if output_file is not None:
output_file = Path(output_file)
if output_file.exists():
output_file.unlink()
self.record_stop_event = Event()
self.record_thread = Thread(target=self._record_loop, args=(output_file,))
if multiprocessing:
self.record_stop_event = process_Event()
self.record_thread = Process(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
else:
self.record_stop_event = thread_Event()
self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
self.record_thread.daemon = True
self.record_thread.start()
@ -304,18 +311,18 @@ class Microphone:
if not self.is_recording:
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
if self.stream.active:
self.stream.stop() #Wait for all buffers to be processed
#Remark : stream.abort() flushes the buffers !
self.is_recording = False
if self.record_thread is not None:
#self.record_queue.join()
self.record_queue.join()
self.record_stop_event.set()
self.record_thread.join()
self.record_thread = None
self.record_stop_event = None
if self.stream.active:
self.stream.stop() #Wait for all buffers to be processed
#Remark : stream.abort() flushes the buffers !
self.is_recording = False
self.is_writing = False
def disconnect(self) -> None:

View File

@ -20,7 +20,7 @@ from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, M
class Microphone(Protocol):
def connect(self): ...
def disconnect(self): ...
def start_recording(self, output_file: str | None = None): ...
def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False): ...
def stop_recording(self): ...
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]: