Adding audio modality in LeRobotDatasets

This commit is contained in:
CarolinePascal 2025-03-28 17:16:51 +01:00
parent 8ddfb299fd
commit 8ee61bb81f
No known key found for this signature in database
5 changed files with 282 additions and 28 deletions

View File

@ -36,8 +36,11 @@ from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_FEATURES, DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH, DEFAULT_IMAGE_PATH,
DEFAULT_RAW_AUDIO_PATH,
DEFAULT_COMPRESSED_AUDIO_PATH,
INFO_PATH, INFO_PATH,
TASKS_PATH, TASKS_PATH,
DEFAULT_AUDIO_CHUNK_DURATION,
append_jsonlines, append_jsonlines,
backward_compatible_episodes_stats, backward_compatible_episodes_stats,
check_delta_timestamps, check_delta_timestamps,
@ -69,8 +72,11 @@ from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames, decode_video_frames,
encode_video_frames, encode_video_frames,
encode_audio,
decode_audio,
get_safe_default_codec, get_safe_default_codec,
get_video_info, get_video_info,
get_audio_info,
) )
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
@ -141,6 +147,11 @@ class LeRobotDatasetMetadata:
ep_chunk = self.get_episode_chunk(ep_index) ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath) return Path(fpath)
def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
episode_chunk = self.get_episode_chunk(episode_index)
fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index)
return self.root / fpath
def get_episode_chunk(self, ep_index: int) -> int: def get_episode_chunk(self, ep_index: int) -> int:
return ep_index // self.chunks_size return ep_index // self.chunks_size
@ -154,6 +165,11 @@ class LeRobotDatasetMetadata:
def video_path(self) -> str | None: def video_path(self) -> str | None:
"""Formattable string for the video files.""" """Formattable string for the video files."""
return self.info["video_path"] return self.info["video_path"]
@property
def audio_path(self) -> str | None:
"""Formattable string for the audio files."""
return self.info["audio_path"]
@property @property
def robot_type(self) -> str | None: def robot_type(self) -> str | None:
@ -184,6 +200,11 @@ class LeRobotDatasetMetadata:
def camera_keys(self) -> list[str]: def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method).""" """Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def audio_keys(self) -> list[str]:
"""Keys to access audio modalities."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
@property @property
def names(self) -> dict[str, list | dict]: def names(self) -> dict[str, list | dict]:
@ -264,6 +285,9 @@ class LeRobotDatasetMetadata:
if len(self.video_keys) > 0: if len(self.video_keys) > 0:
self.update_video_info() self.update_video_info()
if len(self.audio_keys) > 0:
self.update_audio_info()
write_info(self.info, self.root) write_info(self.info, self.root)
episode_dict = { episode_dict = {
@ -288,6 +312,17 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["features"][key]["info"] = get_video_info(video_path) self.info["features"][key]["info"] = get_video_info(video_path)
def update_audio_info(self) -> None:
"""
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None}
for key in set(self.audio_keys) - bound_audio_keys:
if not self.features[key].get("info", None):
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
self.info["features"][key]["info"] = get_audio_info(audio_path)
def __repr__(self): def __repr__(self):
feature_keys = list(self.features) feature_keys = list(self.features)
return ( return (
@ -364,6 +399,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False, force_cache_sync: bool = False,
download_videos: bool = True, download_videos: bool = True,
video_backend: str | None = None, video_backend: str | None = None,
audio_backend: str | None = None,
): ):
""" """
2 modes are available for instantiating this class, depending on 2 different use cases: 2 modes are available for instantiating this class, depending on 2 different use cases:
@ -465,6 +501,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
True. True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
""" """
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
@ -475,6 +512,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec() self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
self.delta_indices = None self.delta_indices = None
# Unused attributes # Unused attributes
@ -499,7 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError): except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision) self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos) self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@ -677,7 +715,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
} }
return query_indices, padding return query_indices, padding
def _get_query_timestamps( def _get_query_timestamps_video(
self, self,
current_ts: float, current_ts: float,
query_indices: dict[str, list[int]] | None = None, query_indices: dict[str, list[int]] | None = None,
@ -691,6 +729,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps[key] = [current_ts] query_timestamps[key] = [current_ts]
return query_timestamps return query_timestamps
#TODO(CarolinePascal): add variable query durations
def _get_query_timestamps_audio(
self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.meta.audio_keys:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return { return {
@ -713,6 +767,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item return item
#TODO(CarolinePascal): add variable query durations
def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
item = {}
bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
for audio_key, query_ts in query_timestamps.items():
#Audio stored with video in a single .mp4 file
if audio_key in bound_audio_keys_mapping.keys():
audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key])
#Audio stored alone in a separate .m4a file
else:
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
item[audio_key] = audio_chunk.squeeze(0)
return item
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
for key, val in padding.items(): for key, val in padding.items():
item[key] = torch.BoolTensor(val) item[key] = torch.BoolTensor(val)
@ -733,11 +802,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, val in query_result.items(): for key, val in query_result.items():
item[key] = val item[key] = val
if len(self.meta.video_keys) > 0: if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
current_ts = item["timestamp"].item() current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx) video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item} item = {**item, **video_frames}
query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
item = {**item, **audio_chunks}
if self.image_transforms is not None: if self.image_transforms is not None:
image_keys = self.meta.camera_keys image_keys = self.meta.camera_keys
@ -776,6 +850,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_key=image_key, episode_index=episode_index, frame_index=frame_index image_key=image_key, episode_index=episode_index, frame_index=frame_index
) )
return self.root / fpath return self.root / fpath
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
return self.root / fpath
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
if self.image_writer is None: if self.image_writer is None:
@ -867,7 +945,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, ft in self.features.items(): for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video # index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data # are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
continue continue
episode_buffer[key] = np.stack(episode_buffer[key]) episode_buffer[key] = np.stack(episode_buffer[key])
@ -880,6 +958,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key in self.meta.video_keys: for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key] episode_buffer[key] = video_paths[key]
if len(self.meta.audio_keys) > 0:
_ = self.encode_episode_audio(episode_index)
# `meta.save_episode` be executed after encoding the videos # `meta.save_episode` be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@ -904,6 +985,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
if img_dir.is_dir(): if img_dir.is_dir():
shutil.rmtree(self.root / "images") shutil.rmtree(self.root / "images")
# delete raw audio
raw_audio_files = list(self.root.rglob("*.wav"))
for raw_audio_file in raw_audio_files:
raw_audio_file.unlink()
if len(list(raw_audio_file.parent.iterdir())) == 0:
raw_audio_file.parent.rmdir()
if not episode_data: # Reset the buffer if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer() self.episode_buffer = self.create_episode_buffer()
@ -971,18 +1059,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
since video encoding with ffmpeg is already using multithreading. since video encoding with ffmpeg is already using multithreading.
""" """
video_paths = {} video_paths = {}
for key in self.meta.video_keys: for video_key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(episode_index, key) video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
video_paths[key] = str(video_path) video_paths[video_key] = str(video_path)
if video_path.is_file(): if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording. # Skip if video is already encoded. Could be the case when resuming data recording.
continue continue
img_dir = self._get_image_file_path( img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0 episode_index=episode_index, image_key=video_key, frame_index=0
).parent ).parent
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
audio_path = None
if self.meta.features[video_key]["audio"] is not None:
audio_key = self.meta.features[video_key]["audio"]
audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
return video_paths return video_paths
def encode_episode_audio(self, episode_index: int) -> dict:
"""
Use ffmpeg to convert .wav raw audio files into .m4a audio files.
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
audio_paths = {}
bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
audio_paths[audio_key] = str(output_audio_path)
if output_audio_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
encode_audio(input_audio_path, output_audio_path, overwrite=True)
return audio_paths
@classmethod @classmethod
def create( def create(
@ -998,6 +1113,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_processes: int = 0, image_writer_processes: int = 0,
image_writer_threads: int = 0, image_writer_threads: int = 0,
video_backend: str | None = None, video_backend: str | None = None,
audio_backend: str | None = None,
) -> "LeRobotDataset": ) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data.""" """Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls) obj = cls.__new__(cls)
@ -1029,6 +1145,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_indices = None obj.delta_indices = None
obj.episode_data_index = None obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
return obj return obj
@ -1049,6 +1166,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
tolerances_s: dict | None = None, tolerances_s: dict | None = None,
download_videos: bool = True, download_videos: bool = True,
video_backend: str | None = None, video_backend: str | None = None,
audio_backend: str | None = None,
): ):
super().__init__() super().__init__()
self.repo_ids = repo_ids self.repo_ids = repo_ids
@ -1066,6 +1184,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
tolerance_s=self.tolerances_s[repo_id], tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos, download_videos=download_videos,
video_backend=video_backend, video_backend=video_backend,
audio_backend=audio_backend,
) )
for repo_id in repo_ids for repo_id in repo_ids
] ]

View File

@ -55,6 +55,10 @@ TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
DATASET_CARD_TEMPLATE = """ DATASET_CARD_TEMPLATE = """
--- ---
@ -363,7 +367,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
def get_hf_features_from_features(features: dict) -> datasets.Features: def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features = {} hf_features = {}
for key, ft in features.items(): for key, ft in features.items():
if ft["dtype"] == "video": if ft["dtype"] == "video" or ft["dtype"] == "audio":
continue continue
elif ft["dtype"] == "image": elif ft["dtype"] == "image":
hf_features[key] = datasets.Image() hf_features[key] = datasets.Image()
@ -394,7 +398,13 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
key: {"dtype": "video" if use_videos else "image", **ft} key: {"dtype": "video" if use_videos else "image", **ft}
for key, ft in robot.camera_features.items() for key, ft in robot.camera_features.items()
} }
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} microphones_ft = {}
if robot.microphones:
microphones_ft = {
key: {"dtype": "audio", **ft}
for key, ft in robot.microphones_features.items()
}
return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES}
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
@ -448,6 +458,7 @@ def create_empty_dataset_info(
"splits": {}, "splits": {},
"data_path": DEFAULT_PARQUET_PATH, "data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None, "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
"features": features, "features": features,
} }
@ -721,6 +732,7 @@ def validate_features_presence(
): ):
error_message = "" error_message = ""
missing_features = expected_features - actual_features missing_features = expected_features - actual_features
missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
extra_features = actual_features - (expected_features | optional_features) extra_features = actual_features - (expected_features | optional_features)
if missing_features or extra_features: if missing_features or extra_features:

View File

@ -26,9 +26,12 @@ from typing import Any, ClassVar
import pyarrow as pa import pyarrow as pa
import torch import torch
import torchvision import torchvision
import torchaudio
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from numpy import ceil
def get_safe_default_codec(): def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"): if importlib.util.find_spec("torchcodec"):
@ -39,7 +42,72 @@ def get_safe_default_codec():
) )
return "pyav" return "pyav"
def decode_audio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
backend: str | None = "ffmpeg",
) -> torch.Tensor:
"""
Decodes audio using the specified backend.
Args:
audio_path (Path): Path to the audio file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "pyav".
Returns:
torch.Tensor: Decoded frames.
Currently supports pyav.
"""
if backend == "torchcodec":
raise NotImplementedError("torchcodec is not yet supported for audio decoding")
elif backend == "ffmpeg":
return decode_audio_torchvision(audio_path, timestamps, duration)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_audio_torchvision(
audio_path: Path | str,
timestamps: list[float],
duration: float,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
#TODO(CarolinePascal) : add channels selection
audio_path = str(audio_path)
reader = torchaudio.io.StreamReader(src=audio_path)
audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
#TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
buffer_chunk_size = -1, #No dropping frames
)
audio_chunks = []
for ts in timestamps:
reader.seek(ts) #Default to closest audio sample
status = reader.fill_buffer()
if status != 0:
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
current_audio_chunk = reader.pop_chunks()[0]
if log_loaded_timestamps:
logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}")
audio_chunks.append(current_audio_chunk)
audio_chunks = torch.stack(audio_chunks)
#TODO(CarolinePascal) : pytorch format conversion ?
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def decode_video_frames( def decode_video_frames(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
@ -69,7 +137,6 @@ def decode_video_frames(
else: else:
raise ValueError(f"Unsupported video backend: {backend}") raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision( def decode_video_frames_torchvision(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
@ -167,7 +234,6 @@ def decode_video_frames_torchvision(
assert len(timestamps) == len(closest_frames) assert len(timestamps) == len(closest_frames)
return closest_frames return closest_frames
def decode_video_frames_torchcodec( def decode_video_frames_torchcodec(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
@ -242,15 +308,52 @@ def decode_video_frames_torchcodec(
assert len(timestamps) == len(closest_frames) assert len(timestamps) == len(closest_frames)
return closest_frames return closest_frames
def encode_audio(
input_path: Path | str,
output_path: Path | str,
codec: str = "aac",
log_level: str | None = "error",
overwrite: bool = False,
) -> None:
"""Encodes an audio file using ffmpeg."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_args = OrderedDict(
[
("-i", str(input_path)),
("-acodec", codec),
]
)
if log_level is not None:
ffmpeg_args["-loglevel"] = str(log_level)
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
if overwrite:
ffmpeg_args.append("-y")
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
if not output_path.exists():
raise OSError(
f"Video encoding did not work. File not found: {output_path}. "
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
)
def encode_video_frames( def encode_video_frames(
imgs_dir: Path | str, imgs_dir: Path | str,
video_path: Path | str, video_path: Path | str,
fps: int, fps: int,
audio_path: Path | str | None = None,
vcodec: str = "libsvtav1", vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p", pix_fmt: str = "yuv420p",
g: int | None = 2, g: int | None = 2,
crf: int | None = 30, crf: int | None = 30,
acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
fast_decode: int = 0, fast_decode: int = 0,
log_level: str | None = "error", log_level: str | None = "error",
overwrite: bool = False, overwrite: bool = False,
@ -260,35 +363,53 @@ def encode_video_frames(
imgs_dir = Path(imgs_dir) imgs_dir = Path(imgs_dir)
video_path.parent.mkdir(parents=True, exist_ok=True) video_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_args = OrderedDict( ffmpeg_video_args = OrderedDict(
[ [
("-f", "image2"), ("-f", "image2"),
("-r", str(fps)), ("-r", str(fps)),
("-i", str(imgs_dir / "frame_%06d.png")), ("-i", str(Path(imgs_dir) / "frame_%06d.png")),
("-vcodec", vcodec),
("-pix_fmt", pix_fmt),
] ]
) )
ffmpeg_audio_args = OrderedDict()
if audio_path is not None:
audio_path = Path(audio_path)
audio_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_audio_args.update(OrderedDict(
[
("-i", str(audio_path)),
]
))
ffmpeg_encoding_args = OrderedDict(
[
("-pix_fmt", pix_fmt),
("-vcodec", vcodec),
]
)
if g is not None: if g is not None:
ffmpeg_args["-g"] = str(g) ffmpeg_encoding_args["-g"] = str(g)
if crf is not None: if crf is not None:
ffmpeg_args["-crf"] = str(crf) ffmpeg_encoding_args["-crf"] = str(crf)
if fast_decode: if fast_decode:
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune" key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
ffmpeg_args[key] = value ffmpeg_encoding_args[key] = value
if audio_path is not None:
ffmpeg_encoding_args["-acodec"] = acodec
if log_level is not None: if log_level is not None:
ffmpeg_args["-loglevel"] = str(log_level) ffmpeg_encoding_args["-loglevel"] = str(log_level)
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair]
ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
if overwrite: if overwrite:
ffmpeg_args.append("-y") ffmpeg_args.append("-y")
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
@ -366,7 +487,6 @@ def get_audio_info(video_path: Path | str) -> dict:
"audio.channel_layout": audio_stream_info.get("channel_layout", None), "audio.channel_layout": audio_stream_info.get("channel_layout", None),
} }
def get_video_info(video_path: Path | str) -> dict: def get_video_info(video_path: Path | str) -> dict:
ffprobe_video_cmd = [ ffprobe_video_cmd = [
"ffprobe", "ffprobe",

View File

@ -216,7 +216,7 @@ class ManipulatorRobot:
@property @property
def features(self): def features(self):
return {**self.motor_features, **self.camera_features} return {**self.motor_features, **self.camera_features, **self.microphones_features}
@property @property
def has_camera(self): def has_camera(self):

View File

@ -29,6 +29,7 @@ from lerobot.common.datasets.utils import (
DEFAULT_FEATURES, DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH, DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH, DEFAULT_VIDEO_PATH,
DEFAULT_COMPRESSED_AUDIO_PATH,
get_hf_features_from_features, get_hf_features_from_features,
hf_transform_to_torch, hf_transform_to_torch,
) )
@ -121,6 +122,7 @@ def info_factory(features_factory):
chunks_size: int = DEFAULT_CHUNK_SIZE, chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH, data_path: str = DEFAULT_PARQUET_PATH,
video_path: str = DEFAULT_VIDEO_PATH, video_path: str = DEFAULT_VIDEO_PATH,
audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True, use_videos: bool = True,
@ -139,6 +141,7 @@ def info_factory(features_factory):
"splits": {}, "splits": {},
"data_path": data_path, "data_path": data_path,
"video_path": video_path if use_videos else None, "video_path": video_path if use_videos else None,
"audio_path": audio_path,
"features": features, "features": features,
} }