Adding audio modality in LeRobotDatasets
This commit is contained in:
parent
8ddfb299fd
commit
8ee61bb81f
|
@ -36,8 +36,11 @@ from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
|||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
DEFAULT_RAW_AUDIO_PATH,
|
||||
DEFAULT_COMPRESSED_AUDIO_PATH,
|
||||
INFO_PATH,
|
||||
TASKS_PATH,
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
|
@ -69,8 +72,11 @@ from lerobot.common.datasets.video_utils import (
|
|||
VideoFrame,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
encode_audio,
|
||||
decode_audio,
|
||||
get_safe_default_codec,
|
||||
get_video_info,
|
||||
get_audio_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
|
@ -142,6 +148,11 @@ class LeRobotDatasetMetadata:
|
|||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
||||
episode_chunk = self.get_episode_chunk(episode_index)
|
||||
fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index)
|
||||
return self.root / fpath
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
return ep_index // self.chunks_size
|
||||
|
||||
|
@ -155,6 +166,11 @@ class LeRobotDatasetMetadata:
|
|||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def audio_path(self) -> str | None:
|
||||
"""Formattable string for the audio files."""
|
||||
return self.info["audio_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
|
@ -185,6 +201,11 @@ class LeRobotDatasetMetadata:
|
|||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def audio_keys(self) -> list[str]:
|
||||
"""Keys to access audio modalities."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
|
@ -264,6 +285,9 @@ class LeRobotDatasetMetadata:
|
|||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
if len(self.audio_keys) > 0:
|
||||
self.update_audio_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
|
@ -288,6 +312,17 @@ class LeRobotDatasetMetadata:
|
|||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_audio_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None}
|
||||
for key in set(self.audio_keys) - bound_audio_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
|
||||
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
|
@ -364,6 +399,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
|
@ -465,6 +501,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
|
@ -475,6 +512,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
self.delta_indices = None
|
||||
|
||||
# Unused attributes
|
||||
|
@ -499,7 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
@ -677,7 +715,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
}
|
||||
return query_indices, padding
|
||||
|
||||
def _get_query_timestamps(
|
||||
def _get_query_timestamps_video(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
|
@ -692,6 +730,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
return query_timestamps
|
||||
|
||||
#TODO(CarolinePascal): add variable query durations
|
||||
def _get_query_timestamps_audio(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.meta.audio_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
|
@ -713,6 +767,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
return item
|
||||
|
||||
#TODO(CarolinePascal): add variable query durations
|
||||
def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
item = {}
|
||||
bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
|
||||
for audio_key, query_ts in query_timestamps.items():
|
||||
#Audio stored with video in a single .mp4 file
|
||||
if audio_key in bound_audio_keys_mapping.keys():
|
||||
audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key])
|
||||
#Audio stored alone in a separate .m4a file
|
||||
else:
|
||||
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
|
||||
audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
|
||||
item[audio_key] = audio_chunk.squeeze(0)
|
||||
return item
|
||||
|
||||
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
||||
for key, val in padding.items():
|
||||
item[key] = torch.BoolTensor(val)
|
||||
|
@ -733,11 +802,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
|
||||
query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
item = {**item, **video_frames}
|
||||
|
||||
query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
|
||||
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
|
||||
item = {**item, **audio_chunks}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
|
@ -777,6 +851,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
||||
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
|
||||
return self.root / fpath
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
|
@ -867,7 +945,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for key, ft in self.features.items():
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
|
@ -880,6 +958,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
_ = self.encode_episode_audio(episode_index)
|
||||
|
||||
# `meta.save_episode` be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
|
||||
|
@ -904,6 +985,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
# delete raw audio
|
||||
raw_audio_files = list(self.root.rglob("*.wav"))
|
||||
for raw_audio_file in raw_audio_files:
|
||||
raw_audio_file.unlink()
|
||||
if len(list(raw_audio_file.parent.iterdir())) == 0:
|
||||
raw_audio_file.parent.rmdir()
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
|
@ -971,19 +1059,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
video_paths = {}
|
||||
for key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
video_paths[key] = str(video_path)
|
||||
for video_key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
|
||||
video_paths[video_key] = str(video_path)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=key, frame_index=0
|
||||
episode_index=episode_index, image_key=video_key, frame_index=0
|
||||
).parent
|
||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
audio_path = None
|
||||
if self.meta.features[video_key]["audio"] is not None:
|
||||
audio_key = self.meta.features[video_key]["audio"]
|
||||
audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
|
||||
encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
|
||||
|
||||
return video_paths
|
||||
|
||||
def encode_episode_audio(self, episode_index: int) -> dict:
|
||||
"""
|
||||
Use ffmpeg to convert .wav raw audio files into .m4a audio files.
|
||||
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
audio_paths = {}
|
||||
bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
|
||||
for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
|
||||
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
|
||||
|
||||
audio_paths[audio_key] = str(output_audio_path)
|
||||
if output_audio_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
|
||||
encode_audio(input_audio_path, output_audio_path, overwrite=True)
|
||||
|
||||
return audio_paths
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
|
@ -998,6 +1113,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
obj = cls.__new__(cls)
|
||||
|
@ -1029,6 +1145,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
return obj
|
||||
|
||||
|
||||
|
@ -1049,6 +1166,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
|
@ -1066,6 +1184,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
audio_backend=audio_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
|
|
@ -55,6 +55,10 @@ TASKS_PATH = "meta/tasks.jsonl"
|
|||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
|
||||
DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
|
||||
|
||||
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
|
@ -363,7 +367,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
|||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if ft["dtype"] == "video" or ft["dtype"] == "audio":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
|
@ -394,7 +398,13 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
|||
key: {"dtype": "video" if use_videos else "image", **ft}
|
||||
for key, ft in robot.camera_features.items()
|
||||
}
|
||||
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
||||
microphones_ft = {}
|
||||
if robot.microphones:
|
||||
microphones_ft = {
|
||||
key: {"dtype": "audio", **ft}
|
||||
for key, ft in robot.microphones_features.items()
|
||||
}
|
||||
return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES}
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
|
@ -448,6 +458,7 @@ def create_empty_dataset_info(
|
|||
"splits": {},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
@ -721,6 +732,7 @@ def validate_features_presence(
|
|||
):
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
|
||||
extra_features = actual_features - (expected_features | optional_features)
|
||||
|
||||
if missing_features or extra_features:
|
||||
|
|
|
@ -26,9 +26,12 @@ from typing import Any, ClassVar
|
|||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
import torchaudio
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from numpy import ceil
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
|
@ -39,6 +42,71 @@ def get_safe_default_codec():
|
|||
)
|
||||
return "pyav"
|
||||
|
||||
def decode_audio(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
backend: str | None = "ffmpeg",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes audio using the specified backend.
|
||||
Args:
|
||||
audio_path (Path): Path to the audio file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "pyav".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames.
|
||||
|
||||
Currently supports pyav.
|
||||
"""
|
||||
if backend == "torchcodec":
|
||||
raise NotImplementedError("torchcodec is not yet supported for audio decoding")
|
||||
elif backend == "ffmpeg":
|
||||
return decode_audio_torchvision(audio_path, timestamps, duration)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
def decode_audio_torchvision(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
#TODO(CarolinePascal) : add channels selection
|
||||
audio_path = str(audio_path)
|
||||
|
||||
reader = torchaudio.io.StreamReader(src=audio_path)
|
||||
audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
|
||||
|
||||
#TODO(CarolinePascal) : sort timestamps ?
|
||||
|
||||
reader.add_basic_audio_stream(
|
||||
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
|
||||
buffer_chunk_size = -1, #No dropping frames
|
||||
)
|
||||
|
||||
audio_chunks = []
|
||||
for ts in timestamps:
|
||||
reader.seek(ts) #Default to closest audio sample
|
||||
status = reader.fill_buffer()
|
||||
if status != 0:
|
||||
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
||||
|
||||
current_audio_chunk = reader.pop_chunks()[0]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}")
|
||||
|
||||
audio_chunks.append(current_audio_chunk)
|
||||
|
||||
audio_chunks = torch.stack(audio_chunks)
|
||||
#TODO(CarolinePascal) : pytorch format conversion ?
|
||||
|
||||
assert len(timestamps) == len(audio_chunks)
|
||||
return audio_chunks
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
|
@ -69,7 +137,6 @@ def decode_video_frames(
|
|||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
|
@ -167,7 +234,6 @@ def decode_video_frames_torchvision(
|
|||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_video_frames_torchcodec(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
|
@ -242,15 +308,52 @@ def decode_video_frames_torchcodec(
|
|||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
def encode_audio(
|
||||
input_path: Path | str,
|
||||
output_path: Path | str,
|
||||
codec: str = "aac",
|
||||
log_level: str | None = "error",
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Encodes an audio file using ffmpeg."""
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ffmpeg_args = OrderedDict(
|
||||
[
|
||||
("-i", str(input_path)),
|
||||
("-acodec", codec),
|
||||
]
|
||||
)
|
||||
|
||||
if log_level is not None:
|
||||
ffmpeg_args["-loglevel"] = str(log_level)
|
||||
|
||||
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||
if overwrite:
|
||||
ffmpeg_args.append("-y")
|
||||
|
||||
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
|
||||
|
||||
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||
|
||||
if not output_path.exists():
|
||||
raise OSError(
|
||||
f"Video encoding did not work. File not found: {output_path}. "
|
||||
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||
)
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
audio_path: Path | str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
|
||||
fast_decode: int = 0,
|
||||
log_level: str | None = "error",
|
||||
overwrite: bool = False,
|
||||
|
@ -260,35 +363,53 @@ def encode_video_frames(
|
|||
imgs_dir = Path(imgs_dir)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ffmpeg_args = OrderedDict(
|
||||
ffmpeg_video_args = OrderedDict(
|
||||
[
|
||||
("-f", "image2"),
|
||||
("-r", str(fps)),
|
||||
("-i", str(imgs_dir / "frame_%06d.png")),
|
||||
("-vcodec", vcodec),
|
||||
("-pix_fmt", pix_fmt),
|
||||
("-i", str(Path(imgs_dir) / "frame_%06d.png")),
|
||||
]
|
||||
)
|
||||
|
||||
ffmpeg_audio_args = OrderedDict()
|
||||
if audio_path is not None:
|
||||
audio_path = Path(audio_path)
|
||||
audio_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ffmpeg_audio_args.update(OrderedDict(
|
||||
[
|
||||
("-i", str(audio_path)),
|
||||
]
|
||||
))
|
||||
|
||||
ffmpeg_encoding_args = OrderedDict(
|
||||
[
|
||||
("-pix_fmt", pix_fmt),
|
||||
("-vcodec", vcodec),
|
||||
]
|
||||
)
|
||||
if g is not None:
|
||||
ffmpeg_args["-g"] = str(g)
|
||||
|
||||
ffmpeg_encoding_args["-g"] = str(g)
|
||||
if crf is not None:
|
||||
ffmpeg_args["-crf"] = str(crf)
|
||||
|
||||
ffmpeg_encoding_args["-crf"] = str(crf)
|
||||
if fast_decode:
|
||||
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
ffmpeg_args[key] = value
|
||||
ffmpeg_encoding_args[key] = value
|
||||
|
||||
if audio_path is not None:
|
||||
ffmpeg_encoding_args["-acodec"] = acodec
|
||||
|
||||
if log_level is not None:
|
||||
ffmpeg_args["-loglevel"] = str(log_level)
|
||||
ffmpeg_encoding_args["-loglevel"] = str(log_level)
|
||||
|
||||
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||
ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
|
||||
ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair]
|
||||
ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
|
||||
if overwrite:
|
||||
ffmpeg_args.append("-y")
|
||||
|
||||
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
||||
|
||||
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||
|
||||
|
@ -366,7 +487,6 @@ def get_audio_info(video_path: Path | str) -> dict:
|
|||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
|
|
|
@ -216,7 +216,7 @@ class ManipulatorRobot:
|
|||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
return {**self.motor_features, **self.camera_features, **self.microphones_features}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
|
|
|
@ -29,6 +29,7 @@ from lerobot.common.datasets.utils import (
|
|||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
DEFAULT_COMPRESSED_AUDIO_PATH,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
@ -121,6 +122,7 @@ def info_factory(features_factory):
|
|||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
|
@ -139,6 +141,7 @@ def info_factory(features_factory):
|
|||
"splits": {},
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"audio_path": audio_path,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue