[skip ci] feat(torchcodec): adding support for torchcodec audio decoding
This commit is contained in:
parent
ca716ed196
commit
0acda9facd
|
@ -21,6 +21,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
import torchcodec
|
||||||
from numpy import ceil
|
from numpy import ceil
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ def decode_audio(
|
||||||
audio_path: Path | str,
|
audio_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
duration: float,
|
duration: float,
|
||||||
backend: str | None = "ffmpeg",
|
backend: str | None = "torchaudio",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Decodes audio using the specified backend.
|
Decodes audio using the specified backend.
|
||||||
|
@ -36,7 +37,7 @@ def decode_audio(
|
||||||
audio_path (Path): Path to the audio file.
|
audio_path (Path): Path to the audio file.
|
||||||
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
|
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
|
||||||
duration (float): Duration of the audio chunks in seconds.
|
duration (float): Duration of the audio chunks in seconds.
|
||||||
backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
|
backend (str, optional): Backend to use for decoding. Defaults to "torchaudio".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Decoded audio chunks.
|
torch.Tensor: Decoded audio chunks.
|
||||||
|
@ -44,13 +45,42 @@ def decode_audio(
|
||||||
Currently supports ffmpeg.
|
Currently supports torchaudio.
|
||||||
"""
|
"""
|
||||||
if backend == "torchcodec":
|
if backend == "torchcodec":
|
||||||
raise NotImplementedError("torchcodec is not yet supported for audio decoding")
|
# return decode_audio_torchcodec(audio_path, timestamps, duration) #TODO(CarolinePascal): uncomment this line at next torchcodec release
|
||||||
elif backend == "ffmpeg":
|
raise ValueError("torchcodec backend is not available yet.")
|
||||||
|
elif backend == "torchaudio":
|
||||||
return decode_audio_torchaudio(audio_path, timestamps, duration)
|
return decode_audio_torchaudio(audio_path, timestamps, duration)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported audio backend: {backend}")
|
raise ValueError(f"Unsupported audio backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_audio_torchcodec(
    audio_path: Path | str,
    timestamps: list[float],
    duration: float,
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """
    Decodes fixed-duration audio chunks using torchcodec's AudioDecoder.

    Args:
        audio_path (Path | str): Path to the audio file.
        timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
        duration (float): Duration of the audio chunks in seconds.
        log_loaded_timestamps (bool, optional): When True, logs the starting timestamp and
            duration reported by the decoder for every loaded chunk. Defaults to False.

    Returns:
        torch.Tensor: Decoded audio chunks stacked along a new leading dimension,
            one entry per requested timestamp.
    """
    # TODO(CarolinePascal) : add channels selection
    decoder = torchcodec.decoders.AudioDecoder(audio_path)

    chunks = []
    for start in timestamps:
        # Each request covers the window [start, start + duration).
        samples = decoder.get_samples_played_in_range(
            start_seconds=start,
            stop_seconds=start + duration,
        )

        if log_loaded_timestamps:
            logging.info(
                f"audio chunk loaded at starting timestamp={samples.pts_seconds:.4f} with duration={samples.duration_seconds:.4f}"
            )

        chunks.append(samples.data)

    # torch.stack requires every chunk to have the same shape.
    stacked = torch.stack(chunks)

    # Sanity check: exactly one decoded chunk per requested timestamp.
    assert len(timestamps) == len(stacked)
    return stacked
|
||||||
|
|
||||||
|
|
||||||
def decode_audio_torchaudio(
|
def decode_audio_torchaudio(
|
||||||
audio_path: Path | str,
|
audio_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
|
|
|
@ -527,7 +527,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
|
download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
|
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
|
||||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'.
|
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchaudio'.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
|
@ -539,7 +539,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
self.audio_backend = (
|
self.audio_backend = (
|
||||||
audio_backend if audio_backend else "ffmpeg"
|
audio_backend if audio_backend else "torchaudio"
|
||||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
|
||||||
|
@ -1229,7 +1229,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.episode_data_index = None
|
obj.episode_data_index = None
|
||||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
obj.audio_backend = (
|
obj.audio_backend = (
|
||||||
audio_backend if audio_backend is not None else "ffmpeg"
|
audio_backend if audio_backend is not None else "torchaudio"
|
||||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue