Caroline Pascal 2025-04-15 15:41:07 +00:00 committed by GitHub
commit 35d58527a4
25 changed files with 1610 additions and 94 deletions

View File

@@ -190,6 +190,11 @@ available_cameras = [
    "intelrealsense",
]
+# lists all available microphones from `lerobot/common/robot_devices/microphones`
+available_microphones = [
+    "microphone",
+]
# lists all available motors from `lerobot/common/robot_devices/motors`
available_motors = [
    "dynamixel",

View File

@@ -0,0 +1,165 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import subprocess
from collections import OrderedDict
from pathlib import Path
import torch
import torchaudio
from numpy import ceil
def decode_audio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
backend: str | None = "ffmpeg",
) -> torch.Tensor:
"""
Decodes audio using the specified backend.
Args:
audio_path (Path): Path to the audio file.
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
duration (float): Duration of the audio chunks in seconds.
backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
Returns:
torch.Tensor: Decoded audio chunks.
Currently supports ffmpeg.
"""
if backend == "torchcodec":
raise NotImplementedError("torchcodec is not yet supported for audio decoding")
elif backend == "ffmpeg":
return decode_audio_torchaudio(audio_path, timestamps, duration)
else:
        raise ValueError(f"Unsupported audio backend: {backend}")
def decode_audio_torchaudio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_path = str(audio_path)
reader = torchaudio.io.StreamReader(src=audio_path)
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
# TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
format="fltp", # Format as float32
)
audio_chunks = []
for ts in timestamps:
reader.seek(ts) # Default to closest audio sample
status = reader.fill_buffer()
if status != 0:
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
current_audio_chunk = reader.pop_chunks()[0]
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
)
audio_chunks.append(current_audio_chunk)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def encode_audio(
input_path: Path | str,
output_path: Path | str,
    codec: str = "aac",  # TODO(CarolinePascal): investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
log_level: str | None = "error",
overwrite: bool = False,
) -> None:
"""Encodes an audio file using ffmpeg."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_args = OrderedDict(
[
("-i", str(input_path)),
("-acodec", codec),
]
)
if log_level is not None:
ffmpeg_args["-loglevel"] = str(log_level)
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
if overwrite:
ffmpeg_args.append("-y")
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
if not output_path.exists():
raise OSError(
f"Audio encoding did not work. File not found: {output_path}. "
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
)
def get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a:0",
"-show_entries",
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
audio_stream_info = info["streams"][0] if info.get("streams") else None
if audio_stream_info is None:
return {"has_audio": False}
# Return the information, defaulting to None if no audio stream is present
return {
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
}
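Taken together, a round trip through these helpers looks like the sketch below. This is an illustration, not part of the commit: the paths are made up, and decoding assumes a torchaudio build with ffmpeg support.

```python
from pathlib import Path

from lerobot.common.datasets.audio_utils import decode_audio, encode_audio, get_audio_info

# Hypothetical paths following the dataset layout introduced by this commit.
raw_path = Path("audio/observation.audio.laptop/episode_000000.wav")
compressed_path = Path("audio/chunk-000/observation.audio.laptop/episode_000000.m4a")

# Compress the raw recording to AAC, then inspect the resulting stream.
encode_audio(raw_path, compressed_path, codec="aac", overwrite=True)
print(get_audio_info(compressed_path))  # e.g. {"has_audio": True, "audio.codec": "aac", ...}

# Extract two 0.5 s chunks starting at t=1.0 s and t=2.5 s.
chunks = decode_audio(compressed_path, timestamps=[1.0, 2.5], duration=0.5, backend="ffmpeg")
print(chunks.shape)  # expected: (2, samples_per_chunk, channels)
```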

View File

@@ -15,7 +15,7 @@
# limitations under the License.
import numpy as np

-from lerobot.common.datasets.utils import load_image_as_numpy
+from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy

def estimate_num_samples(
@@ -72,6 +72,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
    return images

+def sample_audio_from_path(audio_path: str) -> np.ndarray:
+    """Samples audio data from an audio recording stored in a WAV file."""
+    data = load_audio_from_path(audio_path)
+    sampled_indices = sample_indices(len(data))
+    return data[sampled_indices]
+
+def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
+    """Samples audio data from an audio recording stored in a numpy array."""
+    sampled_indices = sample_indices(len(data))
+    return data[sampled_indices]
+
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
    return {
        "min": np.min(array, axis=axis, keepdims=keepdims),
@@ -91,6 +105,13 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
            ep_ft_array = sample_images(data)  # data is a list of image paths
            axes_to_reduce = (0, 2, 3)  # keep channel dim
            keepdims = True
+        elif features[key]["dtype"] == "audio":
+            try:
+                ep_ft_array = sample_audio_from_path(data[0])
+            except TypeError:
+                # Should only be triggered for the LeKiwi robot, for which audio is stored
+                # chunk by chunk in a visual frame-like manner
+                ep_ft_array = sample_audio_from_data(data)
+            axes_to_reduce = 0
+            keepdims = True
        else:
            ep_ft_array = data  # data is already a np.ndarray
            axes_to_reduce = 0  # compute stats over the first axis
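For an audio feature, the stats are reduced over axis 0 (the sample axis) with keepdims=True, so one value per channel survives. A minimal sketch of that reduction, with made-up data and a linspace stand-in for sample_indices:

```python
import numpy as np

# Hypothetical 1 s stereo recording at 48 kHz, float32 in [-1, 1].
data = np.random.uniform(-1.0, 1.0, size=(48_000, 2)).astype(np.float32)

# Stand-in for sample_indices: pick ~100 evenly spaced samples.
idx = np.linspace(0, len(data) - 1, num=100, dtype=int)
sampled = data[idx]

# Reducing over axis 0 with keepdims, as compute_episode_stats does for audio.
mean = np.mean(sampled, axis=0, keepdims=True)
print(mean.shape)  # (1, 2): one statistic per channel
```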

View File

@@ -23,6 +23,7 @@ import datasets
import numpy as np
import packaging.version
import PIL.Image
+import soundfile as sf
import torch
import torch.utils
from datasets import concatenate_datasets, load_dataset
@@ -31,11 +32,18 @@ from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError

from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.audio_utils import (
+    decode_audio,
+    encode_audio,
+    get_audio_info,
+)
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
+    DEFAULT_AUDIO_CHUNK_DURATION,
    DEFAULT_FEATURES,
    DEFAULT_IMAGE_PATH,
+    DEFAULT_RAW_AUDIO_PATH,
    INFO_PATH,
    TASKS_PATH,
    append_jsonlines,
@@ -72,6 +80,7 @@ from lerobot.common.datasets.video_utils import (
    get_safe_default_codec,
    get_video_info,
)
+from lerobot.common.robot_devices.microphones.utils import Microphone
from lerobot.common.robot_devices.robots.utils import Robot

CODEBASE_VERSION = "v2.1"
@@ -142,6 +151,14 @@ class LeRobotDatasetMetadata:
        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
        return Path(fpath)

+    def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        """Returns the path of the compressed (i.e. encoded) audio file."""
+        episode_chunk = self.get_episode_chunk(episode_index)
+        fpath = self.audio_path.format(
+            episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
+        )
+        return self.root / fpath
+
    def get_episode_chunk(self, ep_index: int) -> int:
        return ep_index // self.chunks_size
@@ -155,6 +172,11 @@ class LeRobotDatasetMetadata:
        """Formattable string for the video files."""
        return self.info["video_path"]

+    @property
+    def audio_path(self) -> str | None:
+        """Formattable string for the audio files."""
+        return self.info["audio_path"]
+
    @property
    def robot_type(self) -> str | None:
        """Robot type used in recording this dataset."""
@@ -185,6 +207,11 @@ class LeRobotDatasetMetadata:
        """Keys to access visual modalities (regardless of their storage method)."""
        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

+    @property
+    def audio_keys(self) -> list[str]:
+        """Keys to access audio modalities."""
+        return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
+
    @property
    def names(self) -> dict[str, list | dict]:
        """Names of the various dimensions of vector modalities."""
@@ -264,6 +291,10 @@ class LeRobotDatasetMetadata:
        if len(self.video_keys) > 0:
            self.update_video_info()

+        self.info["total_audio"] += len(self.audio_keys)
+        if len(self.audio_keys) > 0:
+            self.update_audio_info()
+
        write_info(self.info, self.root)

        episode_dict = {
@@ -288,6 +319,19 @@ class LeRobotDatasetMetadata:
            video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
            self.info["features"][key]["info"] = get_video_info(video_path)

+    def update_audio_info(self) -> None:
+        """
+        Warning: this function writes info from the first episode's audio, implicitly assuming that
+        all audio files have been encoded the same way. It also assumes that the first episode exists.
+        """
+        for key in self.audio_keys:
+            if not self.features[key].get("info", None) or (
+                len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]
+            ):
+                # Overwrite if info is empty or only contains the sample rate
+                # (necessary to correctly save audio files recorded by LeKiwi)
+                audio_path = self.root / self.get_compressed_audio_file_path(0, key)
+                self.info["features"][key]["info"] = get_audio_info(audio_path)
+
    def __repr__(self):
        feature_keys = list(self.features)
        return (
@@ -363,7 +407,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        revision: str | None = None,
        force_cache_sync: bool = False,
        download_videos: bool = True,
+        download_audio: bool = True,
        video_backend: str | None = None,
+        audio_backend: str | None = None,
    ):
        """
        2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -394,7 +440,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
            - tasks contains the prompts for each task of the dataset, which can be used for
              task-conditioned training.
            - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
-            - videos (optional) from which frames are loaded to be synchronous with data from parquet files.
+            - videos (optional) from which frames and audio (if any) are loaded to be synchronous with data from parquet files and audio.
+            - audio (optional) from which audio is loaded to be synchronous with data from parquet files and videos.

        A typical LeRobotDataset looks like this from its root path:
        .
@@ -415,17 +462,31 @@
        info.json
        stats.json
        tasks.jsonl
    videos
        chunk-000
            observation.images.laptop
                episode_000000.mp4
                episode_000001.mp4
                episode_000002.mp4
                ...
            observation.images.phone
                episode_000000.mp4
                episode_000001.mp4
                episode_000002.mp4
                ...
        chunk-001
        ...
+   audio
+       chunk-000
+           observation.audio.laptop
+               episode_000000.m4a
+               episode_000001.m4a
+               episode_000002.m4a
+               ...
+           observation.audio.phone
+               episode_000000.m4a
+               episode_000001.m4a
+               episode_000002.m4a
+               ...
+       chunk-001
+       ...
@@ -463,8 +524,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
            download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                video files are already present on local disk, they won't be downloaded again. Defaults to
                True.
+            download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
            video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
                You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
+            audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to the 'ffmpeg' decoder used by 'torchaudio'.
        """
        super().__init__()
        self.repo_id = repo_id
@@ -475,6 +538,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.tolerance_s = tolerance_s
        self.revision = revision if revision else CODEBASE_VERSION
        self.video_backend = video_backend if video_backend else get_safe_default_codec()
+        self.audio_backend = (
+            audio_backend if audio_backend else "ffmpeg"
+        )  # Waiting for torchcodec release  # TODO(CarolinePascal)
        self.delta_indices = None

        # Unused attributes
@@ -499,7 +565,7 @@
            self.hf_dataset = self.load_hf_dataset()
        except (AssertionError, FileNotFoundError, NotADirectoryError):
            self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(download_videos)
+            self.download_episodes(download_videos, download_audio)
            self.hf_dataset = self.load_hf_dataset()

        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -510,6 +576,8 @@
        ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

+        # TODO(CarolinePascal): add a check for audio duration with respect to episode duration,
+        # BUT this will be CPU expensive if there are many episodes!

        # Setup delta_indices
        if self.delta_timestamps is not None:
            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -522,6 +590,7 @@
        license: str | None = "apache-2.0",
        tag_version: bool = True,
        push_videos: bool = True,
+        push_audio: bool = True,
        private: bool = False,
        allow_patterns: list[str] | str | None = None,
        upload_large_folder: bool = False,
@@ -530,6 +599,8 @@
        ignore_patterns = ["images/"]
        if not push_videos:
            ignore_patterns.append("videos/")
+        if not push_audio:
+            ignore_patterns.append("audio/")

        hub_api = HfApi()
        hub_api.create_repo(
@@ -585,7 +656,7 @@
            ignore_patterns=ignore_patterns,
        )
-    def download_episodes(self, download_videos: bool = True) -> None:
+    def download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None:
        """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
        will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
        dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -594,7 +665,11 @@
        # TODO(rcadene, aliberts): implement faster transfer
        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
        files = None
-        ignore_patterns = None if download_videos else "videos/"
+        ignore_patterns = []
+        if not download_videos:
+            ignore_patterns.append("videos/")
+        if not download_audio:
+            ignore_patterns.append("audio/")

        if self.episodes is not None:
            files = self.get_episodes_file_paths()
@@ -611,6 +686,14 @@
            ]
            fpaths += video_files

+        if len(self.meta.audio_keys) > 0:
+            audio_files = [
+                str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
+                for audio_key in self.meta.audio_keys
+                for ep_idx in episodes
+            ]
+            fpaths += audio_files
+
        return fpaths
    def load_hf_dataset(self) -> datasets.Dataset:
@@ -677,7 +760,7 @@
        }
        return query_indices, padding

-    def _get_query_timestamps(
+    def _get_query_timestamps_video(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,
@@ -692,11 +775,27 @@

        return query_timestamps

+    # TODO(CarolinePascal): add variable query durations
+    def _get_query_timestamps_audio(
+        self,
+        current_ts: float,
+        query_indices: dict[str, list[int]] | None = None,
+    ) -> dict[str, list[float]]:
+        query_timestamps = {}
+        for key in self.meta.audio_keys:
+            if query_indices is not None and key in query_indices:
+                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
+                query_timestamps[key] = torch.stack(timestamps).tolist()
+            else:
+                query_timestamps[key] = [current_ts]
+
+        return query_timestamps
+
    def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
        return {
            key: torch.stack(self.hf_dataset.select(q_idx)[key])
            for key, q_idx in query_indices.items()
-            if key not in self.meta.video_keys
+            if key not in self.meta.video_keys and key not in self.meta.audio_keys
        }
    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
@@ -713,6 +812,17 @@
        return item

+    # TODO(CarolinePascal): add variable query durations
+    def _query_audio(
+        self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
+    ) -> dict[str, torch.Tensor]:
+        item = {}
+        for audio_key, query_ts in query_timestamps.items():
+            audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
+            audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
+            item[audio_key] = audio_chunk.squeeze(0)
+        return item
+
    def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
        for key, val in padding.items():
            item[key] = torch.BoolTensor(val)
@@ -733,11 +843,16 @@
            for key, val in query_result.items():
                item[key] = val

-        if len(self.meta.video_keys) > 0:
+        if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
            current_ts = item["timestamp"].item()
-            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+
+            query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
            video_frames = self._query_videos(query_timestamps, ep_idx)
-            item = {**video_frames, **item}
+            item = {**item, **video_frames}
+
+            query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
+            audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
+            item = {**item, **audio_chunks}

        if self.image_transforms is not None:
            image_keys = self.meta.camera_keys
@@ -777,6 +892,10 @@
        )
        return self.root / fpath

+    def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
+        return self.root / fpath
+
    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
        if self.image_writer is None:
            if isinstance(image, torch.Tensor):
@@ -827,11 +946,39 @@
                img_path.parent.mkdir(parents=True, exist_ok=True)
                self._save_image(frame[key], img_path)
                self.episode_buffer[key].append(str(img_path))
+            elif self.features[key]["dtype"] == "audio":
+                if self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi"):
+                    # Raw data storage should only be triggered for the LeKiwi robot, for which
+                    # audio is stored chunk by chunk in a visual frame-like manner
+                    self.episode_buffer[key].append(frame[key])
+                else:
+                    # Otherwise, only the audio file path is stored in the episode buffer
+                    if frame_index == 0:
+                        audio_path = self._get_raw_audio_file_path(
+                            episode_index=self.episode_buffer["episode_index"], audio_key=key
+                        )
+                        self.episode_buffer[key].append(str(audio_path))
            else:
                self.episode_buffer[key].append(frame[key])

        self.episode_buffer["size"] += 1

+    def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None:
+        """
+        Starts recording audio data provided by the microphone and directly writes it to a .wav file.
+        """
+        audio_dir = self._get_raw_audio_file_path(
+            self.num_episodes, "observation.audio." + microphone_key
+        ).parent
+        if not audio_dir.is_dir():
+            audio_dir.mkdir(parents=True, exist_ok=True)
+
+        microphone.start_recording(
+            output_file=self._get_raw_audio_file_path(
+                self.num_episodes, "observation.audio." + microphone_key
+            )
+        )
+
    def save_episode(self, episode_data: dict | None = None) -> None:
        """
        This will save to disk the current episode in self.episode_buffer.
@@ -869,16 +1016,39 @@
            # are processed separately by storing image path and frame info as meta data
            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                continue
+            elif ft["dtype"] == "audio":
+                if self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi"):
+                    # Raw data storage should only be triggered for the LeKiwi robot, for which
+                    # audio is stored chunk by chunk in a visual frame-like manner
+                    episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
+                continue
            episode_buffer[key] = np.stack(episode_buffer[key])

        self._wait_image_writer()
        self._save_episode_table(episode_buffer, episode_index)

+        if self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi"):
+            # Raw data storage should only be triggered for the LeKiwi robot, for which
+            # audio is stored chunk by chunk in a visual frame-like manner
+            for key in self.meta.audio_keys:
+                audio_path = self._get_raw_audio_file_path(
+                    episode_index=self.episode_buffer["episode_index"][0], audio_key=key
+                )
+                with sf.SoundFile(
+                    audio_path,
+                    mode="w",
+                    samplerate=self.meta.features[key]["info"]["sample_rate"],
+                    channels=self.meta.features[key]["shape"][0],
+                ) as file:
+                    file.write(episode_buffer[key])
+
        ep_stats = compute_episode_stats(episode_buffer, self.features)

        if len(self.meta.video_keys) > 0:
-            video_paths = self.encode_episode_videos(episode_index)
-            for key in self.meta.video_keys:
-                episode_buffer[key] = video_paths[key]
+            self.encode_episode_videos(episode_index)
+
+        if len(self.meta.audio_keys) > 0:
+            self.encode_episode_audio(episode_index)

        # `meta.save_episode` must be executed after encoding the videos
        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -904,6 +1074,13 @@
        if img_dir.is_dir():
            shutil.rmtree(self.root / "images")

+        # delete raw audio files
+        raw_audio_files = list(self.root.rglob("*.wav"))
+        for raw_audio_file in raw_audio_files:
+            raw_audio_file.unlink()
+            if len(list(raw_audio_file.parent.iterdir())) == 0:
+                raw_audio_file.parent.rmdir()
+
        if not episode_data:  # Reset the buffer
            self.episode_buffer = self.create_episode_buffer()
@@ -971,19 +1148,40 @@
        since video encoding with ffmpeg is already using multithreading.
        """
        video_paths = {}
-        for key in self.meta.video_keys:
-            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-            video_paths[key] = str(video_path)
+        for video_key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
+            video_paths[video_key] = str(video_path)
            if video_path.is_file():
                # Skip if video is already encoded. Could be the case when resuming data recording.
                continue
            img_dir = self._get_image_file_path(
-                episode_index=episode_index, image_key=key, frame_index=0
+                episode_index=episode_index, image_key=video_key, frame_index=0
            ).parent
            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)

        return video_paths

+    def encode_episode_audio(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert .wav raw audio files into .m4a audio files.
+        Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speed up
+        encoding, since audio encoding with ffmpeg is already using multithreading.
+        """
+        audio_paths = {}
+        for audio_key in self.meta.audio_keys:
+            input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
+            output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
+
+            audio_paths[audio_key] = str(output_audio_path)
+            if output_audio_path.is_file():
+                # Skip if audio is already encoded. Could be the case when resuming data recording.
+                continue
+
+            encode_audio(input_audio_path, output_audio_path, overwrite=True)
+
+        return audio_paths
+
    @classmethod
    def create(
        cls,
@@ -998,6 +1196,7 @@
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
        video_backend: str | None = None,
+        audio_backend: str | None = None,
    ) -> "LeRobotDataset":
        """Create a LeRobot Dataset from scratch in order to record data."""
        obj = cls.__new__(cls)
@@ -1029,6 +1228,9 @@
        obj.delta_indices = None
        obj.episode_data_index = None
        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj.audio_backend = (
+            audio_backend if audio_backend is not None else "ffmpeg"
+        )  # Waiting for torchcodec release  # TODO(CarolinePascal)
        return obj
@@ -1049,6 +1251,7 @@
        tolerances_s: dict | None = None,
        download_videos: bool = True,
        video_backend: str | None = None,
+        audio_backend: str | None = None,
    ):
        super().__init__()
        self.repo_ids = repo_ids
@@ -1066,6 +1269,7 @@
                tolerance_s=self.tolerances_s[repo_id],
                download_videos=download_videos,
                video_backend=video_backend,
+                audio_backend=audio_backend,
            )
            for repo_id in repo_ids
        ]
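End to end, the new constructor arguments surface like this. A hedged sketch: the repo_id is made up, and the audio key assumes a dataset recorded with a laptop microphone.

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset(
    repo_id="user/dataset-with-audio",  # hypothetical repository
    download_audio=True,  # fetch audio/ alongside videos/ (default)
    audio_backend="ffmpeg",  # torchaudio's ffmpeg decoder, the current default
)

item = dataset[0]
# Each audio key yields a chunk of DEFAULT_AUDIO_CHUNK_DURATION (0.5 s) starting at the
# frame's timestamp, shaped (samples, channels).
print(item["observation.audio.laptop"].shape)
```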

View File

@@ -33,6 +33,7 @@ from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
+from soundfile import read
from torchvision import transforms

from lerobot.common.datasets.backward_compatibility import (
@@ -55,6 +56,10 @@ TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
+DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
+DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds

DATASET_CARD_TEMPLATE = """
---
@@ -255,6 +260,11 @@ def load_image_as_numpy(
    return img_array

+def load_audio_from_path(fpath: str | Path) -> np.ndarray:
+    audio_data, _ = read(fpath, dtype="float32")
+    return audio_data
+
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
    """Get a transform function that convert items from Hugging Face dataset (pyarrow)
    to torch tensors. Importantly, images are converted from PIL, which corresponds to
@@ -363,7 +373,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
def get_hf_features_from_features(features: dict) -> datasets.Features:
    hf_features = {}
    for key, ft in features.items():
-        if ft["dtype"] == "video":
+        if ft["dtype"] == "video" or ft["dtype"] == "audio":
            continue
        elif ft["dtype"] == "image":
            hf_features[key] = datasets.Image()
@@ -394,7 +404,7 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
        key: {"dtype": "video" if use_videos else "image", **ft}
        for key, ft in robot.camera_features.items()
    }
-    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
+    return {**robot.motor_features, **camera_ft, **robot.microphone_features, **DEFAULT_FEATURES}
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
@@ -442,12 +452,14 @@ def create_empty_dataset_info(
        "total_frames": 0,
        "total_tasks": 0,
        "total_videos": 0,
+        "total_audio": 0,
        "total_chunks": 0,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "fps": fps,
        "splits": {},
        "data_path": DEFAULT_PARQUET_PATH,
        "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
        "features": features,
    }
@@ -721,6 +733,7 @@ def validate_features_presence(
):
    error_message = ""
    missing_features = expected_features - actual_features
+    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
    extra_features = actual_features - (expected_features | optional_features)

    if missing_features or extra_features:
@@ -740,6 +753,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    elif expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
+    elif expected_dtype == "audio":
+        return validate_feature_audio(name, expected_shape, value)
    elif expected_dtype == "string":
        return validate_feature_string(name, value)
    else:
@@ -781,6 +796,23 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
    return error_message
+def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_shape = value.shape
+        c = expected_shape
+        if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]):
+            # The number of frames might be different
+            error_message += (
+                f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
+            )
+    else:
+        error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
def validate_feature_string(name: str, value: str):
    if not isinstance(value, str):
        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"

View File

@@ -260,35 +260,39 @@ def encode_video_frames(
    imgs_dir = Path(imgs_dir)
    video_path.parent.mkdir(parents=True, exist_ok=True)

-    ffmpeg_args = OrderedDict(
+    ffmpeg_video_args = OrderedDict(
        [
            ("-f", "image2"),
            ("-r", str(fps)),
-            ("-i", str(imgs_dir / "frame_%06d.png")),
-            ("-vcodec", vcodec),
-            ("-pix_fmt", pix_fmt),
+            ("-i", str(Path(imgs_dir) / "frame_%06d.png")),
        ]
    )

+    ffmpeg_encoding_args = OrderedDict(
+        [
+            ("-pix_fmt", pix_fmt),
+            ("-vcodec", vcodec),
+        ]
+    )
+
    if g is not None:
-        ffmpeg_args["-g"] = str(g)
+        ffmpeg_encoding_args["-g"] = str(g)

    if crf is not None:
-        ffmpeg_args["-crf"] = str(crf)
+        ffmpeg_encoding_args["-crf"] = str(crf)

    if fast_decode:
        key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
-        ffmpeg_args[key] = value
+        ffmpeg_encoding_args[key] = value

    if log_level is not None:
-        ffmpeg_args["-loglevel"] = str(log_level)
+        ffmpeg_encoding_args["-loglevel"] = str(log_level)

-    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
    if overwrite:
        ffmpeg_args.append("-y")

    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
@@ -331,42 +335,6 @@ with warnings.catch_warnings():
    register_feature(VideoFrame, "VideoFrame")

-def get_audio_info(video_path: Path | str) -> dict:
-    ffprobe_audio_cmd = [
-        "ffprobe",
-        "-v",
-        "error",
-        "-select_streams",
-        "a:0",
-        "-show_entries",
-        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
-        "-of",
-        "json",
-        str(video_path),
-    ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    if result.returncode != 0:
-        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
-
-    info = json.loads(result.stdout)
-    audio_stream_info = info["streams"][0] if info.get("streams") else None
-    if audio_stream_info is None:
-        return {"has_audio": False}
-
-    # Return the information, defaulting to None if no audio stream is present
-    return {
-        "has_audio": True,
-        "audio.channels": audio_stream_info.get("channels", None),
-        "audio.codec": audio_stream_info.get("codec_name", None),
-        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
-        "audio.sample_rate": int(audio_stream_info["sample_rate"])
-        if audio_stream_info.get("sample_rate")
-        else None,
-        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
-        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
-    }
-
def get_video_info(video_path: Path | str) -> dict:
    ffprobe_video_cmd = [
        "ffprobe",
@@ -402,7 +370,6 @@ def get_video_info(video_path: Path | str) -> dict:
        "video.codec": video_stream_info["codec_name"],
        "video.pix_fmt": video_stream_info["pix_fmt"],
        "video.is_depth_map": False,
-        **get_audio_info(video_path),
    }

    return video_info
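The refactor splits the ffmpeg options into input-related and encoding-related dictionaries so the audio encoder can reuse the same flattening idiom; the idiom itself is just this (illustrative values):

```python
from collections import OrderedDict

video_args = OrderedDict([("-f", "image2"), ("-r", "30")])
encoding_args = OrderedDict([("-pix_fmt", "yuv420p"), ("-vcodec", "libsvtav1")])

# Flatten each (flag, value) pair into a flat argv list, input options first.
argv = [item for pair in video_args.items() for item in pair]
argv += [item for pair in encoding_args.items() for item in pair]
print(argv)  # ['-f', 'image2', '-r', '30', '-pix_fmt', 'yuv420p', '-vcodec', 'libsvtav1']
```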

View File

@@ -78,6 +78,11 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
        if key in robot.logs:
            log_dt(f"dtR{name}", robot.logs[key])

+    for name in robot.microphones:
+        key = f"read_microphone_{name}_dt_s"
+        if key in robot.logs:
+            log_dt(f"dtR{name}", robot.logs[key])
+
    info_str = " ".join(log_items)
    logging.info(info_str)
@@ -107,11 +112,15 @@ def predict_action(observation, policy, device, use_amp):
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
-        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
+            # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
+            # Convert to pytorch format: channel first and float32 in [-1,1] with batch dimension
+            if "audio" in name:
+                observation[name] = observation[name].type(torch.float32)
+                observation[name] = observation[name].permute(1, 0).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)
@@ -243,6 +252,18 @@ def control_loop(
    timestamp = 0
    start_episode_t = time.perf_counter()

+    if dataset is not None and not robot.robot_type.startswith("lekiwi"):
+        # For now, LeKiwi only supports frame audio recording (which may lead to audio chunk loss,
+        # extended post-processing, and increased memory usage)
+        for microphone_key, microphone in robot.microphones.items():
+            # Start recording both in file writing and data reading mode
+            dataset.add_microphone_recording(microphone, microphone_key)
+    else:
+        for _, microphone in robot.microphones.items():
+            # Start recording only in data reading mode
+            microphone.start_recording()
+
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()
@@ -286,6 +307,9 @@ def control_loop(
            events["exit_early"] = False
            break

+    for _, microphone in robot.microphones.items():
+        microphone.stop_recording()
+
def reset_environment(robot, events, reset_time_s, fps):
    # TODO(rcadene): refactor warmup_record and reset_environment
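For reference, the audio branch in predict_action expects a (samples, channels) chunk and makes it channel-first, mirroring the image path. A sketch with made-up sizes:

```python
import torch

# Hypothetical chunk: 8000 samples of 2-channel audio in [-1, 1], as Microphone.read() returns.
audio = torch.rand(8000, 2) * 2 - 1

audio = audio.type(torch.float32)
audio = audio.permute(1, 0).contiguous()  # channel first: (2, 8000)
audio = audio.unsqueeze(0)  # batch dimension: (1, 2, 8000)
print(audio.shape)
```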

View File

@@ -0,0 +1,38 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
import draccus
@dataclass
class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@MicrophoneConfigBase.register_subclass("microphone")
@dataclass
class MicrophoneConfig(MicrophoneConfigBase):
"""
Dataclass for microphone configuration.
"""
microphone_index: int
sample_rate: int | None = None
channels: list[int] | None = None
mock: bool = False
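Since MicrophoneConfigBase is a draccus ChoiceRegistry, configs can be instantiated directly or resolved by their registered name; a minimal sketch (index and sample rate are placeholder values):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig

# Placeholder values: device 0, 16 kHz, first channel only.
config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
print(config.type)  # "microphone", the name registered via register_subclass
```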

View File

@@ -0,0 +1,425 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains utilities for recording audio from a microphone.
"""
import argparse
import logging
import shutil
import time
from multiprocessing import Event as process_Event
from multiprocessing import JoinableQueue as process_Queue
from multiprocessing import Process
from os import getcwd
from pathlib import Path
from queue import Empty
from queue import Queue as thread_Queue
from threading import Event, Thread
from threading import Event as thread_Event
import numpy as np
import soundfile as sf
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceAlreadyRecordingError,
RobotDeviceNotConnectedError,
RobotDeviceNotRecordingError,
)
from lerobot.common.utils.utils import capture_timestamp_utc
def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
"""
Finds and lists all microphones compatible with sounddevice (and the underlying PortAudio library).
Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
"""
microphones = []
if mock:
import tests.microphones.mock_sounddevice as sd
else:
import sounddevice as sd
devices = sd.query_devices()
for device in devices:
if device["max_input_channels"] > 0:
microphones.append(
{
"index": device["index"],
"name": device["name"],
}
)
if raise_when_empty and len(microphones) == 0:
raise OSError(
"Not a single microphone was detected. Try re-plugging the microphone or check the microphone settings."
)
return microphones
def record_audio_from_microphones(
output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0
):
"""
Records audio from all the channels of the specified microphones for the specified duration.
If no microphone ids are provided, all available microphones will be used.
"""
if microphone_ids is None or len(microphone_ids) == 0:
microphones = find_microphones()
microphone_ids = [m["index"] for m in microphones]
microphones = []
for microphone_id in microphone_ids:
config = MicrophoneConfig(microphone_index=microphone_id)
microphone = Microphone(config)
microphone.connect()
print(
f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sample_rate} Hz."
)
microphones.append(microphone)
output_dir = Path(output_dir)
if output_dir.exists():
shutil.rmtree(
output_dir,
)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Saving audio to {output_dir}")
for microphone in microphones:
microphone.start_recording(getcwd() / output_dir / f"microphone_{microphone.microphone_index}.wav")
time.sleep(record_time_s)
for microphone in microphones:
microphone.stop_recording()
# Remark : recording may be resumed here if needed
for microphone in microphones:
microphone.disconnect()
print(f"Images have been saved to {output_dir}")
class Microphone:
"""
The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.
Example of usage:
```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
microphone = Microphone(config)
microphone.connect()
microphone.start_recording("some/output/file.wav")
...
    audio_readings = microphone.read()  # Gets all recorded audio data since the last read or since the beginning of the recording
...
microphone.stop_recording()
microphone.disconnect()
```
"""
def __init__(self, config: MicrophoneConfig):
self.config = config
self.microphone_index = config.microphone_index
# Store the recording sample rate and channels
self.sample_rate = config.sample_rate
self.channels = config.channels
self.mock = config.mock
# Input audio stream
self.stream = None
# Thread/Process-safe concurrent queue to store the recorded/read audio
self.record_queue = None
self.read_queue = None
# Thread/Process to handle data reading and file writing in a separate thread/process (safely)
self.record_thread = None
self.record_stop_event = None
self.logs = {}
self.is_connected = False
self.is_recording = False
self.is_writing = False
def connect(self) -> None:
"""
Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
"""
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
f"Microphone {self.microphone_index} is already connected."
)
if self.mock:
import tests.microphones.mock_sounddevice as sd
else:
import sounddevice as sd
# Check if the provided microphone index does match an input device
is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0
if not is_index_input:
microphones_info = find_microphones()
available_microphones = [m["index"] for m in microphones_info]
raise OSError(
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
)
# Check if provided recording parameters are compatible with the microphone
actual_microphone = sd.query_devices(self.microphone_index)
if self.sample_rate is not None:
if self.sample_rate > actual_microphone["default_samplerate"]:
raise OSError(
f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
)
elif self.sample_rate < actual_microphone["default_samplerate"]:
logging.warning(
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
)
else:
self.sample_rate = int(actual_microphone["default_samplerate"])
if self.channels is not None and len(self.channels) > 0:
if any(c > actual_microphone["max_input_channels"] for c in self.channels):
raise OSError(
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
)
else:
self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1)
# Get channels index instead of number for slicing
self.channels_index = np.array(self.channels) - 1
# Create the audio stream
self.stream = sd.InputStream(
device=self.microphone_index,
samplerate=self.sample_rate,
channels=max(self.channels),
dtype="float32",
callback=self._audio_callback,
)
# Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receive same length buffers.
# However, this may lead to additional latency. We thus stick to blocksize=0 which means that audio_callback will receive varying length buffers, but with no additional latency.
self.is_connected = True
def _audio_callback(self, indata, frames, time, status) -> None:
"""
Low-level sounddevice callback.
"""
if status:
logging.warning(status)
# Slicing makes copy unnecessary
# Two separate queues are necessary because .get() also pops the data from the queue
if self.is_writing:
self.record_queue.put(indata[:, self.channels_index])
self.read_queue.put(indata[:, self.channels_index])
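# Note: sounddevice calls this callback with indata of shape (frames, channels);
# only the selected channels (via channels_index) are forwarded to the queues.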
@staticmethod
def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
"""
Thread/Process-safe loop to write audio data into a file.
"""
# Must only run in a single process/thread to keep file writing safe
with sf.SoundFile(
output_file,
mode="x",
samplerate=sample_rate,
channels=len(channels), # data in the queue is already sliced down to the selected channels
subtype=sf.default_subtype(output_file.suffix[1:]),
) as file:
while not event.is_set():
try:
file.write(
queue.get(timeout=0.02)
) # Timeout set as twice the usual sounddevice buffer size
queue.task_done()
except Empty:
continue
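# Note: _audio_callback and this loop form a producer/consumer pair over the queue;
# queue.task_done() lets stop_recording() block on record_queue.join() until every
# queued chunk has been written to disk.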
def _read(self) -> np.ndarray:
"""
Thread/Process-safe callback to read available audio data
"""
audio_readings = np.empty((0, len(self.channels)), dtype=np.float32) # Match the float32 dtype of the stream to avoid upcasting
while True:
try:
audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
except Empty:
break
self.read_queue = thread_Queue()
return audio_readings
def read(self) -> np.ndarray:
"""
Reads the last audio chunk recorded by the microphone, e.g. all samples recorded since the last read or since the beginning of the recording.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if not self.is_recording:
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
start_time = time.perf_counter()
audio_readings = self._read()
# log the number of seconds it took to read the audio chunk
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
# log the utc time at which the audio chunk was received
self.logs["timestamp_utc"] = capture_timestamp_utc()
return audio_readings
def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
"""
Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if self.is_recording:
raise RobotDeviceAlreadyRecordingError(
f"Microphone {self.microphone_index} is already recording."
)
# Reset queues
self.read_queue = thread_Queue()
if multiprocessing:
self.record_queue = process_Queue()
else:
self.record_queue = thread_Queue()
# Write recordings into a file if output_file is provided
if output_file is not None:
output_file = Path(output_file)
if output_file.exists():
output_file.unlink()
if multiprocessing:
self.record_stop_event = process_Event()
self.record_thread = Process(
target=Microphone._record_loop,
args=(
self.record_queue,
self.record_stop_event,
self.sample_rate,
self.channels,
output_file,
),
)
else:
self.record_stop_event = thread_Event()
self.record_thread = Thread(
target=Microphone._record_loop,
args=(
self.record_queue,
self.record_stop_event,
self.sample_rate,
self.channels,
output_file,
),
)
self.record_thread.daemon = True
self.record_thread.start()
self.is_writing = True
self.is_recording = True
self.stream.start()
def stop_recording(self) -> None:
"""
Stops the recording of the microphones.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if not self.is_recording:
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
if self.stream.active:
self.stream.stop() # Wait for all buffers to be processed
# Remark: stream.abort() would flush the buffers instead!
self.is_recording = False
if self.record_thread is not None:
self.record_queue.join()
self.record_stop_event.set()
self.record_thread.join()
self.record_thread = None
self.record_stop_event = None
self.is_writing = False
def disconnect(self) -> None:
"""
Disconnects the microphone and stops the recording.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if self.is_recording:
self.stop_recording()
self.stream.close()
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset."
)
parser.add_argument(
"--microphone-ids",
type=int,
nargs="*",
default=None,
help="List of microphones indices used to instantiate the `Microphone`. If not provided, find and use all available microphones indices.",
)
parser.add_argument(
"--output-dir",
type=Path,
default="outputs/audio_from_microphones",
help="Set directory to save an audio snippet for each microphone.",
)
parser.add_argument(
"--record-time-s",
type=float,
default=4.0,
help="Set the number of seconds used to record the audio. By default, 4 seconds.",
)
args = parser.parse_args()
record_audio_from_microphones(**vars(args))
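For reference, a minimal end-to-end sketch of the recording API defined in this file; the device index, sample rate, and output path below are placeholder values:

```python
import time

from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.microphone import Microphone

# Index 0 and the output path are placeholders
config = MicrophoneConfig(microphone_index=0, sample_rate=48000, channels=[1])
microphone = Microphone(config)
microphone.connect()
microphone.start_recording(output_file="recording.wav")
time.sleep(2.0)
audio_chunk = microphone.read()  # (n_samples, n_channels) float32 array recorded so far
microphone.stop_recording()
microphone.disconnect()
```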

View File

@ -0,0 +1,48 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase
# Defines a microphone type
class Microphone(Protocol):
def connect(self): ...
def disconnect(self): ...
def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ...
def stop_recording(self): ...
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> dict[str, Microphone]:
microphones = {}
for key, cfg in microphone_configs.items():
if cfg.type == "microphone":
from lerobot.common.robot_devices.microphones.microphone import Microphone
microphones[key] = Microphone(cfg)
else:
raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
return microphones
def make_microphone(microphone_type, **kwargs) -> Microphone:
if microphone_type == "microphone":
from lerobot.common.robot_devices.microphones.microphone import Microphone
return Microphone(MicrophoneConfig(**kwargs))
else:
raise ValueError(f"The microphone type '{microphone_type}' is not valid.")

View File

@ -23,6 +23,7 @@ from lerobot.common.robot_devices.cameras.configs import (
IntelRealSenseCameraConfig, IntelRealSenseCameraConfig,
OpenCVCameraConfig, OpenCVCameraConfig,
) )
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.motors.configs import ( from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig, DynamixelMotorsBusConfig,
FeetechMotorsBusConfig, FeetechMotorsBusConfig,
@ -43,6 +44,7 @@ class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {}) cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
microphones: dict[str, MicrophoneConfig] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes. # Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length # Set this to a positive scalar to have the same value for all motors, or a list that is the same length
@ -68,6 +70,9 @@ class ManipulatorRobotConfig(RobotConfig):
for cam in self.cameras.values(): for cam in self.cameras.values():
if not cam.mock: if not cam.mock:
cam.mock = True cam.mock = True
for mic in self.microphones.values():
if not mic.mock:
mic.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence): if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms: for name in self.follower_arms:
@ -491,6 +496,21 @@ class So100RobotConfig(ManipulatorRobotConfig):
} }
) )
microphones: dict[str, MicrophoneConfig] = field(
default_factory=lambda: {
"laptop": MicrophoneConfig(
microphone_index=0,
sample_rate=48000,
channels=[1],
),
"headset": MicrophoneConfig(
microphone_index=1,
sample_rate=44100,
channels=[1],
),
}
)
mock: bool = False mock: bool = False
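For context, a sketch of how these So100 defaults could be overridden at instantiation time; the values are placeholders, and mock=True would also propagate to the microphones via __post_init__, as shown earlier in this diff:

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.robots.configs import So100RobotConfig

config = So100RobotConfig(
    microphones={
        "laptop": MicrophoneConfig(microphone_index=2, sample_rate=16000, channels=[1]),
    },
    mock=False,
)
```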

View File

@ -52,6 +52,16 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
time.sleep(0.01) time.sleep(0.01)
def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
while not stop_event.is_set():
local_dict = {}
for name, microphone in microphones.items():
audio_readings = microphone.read()
local_dict[name] = audio_readings
with audio_lock:
latest_audio_dict.update(local_dict)
def calibrate_follower_arm(motors_bus, calib_dir_str): def calibrate_follower_arm(motors_bus, calib_dir_str):
""" """
Calibrates the follower arm. Attempts to load an existing calibration file; Calibrates the follower arm. Attempts to load an existing calibration file;
@ -94,6 +104,7 @@ def run_lekiwi(robot_config):
""" """
# Import helper functions and classes # Import helper functions and classes
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
# Initialize cameras from the robot configuration. # Initialize cameras from the robot configuration.
@ -101,6 +112,11 @@ def run_lekiwi(robot_config):
for cam in cameras.values(): for cam in cameras.values():
cam.connect() cam.connect()
# Initialize microphones from the robot configuration.
microphones = make_microphones_from_configs(robot_config.microphones)
for microphone in microphones.values():
microphone.connect()
# Initialize the motors bus using the follower arm configuration. # Initialize the motors bus using the follower arm configuration.
motor_config = robot_config.follower_arms.get("main") motor_config = robot_config.follower_arms.get("main")
if motor_config is None: if motor_config is None:
@ -134,6 +150,20 @@ def run_lekiwi(robot_config):
) )
cam_thread.start() cam_thread.start()
# Start the microphone recording and capture thread.
# TODO(CarolinePascal): Leverage multi-core processing with a multiprocessing.Process instead!
latest_audio_dict = {}
audio_lock = threading.Lock()
audio_stop_event = threading.Event()
microphone_thread = threading.Thread(
target=run_microphone_capture,
args=(microphones, audio_lock, latest_audio_dict, audio_stop_event),
daemon=True,
)
for microphone in microphones.values():
microphone.start_recording()
microphone_thread.start()
last_cmd_time = time.time() last_cmd_time = time.time()
print("LeKiwi robot server started. Waiting for commands...") print("LeKiwi robot server started. Waiting for commands...")
@ -198,9 +228,14 @@ def run_lekiwi(robot_config):
with images_lock: with images_lock:
images_dict_copy = dict(latest_images_dict) images_dict_copy = dict(latest_images_dict)
# Get the latest audio data.
with audio_lock:
audio_dict_copy = dict(latest_audio_dict)
# Build the observation dictionary. # Build the observation dictionary.
observation = { observation = {
"images": images_dict_copy, "images": images_dict_copy,
"audio": audio_dict_copy, # TODO(CarolinePascal) : This is a nasty way to do it, sorry.
"present_speed": current_velocity, "present_speed": current_velocity,
"follower_arm_state": follower_arm_state, "follower_arm_state": follower_arm_state,
} }
@ -217,6 +252,9 @@ def run_lekiwi(robot_config):
finally: finally:
stop_event.set() stop_event.set()
cam_thread.join() cam_thread.join()
audio_stop_event.set() # Signal the audio capture thread to exit; without this, join() below would block forever
microphone_thread.join()
for microphone in microphones.values():
microphone.stop_recording()
robot.stop() robot.stop()
motors_bus.disconnect() motors_bus.disconnect()
cmd_socket.close() cmd_socket.close()

View File

@ -28,6 +28,7 @@ import numpy as np
import torch import torch
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.robots.utils import get_arm_id
@ -164,6 +165,7 @@ class ManipulatorRobot:
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras) self.cameras = make_cameras_from_configs(self.config.cameras)
self.microphones = make_microphones_from_configs(self.config.microphones)
self.is_connected = False self.is_connected = False
self.logs = {} self.logs = {}
@ -199,9 +201,24 @@ class ManipulatorRobot:
}, },
} }
@property
def microphone_features(self) -> dict:
mic_ft = {}
for mic_key, mic in self.microphones.items():
key = f"observation.audio.{mic_key}"
mic_ft[key] = {
"dtype": "audio",
"shape": (len(mic.channels),),
"names": "channels",
"info": {
"sample_rate": mic.sample_rate
}, # the sample rate must be stored here for chunked audio recording (as in LeKiwi), since it is no longer available when the audio file is written
}
return mic_ft
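# Illustration: for the default So100 "laptop" microphone (channels=[1], sample_rate=48000),
# this yields {"observation.audio.laptop": {"dtype": "audio", "shape": (1,), "names": "channels", "info": {"sample_rate": 48000}}}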
@property @property
def features(self): def features(self):
return {**self.motor_features, **self.camera_features} return {**self.motor_features, **self.camera_features, **self.microphone_features}
@property @property
def has_camera(self): def has_camera(self):
@ -211,6 +228,14 @@ class ManipulatorRobot:
def num_cameras(self): def num_cameras(self):
return len(self.cameras) return len(self.cameras)
@property
def has_microphone(self):
return len(self.microphones) > 0
@property
def num_microphones(self):
return len(self.microphones)
@property @property
def available_arms(self): def available_arms(self):
available_arms = [] available_arms = []
@ -228,7 +253,7 @@ class ManipulatorRobot:
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice." "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
) )
if not self.leader_arms and not self.follower_arms and not self.cameras: if not self.leader_arms and not self.follower_arms and not self.cameras and not self.microphones:
raise ValueError( raise ValueError(
"ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class." "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class."
) )
@ -289,6 +314,10 @@ class ManipulatorRobot:
for name in self.cameras: for name in self.cameras:
self.cameras[name].connect() self.cameras[name].connect()
# Connect the microphones
for name in self.microphones:
self.microphones[name].connect()
self.is_connected = True self.is_connected = True
def activate_calibration(self): def activate_calibration(self):
@ -514,12 +543,23 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Capture audio from microphones
audio = {}
for name in self.microphones:
before_audioread_t = time.perf_counter()
audio[name] = self.microphones[name].read()
audio[name] = torch.from_numpy(audio[name])
self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t
# Populate output dictionaries # Populate output dictionaries
obs_dict, action_dict = {}, {} obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state obs_dict["observation.state"] = state
action_dict["action"] = action action_dict["action"] = action
for name in self.cameras: for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name] obs_dict[f"observation.images.{name}"] = images[name]
for name in self.microphones:
obs_dict[f"observation.audio.{name}"] = audio[name]
return obs_dict, action_dict return obs_dict, action_dict
@ -554,11 +594,22 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Capture audio from microphones
audio = {}
for name in self.microphones:
before_audioread_t = time.perf_counter()
audio[name] = self.microphones[name].read()
audio[name] = torch.from_numpy(audio[name])
self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t
# Populate output dictionaries and format to pytorch # Populate output dictionaries and format to pytorch
obs_dict = {} obs_dict = {}
obs_dict["observation.state"] = state obs_dict["observation.state"] = state
for name in self.cameras: for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name] obs_dict[f"observation.images.{name}"] = images[name]
for name in self.microphones:
obs_dict[f"observation.audio.{name}"] = audio[name]
return obs_dict return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor: def send_action(self, action: torch.Tensor) -> torch.Tensor:
@ -620,6 +671,9 @@ class ManipulatorRobot:
for name in self.cameras: for name in self.cameras:
self.cameras[name].disconnect() self.cameras[name].disconnect()
for name in self.microphones:
self.microphones[name].disconnect()
self.is_connected = False self.is_connected = False
def __del__(self): def __del__(self):

View File

@ -24,6 +24,7 @@ import torch
import zmq import zmq
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
@ -79,6 +80,7 @@ class MobileManipulator:
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras) self.cameras = make_cameras_from_configs(self.config.cameras)
self.microphones = make_microphones_from_configs(self.config.microphones)
self.is_connected = False self.is_connected = False
@ -133,6 +135,7 @@ class MobileManipulator:
"shape": (cam.height, cam.width, cam.channels), "shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"], "names": ["height", "width", "channels"],
"info": None, "info": None,
"audio": "observation.audio." + cam.microphone if cam.microphone is not None else None,
} }
return cam_ft return cam_ft
@ -161,9 +164,22 @@ class MobileManipulator:
}, },
} }
@property
def microphone_features(self) -> dict:
mic_ft = {}
for mic_key, mic in self.microphones.items():
key = f"observation.audio.{mic_key}"
mic_ft[key] = {
"dtype": "audio",
"shape": (len(mic.channels),),
"names": "channels",
"info": {"sample_rate": mic.sample_rate},
}
return mic_ft
@property @property
def features(self): def features(self):
return {**self.motor_features, **self.camera_features} return {**self.motor_features, **self.camera_features, **self.microphone_features}
@property @property
def has_camera(self): def has_camera(self):
@ -173,6 +189,14 @@ class MobileManipulator:
def num_cameras(self): def num_cameras(self):
return len(self.cameras) return len(self.cameras)
@property
def has_microphone(self):
return len(self.microphones) > 0
@property
def num_microphones(self):
return len(self.microphones)
@property @property
def available_arms(self): def available_arms(self):
available = [] available = []
@ -344,6 +368,7 @@ class MobileManipulator:
observation = json.loads(last_msg) observation = json.loads(last_msg)
images_dict = observation.get("images", {}) images_dict = observation.get("images", {})
audio_dict = observation.get("audio", {})
new_speed = observation.get("present_speed", {}) new_speed = observation.get("present_speed", {})
new_arm_state = observation.get("follower_arm_state", None) new_arm_state = observation.get("follower_arm_state", None)
@ -356,6 +381,11 @@ class MobileManipulator:
if frame_candidate is not None: if frame_candidate is not None:
frames[cam_name] = frame_candidate frames[cam_name] = frame_candidate
# Receive audio
for microphone_name, audio_data in audio_dict.items():
if audio_data:
frames[microphone_name] = audio_data
# If remote_arm_state is None and frames is None there is no message then use the previous message # If remote_arm_state is None and frames is None there is no message then use the previous message
if new_arm_state is not None and frames is not None: if new_arm_state is not None and frames is not None:
self.last_frames = frames self.last_frames = frames
@ -475,6 +505,14 @@ class MobileManipulator:
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8) frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame) obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
# Loop over each configured microphone
for microphone_name, microphone in self.microphones.items():
frame = frames.get(microphone_name, None)
if frame is None:
# Create silence using the microphone's configured channels
frame = np.zeros((1, len(microphone.channels)), dtype=np.float32)
obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame)
return obs_dict return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor: def send_action(self, action: torch.Tensor) -> torch.Tensor:

View File

@ -63,3 +63,25 @@ class RobotDeviceAlreadyConnectedError(Exception):
): ):
self.message = message self.message = message
super().__init__(self.message) super().__init__(self.message)
class RobotDeviceNotRecordingError(Exception):
"""Exception raised when the robot device is not recording."""
def __init__(
self,
message="This robot device is not recording. Try calling `robot_device.start_recording()` first.",
):
self.message = message
super().__init__(self.message)
class RobotDeviceAlreadyRecordingError(Exception):
"""Exception raised when the robot device is already recording."""
def __init__(
self,
message="This robot device is already recording. Try not calling `robot_device.start_recording()` twice.",
):
self.message = message
super().__init__(self.message)
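A small sketch of the guard pattern these exceptions support, mirroring the checks in the Microphone class from this diff; the Recorder class below is a toy example:

```python
from lerobot.common.robot_devices.utils import (
    RobotDeviceAlreadyRecordingError,
    RobotDeviceNotRecordingError,
)


class Recorder:
    """Toy device illustrating the recording guards."""

    def __init__(self):
        self.is_recording = False

    def start_recording(self):
        if self.is_recording:
            raise RobotDeviceAlreadyRecordingError()
        self.is_recording = True

    def stop_recording(self):
        if not self.is_recording:
            raise RobotDeviceNotRecordingError()
        self.is_recording = False
```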

View File

@ -67,8 +67,11 @@ dependencies = [
"pynput>=1.7.7", "pynput>=1.7.7",
"pyzmq>=26.2.1", "pyzmq>=26.2.1",
"rerun-sdk>=0.21.0", "rerun-sdk>=0.21.0",
"sounddevice>=0.5.1",
"soundfile>=0.13.1",
"termcolor>=2.4.0", "termcolor>=2.4.0",
"torch>=2.2.1", "torch>=2.2.1",
"torchaudio>=2.6.0",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"wandb>=0.16.3", "wandb>=0.16.3",
@ -96,6 +99,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
umi = ["imagecodecs>=2024.1.1"] umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
audio = ["librosa>=0.11.0"]
[tool.poetry] [tool.poetry]
requires-poetry = ">=2.1" requires-poetry = ">=2.1"

View File

@ -19,9 +19,9 @@ import traceback
import pytest import pytest
from serial import SerialException from serial import SerialException
from lerobot import available_cameras, available_motors, available_robots from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robot_devices.robots.utils import make_robot
from tests.utils import DEVICE, make_camera, make_motors_bus from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus
# Import fixture modules as plugins # Import fixture modules as plugins
pytest_plugins = [ pytest_plugins = [
@ -74,6 +74,11 @@ def is_camera_available(camera_type):
return _check_component_availability(camera_type, available_cameras, make_camera) return _check_component_availability(camera_type, available_cameras, make_camera)
@pytest.fixture
def is_microphone_available(microphone_type):
return _check_component_availability(microphone_type, available_microphones, make_microphone)
@pytest.fixture @pytest.fixture
def is_motor_available(motor_type): def is_motor_available(motor_type):
return _check_component_availability(motor_type, available_motors, make_motors_bus) return _check_component_availability(motor_type, available_motors, make_motors_bus)

View File

@ -25,6 +25,8 @@ from lerobot.common.datasets.compute_stats import (
compute_episode_stats, compute_episode_stats,
estimate_num_samples, estimate_num_samples,
get_feature_stats, get_feature_stats,
sample_audio_from_data,
sample_audio_from_path,
sample_images, sample_images,
sample_indices, sample_indices,
) )
@ -34,6 +36,10 @@ def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
def mock_load_audio(path):
return np.ones((16000, 2), dtype=np.float32)
@pytest.fixture @pytest.fixture
def sample_array(): def sample_array():
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@ -71,6 +77,25 @@ def test_sample_images(mock_load):
assert len(images) == estimate_num_samples(100) assert len(images) == estimate_num_samples(100)
@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load):
audio_path = "audio.wav"
audio_samples = sample_audio_from_path(audio_path)
assert isinstance(audio_samples, np.ndarray)
assert audio_samples.shape[1] == 2
assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000)
def test_sample_audio_from_data():
audio_data = np.ones((16000, 2), dtype=np.float32)
audio_samples = sample_audio_from_data(audio_data)
assert isinstance(audio_samples, np.ndarray)
assert audio_samples.shape[1] == 2
assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000)
def test_get_feature_stats_images(): def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32) data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
@ -79,6 +104,14 @@ def test_get_feature_stats_images():
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_audio():
data = np.random.uniform(-1, 1, (16000, 2))
stats = get_feature_stats(data, axis=0, keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
np.testing.assert_equal(stats["count"], np.array([16000]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_axis_0_keepdims(sample_array): def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = { expected = {
"min": np.array([[1, 2, 3]]), "min": np.array([[1, 2, 3]]),
@ -137,22 +170,29 @@ def test_get_feature_stats_single_value():
def test_compute_episode_stats(): def test_compute_episode_stats():
episode_data = { episode_data = {
"observation.image": [f"image_{i}.jpg" for i in range(100)], "observation.image": [f"image_{i}.jpg" for i in range(100)],
"observation.audio": "audio.wav",
"observation.state": np.random.rand(100, 10), "observation.state": np.random.rand(100, 10),
} }
features = { features = {
"observation.image": {"dtype": "image"}, "observation.image": {"dtype": "image"},
"observation.audio": {"dtype": "audio"},
"observation.state": {"dtype": "numeric"}, "observation.state": {"dtype": "numeric"},
} }
with patch( with (
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
),
patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
): ):
stats = compute_episode_stats(episode_data, features) stats = compute_episode_stats(episode_data, features)
assert "observation.image" in stats and "observation.state" in stats assert "observation.image" in stats and "observation.state" in stats and "observation.audio" in stats
assert stats["observation.image"]["count"].item() == 100 assert stats["observation.image"]["count"].item() == estimate_num_samples(100)
assert stats["observation.state"]["count"].item() == 100 assert stats["observation.audio"]["count"].item() == estimate_num_samples(16000)
assert stats["observation.state"]["count"].item() == estimate_num_samples(100)
assert stats["observation.image"]["mean"].shape == (3, 1, 1) assert stats["observation.image"]["mean"].shape == (3, 1, 1)
assert stats["observation.audio"]["mean"].shape == (1, 2)
def test_assert_type_and_shape_valid(): def test_assert_type_and_shape_valid():

View File

@ -16,6 +16,7 @@
import json import json
import logging import logging
import re import re
import time
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import (
MultiLeRobotDataset, MultiLeRobotDataset,
) )
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_AUDIO_CHUNK_DURATION,
create_branch, create_branch,
flatten_dict, flatten_dict,
unflatten_dict, unflatten_dict,
@ -44,8 +46,8 @@ from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import require_x86_64_kernel from tests.utils import make_microphone, require_x86_64_kernel
@pytest.fixture @pytest.fixture
@ -64,6 +66,20 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@pytest.fixture
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
features = {
"observation.audio.microphone": {
"dtype": "audio",
"shape": (DUMMY_AUDIO_CHANNELS,),
"names": [
"channels",
],
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
""" """
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
@ -322,6 +338,24 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
image_array_to_pil_image(image) image_array_to_pil_image(image)
def test_add_frame_audio(audio_dataset):
dataset = audio_dataset
microphone = make_microphone(microphone_type="microphone", mock=True)
microphone.connect()
dataset.add_microphone_recording(microphone, "microphone")
time.sleep(1.0)
dataset.add_frame({"observation.audio.microphone": microphone.read(), "task": "Dummy task"})
microphone.stop_recording()
dataset.save_episode()
assert dataset[0]["observation.audio.microphone"].shape == torch.Size(
(int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS)
)
# TODO(aliberts): # TODO(aliberts):
# - [ ] test various attributes & state from init and create # - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames # - [ ] test init with episodes and check num_frames
@ -354,6 +388,7 @@ def test_factory(env_name, repo_id, policy_name):
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.meta.camera_keys camera_keys = dataset.meta.camera_keys
audio_keys = dataset.meta.audio_keys
item = dataset[0] item = dataset[0]
@ -396,6 +431,11 @@ def test_factory(env_name, repo_id, policy_name):
# test c,h,w # test c,h,w
assert item[key].shape[0] == 3, f"{key}" assert item[key].shape[0] == 3, f"{key}"
for key in audio_keys:
assert item[key].dtype == torch.float32, f"{key}"
assert item[key].max() <= 1.0, f"{key}"
assert item[key].min() >= -1.0, f"{key}"
if delta_timestamps is not None: if delta_timestamps is not None:
# test missing keys in delta_timestamps # test missing keys in delta_timestamps
for key in delta_timestamps: for key in delta_timestamps:

View File

@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = {
}, },
} }
DUMMY_CAMERA_FEATURES = { DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, "laptop": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
"audio": "laptop",
},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
} }
DEFAULT_FPS = 30 DEFAULT_FPS = 30
@ -40,5 +45,18 @@ DUMMY_VIDEO_INFO = {
"video.is_depth_map": False, "video.is_depth_map": False,
"has_audio": False, "has_audio": False,
} }
DUMMY_MICROPHONE_FEATURES = {
"laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
"phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
}
DEFAULT_SAMPLE_RATE = 48000
DUMMY_AUDIO_CHANNELS = 2
DUMMY_AUDIO_INFO = {
"has_audio": True,
"audio.sample_rate": DEFAULT_SAMPLE_RATE,
"audio.codec": "aac",
"audio.channels": DUMMY_AUDIO_CHANNELS,
"audio.channel_layout": "stereo",
}
DUMMY_CHW = (3, 96, 128) DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3) DUMMY_HWC = (96, 128, 3)

View File

@ -26,6 +26,7 @@ import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_SIZE,
DEFAULT_COMPRESSED_AUDIO_PATH,
DEFAULT_FEATURES, DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH, DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH, DEFAULT_VIDEO_PATH,
@ -35,6 +36,7 @@ from lerobot.common.datasets.utils import (
from tests.fixtures.constants import ( from tests.fixtures.constants import (
DEFAULT_FPS, DEFAULT_FPS,
DUMMY_CAMERA_FEATURES, DUMMY_CAMERA_FEATURES,
DUMMY_MICROPHONE_FEATURES,
DUMMY_MOTOR_FEATURES, DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID, DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE, DUMMY_ROBOT_TYPE,
@ -90,6 +92,7 @@ def features_factory():
def _create_features( def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
use_videos: bool = True, use_videos: bool = True,
) -> dict: ) -> dict:
if use_videos: if use_videos:
@ -101,6 +104,7 @@ def features_factory():
return { return {
**motor_features, **motor_features,
**camera_ft, **camera_ft,
**audio_features,
**DEFAULT_FEATURES, **DEFAULT_FEATURES,
} }
@ -117,15 +121,18 @@ def info_factory(features_factory):
total_frames: int = 0, total_frames: int = 0,
total_tasks: int = 0, total_tasks: int = 0,
total_videos: int = 0, total_videos: int = 0,
total_audio: int = 0,
total_chunks: int = 0, total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE, chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH, data_path: str = DEFAULT_PARQUET_PATH,
video_path: str = DEFAULT_VIDEO_PATH, video_path: str = DEFAULT_VIDEO_PATH,
audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
use_videos: bool = True, use_videos: bool = True,
) -> dict: ) -> dict:
features = features_factory(motor_features, camera_features, use_videos) features = features_factory(motor_features, camera_features, audio_features, use_videos)
return { return {
"codebase_version": codebase_version, "codebase_version": codebase_version,
"robot_type": robot_type, "robot_type": robot_type,
@ -133,12 +140,14 @@ def info_factory(features_factory):
"total_frames": total_frames, "total_frames": total_frames,
"total_tasks": total_tasks, "total_tasks": total_tasks,
"total_videos": total_videos, "total_videos": total_videos,
"total_audio": total_audio,
"total_chunks": total_chunks, "total_chunks": total_chunks,
"chunks_size": chunks_size, "chunks_size": chunks_size,
"fps": fps, "fps": fps,
"splits": {}, "splits": {},
"data_path": data_path, "data_path": data_path,
"video_path": video_path if use_videos else None, "video_path": video_path if use_videos else None,
"audio_path": audio_path,
"features": features, "features": features,
} }
@ -162,6 +171,14 @@ def stats_factory():
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
"count": [10], "count": [10],
} }
elif dtype == "audio":
stats[key] = {
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
"max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
"min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
"std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
"count": [10],
}
else: else:
stats[key] = { stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(), "max": np.full(shape, 1, dtype=dtype).tolist(),

View File

@ -0,0 +1,88 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from functools import cache
from threading import Event, Thread
import numpy as np
from lerobot.common.utils.utils import capture_timestamp_utc
from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS
@cache
def _generate_sound(duration: float, sample_rate: int, channels: int):
return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
def query_devices(query_index: int):
return {
"name": "Mock Sound Device",
"index": query_index,
"max_input_channels": DUMMY_AUDIO_CHANNELS,
"default_samplerate": DEFAULT_SAMPLE_RATE,
}
class InputStream:
def __init__(self, *args, **kwargs):
self._mock_dict = {
"channels": DUMMY_AUDIO_CHANNELS,
"samplerate": DEFAULT_SAMPLE_RATE,
}
self._is_active = False
self._audio_callback = kwargs.get("callback")
self.callback_thread = None
self.callback_thread_stop_event = None
def _acquisition_loop(self):
if self._audio_callback is not None:
while not self.callback_thread_stop_event.is_set():
# Simulate audio data acquisition
time.sleep(0.01)
self._audio_callback(
_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS),
int(0.01 * DEFAULT_SAMPLE_RATE), # frame count is an integer, as sounddevice provides
capture_timestamp_utc(),
None,
)
def start(self):
self.callback_thread_stop_event = Event()
self.callback_thread = Thread(target=self._acquisition_loop, args=())
self.callback_thread.daemon = True
self.callback_thread.start()
self._is_active = True
@property
def active(self):
return self._is_active
def stop(self):
if self.callback_thread_stop_event is not None:
self.callback_thread_stop_event.set()
self.callback_thread.join()
self.callback_thread = None
self.callback_thread_stop_event = None
self._is_active = False
def close(self):
if self._is_active:
self.stop()
def __del__(self):
if self._is_active:
self.stop()
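For context, this mock is consumed through the mock=True path of Microphone.connect(); a hardware-free smoke test could look like the following sketch:

```python
import time

from tests.utils import make_microphone

microphone = make_microphone(microphone_type="microphone", mock=True)
microphone.connect()  # imports tests.microphones.mock_sounddevice instead of sounddevice
microphone.start_recording()
time.sleep(0.1)
chunk = microphone.read()  # random float32 samples produced by _generate_sound
microphone.stop_recording()
microphone.disconnect()
```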

View File

@ -0,0 +1,152 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for physical microphones and their mocked versions.
If the physical microphone is not connected to the computer, or not working,
the test will be skipped.
Example of running a specific test:
```bash
pytest -sx tests/microphones/test_microphones.py::test_microphone
```
Example of running test on a real microphone connected to the computer:
```bash
pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-False]'
```
Example of running test on a mocked version of the microphone:
```bash
pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-True]'
```
"""
import time
import numpy as np
import pytest
from soundfile import read
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceAlreadyRecordingError,
RobotDeviceNotConnectedError,
RobotDeviceNotRecordingError,
)
from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
# Maximum recording time difference between two consecutive audio recordings of the same duration.
# Set to 0.02 seconds, i.e. twice the default size of the sounddevice callback buffer (we tolerate the loss of one buffer).
MAX_RECORDING_TIME_DIFFERENCE = 0.02
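# For illustration: at the 48000 Hz DEFAULT_SAMPLE_RATE used by the mock, this tolerance
# amounts to 48000 * 0.02 = 960 samples.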
DUMMY_RECORDING = "test_recording.wav"
@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
@require_microphone
def test_microphone(tmp_path, request, microphone_type, mock):
"""Test assumes that a recroding handled with microphone.start_recording(output_file) and stop_recording() or microphone.read()
leqds to a sample that does not differ from the requested duration by more than 0.1 seconds.
"""
microphone_kwargs = {"microphone_type": microphone_type, "mock": mock}
# Test instantiating
microphone = make_microphone(**microphone_kwargs)
# Test start_recording, stop_recording, read and disconnecting before connecting raises an error
with pytest.raises(RobotDeviceNotConnectedError):
microphone.start_recording()
with pytest.raises(RobotDeviceNotConnectedError):
microphone.stop_recording()
with pytest.raises(RobotDeviceNotConnectedError):
microphone.read()
with pytest.raises(RobotDeviceNotConnectedError):
microphone.disconnect()
# Test deleting the object without connecting first
del microphone
# Test connecting
microphone = make_microphone(**microphone_kwargs)
microphone.connect()
assert microphone.is_connected
assert microphone.sample_rate is not None
assert microphone.channels is not None
# Test connecting twice raises an error
with pytest.raises(RobotDeviceAlreadyConnectedError):
microphone.connect()
# Test reading or stop recording before starting recording raises an error
with pytest.raises(RobotDeviceNotRecordingError):
microphone.read()
with pytest.raises(RobotDeviceNotRecordingError):
microphone.stop_recording()
# Test start_recording
fpath = tmp_path / DUMMY_RECORDING
microphone.start_recording(fpath)
assert microphone.is_recording
# Test start_recording twice raises an error
with pytest.raises(RobotDeviceAlreadyRecordingError):
microphone.start_recording()
# Test reading from the microphone
time.sleep(1.0)
audio_chunk = microphone.read()
assert isinstance(audio_chunk, np.ndarray)
assert audio_chunk.ndim == 2
_, c = audio_chunk.shape
assert c == len(microphone.channels)
# Test stop_recording
microphone.stop_recording()
assert fpath.exists()
assert not microphone.stream.active
assert microphone.record_thread is None
# Test stop_recording twice raises an error
with pytest.raises(RobotDeviceNotRecordingError):
microphone.stop_recording()
# Test reading and recording output similar length audio chunks
microphone.start_recording(tmp_path / DUMMY_RECORDING)
time.sleep(1.0)
audio_chunk = microphone.read()
microphone.stop_recording()
recorded_audio, recorded_sample_rate = read(fpath)
assert recorded_sample_rate == microphone.sample_rate
error_msg = (
"Recording time difference between read() and stop_recording() (in seconds)",
(len(audio_chunk) - len(recorded_audio)) / recorded_sample_rate,
)
np.testing.assert_allclose(
len(audio_chunk),
len(recorded_audio),
atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
err_msg=error_msg,
)
# Test disconnecting
microphone.disconnect()
assert not microphone.is_connected
# Test disconnecting with `__del__`
microphone = make_microphone(**microphone_kwargs)
microphone.connect()
del microphone

View File

@ -100,6 +100,9 @@ def test_robot(tmp_path, request, robot_type, mock):
robot.teleop_step() robot.teleop_step()
# Test data recorded during teleop are well formatted # Test data recorded during teleop are well formatted
for _, microphone in robot.microphones.items():
microphone.start_recording()
observation, action = robot.teleop_step(record_data=True) observation, action = robot.teleop_step(record_data=True)
# State # State
assert "observation.state" in observation assert "observation.state" in observation
@ -112,6 +115,11 @@ def test_robot(tmp_path, request, robot_type, mock):
assert f"observation.images.{name}" in observation assert f"observation.images.{name}" in observation
assert isinstance(observation[f"observation.images.{name}"], torch.Tensor) assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
assert observation[f"observation.images.{name}"].ndim == 3 assert observation[f"observation.images.{name}"].ndim == 3
# Microphones
for name in robot.microphones:
assert f"observation.audio.{name}" in observation
assert isinstance(observation[f"observation.audio.{name}"], torch.Tensor)
assert observation[f"observation.audio.{name}"].ndim == 2
# Action # Action
assert "action" in action assert "action" in action
assert isinstance(action["action"], torch.Tensor) assert isinstance(action["action"], torch.Tensor)
@ -124,8 +132,9 @@ def test_robot(tmp_path, request, robot_type, mock):
captured_observation = robot.capture_observation() captured_observation = robot.capture_observation()
assert set(captured_observation.keys()) == set(observation.keys()) assert set(captured_observation.keys()) == set(observation.keys())
for name in captured_observation: for name in captured_observation:
if "image" in name: if "image" in name or "audio" in name:
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
# Also skipping for audio as audio chunks may be of different length
continue continue
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1) torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
assert captured_observation[name].shape == observation[name].shape assert captured_observation[name].shape == observation[name].shape
@ -134,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
robot.send_action(action["action"]) robot.send_action(action["action"])
# Test disconnecting # Test disconnecting
robot.disconnect() robot.disconnect() # Also handles microphone recording stop, life is beautiful
assert not robot.is_connected assert not robot.is_connected
for name in robot.follower_arms: for name in robot.follower_arms:
assert not robot.follower_arms[name].is_connected assert not robot.follower_arms[name].is_connected
@ -142,3 +151,5 @@ def test_robot(tmp_path, request, robot_type, mock):
assert not robot.leader_arms[name].is_connected assert not robot.leader_arms[name].is_connected
for name in robot.cameras: for name in robot.cameras:
assert not robot.cameras[name].is_connected assert not robot.cameras[name].is_connected
for name in robot.microphones:
assert not robot.microphones[name].is_connected

View File

@ -22,9 +22,11 @@ from pathlib import Path
import pytest import pytest
import torch import torch
from lerobot import available_cameras, available_motors, available_robots from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.microphones.utils import Microphone
from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
from lerobot.common.robot_devices.motors.utils import MotorsBus from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
from lerobot.common.utils.import_utils import is_package_available from lerobot.common.utils.import_utils import is_package_available
@ -39,6 +41,10 @@ TEST_CAMERA_TYPES = []
for camera_type in available_cameras: for camera_type in available_cameras:
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)] TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
TEST_MICROPHONE_TYPES = []
for microphone_type in available_microphones:
TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]
TEST_MOTOR_TYPES = [] TEST_MOTOR_TYPES = []
for motor_type in available_motors: for motor_type in available_motors:
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)] TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
@ -47,6 +53,9 @@ for motor_type in available_motors:
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)) INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
# Microphone indices used for connecting physical microphones
MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081") DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
DYNAMIXEL_MOTORS = { DYNAMIXEL_MOTORS = {
"shoulder_pan": [1, "xl430-w250"], "shoulder_pan": [1, "xl430-w250"],
@ -253,6 +262,29 @@ def require_camera(func):
return wrapper return wrapper
def require_microphone(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Access the pytest request context to get the is_microphone_available fixture
request = kwargs.get("request")
microphone_type = kwargs.get("microphone_type")
mock = kwargs.get("mock")
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
if microphone_type is None:
raise ValueError("The 'microphone_type' must be an argument of the test function.")
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
if not mock and not request.getfixturevalue("is_microphone_available"):
pytest.skip(f"A {microphone_type} microphone is not available.")
return func(*args, **kwargs)
return wrapper
def require_motor(func): def require_motor(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -315,6 +347,14 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
raise ValueError(f"The camera type '{camera_type}' is not valid.") raise ValueError(f"The camera type '{camera_type}' is not valid.")
def make_microphone(microphone_type: str, **kwargs) -> Microphone:
if microphone_type == "microphone":
microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
return make_microphone_device(microphone_type, microphone_index=microphone_index, **kwargs)
else:
raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
# TODO(rcadene, aliberts): remove this dark pattern that overrides # TODO(rcadene, aliberts): remove this dark pattern that overrides
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
if motor_type == "dynamixel": if motor_type == "dynamixel":