Cleaning up bound/linked audio keys mapping recovery

This commit is contained in:
CarolinePascal 2025-04-07 13:21:55 +02:00
parent b00e866c60
commit 17ad249335
No known key found for this signature in database
2 changed files with 10 additions and 7 deletions

View File

@ -207,6 +207,11 @@ class LeRobotDatasetMetadata:
"""Keys to access audio modalities.""" """Keys to access audio modalities."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
@property
def audio_camera_keys_mapping(self) -> dict[str, str]:
    """Mapping from linked audio keys to the camera keys they are bound to.

    A camera feature links to an audio feature by naming it under its
    ``"audio"`` entry. Cameras whose ``"audio"`` entry is missing or ``None``
    are not linked and are left out of the mapping.

    Returns:
        A dict keyed by audio key whose value is the linked camera key, so
        ``mapping[audio_key]`` gives the camera key associated with that audio.
        If several cameras name the same audio key, the last camera in
        ``self.camera_keys`` wins.
    """
    mapping: dict[str, str] = {}
    for camera_key in self.camera_keys:
        # `.get` tolerates camera features that carry no "audio" entry at all,
        # treating them the same as an explicit ``None`` (no linked audio).
        audio_key = self.features[camera_key].get("audio")
        if audio_key is not None:
            mapping[audio_key] = camera_key
    return mapping
@property @property
def names(self) -> dict[str, list | dict]: def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities.""" """Names of the various dimensions of vector modalities."""
@ -318,8 +323,7 @@ class LeRobotDatasetMetadata:
Warning: this function writes info from first episode audio, implicitly assuming that all audio have Warning: this function writes info from first episode audio, implicitly assuming that all audio have
been encoded the same way. Also, this means it assumes the first episode exists. been encoded the same way. Also, this means it assumes the first episode exists.
""" """
bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None} for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
for key in set(self.audio_keys) - bound_audio_keys:
if not self.features[key].get("info", None): if not self.features[key].get("info", None):
audio_path = self.root / self.get_compressed_audio_file_path(0, key) audio_path = self.root / self.get_compressed_audio_file_path(0, key)
self.info["features"][key]["info"] = get_audio_info(audio_path) self.info["features"][key]["info"] = get_audio_info(audio_path)
@ -771,11 +775,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
#TODO(CarolinePascal): add variable query durations #TODO(CarolinePascal): add variable query durations
def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]: def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
item = {} item = {}
bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
for audio_key, query_ts in query_timestamps.items(): for audio_key, query_ts in query_timestamps.items():
#Audio stored with video in a single .mp4 file #Audio stored with video in a single .mp4 file
if audio_key in bound_audio_keys_mapping.keys(): if audio_key in self.meta.audio_camera_keys_mapping:
audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key]) audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key])
#Audio stored alone in a separate .m4a file #Audio stored alone in a separate .m4a file
else: else:
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key) audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
@ -1103,8 +1106,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
since video encoding with ffmpeg is already using multithreading. since video encoding with ffmpeg is already using multithreading.
""" """
audio_paths = {} audio_paths = {}
bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None} for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()):
for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key) input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key) output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)

View File

@ -44,6 +44,7 @@ class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {}) cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
microphones: dict[str, MicrophoneConfig] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes. # Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length # Set this to a positive scalar to have the same value for all motors, or a list that is the same length