Cleaning up bound/linked audio keys mapping recovery

This commit is contained in:
CarolinePascal 2025-04-07 13:21:55 +02:00
parent b00e866c60
commit 17ad249335
No known key found for this signature in database
2 changed files with 10 additions and 7 deletions

View File

@ -207,6 +207,11 @@ class LeRobotDatasetMetadata:
"""Keys to access audio modalities."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
@property
def audio_camera_keys_mapping(self) -> dict[str, str]:
"""Mapping between camera keys and audio keys when both are linked."""
return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None}
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@ -318,8 +323,7 @@ class LeRobotDatasetMetadata:
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None}
for key in set(self.audio_keys) - bound_audio_keys:
for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
if not self.features[key].get("info", None):
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
self.info["features"][key]["info"] = get_audio_info(audio_path)
@ -771,11 +775,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
#TODO(CarolinePascal): add variable query durations
def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
item = {}
bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
for audio_key, query_ts in query_timestamps.items():
#Audio stored with video in a single .mp4 file
if audio_key in bound_audio_keys_mapping.keys():
audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key])
if audio_key in self.meta.audio_camera_keys_mapping:
audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key])
#Audio stored alone in a separate .m4a file
else:
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
@ -1103,8 +1106,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
since video encoding with ffmpeg is already using multithreading.
"""
audio_paths = {}
bound_audio_keys = {self.meta.features[video_key]["audio"] for video_key in self.meta.video_keys if self.meta.features[video_key]["audio"] is not None}
for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()):
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)

View File

@ -44,6 +44,7 @@ class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
microphones: dict[str, MicrophoneConfig] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length