From 0acda9facda0de47bad93ed5f38a0b16eafbffb3 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 15 Apr 2025 18:41:29 +0200 Subject: [PATCH] [skip ci] feat(torchcodec): adding support for torchcodec audio decoding --- lerobot/common/datasets/audio_utils.py | 38 +++++++++++++++++++--- lerobot/common/datasets/lerobot_dataset.py | 6 ++-- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/lerobot/common/datasets/audio_utils.py b/lerobot/common/datasets/audio_utils.py index 901fad52..2f3537d2 100644 --- a/lerobot/common/datasets/audio_utils.py +++ b/lerobot/common/datasets/audio_utils.py @@ -21,6 +21,7 @@ from pathlib import Path import torch import torchaudio +import torchcodec from numpy import ceil @@ -28,7 +29,7 @@ def decode_audio( audio_path: Path | str, timestamps: list[float], duration: float, - backend: str | None = "ffmpeg", + backend: str | None = "torchaudio", ) -> torch.Tensor: """ Decodes audio using the specified backend. @@ -36,7 +37,7 @@ def decode_audio( audio_path (Path): Path to the audio file. timestamps (list[float]): List of (starting) timestamps to extract audio chunks. duration (float): Duration of the audio chunks in seconds. - backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg". + backend (str, optional): Backend to use for decoding. Defaults to "torchaudio". Returns: torch.Tensor: Decoded audio chunks. @@ -44,13 +45,42 @@ def decode_audio( Currently supports ffmpeg. """ if backend == "torchcodec": - raise NotImplementedError("torchcodec is not yet supported for audio decoding") - elif backend == "ffmpeg": + # return decode_audio_torchcodec(audio_path, timestamps, duration) #TODO(CarolinePascal): uncomment this line at next torchcodec release + raise ValueError("torchcodec backend is not available yet.") + elif backend == "torchaudio": return decode_audio_torchaudio(audio_path, timestamps, duration) else: raise ValueError(f"Unsupported video backend: {backend}") +def decode_audio_torchcodec( + audio_path: Path | str, + timestamps: list[float], + duration: float, + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + # TODO(CarolinePascal) : add channels selection + audio_decoder = torchcodec.decoders.AudioDecoder(audio_path) + + audio_chunks = [] + for ts in timestamps: + current_audio_chunk = audio_decoder.get_samples_played_in_range( + start_seconds=ts, stop_seconds=ts + duration + ) + + if log_loaded_timestamps: + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}" + ) + + audio_chunks.append(current_audio_chunk.data) + + audio_chunks = torch.stack(audio_chunks) + + assert len(timestamps) == len(audio_chunks) + return audio_chunks + + def decode_audio_torchaudio( audio_path: Path | str, timestamps: list[float], diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 6c7af6a3..27899870 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -527,7 +527,7 @@ class LeRobotDataset(torch.utils.data.Dataset): download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. - audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchaudio'. """ super().__init__() self.repo_id = repo_id @@ -539,7 +539,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() self.audio_backend = ( - audio_backend if audio_backend else "ffmpeg" + audio_backend if audio_backend else "trochaudio" ) # Waiting for torchcodec release #TODO(CarolinePascal) self.delta_indices = None @@ -1229,7 +1229,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.episode_data_index = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() obj.audio_backend = ( - audio_backend if audio_backend is not None else "ffmpeg" + audio_backend if audio_backend is not None else "trochaudio" ) # Waiting for torchcodec release #TODO(CarolinePascal) return obj