Adding the last missing audio features in LeRobotDataset

CarolinePascal 2025-04-11 13:43:17 +02:00
parent 0cb9345f06
commit a08b5c4105
1 changed file with 48 additions and 18 deletions


@ -206,7 +206,7 @@ class LeRobotDatasetMetadata:
@property
def audio_keys(self) -> list[str]:
"""Keys to access audio modalities."""
"""Keys to access audio modalities (wether they are linked to a camera or not)."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
@property
@ -219,6 +219,16 @@ class LeRobotDatasetMetadata:
and self.features[camera_key]["dtype"] == "video"
}
@property
def linked_audio_keys(self) -> list[str]:
"""Keys to access audio modalities linked to a camera."""
return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping]
@property
def unlinked_audio_keys(self) -> list[str]:
"""Keys to access audio modalities not linked to a camera."""
return [key for key in self.audio_keys if key not in self.audio_camera_keys_mapping]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@ -298,7 +308,8 @@ class LeRobotDatasetMetadata:
if len(self.video_keys) > 0:
self.update_video_info()
if len(self.audio_keys) > 0:
self.info["total_audio"] += len(self.audio_keys)
if len(self.unlinked_audio_keys) > 0:
self.update_audio_info()
write_info(self.info, self.root)
@ -330,10 +341,10 @@ class LeRobotDatasetMetadata:
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
for key in self.unlinked_audio_keys:
if not self.features[key].get("info", None) or (
len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]
):
): # Overwrite if info is empty or only contains the sample rate (necessary to correctly save audio files recorded by LeKiwi)
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
self.info["features"][key]["info"] = get_audio_info(audio_path)
@ -412,6 +423,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
download_audio: bool = True,
video_backend: str | None = None,
audio_backend: str | None = None,
):
@ -444,7 +456,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
- tasks contains the prompts for each task of the dataset, which can be used for
task-conditioned training.
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
- videos (optional) from which frames and embedded audio (if any) are loaded to be synchronous with data from parquet files and standalone audio.
- audio (optional) from which standalone audio is loaded to be synchronous with data from parquet files and videos.
A typical LeRobotDataset looks like this from its root path:
.
@ -513,9 +526,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader', another Torchvision decoder.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to the 'ffmpeg' decoder used by torchaudio.
"""
super().__init__()
self.repo_id = repo_id
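A usage sketch of the new download flag together with the documented backends; the repo id is hypothetical.

# Illustrative sketch only: the repo id is hypothetical.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

ds = LeRobotDataset(
    "user/dataset_with_audio",  # hypothetical repo id
    download_videos=True,       # audio embedded in the .mp4 files is fetched together with the videos
    download_audio=False,       # skip the standalone audio/ files
    audio_backend="ffmpeg",     # torchaudio's ffmpeg decoder, the documented default
)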
@ -554,8 +568,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(
download_videos
) # Should load audio as well #TODO(CarolinePascal): separate audio from video
download_videos,
download_audio
) # Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@ -578,6 +593,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
push_audio: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
@ -585,7 +601,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
) -> None:
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed
if not push_audio:
ignore_patterns.append("audio/")
hub_api = HfApi()
hub_api.create_repo(
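The mirror image on the upload side; a short usage sketch reusing the dataset object from the previous sketch.

# Illustrative sketch only: pushes videos (and the audio embedded in them) but not the standalone audio/ folder.
ds.push_to_hub(push_videos=True, push_audio=False)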
@ -641,7 +659,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns=ignore_patterns,
)
def download_episodes(self, download_videos: bool = True) -> None:
def download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@ -650,7 +668,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
ignore_patterns = None if download_videos else "videos/"
ignore_patterns = []
if not download_videos:
ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded
if not download_audio:
ignore_patterns.append("audio/")
if self.episodes is not None:
files = self.get_episodes_file_paths()
@ -667,6 +689,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
]
fpaths += video_files
if len(self.meta.unlinked_audio_keys) > 0:
audio_files = [
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
for audio_key in self.meta.unlinked_audio_keys
for ep_idx in episodes
]
fpaths += audio_files
return fpaths
def load_hf_dataset(self) -> datasets.Dataset:
@ -755,7 +785,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.meta.audio_keys:
for key in self.meta.audio_keys: # Both standalone audio and audio embedded in video
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
@ -768,7 +798,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
if key not in self.meta.video_keys and key not in self.meta.audio_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
@ -791,12 +821,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
) -> dict[str, torch.Tensor]:
item = {}
for audio_key, query_ts in query_timestamps.items():
# Audio stored with video in a single .mp4 file
if audio_key in self.meta.audio_camera_keys_mapping:
# Audio stored with video in a single .mp4 file
if audio_key in self.meta.linked_audio_keys:
audio_path = self.root / self.meta.get_video_file_path(
ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
)
# Audio stored alone in a separate .m4a file
# Audio stored alone in a separate .m4a file
else:
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
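Once decoded, both linked and unlinked audio end up in the returned item as tensors; a minimal retrieval sketch, with a hypothetical feature key.

# Illustrative sketch only: "observation.audio.microphone" is a hypothetical audio key.
item = ds[0]
audio_chunk = item["observation.audio.microphone"]  # torch.Tensor of decoded samples for the query window
print(audio_chunk.shape)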
@ -1023,7 +1053,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
if len(self.meta.audio_keys) > 0:
if len(self.meta.unlinked_audio_keys) > 0: # Linked audio is already encoded in the video files
_ = self.encode_episode_audio(episode_index)
# `meta.save_episode` must be executed after encoding the videos
@ -1150,7 +1180,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
since video encoding with ffmpeg is already using multithreading.
"""
audio_paths = {}
for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()):
for audio_key in self.meta.unlinked_audio_keys:
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
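For context, the compression step for an unlinked key boils down to an ffmpeg transcode of the raw recording into an .m4a (AAC) file; a generic, self-contained sketch of such a call, not the library's actual encoder.

# Illustrative sketch only: a generic ffmpeg transcode, not LeRobotDataset's actual implementation.
import subprocess
from pathlib import Path

def encode_to_m4a(input_audio_path: Path, output_audio_path: Path) -> None:
    output_audio_path.parent.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(input_audio_path), "-c:a", "aac", str(output_audio_path)],
        check=True,
        capture_output=True,
    )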