From 8ee61bb81fd27bbdae2c5cee672663346e3fd2fa Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Fri, 28 Mar 2025 17:16:51 +0100
Subject: [PATCH] Add audio modality to LeRobotDataset

---
 lerobot/common/datasets/lerobot_dataset.py | 141 +++++++++++++++--
 lerobot/common/datasets/utils.py | 16 +-
 lerobot/common/datasets/video_utils.py | 148 ++++++++++++++++--
 .../robot_devices/robots/manipulator.py | 2 +-
 tests/fixtures/dataset_factories.py | 3 +
 5 files changed, 282 insertions(+), 28 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index d8da85d6..488b7696 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -36,8 +36,11 @@ from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
+    DEFAULT_AUDIO_CHUNK_DURATION,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
+    DEFAULT_RAW_AUDIO_PATH,
     INFO_PATH,
     TASKS_PATH,
     append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
@@ -69,8 +72,11 @@ from lerobot.common.datasets.video_utils import (
     VideoFrame,
+    decode_audio,
     decode_video_frames,
+    encode_audio,
     encode_video_frames,
+    get_audio_info,
     get_safe_default_codec,
     get_video_info,
 )
 from lerobot.common.robot_devices.robots.utils import Robot
@@ -141,6 +147,11 @@ class LeRobotDatasetMetadata:
         ep_chunk = self.get_episode_chunk(ep_index)
         fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)
+
+    def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        episode_chunk = self.get_episode_chunk(episode_index)
+        fpath = self.audio_path.format(
+            episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
+        )
+        return self.root / fpath
 
     def get_episode_chunk(self, ep_index: int) -> int:
         return ep_index // self.chunks_size
@@ -154,6 +165,11 @@
     def video_path(self) -> str | None:
         """Formattable string for the video files."""
         return self.info["video_path"]
+
+    @property
+    def audio_path(self) -> str | None:
+        """Formattable string for the audio files."""
+        return self.info["audio_path"]
 
     @property
     def robot_type(self) -> str | None:
@@ -184,6 +200,11 @@
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
+
+    @property
+    def audio_keys(self) -> list[str]:
+        """Keys to access audio modalities."""
+        return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
 
     @property
     def names(self) -> dict[str, list | dict]:
@@ -264,6 +285,9 @@
         if len(self.video_keys) > 0:
             self.update_video_info()
 
+        if len(self.audio_keys) > 0:
+            self.update_audio_info()
+
         write_info(self.info, self.root)
 
         episode_dict = {
@@ -288,6 +312,17 @@
             video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
             self.info["features"][key]["info"] = get_video_info(video_path)
 
+    def update_audio_info(self) -> None:
+        """
+        Warning: this function writes info from the first episode's audio, implicitly assuming that all
+        audio files have been encoded the same way. Also, this means it assumes the first episode exists.
+ """ + bound_audio_keys = {self.features[video_key]["audio"] for video_key in self.video_keys if self.features[video_key]["audio"] is not None} + for key in set(self.audio_keys) - bound_audio_keys: + if not self.features[key].get("info", None): + audio_path = self.root / self.get_compressed_audio_file_path(0, key) + self.info["features"][key]["info"] = get_audio_info(audio_path) + def __repr__(self): feature_keys = list(self.features) return ( @@ -364,6 +399,7 @@ class LeRobotDataset(torch.utils.data.Dataset): force_cache_sync: bool = False, download_videos: bool = True, video_backend: str | None = None, + audio_backend: str | None = None, ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -465,6 +501,7 @@ class LeRobotDataset(torch.utils.data.Dataset): True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'. """ super().__init__() self.repo_id = repo_id @@ -475,6 +512,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal) self.delta_indices = None # Unused attributes @@ -499,7 +537,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) - self.download_episodes(download_videos) + self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) @@ -677,7 +715,7 @@ class LeRobotDataset(torch.utils.data.Dataset): } return query_indices, padding - def _get_query_timestamps( + def _get_query_timestamps_video( self, current_ts: float, query_indices: dict[str, list[int]] | None = None, @@ -691,6 +729,22 @@ class LeRobotDataset(torch.utils.data.Dataset): query_timestamps[key] = [current_ts] return query_timestamps + + #TODO(CarolinePascal): add variable query durations + def _get_query_timestamps_audio( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.meta.audio_keys: + if query_indices is not None and key in query_indices: + timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: return { @@ -713,6 +767,21 @@ class LeRobotDataset(torch.utils.data.Dataset): return item + #TODO(CarolinePascal): add variable query durations + def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]: + item = {} + bound_audio_keys_mapping = {self.meta.features[video_key]["audio"]:video_key for 
+        for audio_key, query_ts in query_timestamps.items():
+            # Audio stored with video in a single .mp4 file
+            if audio_key in bound_audio_keys_mapping:
+                audio_path = self.root / self.meta.get_video_file_path(ep_idx, bound_audio_keys_mapping[audio_key])
+            # Audio stored alone in a separate .m4a file
+            else:
+                audio_path = self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
+            audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
+            item[audio_key] = audio_chunk.squeeze(0)
+        return item
+
     def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
         for key, val in padding.items():
             item[key] = torch.BoolTensor(val)
@@ -733,11 +802,16 @@
         for key, val in query_result.items():
             item[key] = val
 
-        if len(self.meta.video_keys) > 0:
+        if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
             current_ts = item["timestamp"].item()
-            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+
+            query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
             video_frames = self._query_videos(query_timestamps, ep_idx)
-            item = {**video_frames, **item}
+            item = {**item, **video_frames}
+
+            query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
+            audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
+            item = {**item, **audio_chunks}
 
         if self.image_transforms is not None:
             image_keys = self.meta.camera_keys
@@ -776,6 +850,10 @@
                 image_key=image_key, episode_index=episode_index, frame_index=frame_index
             )
         return self.root / fpath
+
+    def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
+        return self.root / fpath
 
     def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
         if self.image_writer is None:
@@ -867,7 +945,7 @@
         for key, ft in self.features.items():
             # index, episode_index, task_index are already processed above, and image and video
             # are processed separately by storing image path and frame info as meta data
-            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
+            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
                 continue
             episode_buffer[key] = np.stack(episode_buffer[key])
@@ -880,6 +958,9 @@
             for key in self.meta.video_keys:
                 episode_buffer[key] = video_paths[key]
 
+        if len(self.meta.audio_keys) > 0:
+            _ = self.encode_episode_audio(episode_index)
+
         # `meta.save_episode` must be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -904,6 +985,13 @@
             if img_dir.is_dir():
                 shutil.rmtree(self.root / "images")
 
+        # delete raw audio
+        raw_audio_files = list(self.root.rglob("*.wav"))
+        for raw_audio_file in raw_audio_files:
+            raw_audio_file.unlink()
+            if len(list(raw_audio_file.parent.iterdir())) == 0:
+                raw_audio_file.parent.rmdir()
+
         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
@@ -971,18 +1059,45 @@
         since video encoding with ffmpeg is already using multithreading.
         """
         video_paths = {}
-        for key in self.meta.video_keys:
-            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-            video_paths[key] = str(video_path)
+        for video_key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
+            video_paths[video_key] = str(video_path)
             if video_path.is_file():
                 # Skip if video is already encoded. Could be the case when resuming data recording.
                 continue
             img_dir = self._get_image_file_path(
-                episode_index=episode_index, image_key=key, frame_index=0
+                episode_index=episode_index, image_key=video_key, frame_index=0
             ).parent
-            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+
+            audio_path = None
+            audio_key = self.meta.features[video_key].get("audio")
+            if audio_key is not None:
+                audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
+
+            encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
 
         return video_paths
+
+    def encode_episode_audio(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert .wav raw audio files into .m4a audio files.
+        Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speed up
+        encoding, since audio encoding with ffmpeg is already using multithreading.
+        """
+        audio_paths = {}
+        bound_audio_keys = {
+            self.meta.features[video_key].get("audio")
+            for video_key in self.meta.video_keys
+            if self.meta.features[video_key].get("audio") is not None
+        }
+        for audio_key in set(self.meta.audio_keys) - bound_audio_keys:
+            input_audio_path = self._get_raw_audio_file_path(episode_index, audio_key)
+            output_audio_path = self.meta.get_compressed_audio_file_path(episode_index, audio_key)
+
+            audio_paths[audio_key] = str(output_audio_path)
+            if output_audio_path.is_file():
+                # Skip if audio is already encoded. Could be the case when resuming data recording.
+                continue
+
+            encode_audio(input_audio_path, output_audio_path, overwrite=True)
+
+        return audio_paths
 
     @classmethod
     def create(
@@ -998,6 +1113,7 @@
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
@@ -1029,6 +1145,7 @@
         obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg"  # Waiting for torchcodec release. TODO(CarolinePascal)
         return obj
 
 
@@ -1049,6 +1166,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         tolerances_s: dict | None = None,
         download_videos: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         super().__init__()
         self.repo_ids = repo_ids
@@ -1066,6 +1184,7 @@
                 tolerance_s=self.tolerances_s[repo_id],
                 download_videos=download_videos,
                 video_backend=video_backend,
+                audio_backend=audio_backend,
             )
             for repo_id in repo_ids
         ]
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 9d8a54db..827e711b 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -55,6 +55,10 @@ TASKS_PATH = "meta/tasks.jsonl"
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
 DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
+DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
+
+DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds
 
 DATASET_CARD_TEMPLATE = """
 ---
@@ -363,7 +367,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
 def get_hf_features_from_features(features: dict) -> datasets.Features:
     hf_features = {}
     for key, ft in features.items():
-        if ft["dtype"] == "video":
+        if ft["dtype"] in ["video", "audio"]:
             continue
         elif ft["dtype"] == "image":
             hf_features[key] = datasets.Image()
@@ -394,7 +398,13 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
         key: {"dtype": "video" if use_videos else "image", **ft}
         for key, ft in robot.camera_features.items()
     }
-    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
+    microphones_ft = {}
+    if robot.microphones:
+        microphones_ft = {
+            key: {"dtype": "audio", **ft}
+            for key, ft in robot.microphones_features.items()
+        }
+    return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES}
 
 
 def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
@@ -448,6 +458,7 @@ def create_empty_dataset_info(
         "splits": {},
         "data_path": DEFAULT_PARQUET_PATH,
         "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
         "features": features,
     }
 
@@ -721,6 +732,7 @@ def validate_features_presence(
 ):
     error_message = ""
     missing_features = expected_features - actual_features
+    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
     extra_features = actual_features - (expected_features | optional_features)
 
     if missing_features or extra_features:
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index c38d570d..44d5a1a5 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -26,9 +26,12 @@ from typing import Any, ClassVar
 import pyarrow as pa
 import torch
+import torchaudio
 import torchvision
 from datasets.features.features import register_feature
 from PIL import Image
 
+from math import ceil
+
 
 def get_safe_default_codec():
     if importlib.util.find_spec("torchcodec"):
@@ -39,7 +42,72 @@
         )
     return "pyav"
 
+
+def decode_audio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    backend: str | None = "ffmpeg",
+) -> torch.Tensor:
+    """
+    Decodes audio chunks using the specified backend.
+    Args:
+        audio_path (Path): Path to the audio file.
+        timestamps (list[float]): Starting timestamps (in seconds) of the chunks to extract.
+        duration (float): Duration (in seconds) of each extracted chunk.
+        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
+    Returns:
+        torch.Tensor: Decoded audio chunks.
+
+    Currently only supports the ffmpeg backend (through torchaudio).
+    """
+    if backend == "torchcodec":
+        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
+    elif backend == "ffmpeg":
+        return decode_audio_torchaudio(audio_path, timestamps, duration)
+    else:
+        raise ValueError(f"Unsupported audio backend: {backend}")
+
+
+def decode_audio_torchaudio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    # TODO(CarolinePascal): add channels selection
+    audio_path = str(audio_path)
+
+    reader = torchaudio.io.StreamReader(src=audio_path)
+    audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
+
+    # TODO(CarolinePascal): sort timestamps?
+    reader.add_basic_audio_stream(
+        frames_per_chunk=int(ceil(duration * audio_sampling_rate)),  # Too much is better than not enough
+        buffer_chunk_size=-1,  # No dropping frames
+    )
+
+    audio_chunks = []
+    for ts in timestamps:
+        reader.seek(ts)  # Defaults to the closest audio sample
+        status = reader.fill_buffer()
+        if status != 0:
+            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
+
+        current_audio_chunk = reader.pop_chunks()[0]
+
+        if log_loaded_timestamps:
+            logging.info(
+                f"audio chunk loaded at starting timestamp={current_audio_chunk.pts:.4f} "
+                f"with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}"
+            )
+
+        audio_chunks.append(current_audio_chunk)
+
+    audio_chunks = torch.stack(audio_chunks)
+    # TODO(CarolinePascal): pytorch format conversion?
+
+    assert len(timestamps) == len(audio_chunks)
+    return audio_chunks
+
 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
@@ -69,7 +137,6 @@
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
-
 def decode_video_frames_torchvision(
     video_path: Path | str,
     timestamps: list[float],
@@ -167,7 +234,6 @@
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
-
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
@@ -242,15 +308,52 @@
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+def encode_audio(
+    input_path: Path | str,
+    output_path: Path | str,
+    codec: str = "aac",
+    log_level: str | None = "error",
+    overwrite: bool = False,
+) -> None:
+    """Encodes an audio file using ffmpeg."""
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    ffmpeg_args = OrderedDict(
+        [
+            ("-i", str(input_path)),
+            ("-acodec", codec),
+        ]
+    )
+
+    if log_level is not None:
+        ffmpeg_args["-loglevel"] = str(log_level)
+
+    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    if overwrite:
+        ffmpeg_args.append("-y")
+
+    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
+
+    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
+
+    if not output_path.exists():
+        raise OSError(
+            f"Audio encoding did not work. File not found: {output_path}. "
+            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
+        )
 
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
     fps: int,
+    audio_path: Path | str | None = None,
     vcodec: str = "libsvtav1",
     pix_fmt: str = "yuv420p",
     g: int | None = 2,
     crf: int | None = 30,
+    acodec: str = "aac",  # TODO(CarolinePascal): investigate the Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
     fast_decode: int = 0,
     log_level: str | None = "error",
     overwrite: bool = False,
@@ -260,35 +363,53 @@
     imgs_dir = Path(imgs_dir)
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
-    ffmpeg_args = OrderedDict(
+    ffmpeg_video_args = OrderedDict(
         [
             ("-f", "image2"),
             ("-r", str(fps)),
             ("-i", str(imgs_dir / "frame_%06d.png")),
-            ("-vcodec", vcodec),
-            ("-pix_fmt", pix_fmt),
         ]
     )
 
+    ffmpeg_audio_args = OrderedDict()
+    if audio_path is not None:
+        audio_path = Path(audio_path)
+        audio_path.parent.mkdir(parents=True, exist_ok=True)
+        ffmpeg_audio_args["-i"] = str(audio_path)
+
+    ffmpeg_encoding_args = OrderedDict(
+        [
+            ("-pix_fmt", pix_fmt),
+            ("-vcodec", vcodec),
+        ]
+    )
     if g is not None:
-        ffmpeg_args["-g"] = str(g)
-
+        ffmpeg_encoding_args["-g"] = str(g)
     if crf is not None:
-        ffmpeg_args["-crf"] = str(crf)
-
+        ffmpeg_encoding_args["-crf"] = str(crf)
     if fast_decode:
         key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
         value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
-        ffmpeg_args[key] = value
+        ffmpeg_encoding_args[key] = value
+
+    if audio_path is not None:
+        ffmpeg_encoding_args["-acodec"] = acodec
 
     if log_level is not None:
-        ffmpeg_args["-loglevel"] = str(log_level)
+        ffmpeg_encoding_args["-loglevel"] = str(log_level)
 
-    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_audio_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
     if overwrite:
         ffmpeg_args.append("-y")
 
     ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
+
     # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
     subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
 
@@ -366,7 +487,6 @@ def get_audio_info(video_path: Path | str) -> dict:
         "audio.channel_layout": audio_stream_info.get("channel_layout", None),
     }
 
-
 def get_video_info(video_path: Path | str) -> dict:
     ffprobe_video_cmd = [
         "ffprobe",
diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index 443466ed..afa4006a 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -216,7 +216,7 @@ class ManipulatorRobot:
 
     @property
     def features(self):
-        return {**self.motor_features, **self.camera_features}
+        return {**self.motor_features, **self.camera_features, **self.microphones_features}
 
     @property
     def has_camera(self):
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 531977da..fbd7480f 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -29,6 +29,7 @@ from lerobot.common.datasets.utils import (
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     DEFAULT_FEATURES,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
@@ -121,6 +122,7 @@ def info_factory(features_factory):
         chunks_size: int = DEFAULT_CHUNK_SIZE,
         data_path: str = DEFAULT_PARQUET_PATH,
         video_path: str = DEFAULT_VIDEO_PATH,
+        audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
         use_videos: bool = True,
@@ -139,6 +141,7 @@
         "splits": {},
         "data_path": data_path,
         "video_path": video_path if use_videos else None,
+        "audio_path": audio_path,
         "features": features,
     }
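
Usage sketch (illustrative only, not part of the patch): reading audio chunks from a dataset that
declares an "audio" feature. The repo_id and microphone key below are hypothetical; shapes follow
decode_audio(), which stacks one chunk per query timestamp and defaults to the "ffmpeg" backend.

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.common.datasets.video_utils import decode_audio

    # Audio decoding defaults to ffmpeg; "torchcodec" raises NotImplementedError for now.
    dataset = LeRobotDataset("user/dataset_with_audio", audio_backend="ffmpeg")  # hypothetical repo_id

    # __getitem__ now also returns, for every audio key, a DEFAULT_AUDIO_CHUNK_DURATION-long
    # (0.5 s) chunk starting at the frame's timestamp.
    item = dataset[0]
    chunk = item["observation.audio.microphone"]  # hypothetical key; shape (samples, channels)

    # The decoder can also be called directly on a compressed .m4a file:
    chunks = decode_audio(
        dataset.root / "audio/chunk-000/observation.audio.microphone/episode_000000.m4a",
        timestamps=[0.0, 1.0],  # chunk start times, in seconds
        duration=0.5,
        backend="ffmpeg",
    )
    assert chunks.shape[0] == 2  # one chunk per requested timestamp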