From a08b5c4105a7f58416cdd90a6ab4786851eab5b5 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 13:43:17 +0200 Subject: [PATCH] Adding last missing audio features in LeRobotDataset --- lerobot/common/datasets/lerobot_dataset.py | 66 ++++++++++++++++------ 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f844eb72..da5874fb 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -206,7 +206,7 @@ class LeRobotDatasetMetadata: @property def audio_keys(self) -> list[str]: - """Keys to access audio modalities.""" + """Keys to access audio modalities (whether they are linked to a camera or not).""" return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] @property @@ -219,6 +219,16 @@ class LeRobotDatasetMetadata: and self.features[camera_key]["dtype"] == "video" } + @property + def linked_audio_keys(self) -> list[str]: + """Keys to access audio modalities linked to a camera.""" + return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping] + + @property + def unlinked_audio_keys(self) -> list[str]: + """Keys to access audio modalities not linked to a camera.""" + return [key for key in self.audio_keys if key not in self.audio_camera_keys_mapping] + @property def names(self) -> dict[str, list | dict]: """Names of the various dimensions of vector modalities.""" @@ -298,7 +308,8 @@ class LeRobotDatasetMetadata: if len(self.video_keys) > 0: self.update_video_info() - if len(self.audio_keys) > 0: + self.info["total_audio"] += len(self.audio_keys) + if len(self.unlinked_audio_keys) > 0: self.update_audio_info() write_info(self.info, self.root) @@ -330,10 +341,10 @@ class LeRobotDatasetMetadata: Warning: this function writes info from first episode audio, implicitly assuming that all audio have been encoded the same way. 
Also, this means it assumes the first episode exists. """ - for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()): + for key in self.unlinked_audio_keys: if not self.features[key].get("info", None) or ( len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"] - ): + ): #Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi) audio_path = self.root / self.get_compressed_audio_file_path(0, key) self.info["features"][key]["info"] = get_audio_info(audio_path) @@ -412,6 +423,7 @@ class LeRobotDataset(torch.utils.data.Dataset): revision: str | None = None, force_cache_sync: bool = False, download_videos: bool = True, + download_audio: bool = True, video_backend: str | None = None, audio_backend: str | None = None, ): @@ -444,7 +456,8 @@ class LeRobotDataset(torch.utils.data.Dataset): - tasks contains the prompts for each task of the dataset, which can be used for task-conditioned training. - hf_dataset (from datasets.Dataset), which will read any values from parquet files. - - videos (optional) from which frames are loaded to be synchronous with data from parquet files. + - videos (optional) from which frames and audio (if any) are loaded to be synchronous with data from parquet files and audio. + - audio (optional) from which audio is loaded to be synchronous with data from parquet files and videos. A typical LeRobotDataset looks like this from its root path: . @@ -513,9 +526,10 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. + download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True. video_backend (str | None, optional): Video backend to use for decoding videos. 
Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. - audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'. """ super().__init__() self.repo_id = repo_id @@ -554,8 +568,9 @@ class LeRobotDataset(torch.utils.data.Dataset): except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) self.download_episodes( - download_videos - ) # Sould load audio as well #TODO(CarolinePascal): separate audio from video + download_videos, + download_audio + ) #Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -578,6 +593,7 @@ class LeRobotDataset(torch.utils.data.Dataset): license: str | None = "apache-2.0", tag_version: bool = True, push_videos: bool = True, + push_audio: bool = True, private: bool = False, allow_patterns: list[str] | str | None = None, upload_large_folder: bool = False, @@ -585,7 +601,9 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> None: ignore_patterns = ["images/"] if not push_videos: - ignore_patterns.append("videos/") + ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed + if not push_audio: + ignore_patterns.append("audio/") hub_api = HfApi() hub_api.create_repo( @@ -641,7 +659,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ignore_patterns=ignore_patterns, ) - def download_episodes(self, download_videos: bool = True) -> None: + def 
download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None: """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present @@ -650,7 +668,11 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads files = None - ignore_patterns = None if download_videos else "videos/" + ignore_patterns = [] + if not download_videos: + ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded + if not download_audio: + ignore_patterns.append("audio/") if self.episodes is not None: files = self.get_episodes_file_paths() @@ -667,6 +689,14 @@ class LeRobotDataset(torch.utils.data.Dataset): ] fpaths += video_files + if len(self.meta.unlinked_audio_keys) > 0: + audio_files = [ + str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key)) + for audio_key in self.meta.unlinked_audio_keys + for ep_idx in episodes + ] + fpaths += audio_files + return fpaths def load_hf_dataset(self) -> datasets.Dataset: @@ -755,7 +785,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices: dict[str, list[int]] | None = None, ) -> dict[str, list[float]]: query_timestamps = {} - for key in self.meta.audio_keys: + for key in self.meta.audio_keys: #Standalone audio and audio embedded in video as well ! 
if query_indices is not None and key in query_indices: timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] query_timestamps[key] = torch.stack(timestamps).tolist() @@ -768,7 +798,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return { key: torch.stack(self.hf_dataset.select(q_idx)[key]) for key, q_idx in query_indices.items() - if key not in self.meta.video_keys + if key not in self.meta.video_keys and key not in self.meta.audio_keys } def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: @@ -791,12 +821,12 @@ class LeRobotDataset(torch.utils.data.Dataset): ) -> dict[str, torch.Tensor]: item = {} for audio_key, query_ts in query_timestamps.items(): - # Audio stored with video in a single .mp4 file - if audio_key in self.meta.audio_camera_keys_mapping: + #Audio stored with video in a single .mp4 file + if audio_key in self.meta.linked_audio_keys: audio_path = self.root / self.meta.get_video_file_path( ep_idx, self.meta.audio_camera_keys_mapping[audio_key] ) - # Audio stored alone in a separate .m4a file + #Audio stored alone in a separate .m4a file else: audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend) @@ -1023,7 +1053,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for key in self.meta.video_keys: episode_buffer[key] = video_paths[key] - if len(self.meta.audio_keys) > 0: + if len(self.meta.unlinked_audio_keys) > 0: #Linked audio is already encoded in the video files _ = self.encode_episode_audio(episode_index) # `meta.save_episode` be executed after encoding the videos @@ -1150,7 +1180,7 @@ class LeRobotDataset(torch.utils.data.Dataset): since video encoding with ffmpeg is already using multithreading. 
""" audio_paths = {} - for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()): + for audio_key in self.meta.unlinked_audio_keys: input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)