Adding audio modality in LeRobotDatasets

commit 8ee61bb81f (parent 8ddfb299fd)
@@ -36,8 +36,11 @@ from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
+    DEFAULT_RAW_AUDIO_PATH,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     INFO_PATH,
     TASKS_PATH,
+    DEFAULT_AUDIO_CHUNK_DURATION,
     append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
@@ -69,8 +72,11 @@ from lerobot.common.datasets.video_utils import (
     VideoFrame,
     decode_video_frames,
     encode_video_frames,
+    encode_audio,
+    decode_audio,
     get_safe_default_codec,
     get_video_info,
+    get_audio_info,
 )
 from lerobot.common.robot_devices.robots.utils import Robot
 
@@ -141,6 +147,11 @@ class LeRobotDatasetMetadata:
         ep_chunk = self.get_episode_chunk(ep_index)
         fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)
 
+    def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        episode_chunk = self.get_episode_chunk(episode_index)
+        fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index)
+        return self.root / fpath
+
     def get_episode_chunk(self, ep_index: int) -> int:
         return ep_index // self.chunks_size
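For illustration, here is how the new helper resolves a path, assuming the default `chunks_size` of 1000 and the `DEFAULT_COMPRESSED_AUDIO_PATH` template added by this commit (the audio key below is a made-up example):

    chunks_size = 1000  # default: one chunk folder per 1000 episodes
    episode_index = 1042
    episode_chunk = episode_index // chunks_size  # -> 1
    fpath = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a".format(
        episode_chunk=episode_chunk, audio_key="observation.audio.webcam", episode_index=episode_index
    )
    # -> "audio/chunk-001/observation.audio.webcam/episode_001042.m4a"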
@@ -154,6 +165,11 @@ class LeRobotDatasetMetadata:
     def video_path(self) -> str | None:
         """Formattable string for the video files."""
         return self.info["video_path"]
 
+    @property
+    def audio_path(self) -> str | None:
+        """Formattable string for the audio files."""
+        return self.info["audio_path"]
+
     @property
     def robot_type(self) -> str | None:
@@ -184,6 +200,11 @@ class LeRobotDatasetMetadata:
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
 
+    @property
+    def audio_keys(self) -> list[str]:
+        """Keys to access audio modalities."""
+        return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
+
     @property
     def names(self) -> dict[str, list | dict]:
@@ -264,6 +285,9 @@ class LeRobotDatasetMetadata:
         if len(self.video_keys) > 0:
             self.update_video_info()
 
+        if len(self.audio_keys) > 0:
+            self.update_audio_info()
+
         write_info(self.info, self.root)
 
         episode_dict = {
@@ -288,6 +312,17 @@ class LeRobotDatasetMetadata:
                 video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
                 self.info["features"][key]["info"] = get_video_info(video_path)
 
+    def update_audio_info(self) -> None:
+        """
+        Warning: this function writes info from first episode audio, implicitly assuming that all audio have
+        been encoded the same way. Also, this means it assumes the first episode exists.
+        """
+        bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None}
+        for key in set(self.audio_keys) - bound_audio_keys:
+            if not self.features[key].get("info", None):
+                audio_path = self.root / self.get_compressed_audio_file_path(0, key)
+                self.info["features"][key]["info"] = get_audio_info(audio_path)
+
     def __repr__(self):
         feature_keys = list(self.features)
         return (
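The `bound_audio_keys` set above is the central design point of this commit: a microphone can either be bound to a camera, in which case its stream is muxed into that camera's .mp4, or stand alone, in which case it is encoded into its own .m4a file and probed here. A minimal sketch of the feature layout this assumes (key names are made up):

    features = {
        "observation.images.webcam": {"dtype": "video", "audio": "observation.audio.webcam"},
        "observation.images.wrist": {"dtype": "video", "audio": None},
        "observation.audio.webcam": {"dtype": "audio"},    # bound: lives inside the webcam .mp4
        "observation.audio.external": {"dtype": "audio"},  # standalone: gets its own .m4a
    }
    video_keys = [k for k, ft in features.items() if ft["dtype"] == "video"]
    audio_keys = [k for k, ft in features.items() if ft["dtype"] == "audio"]
    bound_audio_keys = {features[k]["audio"] for k in video_keys if features[k]["audio"] is not None}
    standalone = set(audio_keys) - bound_audio_keys  # -> {"observation.audio.external"}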
@@ -364,6 +399,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         force_cache_sync: bool = False,
         download_videos: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         """
         2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -465,6 +501,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 True.
             video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
                 You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
+            audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
         """
         super().__init__()
         self.repo_id = repo_id
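From the caller's side, the new argument would be used like this (the repo id is a placeholder; 'ffmpeg' is currently the only supported audio backend):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset("user/dataset_with_audio", audio_backend="ffmpeg")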
@@ -475,6 +512,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else get_safe_default_codec()
+        self.audio_backend = audio_backend if audio_backend else "ffmpeg"  # Waiting for torchcodec release  # TODO(CarolinePascal)
         self.delta_indices = None
 
         # Unused attributes
@@ -499,7 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(download_videos)
+            self.download_episodes(download_videos)  # Should load audio as well  # TODO(CarolinePascal): separate audio from video
             self.hf_dataset = self.load_hf_dataset()
 
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -677,7 +715,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         }
         return query_indices, padding
 
-    def _get_query_timestamps(
+    def _get_query_timestamps_video(
         self,
         current_ts: float,
         query_indices: dict[str, list[int]] | None = None,
@@ -691,6 +729,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 query_timestamps[key] = [current_ts]
 
         return query_timestamps
 
+    # TODO(CarolinePascal): add variable query durations
+    def _get_query_timestamps_audio(
+        self,
+        current_ts: float,
+        query_indices: dict[str, list[int]] | None = None,
+    ) -> dict[str, list[float]]:
+        query_timestamps = {}
+        for key in self.meta.audio_keys:
+            if query_indices is not None and key in query_indices:
+                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
+                query_timestamps[key] = torch.stack(timestamps).tolist()
+            else:
+                query_timestamps[key] = [current_ts]
+
+        return query_timestamps
+
     def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
         return {
@@ -713,6 +767,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
+    # TODO(CarolinePascal): add variable query durations
+    def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
+        item = {}
+        bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]: video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
+        for audio_key, query_ts in query_timestamps.items():
+            # Audio stored with video in a single .mp4 file
+            if audio_key in bound_audio_keys_mapping:
+                audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key])
+            # Audio stored alone in a separate .m4a file
+            else:
+                audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
+            audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
+            item[audio_key] = audio_chunk.squeeze(0)
+        return item
+
     def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
         for key, val in padding.items():
             item[key] = torch.BoolTensor(val)
@@ -733,11 +802,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for key, val in query_result.items():
             item[key] = val
 
-        if len(self.meta.video_keys) > 0:
+        if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
             current_ts = item["timestamp"].item()
-            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+
+            query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
             video_frames = self._query_videos(query_timestamps, ep_idx)
-            item = {**video_frames, **item}
+            item = {**item, **video_frames}
+
+            query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
+            audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
+            item = {**item, **audio_chunks}
 
         if self.image_transforms is not None:
             image_keys = self.meta.camera_keys
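With this change, `__getitem__` should return a fixed-duration audio chunk per microphone next to the usual frames. A rough sketch of the expected output, assuming a 16 kHz mono microphone and the new 0.5 s `DEFAULT_AUDIO_CHUNK_DURATION` (the key name is a made-up example):

    item = dataset[0]
    audio = item["observation.audio.webcam"]
    # Expected: a torch.Tensor holding ~0.5 s of samples decoded on the fly,
    # e.g. shape (8000, 1) for 16 kHz mono, read from the episode .mp4 or .m4a.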
@@ -776,6 +850,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 image_key=image_key, episode_index=episode_index, frame_index=frame_index
             )
         return self.root / fpath
 
+    def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
+        return self.root / fpath
+
     def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
         if self.image_writer is None:
@@ -867,7 +945,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for key, ft in self.features.items():
             # index, episode_index, task_index are already processed above, and image and video
             # are processed separately by storing image path and frame info as meta data
-            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
+            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
                 continue
             episode_buffer[key] = np.stack(episode_buffer[key])
@@ -880,6 +958,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             for key in self.meta.video_keys:
                 episode_buffer[key] = video_paths[key]
 
+        if len(self.meta.audio_keys) > 0:
+            _ = self.encode_episode_audio(episode_index)
+
         # `meta.save_episode` must be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -904,6 +985,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if img_dir.is_dir():
                 shutil.rmtree(self.root / "images")
 
+        # delete raw audio
+        raw_audio_files = list(self.root.rglob("*.wav"))
+        for raw_audio_file in raw_audio_files:
+            raw_audio_file.unlink()
+            if len(list(raw_audio_file.parent.iterdir())) == 0:
+                raw_audio_file.parent.rmdir()
+
         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
@@ -971,18 +1059,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
         since video encoding with ffmpeg is already using multithreading.
         """
         video_paths = {}
-        for key in self.meta.video_keys:
-            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-            video_paths[key] = str(video_path)
+        for video_key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
+            video_paths[video_key] = str(video_path)
             if video_path.is_file():
                 # Skip if video is already encoded. Could be the case when resuming data recording.
                 continue
             img_dir = self._get_image_file_path(
-                episode_index=episode_index, image_key=key, frame_index=0
+                episode_index=episode_index, image_key=video_key, frame_index=0
             ).parent
-            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+
+            audio_path = None
+            if self.meta.features[video_key]["audio"] is not None:
+                audio_key = self.meta.features[video_key]["audio"]
+                audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
+
+            encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
+
         return video_paths
 
+    def encode_episode_audio(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert .wav raw audio files into .m4a audio files.
+        Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speed up encoding,
+        since audio encoding with ffmpeg is already using multithreading.
+        """
+        audio_paths = {}
+        bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
+        for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
+            input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
+            output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
+
+            audio_paths[audio_key] = str(output_audio_path)
+            if output_audio_path.is_file():
+                # Skip if audio is already encoded. Could be the case when resuming data recording.
+                continue
+
+            encode_audio(input_audio_path, output_audio_path, overwrite=True)
+
+        return audio_paths
+
     @classmethod
     def create(
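Both encoders follow the same resume-safe pattern: compute the target path first, record it in the returned dict, and skip the ffmpeg call when the file already exists. A distilled sketch of that pattern (names are illustrative, not from this diff):

    from pathlib import Path

    def encode_if_missing(input_path: Path, output_path: Path, encode) -> str:
        # Skip re-encoding when resuming an interrupted recording session.
        if not output_path.is_file():
            encode(input_path, output_path)
        return str(output_path)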
@@ -998,6 +1113,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
@@ -1029,6 +1145,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg"  # Waiting for torchcodec release  # TODO(CarolinePascal)
         return obj
 
 
@@ -1049,6 +1166,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         tolerances_s: dict | None = None,
         download_videos: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         super().__init__()
         self.repo_ids = repo_ids
@@ -1066,6 +1184,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 tolerance_s=self.tolerances_s[repo_id],
                 download_videos=download_videos,
                 video_backend=video_backend,
+                audio_backend=audio_backend,
             )
             for repo_id in repo_ids
         ]
@@ -55,6 +55,10 @@ TASKS_PATH = "meta/tasks.jsonl"
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
 DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
+DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
+
+DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds
 
 DATASET_CARD_TEMPLATE = """
 ---
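The two templates differ deliberately: raw .wav files are flat per-episode scratch files that get deleted after encoding, while compressed .m4a files are chunked like videos. A quick check of what they format to (the audio key is a made-up example):

    DEFAULT_RAW_AUDIO_PATH.format(audio_key="observation.audio.mic", episode_index=7)
    # -> "audio/observation.audio.mic/episode_000007.wav"
    DEFAULT_COMPRESSED_AUDIO_PATH.format(episode_chunk=0, audio_key="observation.audio.mic", episode_index=7)
    # -> "audio/chunk-000/observation.audio.mic/episode_000007.m4a"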
@@ -363,7 +367,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
 def get_hf_features_from_features(features: dict) -> datasets.Features:
     hf_features = {}
     for key, ft in features.items():
-        if ft["dtype"] == "video":
+        if ft["dtype"] == "video" or ft["dtype"] == "audio":
             continue
         elif ft["dtype"] == "image":
             hf_features[key] = datasets.Image()
@@ -394,7 +398,13 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
         key: {"dtype": "video" if use_videos else "image", **ft}
         for key, ft in robot.camera_features.items()
     }
-    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
+    microphones_ft = {}
+    if robot.microphones:
+        microphones_ft = {
+            key: {"dtype": "audio", **ft}
+            for key, ft in robot.microphones_features.items()
+        }
+    return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES}
 
 
 def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
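For a robot with one camera and one microphone, the merged dict would look roughly like this (shapes and key names are illustrative, not taken from this diff):

    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},             # motor_features
        "observation.images.webcam": {"dtype": "video", "shape": (480, 640, 3)},  # camera_ft
        "observation.audio.webcam": {"dtype": "audio", "shape": (1,)},            # microphones_ft
        # ...plus the DEFAULT_FEATURES entries (timestamp, frame_index, episode_index, index, task_index)
    }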
@@ -448,6 +458,7 @@ def create_empty_dataset_info(
         "splits": {},
         "data_path": DEFAULT_PARQUET_PATH,
         "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
         "features": features,
     }
@@ -721,6 +732,7 @@ def validate_features_presence(
 ):
     error_message = ""
     missing_features = expected_features - actual_features
+    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
     extra_features = actual_features - (expected_features | optional_features)
 
     if missing_features or extra_features:
@@ -26,9 +26,12 @@ from typing import Any, ClassVar
 import pyarrow as pa
 import torch
 import torchvision
+import torchaudio
 from datasets.features.features import register_feature
 from PIL import Image
 
+from numpy import ceil
+
 
 def get_safe_default_codec():
     if importlib.util.find_spec("torchcodec"):
@@ -39,7 +42,72 @@ def get_safe_default_codec():
     )
     return "pyav"
 
+
+def decode_audio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    backend: str | None = "ffmpeg",
+) -> torch.Tensor:
+    """
+    Decodes audio using the specified backend.
+    Args:
+        audio_path (Path): Path to the audio file.
+        timestamps (list[float]): List of timestamps at which to extract audio chunks.
+        duration (float): Duration of each audio chunk, in seconds.
+        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
+
+    Returns:
+        torch.Tensor: Decoded audio chunks.
+
+    Currently supports the ffmpeg backend (through torchaudio).
+    """
+    if backend == "torchcodec":
+        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
+    elif backend == "ffmpeg":
+        return decode_audio_torchvision(audio_path, timestamps, duration)
+    else:
+        raise ValueError(f"Unsupported audio backend: {backend}")
+
+
+def decode_audio_torchvision(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    # TODO(CarolinePascal): add channels selection
+    audio_path = str(audio_path)
+
+    reader = torchaudio.io.StreamReader(src=audio_path)
+    audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
+
+    # TODO(CarolinePascal): sort timestamps?
+    reader.add_basic_audio_stream(
+        frames_per_chunk=int(ceil(duration * audio_sampling_rate)),  # Too much is better than not enough
+        buffer_chunk_size=-1,  # No dropping frames
+    )
+
+    audio_chunks = []
+    for ts in timestamps:
+        reader.seek(ts)  # Defaults to closest audio sample
+        status = reader.fill_buffer()
+        if status != 0:
+            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
+
+        current_audio_chunk = reader.pop_chunks()[0]
+
+        if log_loaded_timestamps:
+            logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk.pts:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}")
+
+        audio_chunks.append(current_audio_chunk)
+
+    audio_chunks = torch.stack(audio_chunks)
+    # TODO(CarolinePascal): pytorch format conversion?
+
+    assert len(timestamps) == len(audio_chunks)
+    return audio_chunks
+
+
 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
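A sketch of calling the new decoder directly, assuming a 48 kHz stereo track (the path is a placeholder):

    chunks = decode_audio("episode_000000.m4a", timestamps=[0.0, 1.0], duration=0.5, backend="ffmpeg")
    # One chunk per requested timestamp. torchaudio's StreamReader yields (frames, channels)
    # chunks, e.g. (24000, 2) for 0.5 s at 48 kHz stereo, so after torch.stack the result
    # should have shape (2, 24000, 2).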
@@ -69,7 +137,6 @@ def decode_video_frames(
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
-
 def decode_video_frames_torchvision(
     video_path: Path | str,
     timestamps: list[float],
@@ -167,7 +234,6 @@ def decode_video_frames_torchvision(
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
-
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
|
@ -242,15 +308,52 @@ def decode_video_frames_torchcodec(
|
||||||
assert len(timestamps) == len(closest_frames)
|
assert len(timestamps) == len(closest_frames)
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
def encode_audio(
|
||||||
|
input_path: Path | str,
|
||||||
|
output_path: Path | str,
|
||||||
|
codec: str = "aac",
|
||||||
|
log_level: str | None = "error",
|
||||||
|
overwrite: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Encodes an audio file using ffmpeg."""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
ffmpeg_args = OrderedDict(
|
||||||
|
[
|
||||||
|
("-i", str(input_path)),
|
||||||
|
("-acodec", codec),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if log_level is not None:
|
||||||
|
ffmpeg_args["-loglevel"] = str(log_level)
|
||||||
|
|
||||||
|
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||||
|
if overwrite:
|
||||||
|
ffmpeg_args.append("-y")
|
||||||
|
|
||||||
|
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
|
||||||
|
|
||||||
|
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||||
|
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||||
|
|
||||||
|
if not output_path.exists():
|
||||||
|
raise OSError(
|
||||||
|
f"Video encoding did not work. File not found: {output_path}. "
|
||||||
|
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||||
|
)
|
||||||
|
|
||||||
def encode_video_frames(
|
def encode_video_frames(
|
||||||
imgs_dir: Path | str,
|
imgs_dir: Path | str,
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
fps: int,
|
fps: int,
|
||||||
|
audio_path: Path | str | None = None,
|
||||||
vcodec: str = "libsvtav1",
|
vcodec: str = "libsvtav1",
|
||||||
pix_fmt: str = "yuv420p",
|
pix_fmt: str = "yuv420p",
|
||||||
g: int | None = 2,
|
g: int | None = 2,
|
||||||
crf: int | None = 30,
|
crf: int | None = 30,
|
||||||
|
acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
|
||||||
fast_decode: int = 0,
|
fast_decode: int = 0,
|
||||||
log_level: str | None = "error",
|
log_level: str | None = "error",
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
|
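For reference, with the defaults above and overwrite=True, `encode_audio` assembles a command equivalent to the following (paths are placeholders):

    ffmpeg_cmd = [
        "ffmpeg",
        "-i", "audio/observation.audio.mic/episode_000000.wav",
        "-acodec", "aac",
        "-loglevel", "error",
        "-y",
        "audio/chunk-000/observation.audio.mic/episode_000000.m4a",
    ]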
@@ -260,35 +363,53 @@ def encode_video_frames(
     imgs_dir = Path(imgs_dir)
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
-    ffmpeg_args = OrderedDict(
+    ffmpeg_video_args = OrderedDict(
         [
             ("-f", "image2"),
             ("-r", str(fps)),
-            ("-i", str(imgs_dir / "frame_%06d.png")),
-            ("-vcodec", vcodec),
-            ("-pix_fmt", pix_fmt),
+            ("-i", str(Path(imgs_dir) / "frame_%06d.png")),
         ]
     )
 
+    ffmpeg_audio_args = OrderedDict()
+    if audio_path is not None:
+        audio_path = Path(audio_path)
+        audio_path.parent.mkdir(parents=True, exist_ok=True)
+        ffmpeg_audio_args.update(
+            OrderedDict(
+                [
+                    ("-i", str(audio_path)),
+                ]
+            )
+        )
+
+    ffmpeg_encoding_args = OrderedDict(
+        [
+            ("-pix_fmt", pix_fmt),
+            ("-vcodec", vcodec),
+        ]
+    )
     if g is not None:
-        ffmpeg_args["-g"] = str(g)
+        ffmpeg_encoding_args["-g"] = str(g)
 
     if crf is not None:
-        ffmpeg_args["-crf"] = str(crf)
+        ffmpeg_encoding_args["-crf"] = str(crf)
 
     if fast_decode:
         key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
         value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
-        ffmpeg_args[key] = value
+        ffmpeg_encoding_args[key] = value
 
+    if audio_path is not None:
+        ffmpeg_encoding_args["-acodec"] = acodec
+
     if log_level is not None:
-        ffmpeg_args["-loglevel"] = str(log_level)
+        ffmpeg_encoding_args["-loglevel"] = str(log_level)
 
-    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
     if overwrite:
         ffmpeg_args.append("-y")
 
     ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
 
     # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
     subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
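The regrouped argument lists keep ffmpeg's expected ordering: inputs first (video frames, then the optional audio track), encoding options next, output last. With a bound microphone the assembled command is equivalent to the following (fps and paths are placeholders):

    video_args = ["-f", "image2", "-r", "30", "-i", "frames/frame_%06d.png"]
    audio_args = ["-i", "audio/observation.audio.webcam/episode_000000.wav"]
    encoding_args = ["-pix_fmt", "yuv420p", "-vcodec", "libsvtav1", "-g", "2", "-crf", "30", "-acodec", "aac", "-loglevel", "error"]
    ffmpeg_cmd = ["ffmpeg", *video_args, *audio_args, *encoding_args, "-y", "episode_000000.mp4"]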
@@ -366,7 +487,6 @@ def get_audio_info(video_path: Path | str) -> dict:
         "audio.channel_layout": audio_stream_info.get("channel_layout", None),
     }
 
-
 def get_video_info(video_path: Path | str) -> dict:
     ffprobe_video_cmd = [
         "ffprobe",
|
@ -216,7 +216,7 @@ class ManipulatorRobot:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self):
|
def features(self):
|
||||||
return {**self.motor_features, **self.camera_features}
|
return {**self.motor_features, **self.camera_features, **self.microphones_features}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_camera(self):
|
def has_camera(self):
|
||||||
|
|
|
@@ -29,6 +29,7 @@ from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
|
@ -121,6 +122,7 @@ def info_factory(features_factory):
|
||||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||||
data_path: str = DEFAULT_PARQUET_PATH,
|
data_path: str = DEFAULT_PARQUET_PATH,
|
||||||
video_path: str = DEFAULT_VIDEO_PATH,
|
video_path: str = DEFAULT_VIDEO_PATH,
|
||||||
|
audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
|
@ -139,6 +141,7 @@ def info_factory(features_factory):
|
||||||
"splits": {},
|
"splits": {},
|
||||||
"data_path": data_path,
|
"data_path": data_path,
|
||||||
"video_path": video_path if use_videos else None,
|
"video_path": video_path if use_videos else None,
|
||||||
|
"audio_path": audio_path,
|
||||||
"features": features,
|
"features": features,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|