Merge ca716ed196 into 768e36660d
This commit is contained in: commit 35d58527a4

@ -190,6 +190,11 @@ available_cameras = [
    "intelrealsense",
]

# lists all available microphones from `lerobot/common/robot_devices/microphones`
available_microphones = [
    "microphone",
]

# lists all available motors from `lerobot/common/robot_devices/motors`
available_motors = [
    "dynamixel",
@ -0,0 +1,165 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import subprocess
from collections import OrderedDict
from pathlib import Path

import torch
import torchaudio
from numpy import ceil


def decode_audio(
    audio_path: Path | str,
    timestamps: list[float],
    duration: float,
    backend: str | None = "ffmpeg",
) -> torch.Tensor:
    """
    Decodes audio using the specified backend.
    Args:
        audio_path (Path): Path to the audio file.
        timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
        duration (float): Duration of the audio chunks in seconds.
        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".

    Returns:
        torch.Tensor: Decoded audio chunks.

    Currently supports ffmpeg.
    """
    if backend == "torchcodec":
        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
    elif backend == "ffmpeg":
        return decode_audio_torchaudio(audio_path, timestamps, duration)
    else:
        raise ValueError(f"Unsupported audio backend: {backend}")


def decode_audio_torchaudio(
    audio_path: Path | str,
    timestamps: list[float],
    duration: float,
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    # TODO(CarolinePascal) : add channels selection
    audio_path = str(audio_path)

    reader = torchaudio.io.StreamReader(src=audio_path)
    audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate

    # TODO(CarolinePascal) : sort timestamps ?
    reader.add_basic_audio_stream(
        frames_per_chunk=int(ceil(duration * audio_sample_rate)),  # Too much is better than not enough
        buffer_chunk_size=-1,  # No dropping frames
        format="fltp",  # Format as float32
    )

    audio_chunks = []
    for ts in timestamps:
        reader.seek(ts)  # Defaults to closest audio sample
        status = reader.fill_buffer()
        if status != 0:
            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")

        current_audio_chunk = reader.pop_chunks()[0]

        if log_loaded_timestamps:
            logging.info(
                f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
            )

        audio_chunks.append(current_audio_chunk)

    audio_chunks = torch.stack(audio_chunks)

    assert len(timestamps) == len(audio_chunks)
    return audio_chunks


def encode_audio(
    input_path: Path | str,
    output_path: Path | str,
    codec: str = "aac",  # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
    log_level: str | None = "error",
    overwrite: bool = False,
) -> None:
    """Encodes an audio file using ffmpeg."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    ffmpeg_args = OrderedDict(
        [
            ("-i", str(input_path)),
            ("-acodec", codec),
        ]
    )

    if log_level is not None:
        ffmpeg_args["-loglevel"] = str(log_level)

    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
    if overwrite:
        ffmpeg_args.append("-y")

    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]

    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)

    if not output_path.exists():
        raise OSError(
            f"Audio encoding did not work. File not found: {output_path}. "
            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
        )


def get_audio_info(video_path: Path | str) -> dict:
    ffprobe_audio_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "a:0",
        "-show_entries",
        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    audio_stream_info = info["streams"][0] if info.get("streams") else None
    if audio_stream_info is None:
        return {"has_audio": False}

    # Return the information, defaulting to None if no audio stream is present
    return {
        "has_audio": True,
        "audio.channels": audio_stream_info.get("channels", None),
        "audio.codec": audio_stream_info.get("codec_name", None),
        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
        "audio.sample_rate": int(audio_stream_info["sample_rate"])
        if audio_stream_info.get("sample_rate")
        else None,
        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
    }
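
A minimal round-trip sketch of the three helpers above, assuming ffmpeg is installed on PATH and that a `recording.wav` file exists (both file names are illustrative, not part of the diff):

```python
from lerobot.common.datasets.audio_utils import decode_audio, encode_audio, get_audio_info

# Compress a raw WAV recording to AAC (.m4a), then inspect the result with ffprobe.
encode_audio("recording.wav", "recording.m4a", codec="aac", overwrite=True)
print(get_audio_info("recording.m4a"))  # e.g. {"has_audio": True, "audio.codec": "aac", ...}

# Extract two 0.5 s chunks starting at 0.0 s and 1.0 s. The chunks are stacked
# along dim 0: (num_timestamps, frames_per_chunk, channels).
chunks = decode_audio("recording.m4a", timestamps=[0.0, 1.0], duration=0.5)
print(chunks.shape)
```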

@ -15,7 +15,7 @@
# limitations under the License.
import numpy as np

-from lerobot.common.datasets.utils import load_image_as_numpy
+from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy


def estimate_num_samples(
@ -72,6 +72,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
    return images


def sample_audio_from_path(audio_path: str) -> np.ndarray:
    """Samples audio data from an audio recording stored in a WAV file."""
    data = load_audio_from_path(audio_path)
    sampled_indices = sample_indices(len(data))

    return data[sampled_indices]


def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
    """Samples audio data from an audio recording stored in a numpy array."""
    sampled_indices = sample_indices(len(data))
    return data[sampled_indices]


def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
    return {
        "min": np.min(array, axis=axis, keepdims=keepdims),
@ -91,6 +105,13 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
            ep_ft_array = sample_images(data)  # data is a list of image paths
            axes_to_reduce = (0, 2, 3)  # keep channel dim
            keepdims = True
        elif features[key]["dtype"] == "audio":
            try:
                ep_ft_array = sample_audio_from_path(data[0])
            except TypeError:  # Should only be triggered for the LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
                ep_ft_array = sample_audio_from_data(data)
            axes_to_reduce = 0
            keepdims = True
        else:
            ep_ft_array = data  # data is already a np.ndarray
            axes_to_reduce = 0  # compute stats over the first axis
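
The audio branch mirrors the image branch: subsample the waveform, then reduce over the frame axis while keeping the channel dimension. A sketch on synthetic data (the two-channel waveform is invented for illustration; `sample_audio_from_data` and `get_feature_stats` are the functions from this file):

```python
import numpy as np

waveform = np.random.uniform(-1.0, 1.0, size=(48_000, 2)).astype(np.float32)  # (frames, channels)
sampled = sample_audio_from_data(waveform)        # subsample frames to keep stats cheap
stats = get_feature_stats(sampled, axis=0, keepdims=True)
print({k: v.shape for k, v in stats.items()})     # channel dim kept: each stat has shape (1, 2)
```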

@ -23,6 +23,7 @@ import datasets
import numpy as np
import packaging.version
import PIL.Image
import soundfile as sf
import torch
import torch.utils
from datasets import concatenate_datasets, load_dataset
@ -31,11 +32,18 @@ from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError

from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.audio_utils import (
    decode_audio,
    encode_audio,
    get_audio_info,
)
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
    DEFAULT_AUDIO_CHUNK_DURATION,
    DEFAULT_FEATURES,
    DEFAULT_IMAGE_PATH,
    DEFAULT_RAW_AUDIO_PATH,
    INFO_PATH,
    TASKS_PATH,
    append_jsonlines,

@ -72,6 +80,7 @@ from lerobot.common.datasets.video_utils import (
    get_safe_default_codec,
    get_video_info,
)
from lerobot.common.robot_devices.microphones.utils import Microphone
from lerobot.common.robot_devices.robots.utils import Robot

CODEBASE_VERSION = "v2.1"
@ -142,6 +151,14 @@ class LeRobotDatasetMetadata:
        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
        return Path(fpath)

    def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
        """Returns the path of the compressed (i.e. encoded) audio file."""
        episode_chunk = self.get_episode_chunk(episode_index)
        fpath = self.audio_path.format(
            episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
        )
        return self.root / fpath

    def get_episode_chunk(self, ep_index: int) -> int:
        return ep_index // self.chunks_size

@ -155,6 +172,11 @@ class LeRobotDatasetMetadata:
        """Formattable string for the video files."""
        return self.info["video_path"]

    @property
    def audio_path(self) -> str | None:
        """Formattable string for the audio files."""
        return self.info["audio_path"]

    @property
    def robot_type(self) -> str | None:
        """Robot type used in recording this dataset."""

@ -185,6 +207,11 @@ class LeRobotDatasetMetadata:
        """Keys to access visual modalities (regardless of their storage method)."""
        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

    @property
    def audio_keys(self) -> list[str]:
        """Keys to access audio modalities."""
        return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]

    @property
    def names(self) -> dict[str, list | dict]:
        """Names of the various dimensions of vector modalities."""

@ -264,6 +291,10 @@ class LeRobotDatasetMetadata:
        if len(self.video_keys) > 0:
            self.update_video_info()

        self.info["total_audio"] += len(self.audio_keys)
        if len(self.audio_keys) > 0:
            self.update_audio_info()

        write_info(self.info, self.root)

        episode_dict = {

@ -288,6 +319,19 @@ class LeRobotDatasetMetadata:
            video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
            self.info["features"][key]["info"] = get_video_info(video_path)

    def update_audio_info(self) -> None:
        """
        Warning: this function writes info from the first episode's audio, implicitly assuming that all
        audio files have been encoded the same way. It also assumes that the first episode exists.
        """
        for key in self.audio_keys:
            if (
                not self.features[key].get("info", None)
                or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"])
            ):  # Overwrite if info is empty or only contains the sample rate (necessary to correctly save audio files recorded by LeKiwi)
                audio_path = self.root / self.get_compressed_audio_file_path(0, key)
                self.info["features"][key]["info"] = get_audio_info(audio_path)

    def __repr__(self):
        feature_keys = list(self.features)
        return (
@ -363,7 +407,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        revision: str | None = None,
        force_cache_sync: bool = False,
        download_videos: bool = True,
        download_audio: bool = True,
        video_backend: str | None = None,
        audio_backend: str | None = None,
    ):
        """
        2 modes are available for instantiating this class, depending on 2 different use cases:

@ -394,7 +440,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
            - tasks contains the prompts for each task of the dataset, which can be used for
              task-conditioned training.
            - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
-           - videos (optional) from which frames are loaded to be synchronous with data from parquet files.
+           - videos (optional) from which frames and audio (if any) are loaded to be synchronous with data from parquet files and audio.
+           - audio (optional) from which audio is loaded to be synchronous with data from parquet files and videos.

        A typical LeRobotDataset looks like this from its root path:
        .
@ -415,17 +462,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
        │   ├── info.json
        │   ├── stats.json
        │   └── tasks.jsonl
        ├── videos
        │   ├── chunk-000
        │   │   ├── observation.images.laptop
        │   │   │   ├── episode_000000.mp4
        │   │   │   ├── episode_000001.mp4
        │   │   │   ├── episode_000002.mp4
        │   │   │   └── ...
        │   │   ├── observation.images.phone
        │   │   │   ├── episode_000000.mp4
        │   │   │   ├── episode_000001.mp4
        │   │   │   ├── episode_000002.mp4
        │   │   │   └── ...
        │   ├── chunk-001
        │   └── ...
        └── audio
            ├── chunk-000
            │   ├── observation.audio.laptop
            │   │   ├── episode_000000.m4a
            │   │   ├── episode_000001.m4a
            │   │   ├── episode_000002.m4a
            │   │   └── ...
            │   ├── observation.audio.phone
            │   │   ├── episode_000000.m4a
            │   │   ├── episode_000001.m4a
            │   │   ├── episode_000002.m4a
            │   │   └── ...
            ├── chunk-001
            └── ...
@ -463,8 +524,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
            download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                video files are already present on local disk, they won't be downloaded again. Defaults to
                True.
            download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
            video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
                You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
            audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to the 'ffmpeg' decoder used by 'torchaudio'.
        """
        super().__init__()
        self.repo_id = repo_id
@ -475,6 +538,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.tolerance_s = tolerance_s
        self.revision = revision if revision else CODEBASE_VERSION
        self.video_backend = video_backend if video_backend else get_safe_default_codec()
        self.audio_backend = (
            audio_backend if audio_backend else "ffmpeg"
        )  # Waiting for torchcodec release #TODO(CarolinePascal)
        self.delta_indices = None

        # Unused attributes

@ -499,7 +565,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            self.hf_dataset = self.load_hf_dataset()
        except (AssertionError, FileNotFoundError, NotADirectoryError):
            self.revision = get_safe_version(self.repo_id, self.revision)
-           self.download_episodes(download_videos)
+           self.download_episodes(download_videos, download_audio)
            self.hf_dataset = self.load_hf_dataset()

        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

@ -510,6 +576,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
        ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

        # TODO(CarolinePascal) : add check for audio duration with respect to episode duration BUT this will be CPU expensive if there are many episodes !

        # Setup delta_indices
        if self.delta_timestamps is not None:
            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@ -522,6 +590,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        license: str | None = "apache-2.0",
        tag_version: bool = True,
        push_videos: bool = True,
        push_audio: bool = True,
        private: bool = False,
        allow_patterns: list[str] | str | None = None,
        upload_large_folder: bool = False,

@ -530,6 +599,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
        ignore_patterns = ["images/"]
        if not push_videos:
            ignore_patterns.append("videos/")
        if not push_audio:
            ignore_patterns.append("audio/")

        hub_api = HfApi()
        hub_api.create_repo(

@ -585,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            ignore_patterns=ignore_patterns,
        )

-   def download_episodes(self, download_videos: bool = True) -> None:
+   def download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None:
        """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
        will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
        dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@ -594,7 +665,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # TODO(rcadene, aliberts): implement faster transfer
        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
        files = None
-       ignore_patterns = None if download_videos else "videos/"
+       ignore_patterns = []
+       if not download_videos:
+           ignore_patterns.append("videos/")
+       if not download_audio:
+           ignore_patterns.append("audio/")
        if self.episodes is not None:
            files = self.get_episodes_file_paths()

@ -611,6 +686,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
            ]
            fpaths += video_files

        if len(self.meta.audio_keys) > 0:
            audio_files = [
                str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
                for audio_key in self.meta.audio_keys
                for ep_idx in episodes
            ]
            fpaths += audio_files

        return fpaths

    def load_hf_dataset(self) -> datasets.Dataset:
@ -677,7 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        }
        return query_indices, padding

-   def _get_query_timestamps(
+   def _get_query_timestamps_video(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,

@ -692,11 +775,27 @@ class LeRobotDataset(torch.utils.data.Dataset):

        return query_timestamps

    # TODO(CarolinePascal): add variable query durations
    def _get_query_timestamps_audio(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,
    ) -> dict[str, list[float]]:
        query_timestamps = {}
        for key in self.meta.audio_keys:
            if query_indices is not None and key in query_indices:
                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
                query_timestamps[key] = torch.stack(timestamps).tolist()
            else:
                query_timestamps[key] = [current_ts]

        return query_timestamps

    def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
        return {
            key: torch.stack(self.hf_dataset.select(q_idx)[key])
            for key, q_idx in query_indices.items()
-           if key not in self.meta.video_keys
+           if key not in self.meta.video_keys and key not in self.meta.audio_keys
        }

    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
@ -713,6 +812,17 @@ class LeRobotDataset(torch.utils.data.Dataset):

        return item

    # TODO(CarolinePascal): add variable query durations
    def _query_audio(
        self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
    ) -> dict[str, torch.Tensor]:
        item = {}
        for audio_key, query_ts in query_timestamps.items():
            audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
            audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
            item[audio_key] = audio_chunk.squeeze(0)
        return item

    def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
        for key, val in padding.items():
            item[key] = torch.BoolTensor(val)

@ -733,11 +843,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
        for key, val in query_result.items():
            item[key] = val

-       if len(self.meta.video_keys) > 0:
+       if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
            current_ts = item["timestamp"].item()
-           query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+           query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
            video_frames = self._query_videos(query_timestamps, ep_idx)
-           item = {**video_frames, **item}
+           item = {**item, **video_frames}
+
+           query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
+           audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
+           item = {**item, **audio_chunks}

        if self.image_transforms is not None:
            image_keys = self.meta.camera_keys
@ -777,6 +892,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
        )
        return self.root / fpath

    def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
        fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
        return self.root / fpath

    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
        if self.image_writer is None:
            if isinstance(image, torch.Tensor):

@ -827,11 +946,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
                img_path.parent.mkdir(parents=True, exist_ok=True)
                self._save_image(frame[key], img_path)
                self.episode_buffer[key].append(str(img_path))
            elif self.features[key]["dtype"] == "audio":
                if (
                    self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
                ):  # Raw data storage should only be triggered for the LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
                    self.episode_buffer[key].append(frame[key])
                else:  # Otherwise, only the audio file path is stored in the episode buffer
                    if frame_index == 0:
                        audio_path = self._get_raw_audio_file_path(
                            episode_index=self.episode_buffer["episode_index"], audio_key=key
                        )
                        self.episode_buffer[key].append(str(audio_path))
            else:
                self.episode_buffer[key].append(frame[key])

        self.episode_buffer["size"] += 1

    def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None:
        """
        Starts recording audio data provided by the microphone and directly writes it to a .wav file.
        """

        audio_dir = self._get_raw_audio_file_path(
            self.num_episodes, "observation.audio." + microphone_key
        ).parent
        if not audio_dir.is_dir():
            audio_dir.mkdir(parents=True, exist_ok=True)

        microphone.start_recording(
            output_file=self._get_raw_audio_file_path(
                self.num_episodes, "observation.audio." + microphone_key
            )
        )

    def save_episode(self, episode_data: dict | None = None) -> None:
        """
        This will save to disk the current episode in self.episode_buffer.
@ -869,16 +1016,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
            # are processed separately by storing image path and frame info as meta data
            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                continue
            elif ft["dtype"] == "audio":
                if (
                    self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
                ):  # Raw data storage should only be triggered for the LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
                    episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
                continue
            episode_buffer[key] = np.stack(episode_buffer[key])

        self._wait_image_writer()
        self._save_episode_table(episode_buffer, episode_index)

        if (
            self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
        ):  # Raw data storage should only be triggered for the LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
            for key in self.meta.audio_keys:
                audio_path = self._get_raw_audio_file_path(
                    episode_index=self.episode_buffer["episode_index"][0], audio_key=key
                )
                with sf.SoundFile(
                    audio_path,
                    mode="w",
                    samplerate=self.meta.features[key]["info"]["sample_rate"],
                    channels=self.meta.features[key]["shape"][0],
                ) as file:
                    file.write(episode_buffer[key])

        ep_stats = compute_episode_stats(episode_buffer, self.features)

        if len(self.meta.video_keys) > 0:
-           video_paths = self.encode_episode_videos(episode_index)
-           for key in self.meta.video_keys:
-               episode_buffer[key] = video_paths[key]
+           self.encode_episode_videos(episode_index)

        if len(self.meta.audio_keys) > 0:
            self.encode_episode_audio(episode_index)

        # `meta.save_episode` must be executed after encoding the videos
        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)

@ -904,6 +1074,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
        if img_dir.is_dir():
            shutil.rmtree(self.root / "images")

        # delete raw audio files
        raw_audio_files = list(self.root.rglob("*.wav"))
        for raw_audio_file in raw_audio_files:
            raw_audio_file.unlink()
            if len(list(raw_audio_file.parent.iterdir())) == 0:
                raw_audio_file.parent.rmdir()

        if not episode_data:  # Reset the buffer
            self.episode_buffer = self.create_episode_buffer()
@ -971,19 +1148,40 @@ class LeRobotDataset(torch.utils.data.Dataset):
        since video encoding with ffmpeg is already using multithreading.
        """
        video_paths = {}
-       for key in self.meta.video_keys:
-           video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-           video_paths[key] = str(video_path)
+       for video_key in self.meta.video_keys:
+           video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
+           video_paths[video_key] = str(video_path)
            if video_path.is_file():
                # Skip if video is already encoded. Could be the case when resuming data recording.
                continue
            img_dir = self._get_image_file_path(
-               episode_index=episode_index, image_key=key, frame_index=0
+               episode_index=episode_index, image_key=video_key, frame_index=0
            ).parent

            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)

        return video_paths

    def encode_episode_audio(self, episode_index: int) -> dict:
        """
        Use ffmpeg to convert .wav raw audio files into .m4a audio files.
        Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speed up encoding,
        since audio encoding with ffmpeg is already using multithreading.
        """
        audio_paths = {}
        for audio_key in self.meta.audio_keys:
            input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
            output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)

            audio_paths[audio_key] = str(output_audio_path)
            if output_audio_path.is_file():
                # Skip if audio is already encoded. Could be the case when resuming data recording.
                continue

            encode_audio(input_audio_path, output_audio_path, overwrite=True)

        return audio_paths

    @classmethod
    def create(
        cls,

@ -998,6 +1196,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
        video_backend: str | None = None,
        audio_backend: str | None = None,
    ) -> "LeRobotDataset":
        """Create a LeRobot Dataset from scratch in order to record data."""
        obj = cls.__new__(cls)

@ -1029,6 +1228,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        obj.delta_indices = None
        obj.episode_data_index = None
        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
        obj.audio_backend = (
            audio_backend if audio_backend is not None else "ffmpeg"
        )  # Waiting for torchcodec release #TODO(CarolinePascal)
        return obj
@ -1049,6 +1251,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        tolerances_s: dict | None = None,
        download_videos: bool = True,
        video_backend: str | None = None,
        audio_backend: str | None = None,
    ):
        super().__init__()
        self.repo_ids = repo_ids

@ -1066,6 +1269,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                tolerance_s=self.tolerances_s[repo_id],
                download_videos=download_videos,
                video_backend=video_backend,
                audio_backend=audio_backend,
            )
            for repo_id in repo_ids
        ]
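
With the changes above, audio behaves like video at load time. A hypothetical end-to-end sketch (the repo id is made up; shapes follow `_query_audio`, which squeezes away the single-timestamp dimension):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("user/dataset-with-audio", audio_backend="ffmpeg")
item = dataset[0]
for key in dataset.meta.audio_keys:
    # One DEFAULT_AUDIO_CHUNK_DURATION (0.5 s) chunk per frame: (samples, channels)
    print(key, item[key].shape)
```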

@ -33,6 +33,7 @@ from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from soundfile import read
from torchvision import transforms

from lerobot.common.datasets.backward_compatibility import (

@ -55,6 +56,10 @@ TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"

DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds

DATASET_CARD_TEMPLATE = """
---
@ -255,6 +260,11 @@ def load_image_as_numpy(
    return img_array


def load_audio_from_path(fpath: str | Path) -> np.ndarray:
    audio_data, _ = read(fpath, dtype="float32")
    return audio_data


def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
    """Get a transform function that convert items from Hugging Face dataset (pyarrow)
    to torch tensors. Importantly, images are converted from PIL, which corresponds to

@ -363,7 +373,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
def get_hf_features_from_features(features: dict) -> datasets.Features:
    hf_features = {}
    for key, ft in features.items():
-       if ft["dtype"] == "video":
+       if ft["dtype"] == "video" or ft["dtype"] == "audio":
            continue
        elif ft["dtype"] == "image":
            hf_features[key] = datasets.Image()

@ -394,7 +404,7 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
        key: {"dtype": "video" if use_videos else "image", **ft}
        for key, ft in robot.camera_features.items()
    }
-   return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
+   return {**robot.motor_features, **camera_ft, **robot.microphone_features, **DEFAULT_FEATURES}


def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
@ -442,12 +452,14 @@ def create_empty_dataset_info(
        "total_frames": 0,
        "total_tasks": 0,
        "total_videos": 0,
        "total_audio": 0,
        "total_chunks": 0,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "fps": fps,
        "splits": {},
        "data_path": DEFAULT_PARQUET_PATH,
        "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
        "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
        "features": features,
    }

@ -721,6 +733,7 @@ def validate_features_presence(
):
    error_message = ""
    missing_features = expected_features - actual_features
    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
    extra_features = actual_features - (expected_features | optional_features)

    if missing_features or extra_features:

@ -740,6 +753,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    elif expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
    elif expected_dtype == "audio":
        return validate_feature_audio(name, expected_shape, value)
    elif expected_dtype == "string":
        return validate_feature_string(name, value)
    else:

@ -781,6 +796,23 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
    return error_message


def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_shape = value.shape
        c = expected_shape
        if len(actual_shape) != 2 or (
            actual_shape[-1] != c[-1] and actual_shape[0] != c[0]
        ):  # The number of frames might be different
            error_message += (
                f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
            )
    else:
        error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"

    return error_message


def validate_feature_string(name: str, value: str):
    if not isinstance(value, str):
        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
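
How the two new path templates resolve, for a hypothetical audio key (values are illustrative):

```python
raw = DEFAULT_RAW_AUDIO_PATH.format(audio_key="observation.audio.laptop", episode_index=3)
# -> "audio/observation.audio.laptop/episode_000003.wav"

compressed = DEFAULT_COMPRESSED_AUDIO_PATH.format(
    episode_chunk=0, audio_key="observation.audio.laptop", episode_index=3
)
# -> "audio/chunk-000/observation.audio.laptop/episode_000003.m4a"
```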

@ -260,35 +260,39 @@ def encode_video_frames(
    imgs_dir = Path(imgs_dir)
    video_path.parent.mkdir(parents=True, exist_ok=True)

-   ffmpeg_args = OrderedDict(
+   ffmpeg_video_args = OrderedDict(
        [
            ("-f", "image2"),
            ("-r", str(fps)),
-           ("-i", str(imgs_dir / "frame_%06d.png")),
-           ("-vcodec", vcodec),
-           ("-pix_fmt", pix_fmt),
+           ("-i", str(Path(imgs_dir) / "frame_%06d.png")),
        ]
    )

+   ffmpeg_encoding_args = OrderedDict(
+       [
+           ("-pix_fmt", pix_fmt),
+           ("-vcodec", vcodec),
+       ]
+   )
    if g is not None:
-       ffmpeg_args["-g"] = str(g)
+       ffmpeg_encoding_args["-g"] = str(g)
    if crf is not None:
-       ffmpeg_args["-crf"] = str(crf)
+       ffmpeg_encoding_args["-crf"] = str(crf)
    if fast_decode:
        key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
-       ffmpeg_args[key] = value
+       ffmpeg_encoding_args[key] = value

    if log_level is not None:
-       ffmpeg_args["-loglevel"] = str(log_level)
+       ffmpeg_encoding_args["-loglevel"] = str(log_level)

-   ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+   ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
+   ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
    if overwrite:
        ffmpeg_args.append("-y")

    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]

    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
@ -331,42 +335,6 @@ with warnings.catch_warnings():
    register_feature(VideoFrame, "VideoFrame")


-def get_audio_info(video_path: Path | str) -> dict:
-    ffprobe_audio_cmd = [
-        "ffprobe",
-        "-v",
-        "error",
-        "-select_streams",
-        "a:0",
-        "-show_entries",
-        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
-        "-of",
-        "json",
-        str(video_path),
-    ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    if result.returncode != 0:
-        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
-
-    info = json.loads(result.stdout)
-    audio_stream_info = info["streams"][0] if info.get("streams") else None
-    if audio_stream_info is None:
-        return {"has_audio": False}
-
-    # Return the information, defaulting to None if no audio stream is present
-    return {
-        "has_audio": True,
-        "audio.channels": audio_stream_info.get("channels", None),
-        "audio.codec": audio_stream_info.get("codec_name", None),
-        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
-        "audio.sample_rate": int(audio_stream_info["sample_rate"])
-        if audio_stream_info.get("sample_rate")
-        else None,
-        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
-        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
-    }
-
-
def get_video_info(video_path: Path | str) -> dict:
    ffprobe_video_cmd = [
        "ffprobe",

@ -402,7 +370,6 @@ def get_video_info(video_path: Path | str) -> dict:
        "video.codec": video_stream_info["codec_name"],
        "video.pix_fmt": video_stream_info["pix_fmt"],
        "video.is_depth_map": False,
-       **get_audio_info(video_path),
    }

    return video_info
@ -78,6 +78,11 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
        if key in robot.logs:
            log_dt(f"dtR{name}", robot.logs[key])

    for name in robot.microphones:
        key = f"read_microphone_{name}_dt_s"
        if key in robot.logs:
            log_dt(f"dtR{name}", robot.logs[key])

    info_str = " ".join(log_items)
    logging.info(info_str)

@ -107,11 +112,15 @@ def predict_action(observation, policy, device, use_amp):
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
-       # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
+           # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
+           # Convert to pytorch format: channel first and float32 in [-1,1] with batch dimension
+           if "audio" in name:
+               observation[name] = observation[name].type(torch.float32)
+               observation[name] = observation[name].permute(1, 0).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

@ -243,6 +252,18 @@ def control_loop(

    timestamp = 0
    start_episode_t = time.perf_counter()

    if (
        dataset is not None and not robot.robot_type.startswith("lekiwi")
    ):  # For now, LeKiwi only supports frame audio recording (which may lead to audio chunk loss, extended post-processing, increased memory usage)
        for microphone_key, microphone in robot.microphones.items():
            # Start recording both in file writing and data reading mode
            dataset.add_microphone_recording(microphone, microphone_key)
    else:
        for _, microphone in robot.microphones.items():
            # Start recording only in data reading mode
            microphone.start_recording()

    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

@ -286,6 +307,9 @@ def control_loop(
            events["exit_early"] = False
            break

    for _, microphone in robot.microphones.items():
        microphone.stop_recording()


def reset_environment(robot, events, reset_time_s, fps):
    # TODO(rcadene): refactor warmup_record and reset_environment
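
For reference, a tensor-shape walk-through of the new audio branch in `predict_action` (the sizes are invented; audio chunks arrive sample-first from the microphone):

```python
import torch

chunk = torch.zeros(8000, 2)              # (samples, channels), float PCM already in [-1, 1]
chunk = chunk.type(torch.float32)
chunk = chunk.permute(1, 0).contiguous()  # channel-first: (2, 8000)
chunk = chunk.unsqueeze(0)                # batch dimension: (1, 2, 8000)
```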

@ -0,0 +1,38 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from dataclasses import dataclass

import draccus


@dataclass
class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)


@MicrophoneConfigBase.register_subclass("microphone")
@dataclass
class MicrophoneConfig(MicrophoneConfigBase):
    """
    Dataclass for microphone configuration.
    """

    microphone_index: int
    sample_rate: int | None = None
    channels: list[int] | None = None
    mock: bool = False
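
A sketch of how the registry is meant to be used: the `"microphone"` choice name registered above ties a parsed config type back to this class (the parameter values are illustrative):

```python
config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
print(config.type)  # "microphone", resolved through draccus.ChoiceRegistry
```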

@ -0,0 +1,425 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file contains utilities for recording audio from a microphone.
"""

import argparse
import logging
import shutil
import time
from multiprocessing import Event as process_Event
from multiprocessing import JoinableQueue as process_Queue
from multiprocessing import Process
from os import getcwd
from pathlib import Path
from queue import Empty
from queue import Queue as thread_Queue
from threading import Event, Thread
from threading import Event as thread_Event

import numpy as np
import soundfile as sf

from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.utils import (
    RobotDeviceAlreadyConnectedError,
    RobotDeviceAlreadyRecordingError,
    RobotDeviceNotConnectedError,
    RobotDeviceNotRecordingError,
)
from lerobot.common.utils.utils import capture_timestamp_utc


def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
    """
    Finds and lists all microphones compatible with sounddevice (and the underlying PortAudio library).
    Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
    """
    microphones = []

    if mock:
        import tests.microphones.mock_sounddevice as sd
    else:
        import sounddevice as sd

    devices = sd.query_devices()
    for device in devices:
        if device["max_input_channels"] > 0:
            microphones.append(
                {
                    "index": device["index"],
                    "name": device["name"],
                }
            )

    if raise_when_empty and len(microphones) == 0:
        raise OSError(
            "Not a single microphone was detected. Try re-plugging the microphone or check the microphone settings."
        )

    return microphones


def record_audio_from_microphones(
    output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0
):
    """
    Records audio from all the channels of the specified microphones for the specified duration.
    If no microphone ids are provided, all available microphones will be used.
    """

    if microphone_ids is None or len(microphone_ids) == 0:
        microphones = find_microphones()
        microphone_ids = [m["index"] for m in microphones]

    microphones = []
    for microphone_id in microphone_ids:
        config = MicrophoneConfig(microphone_index=microphone_id)
        microphone = Microphone(config)
        microphone.connect()
        print(
            f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sample_rate} Hz."
        )
        microphones.append(microphone)

    output_dir = Path(output_dir)
    if output_dir.exists():
        shutil.rmtree(
            output_dir,
        )
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving audio to {output_dir}")

    for microphone in microphones:
        microphone.start_recording(getcwd() / output_dir / f"microphone_{microphone.microphone_index}.wav")

    time.sleep(record_time_s)

    for microphone in microphones:
        microphone.stop_recording()

    # Remark : recording may be resumed here if needed

    for microphone in microphones:
        microphone.disconnect()

    print(f"Audio recordings have been saved to {output_dir}")

class Microphone:
    """
    The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).

    A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.

    Example of usage:
    ```python
    from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig

    config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
    microphone = Microphone(config)

    microphone.connect()
    microphone.start_recording("some/output/file.wav")
    ...
    audio_readings = microphone.read()  # Gets all recorded audio data since the last read or since the beginning of the recording
    ...
    microphone.stop_recording()
    microphone.disconnect()
    ```
    """

    def __init__(self, config: MicrophoneConfig):
        self.config = config
        self.microphone_index = config.microphone_index

        # Store the recording sample rate and channels
        self.sample_rate = config.sample_rate
        self.channels = config.channels

        self.mock = config.mock

        # Input audio stream
        self.stream = None

        # Thread/Process-safe concurrent queues to store the recorded/read audio
        self.record_queue = None
        self.read_queue = None

        # Thread/Process to handle data reading and file writing in a separate thread/process (safely)
        self.record_thread = None
        self.record_stop_event = None

        self.logs = {}
        self.is_connected = False
        self.is_recording = False
        self.is_writing = False

    def connect(self) -> None:
        """
        Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"Microphone {self.microphone_index} is already connected."
            )

        if self.mock:
            import tests.microphones.mock_sounddevice as sd
        else:
            import sounddevice as sd

        # Check if the provided microphone index does match an input device
        is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0

        if not is_index_input:
            microphones_info = find_microphones()
            available_microphones = [m["index"] for m in microphones_info]
            raise OSError(
                f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
            )

        # Check if the provided recording parameters are compatible with the microphone
        actual_microphone = sd.query_devices(self.microphone_index)

        if self.sample_rate is not None:
            if self.sample_rate > actual_microphone["default_samplerate"]:
                raise OSError(
                    f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
                )
            elif self.sample_rate < actual_microphone["default_samplerate"]:
                logging.warning(
                    "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
                )
        else:
            self.sample_rate = int(actual_microphone["default_samplerate"])

        if self.channels is not None and len(self.channels) > 0:
            if any(c > actual_microphone["max_input_channels"] for c in self.channels):
                raise OSError(
                    f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
                )
        else:
            self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1)

        # Get channel indices instead of channel numbers for slicing
        self.channels_index = np.array(self.channels) - 1

        # Create the audio stream
        self.stream = sd.InputStream(
            device=self.microphone_index,
            samplerate=self.sample_rate,
            channels=max(self.channels),
            dtype="float32",
            callback=self._audio_callback,
        )
        # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receives same-length buffers.
        # However, this may lead to additional latency. We thus stick to blocksize=0, which means that audio_callback will receive varying-length buffers, but with no additional latency.

        self.is_connected = True

    def _audio_callback(self, indata, frames, time, status) -> None:
        """
        Low-level sounddevice callback.
        """
        if status:
            logging.warning(status)
        # Slicing makes copy unnecessary
        # Two separate queues are necessary because .get() also pops the data from the queue
        if self.is_writing:
            self.record_queue.put(indata[:, self.channels_index])
        self.read_queue.put(indata[:, self.channels_index])

    @staticmethod
    def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
        """
        Thread/Process-safe loop to write audio data into a file.
        """
        # Can only be run on a single process/thread for file writing safety
        with sf.SoundFile(
            output_file,
            mode="x",
            samplerate=sample_rate,
            channels=max(channels),
            subtype=sf.default_subtype(output_file.suffix[1:]),
        ) as file:
            while not event.is_set():
                try:
                    file.write(
                        queue.get(timeout=0.02)
                    )  # Timeout set as twice the usual sounddevice buffer size
                    queue.task_done()
                except Empty:
                    continue

    def _read(self) -> np.ndarray:
        """
        Thread/Process-safe callback to read available audio data.
        """
        audio_readings = np.empty((0, len(self.channels)))

        while True:
            try:
                audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
            except Empty:
                break

        self.read_queue = thread_Queue()

        return audio_readings

    def read(self) -> np.ndarray:
        """
        Reads the last audio chunk recorded by the microphone, i.e. all samples recorded since the last read or since the beginning of the recording.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if not self.is_recording:
            raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")

        start_time = time.perf_counter()

        audio_readings = self._read()

        # log the number of seconds it took to read the audio chunk
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the utc time at which the audio chunk was received
        self.logs["timestamp_utc"] = capture_timestamp_utc()

        return audio_readings

    def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
        """
        Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if self.is_recording:
            raise RobotDeviceAlreadyRecordingError(
                f"Microphone {self.microphone_index} is already recording."
            )

        # Reset queues
        self.read_queue = thread_Queue()
        if multiprocessing:
            self.record_queue = process_Queue()
        else:
            self.record_queue = thread_Queue()

        # Write recordings into a file if output_file is provided
        if output_file is not None:
            output_file = Path(output_file)
            if output_file.exists():
                output_file.unlink()

            if multiprocessing:
                self.record_stop_event = process_Event()
                self.record_thread = Process(
                    target=Microphone._record_loop,
                    args=(
                        self.record_queue,
                        self.record_stop_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            else:
                self.record_stop_event = thread_Event()
                self.record_thread = Thread(
                    target=Microphone._record_loop,
                    args=(
                        self.record_queue,
                        self.record_stop_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            self.record_thread.daemon = True
            self.record_thread.start()

            self.is_writing = True

        self.is_recording = True
        self.stream.start()

    def stop_recording(self) -> None:
        """
        Stops the recording of the microphone.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if not self.is_recording:
            raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")

        if self.stream.active:
            self.stream.stop()  # Wait for all buffers to be processed
            # Remark : stream.abort() flushes the buffers !
        self.is_recording = False

        if self.record_thread is not None:
            self.record_queue.join()
            self.record_stop_event.set()
            self.record_thread.join()
            self.record_thread = None
            self.record_stop_event = None
            self.is_writing = False

    def disconnect(self) -> None:
        """
        Disconnects the microphone and stops the recording.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")

        if self.is_recording:
            self.stop_recording()

        self.stream.close()
        self.is_connected = False

    def __del__(self):
        if getattr(self, "is_connected", False):
            self.disconnect()

if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microphone-ids",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="List of microphones indices used to instantiate the `Microphone`. If not provided, find and use all available microphones indices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default="outputs/audio_from_microphones",
|
||||
help="Set directory to save an audio snippet for each microphone.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--record-time-s",
|
||||
type=float,
|
||||
default=4.0,
|
||||
help="Set the number of seconds used to record the audio. By default, 4 seconds.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
record_audio_from_microphones(**vars(args))
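
For orientation, here is a minimal usage sketch of the class added above (the device index and file path are illustrative; `MicrophoneConfig` and `Microphone` are the ones introduced in this diff):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.microphone import Microphone

# Hypothetical configuration: device index 0, 48 kHz, first channel only.
mic = Microphone(MicrophoneConfig(microphone_index=0, sample_rate=48000, channels=[1]))
mic.connect()
mic.start_recording("outputs/test.wav")  # A background thread also writes samples to this file
# ... let audio accumulate ...
chunk = mic.read()  # np.ndarray of shape (n_samples, len(mic.channels))
mic.stop_recording()
mic.disconnect()
```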
@@ -0,0 +1,48 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Protocol

from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase


# Defines a microphone type
class Microphone(Protocol):
    def connect(self): ...
    def disconnect(self): ...
    def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ...
    def stop_recording(self): ...


def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> dict[str, Microphone]:
    microphones = {}

    for key, cfg in microphone_configs.items():
        if cfg.type == "microphone":
            from lerobot.common.robot_devices.microphones.microphone import Microphone

            microphones[key] = Microphone(cfg)
        else:
            raise ValueError(f"The microphone type '{cfg.type}' is not valid.")

    return microphones


def make_microphone(microphone_type, **kwargs) -> Microphone:
    if microphone_type == "microphone":
        from lerobot.common.robot_devices.microphones.microphone import Microphone

        return Microphone(MicrophoneConfig(**kwargs))
    else:
        raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
@@ -23,6 +23,7 @@ from lerobot.common.robot_devices.cameras.configs import (
    IntelRealSenseCameraConfig,
    OpenCVCameraConfig,
)
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.motors.configs import (
    DynamixelMotorsBusConfig,
    FeetechMotorsBusConfig,

@@ -43,6 +44,7 @@ class ManipulatorRobotConfig(RobotConfig):
    leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
    follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
    cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
    microphones: dict[str, MicrophoneConfig] = field(default_factory=lambda: {})

    # Optionally limit the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length

@@ -68,6 +70,9 @@ class ManipulatorRobotConfig(RobotConfig):
        for cam in self.cameras.values():
            if not cam.mock:
                cam.mock = True
        for mic in self.microphones.values():
            if not mic.mock:
                mic.mock = True

        if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
            for name in self.follower_arms:

@@ -491,6 +496,21 @@ class So100RobotConfig(ManipulatorRobotConfig):
        }
    )

    microphones: dict[str, MicrophoneConfig] = field(
        default_factory=lambda: {
            "laptop": MicrophoneConfig(
                microphone_index=0,
                sample_rate=48000,
                channels=[1],
            ),
            "headset": MicrophoneConfig(
                microphone_index=1,
                sample_rate=44100,
                channels=[1],
            ),
        }
    )

    mock: bool = False
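
Since these are dataclass fields, user code can override the defaults at construction time. A sketch (the "webcam" key and index are made up, assuming `So100RobotConfig` lives where the imports elsewhere in this diff suggest):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.robots.configs import So100RobotConfig

config = So100RobotConfig(
    microphones={
        "webcam": MicrophoneConfig(microphone_index=2, sample_rate=44100, channels=[1]),
    }
)
```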
@@ -52,6 +52,16 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
        time.sleep(0.01)


def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
    while not stop_event.is_set():
        local_dict = {}
        for name, microphone in microphones.items():
            audio_readings = microphone.read()
            local_dict[name] = audio_readings
        with audio_lock:
            latest_audio_dict.update(local_dict)


def calibrate_follower_arm(motors_bus, calib_dir_str):
    """
    Calibrates the follower arm. Attempts to load an existing calibration file;

@@ -94,6 +104,7 @@ def run_lekiwi(robot_config):
    """
    # Import helper functions and classes
    from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
    from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
    from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode

    # Initialize cameras from the robot configuration.

@@ -101,6 +112,11 @@ def run_lekiwi(robot_config):
    for cam in cameras.values():
        cam.connect()

    # Initialize microphones from the robot configuration.
    microphones = make_microphones_from_configs(robot_config.microphones)
    for microphone in microphones.values():
        microphone.connect()

    # Initialize the motors bus using the follower arm configuration.
    motor_config = robot_config.follower_arms.get("main")
    if motor_config is None:

@@ -134,6 +150,20 @@ def run_lekiwi(robot_config):
    )
    cam_thread.start()

    # Start the microphone recording and capture thread.
    # TODO(CarolinePascal): Leverage multi-core processing with a multiprocessing.Process instead!
    latest_audio_dict = {}
    audio_lock = threading.Lock()
    audio_stop_event = threading.Event()
    microphone_thread = threading.Thread(
        target=run_microphone_capture,
        args=(microphones, audio_lock, latest_audio_dict, audio_stop_event),
        daemon=True,
    )
    for microphone in microphones.values():
        microphone.start_recording()
    microphone_thread.start()

    last_cmd_time = time.time()
    print("LeKiwi robot server started. Waiting for commands...")


@@ -198,9 +228,14 @@ def run_lekiwi(robot_config):
            with images_lock:
                images_dict_copy = dict(latest_images_dict)

            # Get the latest audio data.
            with audio_lock:
                audio_dict_copy = dict(latest_audio_dict)

            # Build the observation dictionary.
            observation = {
                "images": images_dict_copy,
                "audio": audio_dict_copy,  # TODO(CarolinePascal): This is a nasty way to do it, sorry.
                "present_speed": current_velocity,
                "follower_arm_state": follower_arm_state,
            }

@@ -217,6 +252,9 @@ def run_lekiwi(robot_config):
    finally:
        stop_event.set()
        cam_thread.join()
        microphone_thread.join()
        for microphone in microphones.values():
            microphone.stop_recording()
        robot.stop()
        motors_bus.disconnect()
        cmd_socket.close()
@@ -28,6 +28,7 @@ import numpy as np
import torch

from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id

@@ -164,6 +165,7 @@ class ManipulatorRobot:
        self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
        self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
        self.cameras = make_cameras_from_configs(self.config.cameras)
        self.microphones = make_microphones_from_configs(self.config.microphones)
        self.is_connected = False
        self.logs = {}


@@ -199,9 +201,24 @@ class ManipulatorRobot:
            },
        }

    @property
    def microphone_features(self) -> dict:
        mic_ft = {}
        for mic_key, mic in self.microphones.items():
            key = f"observation.audio.{mic_key}"
            mic_ft[key] = {
                "dtype": "audio",
                "shape": (len(mic.channels),),
                "names": "channels",
                "info": {
                    "sample_rate": mic.sample_rate
                },  # we need to store the sample rate here in the case of audio chunks recording (for LeKiwi), as it will not be available anymore when writing the audio file
            }
        return mic_ft
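
For concreteness, with the So100 defaults above this property yields entries shaped like the following (values illustrative):

```python
{
    "observation.audio.laptop": {
        "dtype": "audio",
        "shape": (1,),  # one configured channel
        "names": "channels",
        "info": {"sample_rate": 48000},
    },
}
```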

    @property
    def features(self):
-        return {**self.motor_features, **self.camera_features}
+        return {**self.motor_features, **self.camera_features, **self.microphone_features}

    @property
    def has_camera(self):

@@ -211,6 +228,14 @@ class ManipulatorRobot:
    def num_cameras(self):
        return len(self.cameras)

    @property
    def has_microphone(self):
        return len(self.microphones) > 0

    @property
    def num_microphones(self):
        return len(self.microphones)

    @property
    def available_arms(self):
        available_arms = []

@@ -228,7 +253,7 @@ class ManipulatorRobot:
                "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
            )

-        if not self.leader_arms and not self.follower_arms and not self.cameras:
+        if not self.leader_arms and not self.follower_arms and not self.cameras and not self.microphones:
            raise ValueError(
                "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class."
            )

@@ -289,6 +314,10 @@ class ManipulatorRobot:
        for name in self.cameras:
            self.cameras[name].connect()

        # Connect the microphones
        for name in self.microphones:
            self.microphones[name].connect()

        self.is_connected = True

    def activate_calibration(self):

@@ -514,12 +543,23 @@ class ManipulatorRobot:
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Capture audio from microphones
        audio = {}
        for name in self.microphones:
            before_audioread_t = time.perf_counter()
            audio[name] = self.microphones[name].read()
            audio[name] = torch.from_numpy(audio[name])
            self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t

        # Populate output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]
        for name in self.microphones:
            obs_dict[f"observation.audio.{name}"] = audio[name]

        return obs_dict, action_dict


@@ -554,11 +594,22 @@ class ManipulatorRobot:
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Capture audio from microphones
        audio = {}
        for name in self.microphones:
            before_audioread_t = time.perf_counter()
            audio[name] = self.microphones[name].read()
            audio[name] = torch.from_numpy(audio[name])
            self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t

        # Populate output dictionaries and format to pytorch
        obs_dict = {}
        obs_dict["observation.state"] = state
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]
        for name in self.microphones:
            obs_dict[f"observation.audio.{name}"] = audio[name]
        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:

@@ -620,6 +671,9 @@ class ManipulatorRobot:
        for name in self.cameras:
            self.cameras[name].disconnect()

        for name in self.microphones:
            self.microphones[name].disconnect()

        self.is_connected = False

    def __del__(self):
@@ -24,6 +24,7 @@ import torch
import zmq

from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig

@@ -79,6 +80,7 @@ class MobileManipulator:
        self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)

        self.cameras = make_cameras_from_configs(self.config.cameras)
        self.microphones = make_microphones_from_configs(self.config.microphones)

        self.is_connected = False


@@ -133,6 +135,7 @@ class MobileManipulator:
                "shape": (cam.height, cam.width, cam.channels),
                "names": ["height", "width", "channels"],
                "info": None,
                "audio": "observation.audio." + cam.microphone if cam.microphone is not None else None,
            }
        return cam_ft


@@ -161,9 +164,22 @@ class MobileManipulator:
            },
        }

    @property
    def microphone_features(self) -> dict:
        mic_ft = {}
        for mic_key, mic in self.microphones.items():
            key = f"observation.audio.{mic_key}"
            mic_ft[key] = {
                "dtype": "audio",
                "shape": (len(mic.channels),),
                "names": "channels",
                "info": {"sample_rate": mic.sample_rate},
            }
        return mic_ft

    @property
    def features(self):
-        return {**self.motor_features, **self.camera_features}
+        return {**self.motor_features, **self.camera_features, **self.microphone_features}

    @property
    def has_camera(self):

@@ -173,6 +189,14 @@ class MobileManipulator:
    def num_cameras(self):
        return len(self.cameras)

    @property
    def has_microphone(self):
        return len(self.microphones) > 0

    @property
    def num_microphones(self):
        return len(self.microphones)

    @property
    def available_arms(self):
        available = []

@@ -344,6 +368,7 @@ class MobileManipulator:
            observation = json.loads(last_msg)

            images_dict = observation.get("images", {})
            audio_dict = observation.get("audio", {})
            new_speed = observation.get("present_speed", {})
            new_arm_state = observation.get("follower_arm_state", None)


@@ -356,6 +381,11 @@ class MobileManipulator:
                if frame_candidate is not None:
                    frames[cam_name] = frame_candidate

            # Receive audio
            for microphone_name, audio_data in audio_dict.items():
                if audio_data:
                    frames[microphone_name] = audio_data

            # If no new message was received (new_arm_state and frames are None), keep the previous message
            if new_arm_state is not None and frames is not None:
                self.last_frames = frames

@@ -475,6 +505,14 @@ class MobileManipulator:
                frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
            obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)

        # Loop over each configured microphone
        for microphone_name, microphone in self.microphones.items():
            frame = frames.get(microphone_name, None)
            if frame is None:
                # Create silence using the microphone's configured channels
                frame = np.zeros((1, len(microphone.channels)), dtype=np.float32)
            obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame)

        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
@@ -63,3 +63,25 @@ class RobotDeviceAlreadyConnectedError(Exception):
    ):
        self.message = message
        super().__init__(self.message)


class RobotDeviceNotRecordingError(Exception):
    """Exception raised when the robot device is not recording."""

    def __init__(
        self,
        message="This robot device is not recording. Try calling `robot_device.start_recording()` first.",
    ):
        self.message = message
        super().__init__(self.message)


class RobotDeviceAlreadyRecordingError(Exception):
    """Exception raised when the robot device is already recording."""

    def __init__(
        self,
        message="This robot device is already recording. Try not calling `robot_device.start_recording()` twice.",
    ):
        self.message = message
        super().__init__(self.message)
@@ -67,8 +67,11 @@ dependencies = [
    "pynput>=1.7.7",
    "pyzmq>=26.2.1",
    "rerun-sdk>=0.21.0",
    "sounddevice>=0.5.1",
    "soundfile>=0.13.1",
    "termcolor>=2.4.0",
    "torch>=2.2.1",
    "torchaudio>=2.6.0",
    "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
    "torchvision>=0.21.0",
    "wandb>=0.16.3",

@@ -96,6 +99,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
audio = ["librosa>=0.11.0"]

[tool.poetry]
requires-poetry = ">=2.1"
@@ -19,9 +19,9 @@ import traceback
import pytest
from serial import SerialException

-from lerobot import available_cameras, available_motors, available_robots
+from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot.common.robot_devices.robots.utils import make_robot
-from tests.utils import DEVICE, make_camera, make_motors_bus
+from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus

# Import fixture modules as plugins
pytest_plugins = [

@@ -74,6 +74,11 @@ def is_camera_available(camera_type):
    return _check_component_availability(camera_type, available_cameras, make_camera)


@pytest.fixture
def is_microphone_available(microphone_type):
    return _check_component_availability(microphone_type, available_microphones, make_microphone)


@pytest.fixture
def is_motor_available(motor_type):
    return _check_component_availability(motor_type, available_motors, make_motors_bus)
@@ -25,6 +25,8 @@ from lerobot.common.datasets.compute_stats import (
    compute_episode_stats,
    estimate_num_samples,
    get_feature_stats,
    sample_audio_from_data,
    sample_audio_from_path,
    sample_images,
    sample_indices,
)

@@ -34,6 +36,10 @@ def mock_load_image_as_numpy(path, dtype, channel_first):
    return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)


def mock_load_audio(path):
    return np.ones((16000, 2), dtype=np.float32)


@pytest.fixture
def sample_array():
    return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

@@ -71,6 +77,25 @@ def test_sample_images(mock_load):
    assert len(images) == estimate_num_samples(100)


@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load):
    audio_path = "audio.wav"
    audio_samples = sample_audio_from_path(audio_path)
    assert isinstance(audio_samples, np.ndarray)
    assert audio_samples.shape[1] == 2
    assert audio_samples.dtype == np.float32
    assert len(audio_samples) == estimate_num_samples(16000)


def test_sample_audio_from_data():
    audio_data = np.ones((16000, 2), dtype=np.float32)
    audio_samples = sample_audio_from_data(audio_data)
    assert isinstance(audio_samples, np.ndarray)
    assert audio_samples.shape[1] == 2
    assert audio_samples.dtype == np.float32
    assert len(audio_samples) == estimate_num_samples(16000)


def test_get_feature_stats_images():
    data = np.random.rand(100, 3, 32, 32)
    stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)

@@ -79,6 +104,14 @@ def test_get_feature_stats_images():
    assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape


def test_get_feature_stats_audio():
    data = np.random.uniform(-1, 1, (16000, 2))
    stats = get_feature_stats(data, axis=0, keepdims=True)
    assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
    np.testing.assert_equal(stats["count"], np.array([16000]))
    assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape


def test_get_feature_stats_axis_0_keepdims(sample_array):
    expected = {
        "min": np.array([[1, 2, 3]]),

@@ -137,22 +170,29 @@ def test_get_feature_stats_single_value():
def test_compute_episode_stats():
    episode_data = {
        "observation.image": [f"image_{i}.jpg" for i in range(100)],
        "observation.audio": "audio.wav",
        "observation.state": np.random.rand(100, 10),
    }
    features = {
        "observation.image": {"dtype": "image"},
        "observation.audio": {"dtype": "audio"},
        "observation.state": {"dtype": "numeric"},
    }

-    with patch(
-        "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
-    ):
+    with (
+        patch(
+            "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
+        ),
+        patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
+    ):
        stats = compute_episode_stats(episode_data, features)

-    assert "observation.image" in stats and "observation.state" in stats
-    assert stats["observation.image"]["count"].item() == 100
-    assert stats["observation.state"]["count"].item() == 100
+    assert "observation.image" in stats and "observation.state" in stats and "observation.audio" in stats
+    assert stats["observation.image"]["count"].item() == estimate_num_samples(100)
+    assert stats["observation.audio"]["count"].item() == estimate_num_samples(16000)
+    assert stats["observation.state"]["count"].item() == estimate_num_samples(100)
+    assert stats["observation.image"]["mean"].shape == (3, 1, 1)
+    assert stats["observation.audio"]["mean"].shape == (1, 2)


def test_assert_type_and_shape_valid():
@@ -16,6 +16,7 @@
import json
import logging
import re
import time
from copy import deepcopy
from itertools import chain
from pathlib import Path

@@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import (
    MultiLeRobotDataset,
)
from lerobot.common.datasets.utils import (
    DEFAULT_AUDIO_CHUNK_DURATION,
    create_branch,
    flatten_dict,
    unflatten_dict,

@@ -44,8 +46,8 @@ from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
-from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
-from tests.utils import require_x86_64_kernel
+from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
+from tests.utils import make_microphone, require_x86_64_kernel


@pytest.fixture

@@ -64,6 +66,20 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
    return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)


@pytest.fixture
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
    features = {
        "observation.audio.microphone": {
            "dtype": "audio",
            "shape": (DUMMY_AUDIO_CHANNELS,),
            "names": [
                "channels",
            ],
        }
    }
    return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)


def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
    """
    Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated

@@ -322,6 +338,24 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
        image_array_to_pil_image(image)


def test_add_frame_audio(audio_dataset):
    dataset = audio_dataset

    microphone = make_microphone(microphone_type="microphone", mock=True)
    microphone.connect()

    dataset.add_microphone_recording(microphone, "microphone")
    time.sleep(1.0)
    dataset.add_frame({"observation.audio.microphone": microphone.read(), "task": "Dummy task"})
    microphone.stop_recording()

    dataset.save_episode()

    assert dataset[0]["observation.audio.microphone"].shape == torch.Size(
        (int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS)
    )


# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames

@@ -354,6 +388,7 @@ def test_factory(env_name, repo_id, policy_name):
    dataset = make_dataset(cfg)
    delta_timestamps = dataset.delta_timestamps
    camera_keys = dataset.meta.camera_keys
    audio_keys = dataset.meta.audio_keys

    item = dataset[0]


@@ -396,6 +431,11 @@ def test_factory(env_name, repo_id, policy_name):
        # test c,h,w
        assert item[key].shape[0] == 3, f"{key}"

    for key in audio_keys:
        assert item[key].dtype == torch.float32, f"{key}"
        assert item[key].max() <= 1.0, f"{key}"
        assert item[key].min() >= -1.0, f"{key}"

    if delta_timestamps is not None:
        # test missing keys in delta_timestamps
        for key in delta_timestamps:
@@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = {
    },
}
DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {
+        "shape": (480, 640, 3),
+        "names": ["height", "width", "channels"],
+        "info": None,
+        "audio": "laptop",
+    },
    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30

@@ -40,5 +45,18 @@ DUMMY_VIDEO_INFO = {
    "video.is_depth_map": False,
    "has_audio": False,
}
DUMMY_MICROPHONE_FEATURES = {
    "laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
    "phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
}
DEFAULT_SAMPLE_RATE = 48000
DUMMY_AUDIO_CHANNELS = 2
DUMMY_AUDIO_INFO = {
    "has_audio": True,
    "audio.sample_rate": DEFAULT_SAMPLE_RATE,
    "audio.codec": "aac",
    "audio.channels": DUMMY_AUDIO_CHANNELS,
    "audio.channel_layout": "stereo",
}
DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3)
@@ -26,6 +26,7 @@ import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_COMPRESSED_AUDIO_PATH,
    DEFAULT_FEATURES,
    DEFAULT_PARQUET_PATH,
    DEFAULT_VIDEO_PATH,

@@ -35,6 +36,7 @@ from lerobot.common.datasets.utils import (
from tests.fixtures.constants import (
    DEFAULT_FPS,
    DUMMY_CAMERA_FEATURES,
    DUMMY_MICROPHONE_FEATURES,
    DUMMY_MOTOR_FEATURES,
    DUMMY_REPO_ID,
    DUMMY_ROBOT_TYPE,

@@ -90,6 +92,7 @@ def features_factory():
    def _create_features(
        motor_features: dict = DUMMY_MOTOR_FEATURES,
        camera_features: dict = DUMMY_CAMERA_FEATURES,
        audio_features: dict = DUMMY_MICROPHONE_FEATURES,
        use_videos: bool = True,
    ) -> dict:
        if use_videos:

@@ -101,6 +104,7 @@ def features_factory():
        return {
            **motor_features,
            **camera_ft,
            **audio_features,
            **DEFAULT_FEATURES,
        }


@@ -117,15 +121,18 @@ def info_factory(features_factory):
        total_frames: int = 0,
        total_tasks: int = 0,
        total_videos: int = 0,
        total_audio: int = 0,
        total_chunks: int = 0,
        chunks_size: int = DEFAULT_CHUNK_SIZE,
        data_path: str = DEFAULT_PARQUET_PATH,
        video_path: str = DEFAULT_VIDEO_PATH,
        audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
        motor_features: dict = DUMMY_MOTOR_FEATURES,
        camera_features: dict = DUMMY_CAMERA_FEATURES,
        audio_features: dict = DUMMY_MICROPHONE_FEATURES,
        use_videos: bool = True,
    ) -> dict:
-        features = features_factory(motor_features, camera_features, use_videos)
+        features = features_factory(motor_features, camera_features, audio_features, use_videos)
        return {
            "codebase_version": codebase_version,
            "robot_type": robot_type,

@@ -133,12 +140,14 @@ def info_factory(features_factory):
            "total_frames": total_frames,
            "total_tasks": total_tasks,
            "total_videos": total_videos,
            "total_audio": total_audio,
            "total_chunks": total_chunks,
            "chunks_size": chunks_size,
            "fps": fps,
            "splits": {},
            "data_path": data_path,
            "video_path": video_path if use_videos else None,
            "audio_path": audio_path,
            "features": features,
        }


@@ -162,6 +171,14 @@ def stats_factory():
                    "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
                    "count": [10],
                }
            elif dtype == "audio":
                stats[key] = {
                    "mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
                    "max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
                    "min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
                    "std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
                    "count": [10],
                }
            else:
                stats[key] = {
                    "max": np.full(shape, 1, dtype=dtype).tolist(),
@@ -0,0 +1,88 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from functools import cache
from threading import Event, Thread

import numpy as np

from lerobot.common.utils.utils import capture_timestamp_utc
from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS


@cache
def _generate_sound(duration: float, sample_rate: int, channels: int):
    return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)


def query_devices(query_index: int):
    return {
        "name": "Mock Sound Device",
        "index": query_index,
        "max_input_channels": DUMMY_AUDIO_CHANNELS,
        "default_samplerate": DEFAULT_SAMPLE_RATE,
    }


class InputStream:
    def __init__(self, *args, **kwargs):
        self._mock_dict = {
            "channels": DUMMY_AUDIO_CHANNELS,
            "samplerate": DEFAULT_SAMPLE_RATE,
        }
        self._is_active = False
        self._audio_callback = kwargs.get("callback")

        self.callback_thread = None
        self.callback_thread_stop_event = None

    def _acquisition_loop(self):
        if self._audio_callback is not None:
            while not self.callback_thread_stop_event.is_set():
                # Simulate audio data acquisition
                time.sleep(0.01)
                self._audio_callback(
                    _generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS),
                    0.01 * DEFAULT_SAMPLE_RATE,
                    capture_timestamp_utc(),
                    None,
                )

    def start(self):
        self.callback_thread_stop_event = Event()
        self.callback_thread = Thread(target=self._acquisition_loop, args=())
        self.callback_thread.daemon = True
        self.callback_thread.start()

        self._is_active = True

    @property
    def active(self):
        return self._is_active

    def stop(self):
        if self.callback_thread_stop_event is not None:
            self.callback_thread_stop_event.set()
            self.callback_thread.join()
            self.callback_thread = None
            self.callback_thread_stop_event = None
        self._is_active = False

    def close(self):
        if self._is_active:
            self.stop()

    def __del__(self):
        if self._is_active:
            self.stop()
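
A sketch of how this mock is presumably consumed during tests (the exact import path of the mock module is an assumption; a `mock=True` microphone would swap it in for `sounddevice`):

```python
# Hypothetical test usage, standing in for `import sounddevice as sd`.
import tests.microphones.mock_sounddevice as sd

received = []
stream = sd.InputStream(
    samplerate=sd.query_devices(0)["default_samplerate"],
    callback=lambda indata, frames, timestamp, status: received.append(indata),
)
stream.start()  # spawns the fake acquisition thread, feeding ~0.01 s random chunks
stream.stop()
stream.close()
```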
@@ -0,0 +1,152 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for physical microphones and their mocked versions.
If the physical microphone is not connected to the computer, or not working,
the test will be skipped.

Example of running a specific test:
```bash
pytest -sx tests/microphones/test_microphones.py::test_microphone
```

Example of running test on a real microphone connected to the computer:
```bash
pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-False]'
```

Example of running test on a mocked version of the microphone:
```bash
pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-True]'
```
"""

import time

import numpy as np
import pytest
from soundfile import read

from lerobot.common.robot_devices.utils import (
    RobotDeviceAlreadyConnectedError,
    RobotDeviceAlreadyRecordingError,
    RobotDeviceNotConnectedError,
    RobotDeviceNotRecordingError,
)
from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone

# Maximum recording time difference between two consecutive audio recordings of the same duration.
# Set to 0.02 seconds, i.e. twice the default sounddevice callback buffer size (we tolerate the loss of one buffer).
MAX_RECORDING_TIME_DIFFERENCE = 0.02
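# (At the 48 kHz sample rate of the mocked device, for instance, this tolerance corresponds to
# 0.02 * 48000 = 960 samples in the assert_allclose check below.)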

DUMMY_RECORDING = "test_recording.wav"


@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
@require_microphone
def test_microphone(tmp_path, request, microphone_type, mock):
    """Test assumes that a recording handled with microphone.start_recording(output_file) and stop_recording() or microphone.read()
    leads to a sample that does not differ from the requested duration by more than 0.1 seconds.
    """

    microphone_kwargs = {"microphone_type": microphone_type, "mock": mock}

    # Test instantiating
    microphone = make_microphone(**microphone_kwargs)

    # Test start_recording, stop_recording, read and disconnecting before connecting raises an error
    with pytest.raises(RobotDeviceNotConnectedError):
        microphone.start_recording()
    with pytest.raises(RobotDeviceNotConnectedError):
        microphone.stop_recording()
    with pytest.raises(RobotDeviceNotConnectedError):
        microphone.read()
    with pytest.raises(RobotDeviceNotConnectedError):
        microphone.disconnect()

    # Test deleting the object without connecting first
    del microphone

    # Test connecting
    microphone = make_microphone(**microphone_kwargs)
    microphone.connect()
    assert microphone.is_connected
    assert microphone.sample_rate is not None
    assert microphone.channels is not None

    # Test connecting twice raises an error
    with pytest.raises(RobotDeviceAlreadyConnectedError):
        microphone.connect()

    # Test reading or stop recording before starting recording raises an error
    with pytest.raises(RobotDeviceNotRecordingError):
        microphone.read()
    with pytest.raises(RobotDeviceNotRecordingError):
        microphone.stop_recording()

    # Test start_recording
    fpath = tmp_path / DUMMY_RECORDING
    microphone.start_recording(fpath)
    assert microphone.is_recording

    # Test start_recording twice raises an error
    with pytest.raises(RobotDeviceAlreadyRecordingError):
        microphone.start_recording()

    # Test reading from the microphone
    time.sleep(1.0)
    audio_chunk = microphone.read()
    assert isinstance(audio_chunk, np.ndarray)
    assert audio_chunk.ndim == 2
    _, c = audio_chunk.shape
    assert c == len(microphone.channels)

    # Test stop_recording
    microphone.stop_recording()
    assert fpath.exists()
    assert not microphone.stream.active
    assert microphone.record_thread is None

    # Test stop_recording twice raises an error
    with pytest.raises(RobotDeviceNotRecordingError):
        microphone.stop_recording()

    # Test reading and recording output similar length audio chunks
    microphone.start_recording(tmp_path / DUMMY_RECORDING)
    time.sleep(1.0)
    audio_chunk = microphone.read()
    microphone.stop_recording()

    recorded_audio, recorded_sample_rate = read(fpath)
    assert recorded_sample_rate == microphone.sample_rate

    error_msg = (
        "Recording time difference between read() and stop_recording()",
        (len(audio_chunk) - len(recorded_audio)) / MAX_RECORDING_TIME_DIFFERENCE,
    )
    np.testing.assert_allclose(
        len(audio_chunk),
        len(recorded_audio),
        atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
        err_msg=error_msg,
    )

    # Test disconnecting
    microphone.disconnect()
    assert not microphone.is_connected

    # Test disconnecting with `__del__`
    microphone = make_microphone(**microphone_kwargs)
    microphone.connect()
    del microphone
@@ -100,6 +100,9 @@ def test_robot(tmp_path, request, robot_type, mock):
    robot.teleop_step()

    # Test data recorded during teleop are well formatted
    for _, microphone in robot.microphones.items():
        microphone.start_recording()

    observation, action = robot.teleop_step(record_data=True)
    # State
    assert "observation.state" in observation

@@ -112,6 +115,11 @@ def test_robot(tmp_path, request, robot_type, mock):
        assert f"observation.images.{name}" in observation
        assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
        assert observation[f"observation.images.{name}"].ndim == 3
    # Microphones
    for name in robot.microphones:
        assert f"observation.audio.{name}" in observation
        assert isinstance(observation[f"observation.audio.{name}"], torch.Tensor)
        assert observation[f"observation.audio.{name}"].ndim == 2
    # Action
    assert "action" in action
    assert isinstance(action["action"], torch.Tensor)

@@ -124,8 +132,9 @@ def test_robot(tmp_path, request, robot_type, mock):
    captured_observation = robot.capture_observation()
    assert set(captured_observation.keys()) == set(observation.keys())
    for name in captured_observation:
-        if "image" in name:
+        if "image" in name or "audio" in name:
            # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
            # Also skipping for audio as audio chunks may be of different length
            continue
        torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
        assert captured_observation[name].shape == observation[name].shape

@@ -134,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
    robot.send_action(action["action"])

    # Test disconnecting
-    robot.disconnect()
+    robot.disconnect()  # Also handles microphone recording stop, life is beautiful
    assert not robot.is_connected
    for name in robot.follower_arms:
        assert not robot.follower_arms[name].is_connected

@@ -142,3 +151,5 @@ def test_robot(tmp_path, request, robot_type, mock):
        assert not robot.leader_arms[name].is_connected
    for name in robot.cameras:
        assert not robot.cameras[name].is_connected
    for name in robot.microphones:
        assert not robot.microphones[name].is_connected
@@ -22,9 +22,11 @@ from pathlib import Path
import pytest
import torch

-from lerobot import available_cameras, available_motors, available_robots
+from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.microphones.utils import Microphone
from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
from lerobot.common.utils.import_utils import is_package_available

@@ -39,6 +41,10 @@ TEST_CAMERA_TYPES = []
for camera_type in available_cameras:
    TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]

TEST_MICROPHONE_TYPES = []
for microphone_type in available_microphones:
    TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]

TEST_MOTOR_TYPES = []
for motor_type in available_motors:
    TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]

@@ -47,6 +53,9 @@ for motor_type in available_motors:
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))

# Microphone indices used for connecting physical microphones
MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))

DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
DYNAMIXEL_MOTORS = {
    "shoulder_pan": [1, "xl430-w250"],

@@ -253,6 +262,29 @@ def require_camera(func):
    return wrapper


def require_microphone(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Access the pytest request context to get the is_microphone_available fixture
        request = kwargs.get("request")
        microphone_type = kwargs.get("microphone_type")
        mock = kwargs.get("mock")

        if request is None:
            raise ValueError("The 'request' fixture must be an argument of the test function.")
        if microphone_type is None:
            raise ValueError("The 'microphone_type' must be an argument of the test function.")
        if mock is None:
            raise ValueError("The 'mock' variable must be an argument of the test function.")

        if not mock and not request.getfixturevalue("is_microphone_available"):
            pytest.skip(f"A {microphone_type} microphone is not available.")

        return func(*args, **kwargs)

    return wrapper


def require_motor(func):
    @wraps(func)
    def wrapper(*args, **kwargs):

@@ -315,6 +347,14 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
        raise ValueError(f"The camera type '{camera_type}' is not valid.")


def make_microphone(microphone_type: str, **kwargs) -> Microphone:
    if microphone_type == "microphone":
        microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
        return make_microphone_device(microphone_type, microphone_index=microphone_index, **kwargs)
    else:
        raise ValueError(f"The microphone type '{microphone_type}' is not valid.")


# TODO(rcadene, aliberts): remove this dark pattern that overrides
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
    if motor_type == "dynamixel":