fix(audio): separate audio from video

parent 6cf9cb35ba
commit ca716ed196

lerobot/common/datasets/audio_utils.py (new file)
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import subprocess
+from collections import OrderedDict
+from pathlib import Path
+
+import torch
+import torchaudio
+from numpy import ceil
+
+
+def decode_audio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    backend: str | None = "ffmpeg",
+) -> torch.Tensor:
+    """
+    Decodes audio using the specified backend.
+    Args:
+        audio_path (Path): Path to the audio file.
+        timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
+        duration (float): Duration of the audio chunks in seconds.
+        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
+
+    Returns:
+        torch.Tensor: Decoded audio chunks.
+
+    Currently supports ffmpeg.
+    """
+    if backend == "torchcodec":
+        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
+    elif backend == "ffmpeg":
+        return decode_audio_torchaudio(audio_path, timestamps, duration)
+    else:
+        raise ValueError(f"Unsupported audio backend: {backend}")
+
+
+def decode_audio_torchaudio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    # TODO(CarolinePascal): add channels selection
+    audio_path = str(audio_path)
+
+    reader = torchaudio.io.StreamReader(src=audio_path)
+    audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
+
+    # TODO(CarolinePascal): sort timestamps?
+    reader.add_basic_audio_stream(
+        frames_per_chunk=int(ceil(duration * audio_sample_rate)),  # Too much is better than not enough
+        buffer_chunk_size=-1,  # No dropping frames
+        format="fltp",  # Format as float32
+    )
+
+    audio_chunks = []
+    for ts in timestamps:
+        reader.seek(ts)  # Defaults to the closest audio sample
+        status = reader.fill_buffer()
+        if status != 0:
+            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
+
+        current_audio_chunk = reader.pop_chunks()[0]
+
+        if log_loaded_timestamps:
+            logging.info(
+                f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
+            )
+
+        audio_chunks.append(current_audio_chunk)
+
+    audio_chunks = torch.stack(audio_chunks)
+
+    assert len(timestamps) == len(audio_chunks)
+    return audio_chunks
+
+
+def encode_audio(
+    input_path: Path | str,
+    output_path: Path | str,
+    codec: str = "aac",  # TODO(CarolinePascal): investigate the Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
+    log_level: str | None = "error",
+    overwrite: bool = False,
+) -> None:
+    """Encodes an audio file using ffmpeg."""
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    ffmpeg_args = OrderedDict(
+        [
+            ("-i", str(input_path)),
+            ("-acodec", codec),
+        ]
+    )
+
+    if log_level is not None:
+        ffmpeg_args["-loglevel"] = str(log_level)
+
+    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    if overwrite:
+        ffmpeg_args.append("-y")
+
+    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
+
+    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
+
+    if not output_path.exists():
+        raise OSError(
+            f"Audio encoding did not work. File not found: {output_path}. "
+            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
+        )
+
+
+def get_audio_info(video_path: Path | str) -> dict:
+    ffprobe_audio_cmd = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-select_streams",
+        "a:0",
+        "-show_entries",
+        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
+        "-of",
+        "json",
+        str(video_path),
+    ]
+    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    if result.returncode != 0:
+        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
+
+    info = json.loads(result.stdout)
+    audio_stream_info = info["streams"][0] if info.get("streams") else None
+    if audio_stream_info is None:
+        return {"has_audio": False}
+
+    # Return the information, defaulting to None for any missing field
+    return {
+        "has_audio": True,
+        "audio.channels": audio_stream_info.get("channels", None),
+        "audio.codec": audio_stream_info.get("codec_name", None),
+        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
+        "audio.sample_rate": int(audio_stream_info["sample_rate"])
+        if audio_stream_info.get("sample_rate")
+        else None,
+        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
+        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
+    }
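A minimal usage sketch of how the three new helpers compose, assuming an existing recording; the file names, timestamps, and chunk duration below are illustrative, not values from this commit:

from pathlib import Path

from lerobot.common.datasets.audio_utils import decode_audio, encode_audio, get_audio_info

raw_path = Path("episode_000000_raw.wav")  # hypothetical raw recording
m4a_path = Path("audio/chunk-000/observation.audio.laptop/episode_000000.m4a")

# Compress the raw recording to AAC; parent directories are created automatically.
encode_audio(raw_path, m4a_path, overwrite=True)

# Inspect the encoded stream (codec, sample rate, channel layout, ...).
print(get_audio_info(m4a_path))

# Decode two 0.5 s chunks starting at t=1.0 s and t=3.0 s.
chunks = decode_audio(m4a_path, timestamps=[1.0, 3.0], duration=0.5)
print(chunks.shape)  # chunks are stacked along the first dimension, one per timestamp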
lerobot/common/datasets/lerobot_dataset.py
@@ -32,6 +32,11 @@ from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import RevisionNotFoundError

 from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.audio_utils import (
+    decode_audio,
+    encode_audio,
+    get_audio_info,
+)
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
@@ -70,11 +75,8 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
-    decode_audio,
     decode_video_frames,
-    encode_audio,
     encode_video_frames,
-    get_audio_info,
     get_safe_default_codec,
     get_video_info,
 )
@@ -207,29 +209,9 @@ class LeRobotDatasetMetadata:

     @property
     def audio_keys(self) -> list[str]:
-        """Keys to access audio modalities (whether they are linked to a camera or not)."""
+        """Keys to access audio modalities."""
         return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]

-    @property
-    def audio_camera_keys_mapping(self) -> dict[str, str]:
-        """Mapping between camera keys and audio keys when both are linked."""
-        return {
-            self.features[camera_key]["audio"]: camera_key
-            for camera_key in self.camera_keys
-            if self.features[camera_key]["audio"] is not None
-            and self.features[camera_key]["dtype"] == "video"
-        }
-
-    @property
-    def linked_audio_keys(self) -> list[str]:
-        """Keys to access audio modalities linked to a camera."""
-        return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping]
-
-    @property
-    def unlinked_audio_keys(self) -> list[str]:
-        """Keys to access audio modalities not linked to a camera."""
-        return [key for key in self.audio_keys if key not in self.audio_camera_keys_mapping]
-
     @property
     def names(self) -> dict[str, list | dict]:
         """Names of the various dimensions of vector modalities."""
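Concretely, with the linked/unlinked split removed, the audio keys are derived from the features dict alone; a sketch with a made-up features dict, not code from the repository:

features = {
    "observation.images.laptop": {"dtype": "video", "shape": (480, 640, 3)},
    "observation.audio.laptop": {"dtype": "audio", "shape": (1,)},
    "observation.state": {"dtype": "float32", "shape": (6,)},
}

# Mirrors the simplified audio_keys property: every feature declared with dtype "audio".
audio_keys = [key for key, ft in features.items() if ft["dtype"] == "audio"]
assert audio_keys == ["observation.audio.laptop"]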
@@ -310,7 +292,7 @@ class LeRobotDatasetMetadata:
         self.update_video_info()

         self.info["total_audio"] += len(self.audio_keys)
-        if len(self.unlinked_audio_keys) > 0:
+        if len(self.audio_keys) > 0:
             self.update_audio_info()

         write_info(self.info, self.root)
@@ -342,7 +324,7 @@ class LeRobotDatasetMetadata:
        Warning: this function writes info from first episode audio, implicitly assuming that all audio have
        been encoded the same way. Also, this means it assumes the first episode exists.
        """
-        for key in self.unlinked_audio_keys:
+        for key in self.audio_keys:
            if (
                not self.features[key].get("info", None)
                or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"])
@@ -480,17 +462,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
        │   ├── info.json
        │   ├── stats.json
        │   └── tasks.jsonl
-       └── videos
+       ├── videos
+       │   ├── chunk-000
+       │   │   ├── observation.images.laptop
+       │   │   │   ├── episode_000000.mp4
+       │   │   │   ├── episode_000001.mp4
+       │   │   │   ├── episode_000002.mp4
+       │   │   │   └── ...
+       │   │   ├── observation.images.phone
+       │   │   │   ├── episode_000000.mp4
+       │   │   │   ├── episode_000001.mp4
+       │   │   │   ├── episode_000002.mp4
+       │   │   │   └── ...
+       │   ├── chunk-001
+       │   └── ...
+       └── audio
            ├── chunk-000
-           │   ├── observation.images.laptop
-           │   │   ├── episode_000000.mp4
-           │   │   ├── episode_000001.mp4
-           │   │   ├── episode_000002.mp4
+           │   ├── observation.audio.laptop
+           │   │   ├── episode_000000.m4a
+           │   │   ├── episode_000001.m4a
+           │   │   ├── episode_000002.m4a
            │   │   └── ...
-           │   ├── observation.images.phone
-           │   │   ├── episode_000000.mp4
-           │   │   ├── episode_000001.mp4
-           │   │   ├── episode_000002.mp4
+           │   ├── observation.audio.phone
+           │   │   ├── episode_000000.m4a
+           │   │   ├── episode_000001.m4a
+           │   │   ├── episode_000002.m4a
            │   │   └── ...
            ├── chunk-001
            └── ...
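The get_compressed_audio_file_path helper referenced throughout this diff is not shown here; the sketch below is one plausible reading of the tree above, and the chunk size and format strings are assumptions rather than code from the commit:

from pathlib import Path

EPISODES_PER_CHUNK = 1000  # assumed chunking granularity

def compressed_audio_path(root: Path, ep_index: int, audio_key: str) -> Path:
    # Matches the layout in the docstring tree: audio/chunk-XXX/<audio_key>/episode_XXXXXX.m4a
    chunk = ep_index // EPISODES_PER_CHUNK
    return root / f"audio/chunk-{chunk:03d}/{audio_key}/episode_{ep_index:06d}.m4a"

# compressed_audio_path(Path("."), 2, "observation.audio.laptop")
# -> audio/chunk-000/observation.audio.laptop/episode_000002.m4a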
@@ -569,9 +565,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(
-                download_videos, download_audio
-            )  # Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio
+            self.download_episodes(download_videos, download_audio)
             self.hf_dataset = self.load_hf_dataset()

         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -582,7 +576,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
         check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

-        # TODO(CarolinePascal): add check for audio duration with respect to video duration and episode duration.
+        # TODO(CarolinePascal): add check for audio duration with respect to episode duration, BUT this will be CPU expensive if there are many episodes!

         # Setup delta_indices
         if self.delta_timestamps is not None:
@@ -604,9 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ) -> None:
         ignore_patterns = ["images/"]
         if not push_videos:
-            ignore_patterns.append(
-                "videos/"
-            )  # Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed
+            ignore_patterns.append("videos/")
         if not push_audio:
             ignore_patterns.append("audio/")

@@ -675,9 +667,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = None
         ignore_patterns = []
         if not download_videos:
-            ignore_patterns.append(
-                "videos/"
-            )  # Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded
+            ignore_patterns.append("videos/")
         if not download_audio:
             ignore_patterns.append("audio/")
         if self.episodes is not None:
@@ -696,10 +686,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ]
             fpaths += video_files

-        if len(self.meta.unlinked_audio_keys) > 0:
+        if len(self.meta.audio_keys) > 0:
             audio_files = [
                 str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
-                for audio_key in self.meta.unlinked_audio_keys
+                for audio_key in self.meta.audio_keys
                 for ep_idx in episodes
             ]
             fpaths += audio_files
@@ -792,7 +782,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         query_indices: dict[str, list[int]] | None = None,
     ) -> dict[str, list[float]]:
         query_timestamps = {}
-        for key in self.meta.audio_keys:  # Standalone audio and audio embedded in video as well!
+        for key in self.meta.audio_keys:
             if query_indices is not None and key in query_indices:
                 timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
                 query_timestamps[key] = torch.stack(timestamps).tolist()
@@ -828,13 +818,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ) -> dict[str, torch.Tensor]:
         item = {}
         for audio_key, query_ts in query_timestamps.items():
-            # Audio stored with video in a single .mp4 file
-            if audio_key in self.meta.linked_audio_keys:
-                audio_path = self.root / self.meta.get_video_file_path(
-                    ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
-                )
-            # Audio stored alone in a separate .m4a file
-            else:
-                audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
+            audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
             audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
             item[audio_key] = audio_chunk.squeeze(0)
@@ -966,7 +949,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             elif self.features[key]["dtype"] == "audio":
                 if (
                     self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
-                ):  # Rw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
+                ):  # Raw data storage should only be triggered for the LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
                     self.episode_buffer[key].append(frame[key])
                 else:  # Otherwise, only the audio file path is stored in the episode buffer
                     if frame_index == 0:
@@ -1062,12 +1045,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_stats = compute_episode_stats(episode_buffer, self.features)

         if len(self.meta.video_keys) > 0:
-            video_paths = self.encode_episode_videos(episode_index)
-            for key in self.meta.video_keys:
-                episode_buffer[key] = video_paths[key]
+            self.encode_episode_videos(episode_index)

-        if len(self.meta.unlinked_audio_keys) > 0:  # Linked audio is already encoded in the video files
-            _ = self.encode_episode_audio(episode_index)
+        if len(self.meta.audio_keys) > 0:
+            self.encode_episode_audio(episode_index)

         # `meta.save_episode` must be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -1177,12 +1158,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 episode_index=episode_index, image_key=video_key, frame_index=0
             ).parent

-            audio_path = None
-            if self.meta.features[video_key]["audio"] is not None:
-                audio_key = self.meta.features[video_key]["audio"]
-                audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
-
-            encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
+            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)

         return video_paths

@@ -1193,7 +1169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        since video encoding with ffmpeg is already using multithreading.
        """
        audio_paths = {}
-       for audio_key in self.meta.unlinked_audio_keys:
+       for audio_key in self.meta.audio_keys:
            input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
            output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)

lerobot/common/datasets/video_utils.py
@@ -25,10 +25,8 @@ from typing import Any, ClassVar

 import pyarrow as pa
 import torch
-import torchaudio
 import torchvision
 from datasets.features.features import register_feature
-from numpy import ceil
 from PIL import Image

@@ -42,74 +40,6 @@ def get_safe_default_codec():
     return "pyav"


-def decode_audio(
-    audio_path: Path | str,
-    timestamps: list[float],
-    duration: float,
-    backend: str | None = "ffmpeg",
-) -> torch.Tensor:
-    """
-    Decodes audio using the specified backend.
-    Args:
-        audio_path (Path): Path to the audio file.
-        timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
-        duration (float): Duration of the audio chunks in seconds.
-        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
-
-    Returns:
-        torch.Tensor: Decoded audio chunks.
-
-    Currently supports ffmpeg.
-    """
-    if backend == "torchcodec":
-        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
-    elif backend == "ffmpeg":
-        return decode_audio_torchaudio(audio_path, timestamps, duration)
-    else:
-        raise ValueError(f"Unsupported video backend: {backend}")
-
-
-def decode_audio_torchaudio(
-    audio_path: Path | str,
-    timestamps: list[float],
-    duration: float,
-    log_loaded_timestamps: bool = False,
-) -> torch.Tensor:
-    # TODO(CarolinePascal) : add channels selection
-    audio_path = str(audio_path)
-
-    reader = torchaudio.io.StreamReader(src=audio_path)
-    audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
-
-    # TODO(CarolinePascal) : sort timestamps ?
-    reader.add_basic_audio_stream(
-        frames_per_chunk=int(ceil(duration * audio_sample_rate)),  # Too much is better than not enough
-        buffer_chunk_size=-1,  # No dropping frames
-        format="fltp",  # Format as float32
-    )
-
-    audio_chunks = []
-    for ts in timestamps:
-        reader.seek(ts)  # Default to closest audio sample
-        status = reader.fill_buffer()
-        if status != 0:
-            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
-
-        current_audio_chunk = reader.pop_chunks()[0]
-
-        if log_loaded_timestamps:
-            logging.info(
-                f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
-            )
-
-        audio_chunks.append(current_audio_chunk)
-
-    audio_chunks = torch.stack(audio_chunks)
-
-    assert len(timestamps) == len(audio_chunks)
-    return audio_chunks
-
-
 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
@@ -313,53 +243,14 @@ def decode_video_frames_torchcodec(
     return closest_frames


-def encode_audio(
-    input_path: Path | str,
-    output_path: Path | str,
-    codec: str = "aac",  # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
-    log_level: str | None = "error",
-    overwrite: bool = False,
-) -> None:
-    """Encodes an audio file using ffmpeg."""
-    output_path = Path(output_path)
-    output_path.parent.mkdir(parents=True, exist_ok=True)
-
-    ffmpeg_args = OrderedDict(
-        [
-            ("-i", str(input_path)),
-            ("-acodec", codec),
-        ]
-    )
-
-    if log_level is not None:
-        ffmpeg_args["-loglevel"] = str(log_level)
-
-    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
-    if overwrite:
-        ffmpeg_args.append("-y")
-
-    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
-
-    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
-    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
-
-    if not output_path.exists():
-        raise OSError(
-            f"Audio encoding did not work. File not found: {output_path}. "
-            f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
-        )
-
-
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
     fps: int,
-    audio_path: Path | str | None = None,
     vcodec: str = "libsvtav1",
     pix_fmt: str = "yuv420p",
     g: int | None = 2,
     crf: int | None = 30,
-    acodec: str = "aac",  # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
     fast_decode: int = 0,
     log_level: str | None = "error",
     overwrite: bool = False,
@@ -377,18 +268,6 @@ def encode_video_frames(
         ]
     )

-    ffmpeg_audio_args = OrderedDict()
-    if audio_path is not None:
-        audio_path = Path(audio_path)
-        audio_path.parent.mkdir(parents=True, exist_ok=True)
-        ffmpeg_audio_args.update(
-            OrderedDict(
-                [
-                    ("-i", str(audio_path)),
-                ]
-            )
-        )
-
     ffmpeg_encoding_args = OrderedDict(
         [
             ("-pix_fmt", pix_fmt),
@@ -404,14 +283,10 @@ def encode_video_frames(
         value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
         ffmpeg_encoding_args[key] = value

-    if audio_path is not None:
-        ffmpeg_encoding_args["-acodec"] = acodec
-
     if log_level is not None:
         ffmpeg_encoding_args["-loglevel"] = str(log_level)

     ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
-    ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair]
     ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
     if overwrite:
         ffmpeg_args.append("-y")
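The OrderedDict-to-argv flattening used in this function is compact enough to deserve a standalone illustration; the option values below are illustrative defaults, not the full argument set:

from collections import OrderedDict

ffmpeg_encoding_args = OrderedDict(
    [("-pix_fmt", "yuv420p"), ("-vcodec", "libsvtav1"), ("-g", "2"), ("-crf", "30")]
)

# Flatten the (flag, value) pairs into a flat argv list: [k1, v1, k2, v2, ...]
argv = [item for pair in ffmpeg_encoding_args.items() for item in pair]
print(argv)  # ['-pix_fmt', 'yuv420p', '-vcodec', 'libsvtav1', '-g', '2', '-crf', '30']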
@@ -460,42 +335,6 @@ with warnings.catch_warnings():
     register_feature(VideoFrame, "VideoFrame")


-def get_audio_info(video_path: Path | str) -> dict:
-    ffprobe_audio_cmd = [
-        "ffprobe",
-        "-v",
-        "error",
-        "-select_streams",
-        "a:0",
-        "-show_entries",
-        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
-        "-of",
-        "json",
-        str(video_path),
-    ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    if result.returncode != 0:
-        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
-
-    info = json.loads(result.stdout)
-    audio_stream_info = info["streams"][0] if info.get("streams") else None
-    if audio_stream_info is None:
-        return {"has_audio": False}
-
-    # Return the information, defaulting to None for any missing field
-    return {
-        "has_audio": True,
-        "audio.channels": audio_stream_info.get("channels", None),
-        "audio.codec": audio_stream_info.get("codec_name", None),
-        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
-        "audio.sample_rate": int(audio_stream_info["sample_rate"])
-        if audio_stream_info.get("sample_rate")
-        else None,
-        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
-        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
-    }
-
-
 def get_video_info(video_path: Path | str) -> dict:
     ffprobe_video_cmd = [
         "ffprobe",
@@ -531,7 +370,6 @@ def get_video_info(video_path: Path | str) -> dict:
         "video.codec": video_stream_info["codec_name"],
         "video.pix_fmt": video_stream_info["pix_fmt"],
         "video.is_depth_map": False,
-        **get_audio_info(video_path),
     }

     return video_info
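Since get_video_info no longer spreads the audio stream fields into its result, a caller that still wants both would merge the two dicts itself; a minimal sketch, assuming the helpers are imported from their new homes and using an illustrative path:

from lerobot.common.datasets.audio_utils import get_audio_info
from lerobot.common.datasets.video_utils import get_video_info

path = "videos/chunk-000/observation.images.laptop/episode_000000.mp4"  # illustrative
full_info = {**get_video_info(path), **get_audio_info(path)}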
lerobot/common/robot_devices/cameras/configs.py
@@ -48,8 +48,6 @@ class OpenCVCameraConfig(CameraConfig):
    rotation: int | None = None
    mock: bool = False

-   microphone: str | None = None
-
    def __post_init__(self):
        if self.color_mode not in ["rgb", "bgr"]:
            raise ValueError(
lerobot/common/robot_devices/cameras/intelrealsense.py
@@ -265,8 +265,6 @@ class IntelRealSenseCamera:
        elif config.rotation == 180:
            self.rotation = cv2.ROTATE_180

-       self.microphone = None  # No microphones on realsense cameras, sorry
-
    def find_serial_number_from_name(self, name):
        camera_infos = find_cameras()
        camera_names = [cam["name"] for cam in camera_infos]
lerobot/common/robot_devices/cameras/opencv.py
@@ -281,8 +281,6 @@ class OpenCVCamera:
        elif config.rotation == 180:
            self.rotation = cv2.ROTATE_180

-       self.microphone = config.microphone
-
    def connect(self):
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
lerobot/common/robot_devices/robots/configs.py
@@ -486,7 +486,6 @@ class So100RobotConfig(ManipulatorRobotConfig):
                fps=30,
                width=640,
                height=480,
-               microphone="laptop",
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
lerobot/common/robot_devices/robots/manipulator.py
@@ -181,7 +181,6 @@ class ManipulatorRobot:
                "shape": (cam.height, cam.width, cam.channels),
                "names": ["height", "width", "channels"],
                "info": None,
-               "audio": "observation.audio." + cam.microphone if cam.microphone is not None else None,
            }
        return cam_ft

@@ -211,7 +210,9 @@ class ManipulatorRobot:
                "dtype": "audio",
                "shape": (len(mic.channels),),
                "names": "channels",
-               "info": {"sample_rate": mic.sample_rate},
+               "info": {
+                   "sample_rate": mic.sample_rate
+               },  # we need to store the sample rate here in the case of audio chunk recording (for LeKiwi), as it will no longer be available when writing the audio file
            }
        return mic_ft

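For reference, the resulting microphone feature entry would look roughly like this; the key name, channel count, and sample rate are assumed values for a mono 16 kHz microphone:

mic_ft = {
    "observation.audio.laptop": {
        "dtype": "audio",
        "shape": (1,),  # one channel
        "names": "channels",
        "info": {"sample_rate": 16000},  # kept so chunked (LeKiwi-style) recording still knows the rate
    }
}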