[skip ci] feat(torchcodec): adding support for torchcodec audio decoding

This commit is contained in:
CarolinePascal 2025-04-15 18:41:29 +02:00
parent ca716ed196
commit 0acda9facd
No known key found for this signature in database
2 changed files with 37 additions and 7 deletions

View File

@ -21,6 +21,7 @@ from pathlib import Path
import torch import torch
import torchaudio import torchaudio
import torchcodec
from numpy import ceil from numpy import ceil
@ -28,7 +29,7 @@ def decode_audio(
audio_path: Path | str, audio_path: Path | str,
timestamps: list[float], timestamps: list[float],
duration: float, duration: float,
backend: str | None = "ffmpeg", backend: str | None = "torchaudio",
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes audio using the specified backend. Decodes audio using the specified backend.
@ -36,7 +37,7 @@ def decode_audio(
audio_path (Path): Path to the audio file. audio_path (Path): Path to the audio file.
timestamps (list[float]): List of (starting) timestamps to extract audio chunks. timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
duration (float): Duration of the audio chunks in seconds. duration (float): Duration of the audio chunks in seconds.
backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg". backend (str, optional): Backend to use for decoding. Defaults to "torchaudio".
Returns: Returns:
torch.Tensor: Decoded audio chunks. torch.Tensor: Decoded audio chunks.
@ -44,13 +45,42 @@ def decode_audio(
Currently supports ffmpeg. Currently supports ffmpeg.
""" """
if backend == "torchcodec": if backend == "torchcodec":
raise NotImplementedError("torchcodec is not yet supported for audio decoding") # return decode_audio_torchcodec(audio_path, timestamps, duration) #TODO(CarolinePascal): uncomment this line at next torchcodec release
elif backend == "ffmpeg": raise ValueError("torchcodec backend is not available yet.")
elif backend == "torchaudio":
return decode_audio_torchaudio(audio_path, timestamps, duration) return decode_audio_torchaudio(audio_path, timestamps, duration)
else: else:
raise ValueError(f"Unsupported video backend: {backend}") raise ValueError(f"Unsupported video backend: {backend}")
def decode_audio_torchcodec(
audio_path: Path | str,
timestamps: list[float],
duration: float,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
audio_chunks = []
for ts in timestamps:
current_audio_chunk = audio_decoder.get_samples_played_in_range(
start_seconds=ts, stop_seconds=ts + duration
)
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at starting timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
)
audio_chunks.append(current_audio_chunk.data)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def decode_audio_torchaudio( def decode_audio_torchaudio(
audio_path: Path | str, audio_path: Path | str,
timestamps: list[float], timestamps: list[float],

View File

@ -527,7 +527,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True. download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'. audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchaudio'.
""" """
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
@ -539,7 +539,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec() self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.audio_backend = ( self.audio_backend = (
audio_backend if audio_backend else "ffmpeg" audio_backend if audio_backend else "trochaudio"
) # Waiting for torchcodec release #TODO(CarolinePascal) ) # Waiting for torchcodec release #TODO(CarolinePascal)
self.delta_indices = None self.delta_indices = None
@ -1229,7 +1229,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episode_data_index = None obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.audio_backend = ( obj.audio_backend = (
audio_backend if audio_backend is not None else "ffmpeg" audio_backend if audio_backend is not None else "trochaudio"
) # Waiting for torchcodec release #TODO(CarolinePascal) ) # Waiting for torchcodec release #TODO(CarolinePascal)
return obj return obj