Renamming sampling rate to sample rate for consistency

This commit is contained in:
CarolinePascal 2025-04-08 16:24:44 +02:00
parent 8c69b0b9cd
commit 43a82e2aef
No known key found for this signature in database
5 changed files with 20 additions and 20 deletions

View File

@ -79,12 +79,12 @@ def decode_audio_torchvision(
audio_path = str(audio_path) audio_path = str(audio_path)
reader = torchaudio.io.StreamReader(src=audio_path) reader = torchaudio.io.StreamReader(src=audio_path)
audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
#TODO(CarolinePascal) : sort timestamps ? #TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream( reader.add_basic_audio_stream(
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough
buffer_chunk_size = -1, #No dropping frames buffer_chunk_size = -1, #No dropping frames
format = "fltp", #Format as float32 format = "fltp", #Format as float32
) )
@ -99,7 +99,7 @@ def decode_audio_torchvision(
current_audio_chunk = reader.pop_chunks()[0] current_audio_chunk = reader.pop_chunks()[0]
if log_loaded_timestamps: if log_loaded_timestamps:
logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}") logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}")
audio_chunks.append(current_audio_chunk) audio_chunks.append(current_audio_chunk)

View File

@ -31,6 +31,6 @@ class MicrophoneConfig(MicrophoneConfigBase):
""" """
microphone_index: int microphone_index: int
sampling_rate: int | None = None sample_rate: int | None = None
channels: list[int] | None = None channels: list[int] | None = None
mock: bool = False mock: bool = False

View File

@ -80,7 +80,7 @@ def record_audio_from_microphones(
microphone = Microphone(config) microphone = Microphone(config)
microphone.connect() microphone.connect()
print( print(
f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sampling_rate} Hz." f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sample_rate} Hz."
) )
microphones.append(microphone) microphones.append(microphone)
@ -111,13 +111,13 @@ class Microphone:
""" """
The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows). The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows).
A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sampling rate as well as the list of recorded channels. A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.
Example of usage: Example of usage:
```python ```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
config = MicrophoneConfig(microphone_index=0, sampling_rate=16000, channels=[1]) config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
microphone = Microphone(config) microphone = Microphone(config)
microphone.connect() microphone.connect()
@ -134,8 +134,8 @@ class Microphone:
self.config = config self.config = config
self.microphone_index = config.microphone_index self.microphone_index = config.microphone_index
#Store the recording sampling rate and channels #Store the recording sample rate and channels
self.sampling_rate = config.sampling_rate self.sample_rate = config.sample_rate
self.channels = config.channels self.channels = config.channels
self.mock = config.mock self.mock = config.mock
@ -177,15 +177,15 @@ class Microphone:
#Check if provided recording parameters are compatible with the microphone #Check if provided recording parameters are compatible with the microphone
actual_microphone = sd.query_devices(self.microphone_index) actual_microphone = sd.query_devices(self.microphone_index)
if self.sampling_rate is not None : if self.sample_rate is not None :
if self.sampling_rate > actual_microphone["default_samplerate"]: if self.sample_rate > actual_microphone["default_samplerate"]:
raise OSError( raise OSError(
f"Provided sampling rate {self.sampling_rate} is higher than the sampling rate of the microphone {actual_microphone['default_samplerate']}." f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
) )
elif self.sampling_rate < actual_microphone["default_samplerate"]: elif self.sample_rate < actual_microphone["default_samplerate"]:
logging.warning("Provided sampling rate is lower than the sampling rate of the microphone. Performance may be impacted.") logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.")
else: else:
self.sampling_rate = int(actual_microphone["default_samplerate"]) self.sample_rate = int(actual_microphone["default_samplerate"])
if self.channels is not None: if self.channels is not None:
if any(c > actual_microphone["max_input_channels"] for c in self.channels): if any(c > actual_microphone["max_input_channels"] for c in self.channels):
@ -201,7 +201,7 @@ class Microphone:
#Create the audio stream #Create the audio stream
self.stream = sd.InputStream( self.stream = sd.InputStream(
device=self.microphone_index, device=self.microphone_index,
samplerate=self.sampling_rate, samplerate=self.sample_rate,
channels=max(self.channels)+1, channels=max(self.channels)+1,
dtype="float32", dtype="float32",
callback=self._audio_callback, callback=self._audio_callback,
@ -221,7 +221,7 @@ class Microphone:
def _record_loop(self, output_file: Path) -> None: def _record_loop(self, output_file: Path) -> None:
#Can only be run on a single process/thread for file writing safety #Can only be run on a single process/thread for file writing safety
with sf.SoundFile(output_file, mode='x', samplerate=self.sampling_rate, with sf.SoundFile(output_file, mode='x', samplerate=self.sample_rate,
channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file: channels=max(self.channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
while not self.record_stop_event.is_set(): while not self.record_stop_event.is_set():
file.write(self.record_queue.get()) file.write(self.record_queue.get())

View File

@ -349,7 +349,7 @@ def test_add_frame_audio(audio_dataset):
dataset.save_episode() dataset.save_episode()
assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sampling_rate),DUMMY_AUDIO_CHANNELS)) assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sample_rate),DUMMY_AUDIO_CHANNELS))
# TODO(aliberts): # TODO(aliberts):
# - [ ] test various attributes & state from init and create # - [ ] test various attributes & state from init and create

View File

@ -75,7 +75,7 @@ def test_microphone(tmp_path, request, microphone_type, mock):
microphone = make_microphone(**microphone_kwargs) microphone = make_microphone(**microphone_kwargs)
microphone.connect() microphone.connect()
assert microphone.is_connected assert microphone.is_connected
assert microphone.sampling_rate is not None assert microphone.sample_rate is not None
assert microphone.channels is not None assert microphone.channels is not None
# Test connecting twice raises an error # Test connecting twice raises an error
@ -122,7 +122,7 @@ def test_microphone(tmp_path, request, microphone_type, mock):
microphone.stop_recording() microphone.stop_recording()
recorded_audio, recorded_sample_rate = read(fpath) recorded_audio, recorded_sample_rate = read(fpath)
assert recorded_sample_rate == microphone.sampling_rate assert recorded_sample_rate == microphone.sample_rate
error_msg = ( error_msg = (
"Recording time difference between read() and stop_recording()", "Recording time difference between read() and stop_recording()",