docs: add method descriptions and comments on tricky parts

CarolinePascal 2025-04-11 13:46:34 +02:00
parent a08b5c4105
commit 5384309e6f
No known key found for this signature in database
5 changed files with 82 additions and 40 deletions

View File

@@ -73,6 +73,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
 def sample_audio_from_path(audio_path: str) -> np.ndarray:
+    """Samples audio data from an audio recording stored in a WAV file."""
     data = load_audio_from_path(audio_path)
     sampled_indices = sample_indices(len(data))
@@ -80,6 +81,7 @@ def sample_audio_from_path(audio_path: str) -> np.ndarray:
 def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
+    """Samples audio data from an audio recording stored in a numpy array."""
     sampled_indices = sample_indices(len(data))
     return data[sampled_indices]
@@ -106,7 +108,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features
         elif features[key]["dtype"] == "audio":
             try:
                 ep_ft_array = sample_audio_from_path(data[0])
-            except TypeError:  # Should only be triggered for LeKiwi robot
+            except TypeError:  # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
                 ep_ft_array = sample_audio_from_data(data)
             axes_to_reduce = 0
             keepdims = True
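Aside from the diff itself, the pattern here is: sample a subset of audio frames, then reduce over the sample axis with `keepdims=True`. A minimal self-contained sketch of that pattern (`sample_indices` below is a hypothetical evenly-spaced stand-in for the repo's sampler, not the actual implementation):

```python
import numpy as np

def sample_indices(data_len: int, num_samples: int = 100) -> np.ndarray:
    # Hypothetical stand-in: evenly spaced indices over the recording.
    return np.linspace(0, data_len - 1, min(num_samples, data_len)).astype(int)

def estimate_audio_stats(data: np.ndarray) -> dict[str, np.ndarray]:
    # Reduce over the sampled frames (axis 0) and keep dims,
    # mirroring axes_to_reduce = 0 and keepdims = True above.
    sampled = data[sample_indices(len(data))]
    return {
        "min": sampled.min(axis=0, keepdims=True),
        "max": sampled.max(axis=0, keepdims=True),
        "mean": sampled.mean(axis=0, keepdims=True),
        "std": sampled.std(axis=0, keepdims=True),
    }

# 1 s of fake stereo audio at 48 kHz -> stats of shape (1, 2) per key
stats = estimate_audio_stats(np.random.randn(48_000, 2))
```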

View File

@@ -150,6 +150,7 @@ class LeRobotDatasetMetadata:
         return Path(fpath)
     def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        """Returns the path of the compressed (i.e. encoded) audio file."""
         episode_chunk = self.get_episode_chunk(episode_index)
         fpath = self.audio_path.format(
             episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
@@ -206,7 +207,7 @@ class LeRobotDatasetMetadata:
     @property
     def audio_keys(self) -> list[str]:
-        """Keys to access audio modalities (wether they are linked to a camera or not)."""
+        """Keys to access audio modalities (whether they are linked to a camera or not)."""
         return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
     @property
@@ -223,7 +224,7 @@ class LeRobotDatasetMetadata:
     def linked_audio_keys(self) -> list[str]:
         """Keys to access audio modalities linked to a camera."""
         return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping]
     @property
     def unlinked_audio_keys(self) -> list[str]:
         """Keys to access audio modalities not linked to a camera."""
@@ -342,9 +343,10 @@ class LeRobotDatasetMetadata:
         been encoded the same way. Also, this means it assumes the first episode exists.
         """
         for key in self.unlinked_audio_keys:
-            if not self.features[key].get("info", None) or (
-                len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]
-            ): #Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi)
+            if (
+                not self.features[key].get("info", None)
+                or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"])
+            ):  # Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi)
                 audio_path = self.root / self.get_compressed_audio_file_path(0, key)
                 self.info["features"][key]["info"] = get_audio_info(audio_path)
@@ -568,9 +570,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
             self.download_episodes(
-                download_videos,
-                download_audio
-            ) #Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio
+                download_videos, download_audio
+            )  # Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -581,6 +582,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
         check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
+        # TODO(CarolinePascal) : add check for audio duration with respect to video duration and episode duration.
         # Setup delta_indices
         if self.delta_timestamps is not None:
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -601,7 +604,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ) -> None:
         ignore_patterns = ["images/"]
         if not push_videos:
-            ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed
+            ignore_patterns.append(
+                "videos/"
+            )  # Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed
         if not push_audio:
             ignore_patterns.append("audio/")
@@ -670,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = None
         ignore_patterns = []
         if not download_videos:
-            ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded
+            ignore_patterns.append(
+                "videos/"
+            )  # Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded
         if not download_audio:
             ignore_patterns.append("audio/")
         if self.episodes is not None:
@@ -785,7 +792,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         query_indices: dict[str, list[int]] | None = None,
     ) -> dict[str, list[float]]:
         query_timestamps = {}
-        for key in self.meta.audio_keys: #Standalone audio and audio embedded in video as well !
+        for key in self.meta.audio_keys:  # Standalone audio and audio embedded in video as well!
             if query_indices is not None and key in query_indices:
                 timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
                 query_timestamps[key] = torch.stack(timestamps).tolist()
@@ -821,12 +828,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ) -> dict[str, torch.Tensor]:
         item = {}
         for audio_key, query_ts in query_timestamps.items():
-            #Audio stored with video in a single .mp4 file
+            # Audio stored with video in a single .mp4 file
             if audio_key in self.meta.linked_audio_keys:
                 audio_path = self.root / self.meta.get_video_file_path(
                     ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
                 )
-            #Audio stored alone in a separate .m4a file
+            # Audio stored alone in a separate .m4a file
             else:
                 audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
             audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
@@ -957,9 +964,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
             elif self.features[key]["dtype"] == "audio":
-                if self.meta.robot_type.startswith("lekiwi"):
+                if self.meta.robot_type.startswith(
+                    "lekiwi"
+                ):  # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
                     self.episode_buffer[key].append(frame[key])
-                else:
+                else:  # Otherwise, only the audio file path is stored in the episode buffer
                     if frame_index == 0:
                         audio_path = self._get_raw_audio_file_path(
                             episode_index=self.episode_buffer["episode_index"], audio_key=key
@@ -972,7 +981,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None:
         """
-        This function will start recording audio from the microphone and save it to disk.
+        Starts recording audio data provided by the microphone and writes it directly to a .wav file.
         """
         audio_dir = self._get_raw_audio_file_path(
@@ -1025,7 +1034,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                 continue
             elif ft["dtype"] == "audio":
-                if self.meta.robot_type.startswith("lekiwi"):
+                if self.meta.robot_type.startswith(
+                    "lekiwi"
+                ):  # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
                     episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
                     continue
             episode_buffer[key] = np.stack(episode_buffer[key])
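The LeKiwi branch above uses `np.concatenate` instead of `np.stack` because its audio arrives as variable-length chunks, one per visual frame, which cannot be stacked. A small sketch of the difference:

```python
import numpy as np

# LeKiwi-style audio: variable-length stereo chunks, one per visual frame
chunks = [np.zeros((1600, 2)), np.zeros((1280, 2))]
episode_audio = np.concatenate(chunks, axis=0)  # (2880, 2): one contiguous recording

# Regular features: fixed-size vectors, one per frame
frames = [np.zeros(6) for _ in range(3)]
episode_state = np.stack(frames)  # (3, 6): np.stack would fail on the ragged chunks above
```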
@@ -1033,7 +1044,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)
-        if self.meta.robot_type.startswith("lekiwi"):
+        if self.meta.robot_type.startswith(
+            "lekiwi"
+        ):  # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
             for key in self.meta.audio_keys:
                 audio_path = self._get_raw_audio_file_path(
                     episode_index=self.episode_buffer["episode_index"][0], audio_key=key
@@ -1053,7 +1066,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for key in self.meta.video_keys:
             episode_buffer[key] = video_paths[key]
-        if len(self.meta.unlinked_audio_keys) > 0: #Linked audio is already encoded in the video files
+        if len(self.meta.unlinked_audio_keys) > 0:  # Linked audio is already encoded in the video files
             _ = self.encode_episode_audio(episode_index)
         # `meta.save_episode` be executed after encoding the videos
@@ -1080,7 +1093,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if img_dir.is_dir():
             shutil.rmtree(self.root / "images")
-        # delete raw audio
+        # delete raw audio files
         raw_audio_files = list(self.root.rglob("*.wav"))
         for raw_audio_file in raw_audio_files:
             raw_audio_file.unlink()

View File

@@ -52,14 +52,14 @@ def decode_audio(
     Decodes audio using the specified backend.
     Args:
         audio_path (Path): Path to the audio file.
-        timestamps (list[float]): List of timestamps to extract frames.
-        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
-        backend (str, optional): Backend to use for decoding. Defaults to "pyav".
+        timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
+        duration (float): Duration of the audio chunks in seconds.
+        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
     Returns:
-        torch.Tensor: Decoded frames.
-    Currently supports pyav.
+        torch.Tensor: Decoded audio chunks.
+    Currently supports ffmpeg.
     """
     if backend == "torchcodec":
         raise NotImplementedError("torchcodec is not yet supported for audio decoding")
@@ -82,7 +82,6 @@ def decode_audio_torchvision(
     audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
-    # TODO(CarolinePascal) : sort timestamps ?
     reader.add_basic_audio_stream(
         frames_per_chunk=int(ceil(duration * audio_sample_rate)),  # Too much is better than not enough
         buffer_chunk_size=-1,  # No dropping frames
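A minimal sketch of how one chunk could be pulled with torchaudio's `StreamReader`, mirroring the parameters above (`read_chunk` is a hypothetical helper written for illustration, not the repo's API):

```python
from math import ceil
from torchaudio.io import StreamReader

def read_chunk(audio_path: str, start_s: float, duration_s: float):
    reader = StreamReader(audio_path)
    sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
    reader.add_basic_audio_stream(
        frames_per_chunk=int(ceil(duration_s * sample_rate)),  # too much is better than not enough
        buffer_chunk_size=-1,  # unbounded buffer, so no frames are dropped
    )
    reader.seek(start_s)  # jump to the requested timestamp
    reader.fill_buffer()  # decode until the chunk is full (or the stream ends)
    (chunk,) = reader.pop_chunks()
    return chunk  # tensor of shape (frames, channels)
```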
@@ -317,7 +316,7 @@ def decode_video_frames_torchcodec(
 def encode_audio(
     input_path: Path | str,
     output_path: Path | str,
-    codec: str = "aac",
+    codec: str = "aac",  # TODO(CarolinePascal) : investigate the Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
     log_level: str | None = "error",
     overwrite: bool = False,
 ) -> None:
@@ -346,7 +345,7 @@ def encode_audio(
     if not output_path.exists():
         raise OSError(
-            f"Video encoding did not work. File not found: {output_path}. "
+            f"Audio encoding did not work. File not found: {output_path}. "
             f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
         )
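The body of `encode_audio` is not shown in this diff; a minimal sketch of the kind of ffmpeg invocation it wraps (assumes `ffmpeg` is installed; `encode_wav_to_m4a` is a hypothetical illustration, not the repo's function):

```python
import subprocess
from pathlib import Path

def encode_wav_to_m4a(input_path: Path, output_path: Path, overwrite: bool = False) -> None:
    output_path.parent.mkdir(parents=True, exist_ok=True)
    cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-acodec", "aac",      # matches the codec="aac" default above
        "-loglevel", "error",  # matches log_level="error"
        "-y" if overwrite else "-n",
        str(output_path),
    ]
    subprocess.run(cmd, check=True)
```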

View File

@@ -44,6 +44,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc
 def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
+    """
+    Finds and lists all microphones compatible with sounddevice (and the underlying PortAudio library).
+    Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
+    """
     microphones = []
     if mock:
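As a usage illustration of the docstring above, a minimal sketch of enumerating input-capable devices with sounddevice (`list_input_devices` is hypothetical, not the repo's helper):

```python
import sounddevice as sd

def list_input_devices() -> list[dict]:
    # A device qualifies as a microphone if it exposes at least one input channel.
    return [
        {"index": i, "name": d["name"], "channels": d["max_input_channels"]}
        for i, d in enumerate(sd.query_devices())
        if d["max_input_channels"] > 0
    ]

print(list_input_devices())
```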
@@ -72,6 +76,11 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
 def record_audio_from_microphones(
     output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0
 ):
+    """
+    Records audio from all the channels of the specified microphones for the specified duration.
+    If no microphone ids are provided, all available microphones will be used.
+    """
     if microphone_ids is None or len(microphone_ids) == 0:
         microphones = find_microphones()
         microphone_ids = [m["index"] for m in microphones]
@@ -112,7 +121,7 @@ def record_audio_from_microphones(
 class Microphone:
     """
-    The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows).
+    The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
     A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.
@@ -146,11 +155,11 @@ class Microphone:
         # Input audio stream
         self.stream = None
-        # Thread-safe concurrent queue to store the recorded/read audio
+        # Thread/Process-safe concurrent queue to store the recorded/read audio
         self.record_queue = None
         self.read_queue = None
-        # Thread to handle data reading and file writing in a separate thread (safely)
+        # Thread/Process to handle data reading and file writing in a separate thread/process (safely)
         self.record_thread = None
         self.record_stop_event = None
@@ -160,6 +169,9 @@ class Microphone:
         self.is_writing = False
     def connect(self) -> None:
+        """
+        Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
+        """
         if self.is_connected:
             raise RobotDeviceAlreadyConnectedError(
                 f"Microphone {self.microphone_index} is already connected."
@@ -214,15 +226,18 @@ class Microphone:
             dtype="float32",
             callback=self._audio_callback,
         )
-        # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers.
-        # However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency.
+        # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receives same-length buffers.
+        # However, this may lead to additional latency. We thus stick to blocksize=0, which means that audio_callback will receive varying-length buffers, but with no additional latency.
         self.is_connected = True
     def _audio_callback(self, indata, frames, time, status) -> None:
+        """
+        Low-level sounddevice callback.
+        """
         if status:
             logging.warning(status)
-        # Slicing makes copy unecessary
+        # Slicing makes copy unnecessary
         # Two separate queues are necessary because .get() also pops the data from the queue
         if self.is_writing:
             self.record_queue.put(indata[:, self.channels])
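The two-queue comment is worth unpacking: `queue.Queue.get()` is destructive, so a single queue cannot feed both the file-writing loop and `read()`. A tiny sketch of why:

```python
import queue

q = queue.Queue()
q.put("chunk-0")
item = q.get()    # get() removes the item from the queue...
assert q.empty()  # ...so a second consumer would never see it.
# Hence one queue for the record/write loop and a separate one for read().
```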
@@ -230,6 +245,9 @@ class Microphone:
     @staticmethod
     def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
+        """
+        Thread/Process-safe loop to write audio data into a file.
+        """
         # Can only be run on a single process/thread for file writing safety
         with sf.SoundFile(
             output_file,
@@ -249,9 +267,7 @@ class Microphone:
     def _read(self) -> np.ndarray:
         """
-        Gets audio data from the queue and coverts it to a numpy array.
-        -> PROS : Inherently thread safe, no need to lock the queue, lightweight CPU usage
-        -> CONS : Reading duration does not scale well with the number of channels and reading duration
+        Thread/Process-safe callback to read available audio data.
         """
         audio_readings = np.empty((0, len(self.channels)))
@@ -266,6 +282,9 @@ class Microphone:
         return audio_readings
     def read(self) -> np.ndarray:
+        """
+        Reads the last audio chunk recorded by the microphone, i.e. all samples recorded since the last read or since the beginning of the recording.
+        """
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if not self.is_recording:
@@ -284,6 +303,9 @@ class Microphone:
         return audio_readings
     def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
+        """
+        Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
+        """
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if self.is_recording:
@@ -337,6 +359,9 @@ class Microphone:
         self.stream.start()
     def stop_recording(self) -> None:
+        """
+        Stops the recording of the microphone.
+        """
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if not self.is_recording:
@@ -356,6 +381,9 @@ class Microphone:
         self.is_writing = False
     def disconnect(self) -> None:
+        """
+        Disconnects the microphone and stops the recording.
+        """
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
@@ -385,7 +413,7 @@ if __name__ == "__main__":
         "--output-dir",
         type=Path,
         default="outputs/audio_from_microphones",
-        help="Set directory to save an audio snipet for each microphone.",
+        help="Set directory to save an audio snippet for each microphone.",
     )
     parser.add_argument(
         "--record-time-s",

View File

@@ -381,7 +381,7 @@ class MobileManipulator:
             if frame_candidate is not None:
                 frames[cam_name] = frame_candidate
-        # Recieve audio
+        # Receive audio
         for microphone_name, audio_data in audio_dict.items():
             if audio_data:
                 frames[microphone_name] = audio_data