fix(audio): separate audio from video

CarolinePascal 2025-04-15 17:14:55 +02:00
parent 6cf9cb35ba
commit ca716ed196
8 changed files with 213 additions and 240 deletions

View File

@@ -0,0 +1,165 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import subprocess
from collections import OrderedDict
from pathlib import Path
import torch
import torchaudio
from numpy import ceil
def decode_audio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
backend: str | None = "ffmpeg",
) -> torch.Tensor:
"""
Decodes audio using the specified backend.
Args:
audio_path (Path): Path to the audio file.
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
duration (float): Duration of the audio chunks in seconds.
backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
Returns:
torch.Tensor: Decoded audio chunks.
Currently supports ffmpeg.
"""
if backend == "torchcodec":
raise NotImplementedError("torchcodec is not yet supported for audio decoding")
elif backend == "ffmpeg":
return decode_audio_torchaudio(audio_path, timestamps, duration)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_audio_torchaudio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_path = str(audio_path)
reader = torchaudio.io.StreamReader(src=audio_path)
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
# TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
format="fltp", # Format as float32
)
audio_chunks = []
for ts in timestamps:
reader.seek(ts) # Default to closest audio sample
status = reader.fill_buffer()
if status != 0:
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
current_audio_chunk = reader.pop_chunks()[0]
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
)
audio_chunks.append(current_audio_chunk)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
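# Note on shapes (a sketch based on torchaudio's StreamReader semantics): each chunk
# popped by pop_chunks() is laid out as (frames, channels), so the stacked result is
# expected to have shape (len(timestamps), frames_per_chunk, channels).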
def encode_audio(
input_path: Path | str,
output_path: Path | str,
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
log_level: str | None = "error",
overwrite: bool = False,
) -> None:
"""Encodes an audio file using ffmpeg."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_args = OrderedDict(
[
("-i", str(input_path)),
("-acodec", codec),
]
)
if log_level is not None:
ffmpeg_args["-loglevel"] = str(log_level)
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
if overwrite:
ffmpeg_args.append("-y")
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
if not output_path.exists():
raise OSError(
f"Audio encoding did not work. File not found: {output_path}. "
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
)
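# With the defaults above, a call such as encode_audio("raw.wav", "episode_000000.m4a",
# overwrite=True) (hypothetical paths) assembles and runs roughly:
#
#     ffmpeg -i raw.wav -acodec aac -loglevel error -y episode_000000.m4a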
def get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a:0",
"-show_entries",
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
audio_stream_info = info["streams"][0] if info.get("streams") else None
if audio_stream_info is None:
return {"has_audio": False}
# Return the audio stream information, defaulting missing fields to None
return {
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
}
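# Illustrative return value for a stereo AAC stream (example values only):
#
#     {"has_audio": True, "audio.channels": 2, "audio.codec": "aac",
#      "audio.bit_rate": 128000, "audio.sample_rate": 48000,
#      "audio.bit_depth": None, "audio.channel_layout": "stereo"}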

View File

@@ -32,6 +32,11 @@ from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.audio_utils import (
decode_audio,
encode_audio,
get_audio_info,
)
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
@@ -70,11 +75,8 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_audio,
decode_video_frames,
encode_audio,
encode_video_frames,
get_audio_info,
get_safe_default_codec,
get_video_info,
)
@@ -207,29 +209,9 @@ class LeRobotDatasetMetadata:
@property
def audio_keys(self) -> list[str]:
"""Keys to access audio modalities (whether they are linked to a camera or not)."""
"""Keys to access audio modalities."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
@property
def audio_camera_keys_mapping(self) -> dict[str, str]:
"""Mapping between camera keys and audio keys when both are linked."""
return {
self.features[camera_key]["audio"]: camera_key
for camera_key in self.camera_keys
if self.features[camera_key]["audio"] is not None
and self.features[camera_key]["dtype"] == "video"
}
@property
def linked_audio_keys(self) -> list[str]:
"""Keys to access audio modalities linked to a camera."""
return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping]
@property
def unlinked_audio_keys(self) -> list[str]:
"""Keys to access audio modalities not linked to a camera."""
return [key for key in self.audio_keys if key not in self.audio_camera_keys_mapping]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -310,7 +292,7 @@ class LeRobotDatasetMetadata:
self.update_video_info()
self.info["total_audio"] += len(self.audio_keys)
if len(self.unlinked_audio_keys) > 0:
if len(self.audio_keys) > 0:
self.update_audio_info()
write_info(self.info, self.root)
@@ -342,7 +324,7 @@ class LeRobotDatasetMetadata:
Warning: this function writes info from the first episode's audio, implicitly assuming that all audio files
have been encoded the same way. Also, this means it assumes the first episode exists.
"""
for key in self.unlinked_audio_keys:
for key in self.audio_keys:
if (
not self.features[key].get("info", None)
or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"])
@@ -480,17 +462,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
info.json
stats.json
tasks.jsonl
videos
    chunk-000
        observation.images.laptop
            episode_000000.mp4
            episode_000001.mp4
            episode_000002.mp4
            ...
        observation.images.phone
            episode_000000.mp4
            episode_000001.mp4
            episode_000002.mp4
            ...
    chunk-001
    ...
audio
    chunk-000
        observation.images.laptop
            episode_000000.mp4
            episode_000001.mp4
            episode_000002.mp4
        observation.audio.laptop
            episode_000000.m4a
            episode_000001.m4a
            episode_000002.m4a
            ...
        observation.images.phone
            episode_000000.mp4
            episode_000001.mp4
            episode_000002.mp4
        observation.audio.phone
            episode_000000.m4a
            episode_000001.m4a
            episode_000002.m4a
            ...
    chunk-001
    ...
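For example, episode 0 of a "laptop" camera with a paired microphone (hypothetical
names) now maps to two separate files:
    videos/chunk-000/observation.images.laptop/episode_000000.mp4
    audio/chunk-000/observation.audio.laptop/episode_000000.m4a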
@@ -569,9 +565,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(
download_videos, download_audio
) # Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio
self.download_episodes(download_videos, download_audio)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -582,7 +576,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# TODO(CarolinePascal) : add check for audio duration with respect to video duration and episode duration.
# TODO(CarolinePascal) : add check for audio duration with respect to episode duration, BUT this will be CPU-expensive if there are many episodes!
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -604,9 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
) -> None:
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append(
"videos/"
) # Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed
ignore_patterns.append("videos/")
if not push_audio:
ignore_patterns.append("audio/")
@@ -675,9 +667,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None
ignore_patterns = []
if not download_videos:
ignore_patterns.append(
"videos/"
) # Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded
ignore_patterns.append("videos/")
if not download_audio:
ignore_patterns.append("audio/")
if self.episodes is not None:
@@ -696,10 +686,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
]
fpaths += video_files
if len(self.meta.unlinked_audio_keys) > 0:
if len(self.meta.audio_keys) > 0:
audio_files = [
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
for audio_key in self.meta.unlinked_audio_keys
for audio_key in self.meta.audio_keys
for ep_idx in episodes
]
fpaths += audio_files
@@ -792,7 +782,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.meta.audio_keys: # Standalone audio and audio embedded in video as well !
for key in self.meta.audio_keys:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
@@ -828,14 +818,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
) -> dict[str, torch.Tensor]:
item = {}
for audio_key, query_ts in query_timestamps.items():
# Audio stored with video in a single .mp4 file
if audio_key in self.meta.linked_audio_keys:
audio_path = self.root / self.meta.get_video_file_path(
ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
)
# Audio stored alone in a separate .m4a file
else:
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
item[audio_key] = audio_chunk.squeeze(0)
return item
@@ -966,7 +949,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
elif self.features[key]["dtype"] == "audio":
if (
self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
): # Rw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
self.episode_buffer[key].append(frame[key])
else: # Otherwise, only the audio file path is stored in the episode buffer
if frame_index == 0:
@@ -1062,12 +1045,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_stats = compute_episode_stats(episode_buffer, self.features)
if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
self.encode_episode_videos(episode_index)
if len(self.meta.unlinked_audio_keys) > 0: # Linked audio is already encoded in the video files
_ = self.encode_episode_audio(episode_index)
if len(self.meta.audio_keys) > 0:
self.encode_episode_audio(episode_index)
# `meta.save_episode` must be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -1177,12 +1158,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index=episode_index, image_key=video_key, frame_index=0
).parent
audio_path = None
if self.meta.features[video_key]["audio"] is not None:
audio_key = self.meta.features[video_key]["audio"]
audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
return video_paths
@@ -1193,7 +1169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
since video encoding with ffmpeg is already using multithreading.
"""
audio_paths = {}
for audio_key in self.meta.unlinked_audio_keys:
for audio_key in self.meta.audio_keys:
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)

View File

@@ -25,10 +25,8 @@ from typing import Any, ClassVar
import pyarrow as pa
import torch
import torchaudio
import torchvision
from datasets.features.features import register_feature
from numpy import ceil
from PIL import Image
@@ -42,74 +40,6 @@ def get_safe_default_codec():
return "pyav"
def decode_audio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
backend: str | None = "ffmpeg",
) -> torch.Tensor:
"""
Decodes audio using the specified backend.
Args:
audio_path (Path): Path to the audio file.
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
duration (float): Duration of the audio chunks in seconds.
backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
Returns:
torch.Tensor: Decoded audio chunks.
Currently supports ffmpeg.
"""
if backend == "torchcodec":
raise NotImplementedError("torchcodec is not yet supported for audio decoding")
elif backend == "ffmpeg":
return decode_audio_torchaudio(audio_path, timestamps, duration)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_audio_torchaudio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_path = str(audio_path)
reader = torchaudio.io.StreamReader(src=audio_path)
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
# TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
format="fltp", # Format as float32
)
audio_chunks = []
for ts in timestamps:
reader.seek(ts) # Default to closest audio sample
status = reader.fill_buffer()
if status != 0:
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
current_audio_chunk = reader.pop_chunks()[0]
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
)
audio_chunks.append(current_audio_chunk)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
@@ -313,53 +243,14 @@ def decode_video_frames_torchcodec(
return closest_frames
def encode_audio(
input_path: Path | str,
output_path: Path | str,
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
log_level: str | None = "error",
overwrite: bool = False,
) -> None:
"""Encodes an audio file using ffmpeg."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_args = OrderedDict(
[
("-i", str(input_path)),
("-acodec", codec),
]
)
if log_level is not None:
ffmpeg_args["-loglevel"] = str(log_level)
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
if overwrite:
ffmpeg_args.append("-y")
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
if not output_path.exists():
raise OSError(
f"Audio encoding did not work. File not found: {output_path}. "
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
)
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
audio_path: Path | str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
acodec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
fast_decode: int = 0,
log_level: str | None = "error",
overwrite: bool = False,
@@ -377,18 +268,6 @@ def encode_video_frames(
]
)
ffmpeg_audio_args = OrderedDict()
if audio_path is not None:
audio_path = Path(audio_path)
audio_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_audio_args.update(
OrderedDict(
[
("-i", str(audio_path)),
]
)
)
ffmpeg_encoding_args = OrderedDict(
[
("-pix_fmt", pix_fmt),
@@ -404,14 +283,10 @@ def encode_video_frames(
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
ffmpeg_encoding_args[key] = value
if audio_path is not None:
ffmpeg_encoding_args["-acodec"] = acodec
if log_level is not None:
ffmpeg_encoding_args["-loglevel"] = str(log_level)
ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair]
ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
if overwrite:
ffmpeg_args.append("-y")
@@ -460,42 +335,6 @@ with warnings.catch_warnings():
register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a:0",
"-show_entries",
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
audio_stream_info = info["streams"][0] if info.get("streams") else None
if audio_stream_info is None:
return {"has_audio": False}
# Return the audio stream information, defaulting missing fields to None
return {
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
}
def get_video_info(video_path: Path | str) -> dict:
ffprobe_video_cmd = [
"ffprobe",
@@ -531,7 +370,6 @@ def get_video_info(video_path: Path | str) -> dict:
"video.codec": video_stream_info["codec_name"],
"video.pix_fmt": video_stream_info["pix_fmt"],
"video.is_depth_map": False,
**get_audio_info(video_path),
}
return video_info

View File

@@ -48,8 +48,6 @@ class OpenCVCameraConfig(CameraConfig):
rotation: int | None = None
mock: bool = False
microphone: str | None = None
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(

View File

@@ -265,8 +265,6 @@ class IntelRealSenseCamera:
elif config.rotation == 180:
self.rotation = cv2.ROTATE_180
self.microphone = None # No microphones on realsense cameras, sorry
def find_serial_number_from_name(self, name):
camera_infos = find_cameras()
camera_names = [cam["name"] for cam in camera_infos]

View File

@@ -281,8 +281,6 @@ class OpenCVCamera:
elif config.rotation == 180:
self.rotation = cv2.ROTATE_180
self.microphone = config.microphone
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")

View File

@@ -486,7 +486,6 @@ class So100RobotConfig(ManipulatorRobotConfig):
fps=30,
width=640,
height=480,
microphone="laptop",
),
"phone": OpenCVCameraConfig(
camera_index=1,

View File

@@ -181,7 +181,6 @@ class ManipulatorRobot:
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
"audio": "observation.audio." + cam.microphone if cam.microphone is not None else None,
}
return cam_ft
@@ -211,7 +210,9 @@ class ManipulatorRobot:
"dtype": "audio",
"shape": (len(mic.channels),),
"names": "channels",
"info": {"sample_rate": mic.sample_rate},
"info": {
"sample_rate": mic.sample_rate
}, # The sample rate must be stored here when audio is recorded chunk by chunk (LeKiwi), as it will no longer be available when the audio file is written
}
return mic_ft
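# Illustrative resulting entry for a stereo microphone named "laptop" (hypothetical
# name and sample rate, following the "observation.audio." + name pattern above):
#
#     {"observation.audio.laptop": {"dtype": "audio", "shape": (2,),
#      "names": "channels", "info": {"sample_rate": 48000}}}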