Adding last missing audio features in LeRobotDataset
This commit is contained in:
parent
0cb9345f06
commit
a08b5c4105
|
@ -206,7 +206,7 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
@property
|
||||
def audio_keys(self) -> list[str]:
|
||||
"""Keys to access audio modalities."""
|
||||
"""Keys to access audio modalities (whether they are linked to a camera or not)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
|
||||
|
||||
@property
|
||||
|
@ -219,6 +219,16 @@ class LeRobotDatasetMetadata:
|
|||
and self.features[camera_key]["dtype"] == "video"
|
||||
}
|
||||
|
||||
@property
|
||||
def linked_audio_keys(self) -> list[str]:
|
||||
"""Keys to access audio modalities linked to a camera."""
|
||||
return [key for key in self.audio_keys if key in self.audio_camera_keys_mapping]
|
||||
|
||||
@property
|
||||
def unlinked_audio_keys(self) -> list[str]:
|
||||
"""Keys to access audio modalities not linked to a camera."""
|
||||
return [key for key in self.audio_keys if key not in self.audio_camera_keys_mapping]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
|
@ -298,7 +308,8 @@ class LeRobotDatasetMetadata:
|
|||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
if len(self.audio_keys) > 0:
|
||||
self.info["total_audio"] += len(self.audio_keys)
|
||||
if len(self.unlinked_audio_keys) > 0:
|
||||
self.update_audio_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
@ -330,10 +341,10 @@ class LeRobotDatasetMetadata:
|
|||
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
|
||||
for key in self.unlinked_audio_keys:
|
||||
if not self.features[key].get("info", None) or (
|
||||
len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]
|
||||
):
|
||||
): #Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi)
|
||||
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
|
||||
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||
|
||||
|
@ -412,6 +423,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
download_audio: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
):
|
||||
|
@ -444,7 +456,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
- tasks contains the prompts for each task of the dataset, which can be used for
|
||||
task-conditioned training.
|
||||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
||||
- videos (optional) from which frames and audio (if any) are loaded to be synchronous with data from parquet files and audio.
|
||||
- audio (optional) from which audio is loaded to be synchronous with data from parquet files and videos.
|
||||
|
||||
A typical LeRobotDataset looks like this from its root path:
|
||||
.
|
||||
|
@ -513,9 +526,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
|
@ -554,8 +568,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(
|
||||
download_videos
|
||||
) # Should load audio as well #TODO(CarolinePascal): separate audio from video
|
||||
download_videos,
|
||||
download_audio
|
||||
) #Audio embedded in video files (.mp4) will be downloaded if download_videos is set to True, regardless of the value of download_audio
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
@ -578,6 +593,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
push_audio: bool = True,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
|
@ -585,7 +601,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
) -> None:
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically pushed if videos are pushed
|
||||
if not push_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
|
@ -641,7 +659,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
def download_episodes(self, download_videos: bool = True) -> None:
|
||||
def download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
|
@ -650,7 +668,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
ignore_patterns = []
|
||||
if not download_videos:
|
||||
ignore_patterns.append("videos/") #Audio embedded in video files (.mp4) will be automatically downloaded if videos are downloaded
|
||||
if not download_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
|
||||
|
@ -667,6 +689,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
]
|
||||
fpaths += video_files
|
||||
|
||||
if len(self.meta.unlinked_audio_keys) > 0:
|
||||
audio_files = [
|
||||
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
|
||||
for audio_key in self.meta.unlinked_audio_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += audio_files
|
||||
|
||||
return fpaths
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
|
@ -755,7 +785,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.meta.audio_keys:
|
||||
for key in self.meta.audio_keys: #Standalone audio and audio embedded in video as well !
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
|
@ -768,7 +798,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
if key not in self.meta.video_keys and key not in self.meta.audio_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
|
@ -792,7 +822,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
item = {}
|
||||
for audio_key, query_ts in query_timestamps.items():
|
||||
#Audio stored with video in a single .mp4 file
|
||||
if audio_key in self.meta.audio_camera_keys_mapping:
|
||||
if audio_key in self.meta.linked_audio_keys:
|
||||
audio_path = self.root / self.meta.get_video_file_path(
|
||||
ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
|
||||
)
|
||||
|
@ -1023,7 +1053,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
if len(self.meta.unlinked_audio_keys) > 0: #Linked audio is already encoded in the video files
|
||||
_ = self.encode_episode_audio(episode_index)
|
||||
|
||||
# `meta.save_episode` should be executed after encoding the videos
|
||||
|
@ -1150,7 +1180,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
audio_paths = {}
|
||||
for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()):
|
||||
for audio_key in self.meta.unlinked_audio_keys:
|
||||
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
|
||||
|
||||
|
|
Loading…
Reference in New Issue