Merge ca716ed196 into 768e36660d
This commit is contained in: commit 35d58527a4
@@ -190,6 +190,11 @@ available_cameras = [
     "intelrealsense",
 ]
 
+# lists all available microphones from `lerobot/common/robot_devices/microphones`
+available_microphones = [
+    "microphone",
+]
+
 # lists all available motors from `lerobot/common/robot_devices/motors`
 available_motors = [
     "dynamixel",
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import subprocess
+from collections import OrderedDict
+from pathlib import Path
+
+import torch
+import torchaudio
+from numpy import ceil
+
+
+def decode_audio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    backend: str | None = "ffmpeg",
+) -> torch.Tensor:
+    """
+    Decodes audio using the specified backend.
+    Args:
+        audio_path (Path): Path to the audio file.
+        timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
+        duration (float): Duration of the audio chunks in seconds.
+        backend (str, optional): Backend to use for decoding. Defaults to "ffmpeg".
+
+    Returns:
+        torch.Tensor: Decoded audio chunks.
+
+    Currently supports ffmpeg.
+    """
+    if backend == "torchcodec":
+        raise NotImplementedError("torchcodec is not yet supported for audio decoding")
+    elif backend == "ffmpeg":
+        return decode_audio_torchaudio(audio_path, timestamps, duration)
+    else:
+        raise ValueError(f"Unsupported audio backend: {backend}")
+
+
+def decode_audio_torchaudio(
+    audio_path: Path | str,
+    timestamps: list[float],
+    duration: float,
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    # TODO(CarolinePascal) : add channels selection
+    audio_path = str(audio_path)
+
+    reader = torchaudio.io.StreamReader(src=audio_path)
+    audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
+
+    # TODO(CarolinePascal) : sort timestamps ?
+    reader.add_basic_audio_stream(
+        frames_per_chunk=int(ceil(duration * audio_sample_rate)),  # Too much is better than not enough
+        buffer_chunk_size=-1,  # No dropping frames
+        format="fltp",  # Format as float32
+    )
+
+    audio_chunks = []
+    for ts in timestamps:
+        reader.seek(ts)  # Default to closest audio sample
+        status = reader.fill_buffer()
+        if status != 0:
+            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
+
+        current_audio_chunk = reader.pop_chunks()[0]
+
+        if log_loaded_timestamps:
+            logging.info(
+                f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
+            )
+
+        audio_chunks.append(current_audio_chunk)
+
+    audio_chunks = torch.stack(audio_chunks)
+
+    assert len(timestamps) == len(audio_chunks)
+    return audio_chunks
+
+
+def encode_audio(
+    input_path: Path | str,
+    output_path: Path | str,
+    codec: str = "aac",  # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
+    log_level: str | None = "error",
+    overwrite: bool = False,
+) -> None:
+    """Encodes an audio file using ffmpeg."""
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    ffmpeg_args = OrderedDict(
+        [
+            ("-i", str(input_path)),
+            ("-acodec", codec),
+        ]
+    )
+
+    if log_level is not None:
+        ffmpeg_args["-loglevel"] = str(log_level)
+
+    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    if overwrite:
+        ffmpeg_args.append("-y")
+
+    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(output_path)]
+
+    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
+
+    if not output_path.exists():
+        raise OSError(
+            f"Audio encoding did not work. File not found: {output_path}. "
+            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
+        )
+
+
+def get_audio_info(video_path: Path | str) -> dict:
+    ffprobe_audio_cmd = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-select_streams",
+        "a:0",
+        "-show_entries",
+        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
+        "-of",
+        "json",
+        str(video_path),
+    ]
+    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    if result.returncode != 0:
+        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
+
+    info = json.loads(result.stdout)
+    audio_stream_info = info["streams"][0] if info.get("streams") else None
+    if audio_stream_info is None:
+        return {"has_audio": False}
+
+    # Return the information, defaulting to None for any missing field
+    return {
+        "has_audio": True,
+        "audio.channels": audio_stream_info.get("channels", None),
+        "audio.codec": audio_stream_info.get("codec_name", None),
+        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
+        "audio.sample_rate": int(audio_stream_info["sample_rate"])
+        if audio_stream_info.get("sample_rate")
+        else None,
+        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
+        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
+    }
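A quick sketch of how these helpers fit together (not part of the diff; the file paths are hypothetical placeholders):

from pathlib import Path

from lerobot.common.datasets.audio_utils import decode_audio, encode_audio, get_audio_info

raw_path = Path("audio/observation.audio.laptop/episode_000000.wav")  # hypothetical
compressed_path = Path("audio/chunk-000/observation.audio.laptop/episode_000000.m4a")  # hypothetical

# Compress the raw recording to AAC in an .m4a container.
encode_audio(raw_path, compressed_path, overwrite=True)

# Inspect the encoded stream (channels, codec, sample rate, ...).
print(get_audio_info(compressed_path))

# Decode two 0.5 s chunks starting at t=0.0 s and t=1.2 s.
chunks = decode_audio(compressed_path, timestamps=[0.0, 1.2], duration=0.5)
print(chunks.shape)  # (2, frames_per_chunk, channels), one chunk per timestamp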
@@ -15,7 +15,7 @@
 # limitations under the License.
 import numpy as np
 
-from lerobot.common.datasets.utils import load_image_as_numpy
+from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy
 
 
 def estimate_num_samples(

@@ -72,6 +72,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
     return images
 
 
+def sample_audio_from_path(audio_path: str) -> np.ndarray:
+    """Samples audio data from an audio recording stored in a WAV file."""
+    data = load_audio_from_path(audio_path)
+    sampled_indices = sample_indices(len(data))
+
+    return data[sampled_indices]
+
+
+def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
+    """Samples audio data from an audio recording stored in a numpy array."""
+    sampled_indices = sample_indices(len(data))
+    return data[sampled_indices]
+
+
 def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
     return {
         "min": np.min(array, axis=axis, keepdims=keepdims),

@@ -91,6 +105,13 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict):
             ep_ft_array = sample_images(data)  # data is a list of image paths
             axes_to_reduce = (0, 2, 3)  # keep channel dim
             keepdims = True
+        elif features[key]["dtype"] == "audio":
+            try:
+                ep_ft_array = sample_audio_from_path(data[0])
+            except TypeError:  # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
+                ep_ft_array = sample_audio_from_data(data)
+            axes_to_reduce = 0
+            keepdims = True
         else:
             ep_ft_array = data  # data is already a np.ndarray
             axes_to_reduce = 0  # compute stats over the first axis
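For intuition on the stats path: audio reduces over the frame axis (axis 0) while keeping the channel axis, much as images keep their channel dim. A minimal sketch, assuming a one-second mono recording at 16 kHz:

import numpy as np

from lerobot.common.datasets.compute_stats import get_feature_stats, sample_audio_from_data

data = np.random.randn(16000, 1).astype(np.float32)  # (frames, channels)
subsampled = sample_audio_from_data(data)            # subsampled frames, shape (n, 1)
stats = get_feature_stats(subsampled, axis=0, keepdims=True)
print(stats["min"].shape)  # (1, 1): frame axis reduced, channel axis kept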
@@ -23,6 +23,7 @@ import datasets
 import numpy as np
 import packaging.version
 import PIL.Image
+import soundfile as sf
 import torch
 import torch.utils
 from datasets import concatenate_datasets, load_dataset

@@ -31,11 +32,18 @@ from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import RevisionNotFoundError
 
 from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.audio_utils import (
+    decode_audio,
+    encode_audio,
+    get_audio_info,
+)
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
+    DEFAULT_AUDIO_CHUNK_DURATION,
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
+    DEFAULT_RAW_AUDIO_PATH,
     INFO_PATH,
     TASKS_PATH,
     append_jsonlines,

@@ -72,6 +80,7 @@ from lerobot.common.datasets.video_utils import (
     get_safe_default_codec,
     get_video_info,
 )
+from lerobot.common.robot_devices.microphones.utils import Microphone
 from lerobot.common.robot_devices.robots.utils import Robot
 
 CODEBASE_VERSION = "v2.1"
@@ -142,6 +151,14 @@ class LeRobotDatasetMetadata:
         fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)
 
+    def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        """Returns the path of the compressed (i.e. encoded) audio file."""
+        episode_chunk = self.get_episode_chunk(episode_index)
+        fpath = self.audio_path.format(
+            episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
+        )
+        return self.root / fpath
+
     def get_episode_chunk(self, ep_index: int) -> int:
         return ep_index // self.chunks_size
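Concretely, with the default audio path template (defined in the utils changes further down in this diff), this resolves to a chunked location. A sketch with a hypothetical dataset root:

from pathlib import Path

DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"

root = Path("/data/my_dataset")  # hypothetical root
fpath = DEFAULT_COMPRESSED_AUDIO_PATH.format(
    episode_chunk=0, audio_key="observation.audio.laptop", episode_index=7
)
print(root / fpath)  # /data/my_dataset/audio/chunk-000/observation.audio.laptop/episode_000007.m4a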
@@ -155,6 +172,11 @@ class LeRobotDatasetMetadata:
         """Formattable string for the video files."""
         return self.info["video_path"]
 
+    @property
+    def audio_path(self) -> str | None:
+        """Formattable string for the audio files."""
+        return self.info["audio_path"]
+
     @property
     def robot_type(self) -> str | None:
         """Robot type used in recording this dataset."""

@@ -185,6 +207,11 @@ class LeRobotDatasetMetadata:
         """Keys to access visual modalities (regardless of their storage method)."""
         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
 
+    @property
+    def audio_keys(self) -> list[str]:
+        """Keys to access audio modalities."""
+        return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
+
     @property
     def names(self) -> dict[str, list | dict]:
         """Names of the various dimensions of vector modalities."""
@@ -264,6 +291,10 @@ class LeRobotDatasetMetadata:
         if len(self.video_keys) > 0:
             self.update_video_info()
 
+        self.info["total_audio"] += len(self.audio_keys)
+        if len(self.audio_keys) > 0:
+            self.update_audio_info()
+
         write_info(self.info, self.root)
 
         episode_dict = {

@@ -288,6 +319,19 @@ class LeRobotDatasetMetadata:
             video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
             self.info["features"][key]["info"] = get_video_info(video_path)
 
+    def update_audio_info(self) -> None:
+        """
+        Warning: this function writes info from first episode audio, implicitly assuming that all audio have
+        been encoded the same way. Also, this means it assumes the first episode exists.
+        """
+        for key in self.audio_keys:
+            if (
+                not self.features[key].get("info", None)
+                or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"])
+            ):  # Overwrite if info is empty or only contains sample rate (necessary to correctly save audio files recorded by LeKiwi)
+                audio_path = self.root / self.get_compressed_audio_file_path(0, key)
+                self.info["features"][key]["info"] = get_audio_info(audio_path)
+
     def __repr__(self):
         feature_keys = list(self.features)
         return (
@@ -363,7 +407,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         revision: str | None = None,
         force_cache_sync: bool = False,
         download_videos: bool = True,
+        download_audio: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         """
         2 modes are available for instantiating this class, depending on 2 different use cases:

@@ -394,7 +440,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         - tasks contains the prompts for each task of the dataset, which can be used for
           task-conditioned training.
         - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
-        - videos (optional) from which frames are loaded to be synchronous with data from parquet files.
+        - videos (optional) from which frames and audio (if any) are loaded to be synchronous with data from parquet files and audio.
+        - audio (optional) from which audio is loaded to be synchronous with data from parquet files and videos.
 
         A typical LeRobotDataset looks like this from its root path:
         .

@@ -415,17 +462,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
         │   ├── info.json
         │   ├── stats.json
         │   └── tasks.jsonl
-        └── videos
-            ├── chunk-000
-            │   ├── observation.images.laptop
-            │   │   ├── episode_000000.mp4
-            │   │   ├── episode_000001.mp4
-            │   │   ├── episode_000002.mp4
-            │   │   └── ...
-            │   ├── observation.images.phone
-            │   │   ├── episode_000000.mp4
-            │   │   ├── episode_000001.mp4
-            │   │   ├── episode_000002.mp4
-            │   │   └── ...
-            ├── chunk-001
-            └── ...
+        ├── videos
+        │   ├── chunk-000
+        │   │   ├── observation.images.laptop
+        │   │   │   ├── episode_000000.mp4
+        │   │   │   ├── episode_000001.mp4
+        │   │   │   ├── episode_000002.mp4
+        │   │   │   └── ...
+        │   │   ├── observation.images.phone
+        │   │   │   ├── episode_000000.mp4
+        │   │   │   ├── episode_000001.mp4
+        │   │   │   ├── episode_000002.mp4
+        │   │   │   └── ...
+        │   ├── chunk-001
+        │   └── ...
+        └── audio
+            ├── chunk-000
+            │   ├── observation.audio.laptop
+            │   │   ├── episode_000000.m4a
+            │   │   ├── episode_000001.m4a
+            │   │   ├── episode_000002.m4a
+            │   │   └── ...
+            │   ├── observation.audio.phone
+            │   │   ├── episode_000000.m4a
+            │   │   ├── episode_000001.m4a
+            │   │   ├── episode_000002.m4a
+            │   │   └── ...
+            ├── chunk-001
+            └── ...
@@ -463,8 +524,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
             download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                 video files are already present on local disk, they won't be downloaded again. Defaults to
                 True.
+            download_audio (bool, optional): Flag to download the audio (see download_videos). Defaults to True.
             video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available in the platform; otherwise, defaults to 'pyav'.
                 You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
+            audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to the 'ffmpeg' decoder used by 'torchaudio'.
         """
         super().__init__()
         self.repo_id = repo_id
@@ -475,6 +538,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else get_safe_default_codec()
+        self.audio_backend = (
+            audio_backend if audio_backend else "ffmpeg"
+        )  # Waiting for torchcodec release #TODO(CarolinePascal)
         self.delta_indices = None
 
         # Unused attributes
@@ -499,7 +565,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(download_videos)
+            self.download_episodes(download_videos, download_audio)
             self.hf_dataset = self.load_hf_dataset()
 
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

@@ -510,6 +576,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
             check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
 
+        # TODO(CarolinePascal) : add check for audio duration with respect to episode duration BUT this will be CPU expensive if there are many episodes !
+
         # Setup delta_indices
         if self.delta_timestamps is not None:
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -522,6 +590,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         license: str | None = "apache-2.0",
         tag_version: bool = True,
         push_videos: bool = True,
+        push_audio: bool = True,
         private: bool = False,
         allow_patterns: list[str] | str | None = None,
         upload_large_folder: bool = False,

@@ -530,6 +599,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ignore_patterns = ["images/"]
         if not push_videos:
             ignore_patterns.append("videos/")
+        if not push_audio:
+            ignore_patterns.append("audio/")
 
         hub_api = HfApi()
         hub_api.create_repo(
@@ -585,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ignore_patterns=ignore_patterns,
         )
 
-    def download_episodes(self, download_videos: bool = True) -> None:
+    def download_episodes(self, download_videos: bool = True, download_audio: bool = True) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
         dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present

@@ -594,7 +665,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
         files = None
-        ignore_patterns = None if download_videos else "videos/"
+        ignore_patterns = []
+        if not download_videos:
+            ignore_patterns.append("videos/")
+        if not download_audio:
+            ignore_patterns.append("audio/")
         if self.episodes is not None:
             files = self.get_episodes_file_paths()
@@ -611,6 +686,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ]
             fpaths += video_files
 
+        if len(self.meta.audio_keys) > 0:
+            audio_files = [
+                str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
+                for audio_key in self.meta.audio_keys
+                for ep_idx in episodes
+            ]
+            fpaths += audio_files
+
         return fpaths
 
     def load_hf_dataset(self) -> datasets.Dataset:
@@ -677,7 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         }
         return query_indices, padding
 
-    def _get_query_timestamps(
+    def _get_query_timestamps_video(
         self,
         current_ts: float,
         query_indices: dict[str, list[int]] | None = None,

@@ -692,11 +775,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return query_timestamps
 
+    # TODO(CarolinePascal): add variable query durations
+    def _get_query_timestamps_audio(
+        self,
+        current_ts: float,
+        query_indices: dict[str, list[int]] | None = None,
+    ) -> dict[str, list[float]]:
+        query_timestamps = {}
+        for key in self.meta.audio_keys:
+            if query_indices is not None and key in query_indices:
+                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
+                query_timestamps[key] = torch.stack(timestamps).tolist()
+            else:
+                query_timestamps[key] = [current_ts]
+
+        return query_timestamps
+
     def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
         return {
             key: torch.stack(self.hf_dataset.select(q_idx)[key])
             for key, q_idx in query_indices.items()
-            if key not in self.meta.video_keys
+            if key not in self.meta.video_keys and key not in self.meta.audio_keys
         }
 
     def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
@@ -713,6 +812,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
+    # TODO(CarolinePascal): add variable query durations
+    def _query_audio(
+        self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
+    ) -> dict[str, torch.Tensor]:
+        item = {}
+        for audio_key, query_ts in query_timestamps.items():
+            audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
+            audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
+            item[audio_key] = audio_chunk.squeeze(0)
+        return item
+
     def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
         for key, val in padding.items():
             item[key] = torch.BoolTensor(val)
@@ -733,11 +843,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
             for key, val in query_result.items():
                 item[key] = val
 
-        if len(self.meta.video_keys) > 0:
+        if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
             current_ts = item["timestamp"].item()
-            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+
+            query_timestamps = self._get_query_timestamps_video(current_ts, query_indices)
             video_frames = self._query_videos(query_timestamps, ep_idx)
-            item = {**video_frames, **item}
+            item = {**item, **video_frames}
+
+            query_timestamps = self._get_query_timestamps_audio(current_ts, query_indices)
+            audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
+            item = {**item, **audio_chunks}
 
         if self.image_transforms is not None:
             image_keys = self.meta.camera_keys
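Putting it together, indexing the dataset now returns video frames plus a fixed-duration audio chunk per audio key, all synchronized on the frame timestamp. A rough sketch of what a sample might contain (the repo id and keys are hypothetical):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("user/dataset_with_audio")  # hypothetical repo id

item = dataset[0]
print(item["observation.images.laptop"].shape)  # e.g. a (c, h, w) video frame
print(item["observation.audio.laptop"].shape)   # e.g. (frames, channels) for a 0.5 s chunk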
@@ -777,6 +892,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         return self.root / fpath
 
+    def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
+        fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
+        return self.root / fpath
+
     def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
         if self.image_writer is None:
             if isinstance(image, torch.Tensor):
@@ -827,11 +946,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 img_path.parent.mkdir(parents=True, exist_ok=True)
             self._save_image(frame[key], img_path)
             self.episode_buffer[key].append(str(img_path))
+        elif self.features[key]["dtype"] == "audio":
+            if (
+                self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
+            ):  # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
+                self.episode_buffer[key].append(frame[key])
+            else:  # Otherwise, only the audio file path is stored in the episode buffer
+                if frame_index == 0:
+                    audio_path = self._get_raw_audio_file_path(
+                        episode_index=self.episode_buffer["episode_index"], audio_key=key
+                    )
+                    self.episode_buffer[key].append(str(audio_path))
         else:
             self.episode_buffer[key].append(frame[key])
 
         self.episode_buffer["size"] += 1
 
+    def add_microphone_recording(self, microphone: Microphone, microphone_key: str) -> None:
+        """
+        Starts recording audio data provided by the microphone and directly writes it in a .wav file.
+        """
+        audio_dir = self._get_raw_audio_file_path(
+            self.num_episodes, "observation.audio." + microphone_key
+        ).parent
+        if not audio_dir.is_dir():
+            audio_dir.mkdir(parents=True, exist_ok=True)
+
+        microphone.start_recording(
+            output_file=self._get_raw_audio_file_path(
+                self.num_episodes, "observation.audio." + microphone_key
+            )
+        )
+
     def save_episode(self, episode_data: dict | None = None) -> None:
         """
         This will save to disk the current episode in self.episode_buffer.
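A sketch of how this might be wired up during recording (the robot and dataset objects are placeholders; see the control_loop changes further down for the actual call site):

# Assumes `robot` exposes a `microphones` dict and `dataset` was created with audio features.
for microphone_key, microphone in robot.microphones.items():
    # Streams samples straight into audio/{audio_key}/episode_{index:06d}.wav under the dataset root.
    dataset.add_microphone_recording(microphone, microphone_key)

# ... run the control loop, calling dataset.add_frame(...) at each step ...

for microphone in robot.microphones.values():
    microphone.stop_recording()

dataset.save_episode()  # compresses the .wav files to .m4a via encode_episode_audio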
@@ -869,16 +1016,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # are processed separately by storing image path and frame info as meta data
             if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                 continue
+            elif ft["dtype"] == "audio":
+                if (
+                    self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
+                ):  # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
+                    episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
+                continue
             episode_buffer[key] = np.stack(episode_buffer[key])
 
         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)
+
+        if (
+            self.meta.robot_type is not None and self.meta.robot_type.startswith("lekiwi")
+        ):  # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
+            for key in self.meta.audio_keys:
+                audio_path = self._get_raw_audio_file_path(
+                    episode_index=self.episode_buffer["episode_index"][0], audio_key=key
+                )
+                with sf.SoundFile(
+                    audio_path,
+                    mode="w",
+                    samplerate=self.meta.features[key]["info"]["sample_rate"],
+                    channels=self.meta.features[key]["shape"][0],
+                ) as file:
+                    file.write(episode_buffer[key])
+
         ep_stats = compute_episode_stats(episode_buffer, self.features)
 
         if len(self.meta.video_keys) > 0:
-            video_paths = self.encode_episode_videos(episode_index)
-            for key in self.meta.video_keys:
-                episode_buffer[key] = video_paths[key]
+            self.encode_episode_videos(episode_index)
+
+        if len(self.meta.audio_keys) > 0:
+            self.encode_episode_audio(episode_index)
 
         # `meta.save_episode` must be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -904,6 +1074,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if img_dir.is_dir():
             shutil.rmtree(self.root / "images")
 
+        # delete raw audio files
+        raw_audio_files = list(self.root.rglob("*.wav"))
+        for raw_audio_file in raw_audio_files:
+            raw_audio_file.unlink()
+            if len(list(raw_audio_file.parent.iterdir())) == 0:
+                raw_audio_file.parent.rmdir()
+
         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
@@ -971,19 +1148,40 @@ class LeRobotDataset(torch.utils.data.Dataset):
         since video encoding with ffmpeg is already using multithreading.
         """
         video_paths = {}
-        for key in self.meta.video_keys:
-            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-            video_paths[key] = str(video_path)
+        for video_key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, video_key)
+            video_paths[video_key] = str(video_path)
             if video_path.is_file():
                 # Skip if video is already encoded. Could be the case when resuming data recording.
                 continue
             img_dir = self._get_image_file_path(
-                episode_index=episode_index, image_key=key, frame_index=0
+                episode_index=episode_index, image_key=video_key, frame_index=0
             ).parent
 
             encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
 
         return video_paths
 
+    def encode_episode_audio(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert .wav raw audio files into .m4a audio files.
+        Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+        since audio encoding with ffmpeg is already using multithreading.
+        """
+        audio_paths = {}
+        for audio_key in self.meta.audio_keys:
+            input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
+            output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
+
+            audio_paths[audio_key] = str(output_audio_path)
+            if output_audio_path.is_file():
+                # Skip if audio is already encoded. Could be the case when resuming data recording.
+                continue
+
+            encode_audio(input_audio_path, output_audio_path, overwrite=True)
+
+        return audio_paths
+
     @classmethod
     def create(
         cls,
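For reference, the ffmpeg invocation that encode_audio assembles with these defaults looks roughly as follows; a sketch mirroring the argument assembly, with placeholder file names:

from collections import OrderedDict

ffmpeg_args = OrderedDict([("-i", "episode_000000.wav"), ("-acodec", "aac")])
ffmpeg_args["-loglevel"] = "error"
args = [item for pair in ffmpeg_args.items() for item in pair]
args.append("-y")  # overwrite=True
print(" ".join(["ffmpeg"] + args + ["episode_000000.m4a"]))
# ffmpeg -i episode_000000.wav -acodec aac -loglevel error -y episode_000000.m4a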
@@ -998,6 +1196,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)

@@ -1029,6 +1228,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj.audio_backend = (
+            audio_backend if audio_backend is not None else "ffmpeg"
+        )  # Waiting for torchcodec release #TODO(CarolinePascal)
         return obj

@@ -1049,6 +1251,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         tolerances_s: dict | None = None,
         download_videos: bool = True,
         video_backend: str | None = None,
+        audio_backend: str | None = None,
     ):
         super().__init__()
         self.repo_ids = repo_ids

@@ -1066,6 +1269,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 tolerance_s=self.tolerances_s[repo_id],
                 download_videos=download_videos,
                 video_backend=video_backend,
+                audio_backend=audio_backend,
             )
             for repo_id in repo_ids
         ]
@@ -33,6 +33,7 @@ from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
 from huggingface_hub.errors import RevisionNotFoundError
 from PIL import Image as PILImage
+from soundfile import read
 from torchvision import transforms
 
 from lerobot.common.datasets.backward_compatibility import (

@@ -55,6 +56,10 @@ TASKS_PATH = "meta/tasks.jsonl"
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
 DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_RAW_AUDIO_PATH = "audio/{audio_key}/episode_{episode_index:06d}.wav"
+DEFAULT_COMPRESSED_AUDIO_PATH = "audio/chunk-{episode_chunk:03d}/{audio_key}/episode_{episode_index:06d}.m4a"
+
+DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds
 
 DATASET_CARD_TEMPLATE = """
 ---
@@ -255,6 +260,11 @@ def load_image_as_numpy(
     return img_array
 
 
+def load_audio_from_path(fpath: str | Path) -> np.ndarray:
+    audio_data, _ = read(fpath, dtype="float32")
+    return audio_data
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
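A small sketch of what this helper returns: soundfile reads the whole file into float32 samples of shape (frames, channels), and the sample rate returned by read() is discarded here. The WAV file below is a hypothetical placeholder:

import numpy as np
import soundfile as sf

sf.write("episode_000000.wav", np.zeros((16000, 2), dtype=np.float32), samplerate=16000)

from lerobot.common.datasets.utils import load_audio_from_path

data = load_audio_from_path("episode_000000.wav")
print(data.dtype, data.shape)  # float32 (16000, 2)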
@@ -363,7 +373,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
 def get_hf_features_from_features(features: dict) -> datasets.Features:
     hf_features = {}
     for key, ft in features.items():
-        if ft["dtype"] == "video":
+        if ft["dtype"] == "video" or ft["dtype"] == "audio":
             continue
         elif ft["dtype"] == "image":
             hf_features[key] = datasets.Image()

@@ -394,7 +404,7 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
         key: {"dtype": "video" if use_videos else "image", **ft}
         for key, ft in robot.camera_features.items()
     }
-    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
+    return {**robot.motor_features, **camera_ft, **robot.microphone_features, **DEFAULT_FEATURES}
 
 
 def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:

@@ -442,12 +452,14 @@ def create_empty_dataset_info(
         "total_frames": 0,
         "total_tasks": 0,
         "total_videos": 0,
+        "total_audio": 0,
         "total_chunks": 0,
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": fps,
         "splits": {},
         "data_path": DEFAULT_PARQUET_PATH,
         "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "audio_path": DEFAULT_COMPRESSED_AUDIO_PATH,
         "features": features,
     }
@@ -721,6 +733,7 @@ def validate_features_presence(
 ):
     error_message = ""
     missing_features = expected_features - actual_features
+    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
     extra_features = actual_features - (expected_features | optional_features)
 
     if missing_features or extra_features:

@@ -740,6 +753,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
         return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
     elif expected_dtype in ["image", "video"]:
         return validate_feature_image_or_video(name, expected_shape, value)
+    elif expected_dtype == "audio":
+        return validate_feature_audio(name, expected_shape, value)
     elif expected_dtype == "string":
         return validate_feature_string(name, value)
     else:
@@ -781,6 +796,23 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
     return error_message
 
 
+def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_shape = value.shape
+        c = expected_shape
+        if len(actual_shape) != 2 or (
+            actual_shape[-1] != c[-1] and actual_shape[0] != c[0]
+        ):  # The number of frames might be different
+            error_message += (
+                f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
+            )
+    else:
+        error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
+
 def validate_feature_string(name: str, value: str):
     if not isinstance(value, str):
         return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
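To make the shape rule concrete: a 2-D array passes as long as the channel count matches on either axis, since the number of frames may legitimately differ between chunks. A sketch:

import numpy as np

from lerobot.common.datasets.utils import validate_feature_audio

chunk = np.zeros((800, 2), dtype=np.float32)  # 800 frames, 2 channels

print(validate_feature_audio("observation.audio.laptop", (2,), chunk))   # "" -> valid
print(validate_feature_audio("observation.audio.laptop", (2,), [1, 2]))  # type error message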
@@ -260,35 +260,39 @@ def encode_video_frames(
     imgs_dir = Path(imgs_dir)
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
-    ffmpeg_args = OrderedDict(
+    ffmpeg_video_args = OrderedDict(
         [
             ("-f", "image2"),
             ("-r", str(fps)),
-            ("-i", str(imgs_dir / "frame_%06d.png")),
-            ("-vcodec", vcodec),
-            ("-pix_fmt", pix_fmt),
+            ("-i", str(Path(imgs_dir) / "frame_%06d.png")),
         ]
     )
 
+    ffmpeg_encoding_args = OrderedDict(
+        [
+            ("-pix_fmt", pix_fmt),
+            ("-vcodec", vcodec),
+        ]
+    )
+
     if g is not None:
-        ffmpeg_args["-g"] = str(g)
+        ffmpeg_encoding_args["-g"] = str(g)
 
     if crf is not None:
-        ffmpeg_args["-crf"] = str(crf)
+        ffmpeg_encoding_args["-crf"] = str(crf)
 
     if fast_decode:
         key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
         value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
-        ffmpeg_args[key] = value
+        ffmpeg_encoding_args[key] = value
 
     if log_level is not None:
-        ffmpeg_args["-loglevel"] = str(log_level)
+        ffmpeg_encoding_args["-loglevel"] = str(log_level)
 
-    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    ffmpeg_args = [item for pair in ffmpeg_video_args.items() for item in pair]
+    ffmpeg_args += [item for pair in ffmpeg_encoding_args.items() for item in pair]
     if overwrite:
         ffmpeg_args.append("-y")
 
     ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
 
     # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
     subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
@@ -331,42 +335,6 @@ with warnings.catch_warnings():
     register_feature(VideoFrame, "VideoFrame")
 
 
-def get_audio_info(video_path: Path | str) -> dict:
-    ffprobe_audio_cmd = [
-        "ffprobe",
-        "-v",
-        "error",
-        "-select_streams",
-        "a:0",
-        "-show_entries",
-        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
-        "-of",
-        "json",
-        str(video_path),
-    ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    if result.returncode != 0:
-        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
-
-    info = json.loads(result.stdout)
-    audio_stream_info = info["streams"][0] if info.get("streams") else None
-    if audio_stream_info is None:
-        return {"has_audio": False}
-
-    # Return the information, defaulting to None if no audio stream is present
-    return {
-        "has_audio": True,
-        "audio.channels": audio_stream_info.get("channels", None),
-        "audio.codec": audio_stream_info.get("codec_name", None),
-        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
-        "audio.sample_rate": int(audio_stream_info["sample_rate"])
-        if audio_stream_info.get("sample_rate")
-        else None,
-        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
-        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
-    }
-
-
 def get_video_info(video_path: Path | str) -> dict:
     ffprobe_video_cmd = [
         "ffprobe",

@@ -402,7 +370,6 @@ def get_video_info(video_path: Path | str) -> dict:
         "video.codec": video_stream_info["codec_name"],
         "video.pix_fmt": video_stream_info["pix_fmt"],
         "video.is_depth_map": False,
-        **get_audio_info(video_path),
     }
 
     return video_info
@@ -78,6 +78,11 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
         if key in robot.logs:
             log_dt(f"dtR{name}", robot.logs[key])
 
+    for name in robot.microphones:
+        key = f"read_microphone_{name}_dt_s"
+        if key in robot.logs:
+            log_dt(f"dtR{name}", robot.logs[key])
+
     info_str = " ".join(log_items)
     logging.info(info_str)
@ -107,11 +112,15 @@ def predict_action(observation, policy, device, use_amp):
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        for name in observation:
            # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            # Convert to pytorch format: channel first and float32 in [-1,1] with batch dimension
            if "audio" in name:
                observation[name] = observation[name].type(torch.float32)
                observation[name] = observation[name].permute(1, 0).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

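As a standalone illustration of the shape conventions this loop enforces (the sizes below are made up):

```python
import torch

# Image observations arrive as (H, W, C) uint8 in [0, 255] and leave as
# (1, C, H, W) float32 in [0, 1]:
image = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
image = (image.type(torch.float32) / 255).permute(2, 0, 1).contiguous().unsqueeze(0)
assert image.shape == (1, 3, 480, 640)

# Audio observations arrive as (n_samples, n_channels) float32 in [-1, 1]
# (the sounddevice convention) and leave as (1, n_channels, n_samples):
audio = torch.zeros(4800, 2)
audio = audio.permute(1, 0).contiguous().unsqueeze(0)
assert audio.shape == (1, 2, 4800)
```
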
@ -243,6 +252,18 @@ def control_loop(
    timestamp = 0
    start_episode_t = time.perf_counter()

    if (
        dataset is not None and not robot.robot_type.startswith("lekiwi")
    ):  # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
        for microphone_key, microphone in robot.microphones.items():
            # Start recording both in file writing and data reading mode
            dataset.add_microphone_recording(microphone, microphone_key)
    else:
        for _, microphone in robot.microphones.items():
            # Start recording only in data reading mode
            microphone.start_recording()

    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

@ -286,6 +307,9 @@ def control_loop(
            events["exit_early"] = False
            break

    for _, microphone in robot.microphones.items():
        microphone.stop_recording()


def reset_environment(robot, events, reset_time_s, fps):
    # TODO(rcadene): refactor warmup_record and reset_environment

@ -0,0 +1,38 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from dataclasses import dataclass

import draccus


@dataclass
class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)


@MicrophoneConfigBase.register_subclass("microphone")
@dataclass
class MicrophoneConfig(MicrophoneConfigBase):
    """
    Dataclass for microphone configuration.
    """

    microphone_index: int
    sample_rate: int | None = None
    channels: list[int] | None = None
    mock: bool = False

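A minimal sketch of using this config directly; the `type` property resolves the name registered with the draccus choice registry above:

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig

cfg = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
print(cfg.type)  # -> "microphone", via draccus.ChoiceRegistry.get_choice_name
```
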
@ -0,0 +1,425 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file contains utilities for recording audio from a microphone.
"""

import argparse
import logging
import shutil
import time
from multiprocessing import Event as process_Event
from multiprocessing import JoinableQueue as process_Queue
from multiprocessing import Process
from os import getcwd
from pathlib import Path
from queue import Empty
from queue import Queue as thread_Queue
from threading import Event, Thread
from threading import Event as thread_Event

import numpy as np
import soundfile as sf

from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.utils import (
    RobotDeviceAlreadyConnectedError,
    RobotDeviceAlreadyRecordingError,
    RobotDeviceNotConnectedError,
    RobotDeviceNotRecordingError,
)
from lerobot.common.utils.utils import capture_timestamp_utc


def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
    """
    Finds and lists all microphones compatible with sounddevice (and the underlying PortAudio library).
    Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
    """
    microphones = []

    if mock:
        import tests.microphones.mock_sounddevice as sd
    else:
        import sounddevice as sd

    devices = sd.query_devices()
    for device in devices:
        if device["max_input_channels"] > 0:
            microphones.append(
                {
                    "index": device["index"],
                    "name": device["name"],
                }
            )

    if raise_when_empty and len(microphones) == 0:
        raise OSError(
            "Not a single microphone was detected. Try re-plugging the microphone or check the microphone settings."
        )

    return microphones


def record_audio_from_microphones(
    output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0
):
    """
    Records audio from all the channels of the specified microphones for the specified duration.
    If no microphone ids are provided, all available microphones will be used.
    """

    if microphone_ids is None or len(microphone_ids) == 0:
        microphones = find_microphones()
        microphone_ids = [m["index"] for m in microphones]

    microphones = []
    for microphone_id in microphone_ids:
        config = MicrophoneConfig(microphone_index=microphone_id)
        microphone = Microphone(config)
        microphone.connect()
        print(
            f"Recording audio from microphone {microphone_id} for {record_time_s} seconds at {microphone.sample_rate} Hz."
        )
        microphones.append(microphone)

    output_dir = Path(output_dir)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving audio to {output_dir}")

    for microphone in microphones:
        microphone.start_recording(getcwd() / output_dir / f"microphone_{microphone.microphone_index}.wav")

    time.sleep(record_time_s)

    for microphone in microphones:
        microphone.stop_recording()

    # Remark : recording may be resumed here if needed

    for microphone in microphones:
        microphone.disconnect()

    print(f"Audio recordings have been saved to {output_dir}")

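A hedged usage sketch of the two helpers above (the device index and output path are hypothetical, and the example assumes it runs where this module is importable):

```python
from pathlib import Path

mics = find_microphones()
print(mics)  # e.g. [{"index": 0, "name": "Built-in Microphone"}] on a laptop

# Record 2 seconds from the first detected microphone into outputs/mic_check/
record_audio_from_microphones(Path("outputs/mic_check"), microphone_ids=[0], record_time_s=2.0)
```
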
class Microphone:
    """
    The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).

    A Microphone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.

    Example of usage:
    ```python
    from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig

    config = MicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
    microphone = Microphone(config)

    microphone.connect()
    microphone.start_recording("some/output/file.wav")
    ...
    audio_readings = microphone.read()  # Gets all recorded audio data since the last read or since the beginning of the recording
    ...
    microphone.stop_recording()
    microphone.disconnect()
    ```
    """

    def __init__(self, config: MicrophoneConfig):
        self.config = config
        self.microphone_index = config.microphone_index

        # Store the recording sample rate and channels
        self.sample_rate = config.sample_rate
        self.channels = config.channels

        self.mock = config.mock

        # Input audio stream
        self.stream = None

        # Thread/Process-safe concurrent queue to store the recorded/read audio
        self.record_queue = None
        self.read_queue = None

        # Thread/Process to handle data reading and file writing in a separate thread/process (safely)
        self.record_thread = None
        self.record_stop_event = None

        self.logs = {}
        self.is_connected = False
        self.is_recording = False
        self.is_writing = False

    def connect(self) -> None:
        """
        Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"Microphone {self.microphone_index} is already connected."
            )

        if self.mock:
            import tests.microphones.mock_sounddevice as sd
        else:
            import sounddevice as sd

        # Check if the provided microphone index does match an input device
        is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0

        if not is_index_input:
            microphones_info = find_microphones()
            available_microphones = [m["index"] for m in microphones_info]
            raise OSError(
                f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
            )

        # Check if provided recording parameters are compatible with the microphone
        actual_microphone = sd.query_devices(self.microphone_index)

        if self.sample_rate is not None:
            if self.sample_rate > actual_microphone["default_samplerate"]:
                raise OSError(
                    f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
                )
            elif self.sample_rate < actual_microphone["default_samplerate"]:
                logging.warning(
                    "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
                )
        else:
            self.sample_rate = int(actual_microphone["default_samplerate"])

        if self.channels is not None and len(self.channels) > 0:
            if any(c > actual_microphone["max_input_channels"] for c in self.channels):
                raise OSError(
                    f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
                )
        else:
            self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1)

        # Get channels index instead of number for slicing
        self.channels_index = np.array(self.channels) - 1

        # Create the audio stream
        self.stream = sd.InputStream(
            device=self.microphone_index,
            samplerate=self.sample_rate,
            channels=max(self.channels),
            dtype="float32",
            callback=self._audio_callback,
        )
        # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receive same length buffers.
        # However, this may lead to additional latency. We thus stick to blocksize=0 which means that audio_callback will receive varying length buffers, but with no additional latency.

        self.is_connected = True

    def _audio_callback(self, indata, frames, time, status) -> None:
        """
        Low-level sounddevice callback.
        """
        if status:
            logging.warning(status)
        # Slicing makes copy unnecessary
        # Two separate queues are necessary because .get() also pops the data from the queue
        if self.is_writing:
            self.record_queue.put(indata[:, self.channels_index])
        self.read_queue.put(indata[:, self.channels_index])

    @staticmethod
    def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
        """
        Thread/Process-safe loop to write audio data into a file.
        """
        # Can only be run on a single process/thread for file writing safety
        with sf.SoundFile(
            output_file,
            mode="x",
            samplerate=sample_rate,
            channels=max(channels),
            subtype=sf.default_subtype(output_file.suffix[1:]),
        ) as file:
            while not event.is_set():
                try:
                    file.write(
                        queue.get(timeout=0.02)
                    )  # Timeout set as twice the usual sounddevice buffer size
                    queue.task_done()
                except Empty:
                    continue

    def _read(self) -> np.ndarray:
        """
        Thread/Process-safe callback to read available audio data.
        """
        audio_readings = np.empty((0, len(self.channels)))

        while True:
            try:
                audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
            except Empty:
                break

        self.read_queue = thread_Queue()

        return audio_readings

    def read(self) -> np.ndarray:
        """
        Reads the last audio chunk recorded by the microphone, i.e. all samples recorded since the last read or since the beginning of the recording.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if not self.is_recording:
            raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")

        start_time = time.perf_counter()

        audio_readings = self._read()

        # log the number of seconds it took to read the audio chunk
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the utc time at which the audio chunk was received
        self.logs["timestamp_utc"] = capture_timestamp_utc()

        return audio_readings

    def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
        """
        Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if self.is_recording:
            raise RobotDeviceAlreadyRecordingError(
                f"Microphone {self.microphone_index} is already recording."
            )

        # Reset queues
        self.read_queue = thread_Queue()
        if multiprocessing:
            self.record_queue = process_Queue()
        else:
            self.record_queue = thread_Queue()

        # Write recordings into a file if output_file is provided
        if output_file is not None:
            output_file = Path(output_file)
            if output_file.exists():
                output_file.unlink()

            if multiprocessing:
                self.record_stop_event = process_Event()
                self.record_thread = Process(
                    target=Microphone._record_loop,
                    args=(
                        self.record_queue,
                        self.record_stop_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            else:
                self.record_stop_event = thread_Event()
                self.record_thread = Thread(
                    target=Microphone._record_loop,
                    args=(
                        self.record_queue,
                        self.record_stop_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            self.record_thread.daemon = True
            self.record_thread.start()

            self.is_writing = True

        self.is_recording = True
        self.stream.start()

    def stop_recording(self) -> None:
        """
        Stops the recording of the microphone.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if not self.is_recording:
            raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")

        if self.stream.active:
            self.stream.stop()  # Wait for all buffers to be processed
            # Remark : stream.abort() flushes the buffers !
        self.is_recording = False

        if self.record_thread is not None:
            self.record_queue.join()
            self.record_stop_event.set()
            self.record_thread.join()
            self.record_thread = None
            self.record_stop_event = None
            self.is_writing = False

    def disconnect(self) -> None:
        """
        Disconnects the microphone and stops the recording.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")

        if self.is_recording:
            self.stop_recording()

        self.stream.close()
        self.is_connected = False

    def __del__(self):
        if getattr(self, "is_connected", False):
            self.disconnect()

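Complementing the file-writing example in the class docstring, a minimal data-reading-only sketch built from the class above (device index 0 is an assumption; without an output file, only the in-memory read queue is filled):

```python
import time

config = MicrophoneConfig(microphone_index=0)  # sample_rate and channels resolved at connect()
mic = Microphone(config)
mic.connect()
mic.start_recording()  # no output_file: data-reading mode only

for _ in range(5):
    time.sleep(0.1)
    chunk = mic.read()  # shape (n_samples, len(mic.channels)); n_samples varies with callback timing
    print(chunk.shape)

mic.stop_recording()
mic.disconnect()
```
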
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset."
    )
    parser.add_argument(
        "--microphone-ids",
        type=int,
        nargs="*",
        default=None,
        help="List of microphone indices used to instantiate the `Microphone`. If not provided, find and use all available microphone indices.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default="outputs/audio_from_microphones",
        help="Set directory to save an audio snippet for each microphone.",
    )
    parser.add_argument(
        "--record-time-s",
        type=float,
        default=4.0,
        help="Set the number of seconds used to record the audio. By default, 4 seconds.",
    )
    args = parser.parse_args()
    record_audio_from_microphones(**vars(args))

@ -0,0 +1,48 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Protocol

from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase


# Defines a microphone type
class Microphone(Protocol):
    def connect(self): ...
    def disconnect(self): ...
    def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ...
    def stop_recording(self): ...


def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> dict[str, Microphone]:
    microphones = {}

    for key, cfg in microphone_configs.items():
        if cfg.type == "microphone":
            from lerobot.common.robot_devices.microphones.microphone import Microphone

            microphones[key] = Microphone(cfg)
        else:
            raise ValueError(f"The microphone type '{cfg.type}' is not valid.")

    return microphones


def make_microphone(microphone_type, **kwargs) -> Microphone:
    if microphone_type == "microphone":
        from lerobot.common.robot_devices.microphones.microphone import Microphone

        return Microphone(MicrophoneConfig(**kwargs))
    else:
        raise ValueError(f"The microphone type '{microphone_type}' is not valid.")

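A short usage sketch of both factories (the keyword values are hypothetical):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.utils import (
    make_microphone,
    make_microphones_from_configs,
)

# Single device from keyword arguments:
mic = make_microphone("microphone", microphone_index=0, sample_rate=16000, channels=[1])

# Several devices from a config mapping, as a robot class would do:
mics = make_microphones_from_configs({"laptop": MicrophoneConfig(microphone_index=0)})
mics["laptop"].connect()
```
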
@ -23,6 +23,7 @@ from lerobot.common.robot_devices.cameras.configs import (
    IntelRealSenseCameraConfig,
    OpenCVCameraConfig,
)
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.motors.configs import (
    DynamixelMotorsBusConfig,
    FeetechMotorsBusConfig,

@ -43,6 +44,7 @@ class ManipulatorRobotConfig(RobotConfig):
    leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
    follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
    cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
    microphones: dict[str, MicrophoneConfig] = field(default_factory=lambda: {})

    # Optionally limit the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length

@ -68,6 +70,9 @@ class ManipulatorRobotConfig(RobotConfig):
            for cam in self.cameras.values():
                if not cam.mock:
                    cam.mock = True
            for mic in self.microphones.values():
                if not mic.mock:
                    mic.mock = True

        if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
            for name in self.follower_arms:

@ -491,6 +496,21 @@ class So100RobotConfig(ManipulatorRobotConfig):
        }
    )

    microphones: dict[str, MicrophoneConfig] = field(
        default_factory=lambda: {
            "laptop": MicrophoneConfig(
                microphone_index=0,
                sample_rate=48000,
                channels=[1],
            ),
            "headset": MicrophoneConfig(
                microphone_index=1,
                sample_rate=44100,
                channels=[1],
            ),
        }
    )

    mock: bool = False

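The `__post_init__` fragment above suggests that a robot-level `mock=True` cascades to every attached device config. A sketch of that expectation, assuming the loops sit inside an `if self.mock:` branch that this hunk elides:

```python
from lerobot.common.robot_devices.robots.configs import So100RobotConfig  # assumed import path

cfg = So100RobotConfig(mock=True)
# Assumption: __post_init__ propagates `mock` to cameras and microphones alike.
assert all(cam.mock for cam in cfg.cameras.values())
assert all(mic.mock for mic in cfg.microphones.values())
```
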
@ -52,6 +52,16 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
        time.sleep(0.01)


def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
    while not stop_event.is_set():
        local_dict = {}
        for name, microphone in microphones.items():
            audio_readings = microphone.read()
            local_dict[name] = audio_readings
        with audio_lock:
            latest_audio_dict.update(local_dict)


def calibrate_follower_arm(motors_bus, calib_dir_str):
    """
    Calibrates the follower arm. Attempts to load an existing calibration file;

@ -94,6 +104,7 @@ def run_lekiwi(robot_config):
    """
    # Import helper functions and classes
    from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
    from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
    from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode

    # Initialize cameras from the robot configuration.

@ -101,6 +112,11 @@ def run_lekiwi(robot_config):
    for cam in cameras.values():
        cam.connect()

    # Initialize microphones from the robot configuration.
    microphones = make_microphones_from_configs(robot_config.microphones)
    for microphone in microphones.values():
        microphone.connect()

    # Initialize the motors bus using the follower arm configuration.
    motor_config = robot_config.follower_arms.get("main")
    if motor_config is None:

@ -134,6 +150,20 @@ def run_lekiwi(robot_config):
    )
    cam_thread.start()

    # Start the microphone recording and capture thread.
    # TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead !
    latest_audio_dict = {}
    audio_lock = threading.Lock()
    audio_stop_event = threading.Event()
    microphone_thread = threading.Thread(
        target=run_microphone_capture,
        args=(microphones, audio_lock, latest_audio_dict, audio_stop_event),
        daemon=True,
    )
    for microphone in microphones.values():
        microphone.start_recording()
    microphone_thread.start()

    last_cmd_time = time.time()
    print("LeKiwi robot server started. Waiting for commands...")

@ -198,9 +228,14 @@ def run_lekiwi(robot_config):
            with images_lock:
                images_dict_copy = dict(latest_images_dict)

            # Get the latest audio data.
            with audio_lock:
                audio_dict_copy = dict(latest_audio_dict)

            # Build the observation dictionary.
            observation = {
                "images": images_dict_copy,
                "audio": audio_dict_copy,  # TODO(CarolinePascal) : This is a nasty way to do it, sorry.
                "present_speed": current_velocity,
                "follower_arm_state": follower_arm_state,
            }

@ -217,6 +252,9 @@ def run_lekiwi(robot_config):
    finally:
        stop_event.set()
        cam_thread.join()
        microphone_thread.join()
        for microphone in microphones.values():
            microphone.stop_recording()
        robot.stop()
        motors_bus.disconnect()
        cmd_socket.close()

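The capture thread and the observation loop share `latest_audio_dict` through `audio_lock`. A self-contained sketch of the same copy-under-lock pattern, with stand-in data instead of real microphones:

```python
import threading
import time

latest_audio_dict = {}
audio_lock = threading.Lock()
stop_event = threading.Event()

def producer():
    # Stand-in for run_microphone_capture: write under the lock.
    while not stop_event.is_set():
        with audio_lock:
            latest_audio_dict["laptop"] = time.time()  # stand-in for microphone.read()
        time.sleep(0.01)

thread = threading.Thread(target=producer, daemon=True)
thread.start()
time.sleep(0.1)
with audio_lock:  # same snapshot-by-copy as the observation loop above
    snapshot = dict(latest_audio_dict)
stop_event.set()
thread.join()
print(snapshot)
```
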
@ -28,6 +28,7 @@ import numpy as np
import torch

from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id

@ -164,6 +165,7 @@ class ManipulatorRobot:
        self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
        self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
        self.cameras = make_cameras_from_configs(self.config.cameras)
        self.microphones = make_microphones_from_configs(self.config.microphones)
        self.is_connected = False
        self.logs = {}

@ -199,9 +201,24 @@ class ManipulatorRobot:
            },
        }

    @property
    def microphone_features(self) -> dict:
        mic_ft = {}
        for mic_key, mic in self.microphones.items():
            key = f"observation.audio.{mic_key}"
            mic_ft[key] = {
                "dtype": "audio",
                "shape": (len(mic.channels),),
                "names": "channels",
                # We need to store the sample rate here in the case of audio chunk recording (for LeKiwi),
                # as it will not be available anymore when writing the audio file.
                "info": {"sample_rate": mic.sample_rate},
            }
        return mic_ft

    @property
    def features(self):
        return {**self.motor_features, **self.camera_features, **self.microphone_features}

    @property
    def has_camera(self):

@ -211,6 +228,14 @@ class ManipulatorRobot:
    def num_cameras(self):
        return len(self.cameras)

    @property
    def has_microphone(self):
        return len(self.microphones) > 0

    @property
    def num_microphones(self):
        return len(self.microphones)

    @property
    def available_arms(self):
        available_arms = []

@ -228,7 +253,7 @@ class ManipulatorRobot:
                "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
            )

        if not self.leader_arms and not self.follower_arms and not self.cameras and not self.microphones:
            raise ValueError(
                "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class."
            )

@ -289,6 +314,10 @@ class ManipulatorRobot:
        for name in self.cameras:
            self.cameras[name].connect()

        # Connect the microphones
        for name in self.microphones:
            self.microphones[name].connect()

        self.is_connected = True

    def activate_calibration(self):

@ -514,12 +543,23 @@ class ManipulatorRobot:
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Capture audio from microphones
        audio = {}
        for name in self.microphones:
            before_audioread_t = time.perf_counter()
            audio[name] = self.microphones[name].read()
            audio[name] = torch.from_numpy(audio[name])
            self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t

        # Populate output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]
        for name in self.microphones:
            obs_dict[f"observation.audio.{name}"] = audio[name]

        return obs_dict, action_dict

@ -554,11 +594,22 @@ class ManipulatorRobot:
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Capture audio from microphones
        audio = {}
        for name in self.microphones:
            before_audioread_t = time.perf_counter()
            audio[name] = self.microphones[name].read()
            audio[name] = torch.from_numpy(audio[name])
            self.logs[f"read_microphone_{name}_dt_s"] = self.microphones[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_microphone_{name}_dt_s"] = time.perf_counter() - before_audioread_t

        # Populate output dictionaries and format to pytorch
        obs_dict = {}
        obs_dict["observation.state"] = state
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]
        for name in self.microphones:
            obs_dict[f"observation.audio.{name}"] = audio[name]
        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:

@ -620,6 +671,9 @@ class ManipulatorRobot:
        for name in self.cameras:
            self.cameras[name].disconnect()

        for name in self.microphones:
            self.microphones[name].disconnect()

        self.is_connected = False

    def __del__(self):

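For concreteness, with a single mono microphone named "laptop" at 48 kHz, `microphone_features` would evaluate to roughly the following (a hypothetical rendering, not output captured from this code):

```python
expected_features = {
    "observation.audio.laptop": {
        "dtype": "audio",
        "shape": (1,),  # one recorded channel
        "names": "channels",
        "info": {"sample_rate": 48000},
    }
}
```
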
@ -24,6 +24,7 @@ import torch
import zmq

from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig

@ -79,6 +80,7 @@ class MobileManipulator:
        self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)

        self.cameras = make_cameras_from_configs(self.config.cameras)
        self.microphones = make_microphones_from_configs(self.config.microphones)

        self.is_connected = False

@ -133,6 +135,7 @@ class MobileManipulator:
                "shape": (cam.height, cam.width, cam.channels),
                "names": ["height", "width", "channels"],
                "info": None,
                "audio": "observation.audio." + cam.microphone if cam.microphone is not None else None,
            }
        return cam_ft

@ -161,9 +164,22 @@ class MobileManipulator:
            },
        }

    @property
    def microphone_features(self) -> dict:
        mic_ft = {}
        for mic_key, mic in self.microphones.items():
            key = f"observation.audio.{mic_key}"
            mic_ft[key] = {
                "dtype": "audio",
                "shape": (len(mic.channels),),
                "names": "channels",
                "info": {"sample_rate": mic.sample_rate},
            }
        return mic_ft

    @property
    def features(self):
        return {**self.motor_features, **self.camera_features, **self.microphone_features}

    @property
    def has_camera(self):

@ -173,6 +189,14 @@ class MobileManipulator:
    def num_cameras(self):
        return len(self.cameras)

    @property
    def has_microphone(self):
        return len(self.microphones) > 0

    @property
    def num_microphones(self):
        return len(self.microphones)

    @property
    def available_arms(self):
        available = []

@ -344,6 +368,7 @@ class MobileManipulator:
            observation = json.loads(last_msg)

            images_dict = observation.get("images", {})
            audio_dict = observation.get("audio", {})
            new_speed = observation.get("present_speed", {})
            new_arm_state = observation.get("follower_arm_state", None)

@ -356,6 +381,11 @@ class MobileManipulator:
                if frame_candidate is not None:
                    frames[cam_name] = frame_candidate

            # Receive audio
            for microphone_name, audio_data in audio_dict.items():
                if audio_data:
                    frames[microphone_name] = audio_data

            # If new_arm_state is None and frames is None, there is no new message; use the previous message
            if new_arm_state is not None and frames is not None:
                self.last_frames = frames

@ -475,6 +505,14 @@ class MobileManipulator:
                frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
            obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)

        # Loop over each configured microphone
        for microphone_name, microphone in self.microphones.items():
            frame = frames.get(microphone_name, None)
            if frame is None:
                # Create silence using the microphone's configured channels
                frame = np.zeros((1, len(microphone.channels)), dtype=np.float32)
            obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame)

        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:

@ -63,3 +63,25 @@ class RobotDeviceAlreadyConnectedError(Exception):
    ):
        self.message = message
        super().__init__(self.message)


class RobotDeviceNotRecordingError(Exception):
    """Exception raised when the robot device is not recording."""

    def __init__(
        self,
        message="This robot device is not recording. Try calling `robot_device.start_recording()` first.",
    ):
        self.message = message
        super().__init__(self.message)


class RobotDeviceAlreadyRecordingError(Exception):
    """Exception raised when the robot device is already recording."""

    def __init__(
        self,
        message="This robot device is already recording. Try not calling `robot_device.start_recording()` twice.",
    ):
        self.message = message
        super().__init__(self.message)

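A small sketch of how the new exceptions surface in practice, reusing the `Microphone` class added earlier in this diff (the device index is assumed):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.microphone import Microphone
from lerobot.common.robot_devices.utils import RobotDeviceNotRecordingError

mic = Microphone(MicrophoneConfig(microphone_index=0))
mic.connect()
try:
    mic.read()  # raises: recording has not been started yet
except RobotDeviceNotRecordingError as e:
    print(e.message)
finally:
    mic.disconnect()
```
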
@ -67,8 +67,11 @@ dependencies = [
    "pynput>=1.7.7",
    "pyzmq>=26.2.1",
    "rerun-sdk>=0.21.0",
    "sounddevice>=0.5.1",
    "soundfile>=0.13.1",
    "termcolor>=2.4.0",
    "torch>=2.2.1",
    "torchaudio>=2.6.0",
    "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
    "torchvision>=0.21.0",
    "wandb>=0.16.3",

@ -96,6 +99,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
audio = ["librosa>=0.11.0"]

[tool.poetry]
requires-poetry = ">=2.1"

@ -19,9 +19,9 @@ import traceback
import pytest
from serial import SerialException

from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot.common.robot_devices.robots.utils import make_robot
from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus

# Import fixture modules as plugins
pytest_plugins = [

@ -74,6 +74,11 @@ def is_camera_available(camera_type):
    return _check_component_availability(camera_type, available_cameras, make_camera)


@pytest.fixture
def is_microphone_available(microphone_type):
    return _check_component_availability(microphone_type, available_microphones, make_microphone)


@pytest.fixture
def is_motor_available(motor_type):
    return _check_component_availability(motor_type, available_motors, make_motors_bus)

@ -25,6 +25,8 @@ from lerobot.common.datasets.compute_stats import (
    compute_episode_stats,
    estimate_num_samples,
    get_feature_stats,
    sample_audio_from_data,
    sample_audio_from_path,
    sample_images,
    sample_indices,
)

@ -34,6 +36,10 @@ def mock_load_image_as_numpy(path, dtype, channel_first):
    return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)


def mock_load_audio(path):
    return np.ones((16000, 2), dtype=np.float32)


@pytest.fixture
def sample_array():
    return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

@ -71,6 +77,25 @@ def test_sample_images(mock_load):
    assert len(images) == estimate_num_samples(100)


@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load):
    audio_path = "audio.wav"
    audio_samples = sample_audio_from_path(audio_path)
    assert isinstance(audio_samples, np.ndarray)
    assert audio_samples.shape[1] == 2
    assert audio_samples.dtype == np.float32
    assert len(audio_samples) == estimate_num_samples(16000)


def test_sample_audio_from_data():
    audio_data = np.ones((16000, 2), dtype=np.float32)
    audio_samples = sample_audio_from_data(audio_data)
    assert isinstance(audio_samples, np.ndarray)
    assert audio_samples.shape[1] == 2
    assert audio_samples.dtype == np.float32
    assert len(audio_samples) == estimate_num_samples(16000)


def test_get_feature_stats_images():
    data = np.random.rand(100, 3, 32, 32)
    stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)

@ -79,6 +104,14 @@ def test_get_feature_stats_images():
    assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape


def test_get_feature_stats_audio():
    data = np.random.uniform(-1, 1, (16000, 2))
    stats = get_feature_stats(data, axis=0, keepdims=True)
    assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
    np.testing.assert_equal(stats["count"], np.array([16000]))
    assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape


def test_get_feature_stats_axis_0_keepdims(sample_array):
    expected = {
        "min": np.array([[1, 2, 3]]),

@ -137,22 +170,29 @@ def test_get_feature_stats_single_value():
def test_compute_episode_stats():
    episode_data = {
        "observation.image": [f"image_{i}.jpg" for i in range(100)],
        "observation.audio": "audio.wav",
        "observation.state": np.random.rand(100, 10),
    }
    features = {
        "observation.image": {"dtype": "image"},
        "observation.audio": {"dtype": "audio"},
        "observation.state": {"dtype": "numeric"},
    }

    with (
        patch(
            "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
        ),
        patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
    ):
        stats = compute_episode_stats(episode_data, features)

    assert "observation.image" in stats and "observation.state" in stats and "observation.audio" in stats
    assert stats["observation.image"]["count"].item() == estimate_num_samples(100)
    assert stats["observation.audio"]["count"].item() == estimate_num_samples(16000)
    assert stats["observation.state"]["count"].item() == estimate_num_samples(100)
    assert stats["observation.image"]["mean"].shape == (3, 1, 1)
    assert stats["observation.audio"]["mean"].shape == (1, 2)


def test_assert_type_and_shape_valid():

@ -16,6 +16,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import (
|
||||||
MultiLeRobotDataset,
|
MultiLeRobotDataset,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||||
create_branch,
|
create_branch,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
|
@ -44,8 +46,8 @@ from lerobot.common.policies.factory import make_policy_config
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import make_microphone, require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -64,6 +66,20 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
features = {
|
||||||
|
"observation.audio.microphone": {
|
||||||
|
"dtype": "audio",
|
||||||
|
"shape": (DUMMY_AUDIO_CHANNELS,),
|
||||||
|
"names": [
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|
||||||
|
|
||||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||||
"""
|
"""
|
||||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||||
|
@@ -322,6 +338,24 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
         image_array_to_pil_image(image)


+def test_add_frame_audio(audio_dataset):
+    dataset = audio_dataset
+
+    microphone = make_microphone(microphone_type="microphone", mock=True)
+    microphone.connect()
+
+    dataset.add_microphone_recording(microphone, "microphone")
+    time.sleep(1.0)
+    dataset.add_frame({"observation.audio.microphone": microphone.read(), "task": "Dummy task"})
+    microphone.stop_recording()
+
+    dataset.save_episode()
+
+    assert dataset[0]["observation.audio.microphone"].shape == torch.Size(
+        (int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS)
+    )
+
+
 # TODO(aliberts):
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames
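The expected shape follows from simple arithmetic: the decoded chunk holds `int(DEFAULT_AUDIO_CHUNK_DURATION * sample_rate)` frames per channel. With the mock microphone's 48 kHz default and an assumed 0.5 s chunk duration (the actual constant lives in `lerobot.common.datasets.utils`):

```python
# Back-of-the-envelope check of the shape asserted above; the 0.5 s value is
# an assumption for illustration, not necessarily the library default.
chunk_duration = 0.5   # seconds, stand-in for DEFAULT_AUDIO_CHUNK_DURATION
sample_rate = 48_000   # DEFAULT_SAMPLE_RATE of the mocked device
channels = 2           # DUMMY_AUDIO_CHANNELS

expected_shape = (int(chunk_duration * sample_rate), channels)
print(expected_shape)  # (24000, 2)
```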
@@ -354,6 +388,7 @@ def test_factory(env_name, repo_id, policy_name):
     dataset = make_dataset(cfg)
     delta_timestamps = dataset.delta_timestamps
     camera_keys = dataset.meta.camera_keys
+    audio_keys = dataset.meta.audio_keys

     item = dataset[0]

@@ -396,6 +431,11 @@ def test_factory(env_name, repo_id, policy_name):
         # test c,h,w
         assert item[key].shape[0] == 3, f"{key}"

+    for key in audio_keys:
+        assert item[key].dtype == torch.float32, f"{key}"
+        assert item[key].max() <= 1.0, f"{key}"
+        assert item[key].min() >= -1.0, f"{key}"
+
     if delta_timestamps is not None:
         # test missing keys in delta_timestamps
         for key in delta_timestamps:
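The `[-1.0, 1.0]` bounds checked here are the usual float-PCM convention; a sketch of how signed 16-bit samples map into that range (illustrative, not necessarily the library's decode path):

```python
import numpy as np

pcm_int16 = np.array([-32768, 0, 32767], dtype=np.int16)
audio = pcm_int16.astype(np.float32) / 32768.0  # full int16 range maps to [-1.0, ~1.0)
assert audio.dtype == np.float32
assert audio.min() >= -1.0 and audio.max() <= 1.0
```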
@@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = {
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {
+        "shape": (480, 640, 3),
+        "names": ["height", "width", "channels"],
+        "info": None,
+        "audio": "laptop",
+    },
     "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
 }
 DEFAULT_FPS = 30
@@ -40,5 +45,18 @@ DUMMY_VIDEO_INFO = {
     "video.is_depth_map": False,
     "has_audio": False,
 }
+DUMMY_MICROPHONE_FEATURES = {
+    "laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
+    "phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
+}
+DEFAULT_SAMPLE_RATE = 48000
+DUMMY_AUDIO_CHANNELS = 2
+DUMMY_AUDIO_INFO = {
+    "has_audio": True,
+    "audio.sample_rate": DEFAULT_SAMPLE_RATE,
+    "audio.codec": "aac",
+    "audio.channels": DUMMY_AUDIO_CHANNELS,
+    "audio.channel_layout": "stereo",
+}
 DUMMY_CHW = (3, 96, 128)
 DUMMY_HWC = (96, 128, 3)
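The new `"audio": "laptop"` entry on the laptop camera appears to name the microphone whose track accompanies that camera's video; a toy lookup under that assumption:

```python
# Toy lookup, assuming the "audio" key links a camera to a microphone name.
camera_features = {
    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"},
    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
linked_microphones = {name: ft.get("audio") for name, ft in camera_features.items()}
print(linked_microphones)  # {'laptop': 'laptop', 'phone': None}
```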
@@ -26,6 +26,7 @@ import torch
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     DEFAULT_FEATURES,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
@@ -35,6 +36,7 @@ from lerobot.common.datasets.utils import (
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
+    DUMMY_MICROPHONE_FEATURES,
     DUMMY_MOTOR_FEATURES,
     DUMMY_REPO_ID,
     DUMMY_ROBOT_TYPE,
@@ -90,6 +92,7 @@ def features_factory():
     def _create_features(
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
+        audio_features: dict = DUMMY_MICROPHONE_FEATURES,
         use_videos: bool = True,
     ) -> dict:
         if use_videos:
@@ -101,6 +104,7 @@ def features_factory():
         return {
             **motor_features,
             **camera_ft,
+            **audio_features,
             **DEFAULT_FEATURES,
         }

@@ -117,15 +121,18 @@ def info_factory(features_factory):
         total_frames: int = 0,
         total_tasks: int = 0,
         total_videos: int = 0,
+        total_audio: int = 0,
         total_chunks: int = 0,
         chunks_size: int = DEFAULT_CHUNK_SIZE,
         data_path: str = DEFAULT_PARQUET_PATH,
         video_path: str = DEFAULT_VIDEO_PATH,
+        audio_path: str = DEFAULT_COMPRESSED_AUDIO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
+        audio_features: dict = DUMMY_MICROPHONE_FEATURES,
         use_videos: bool = True,
     ) -> dict:
-        features = features_factory(motor_features, camera_features, use_videos)
+        features = features_factory(motor_features, camera_features, audio_features, use_videos)
         return {
             "codebase_version": codebase_version,
             "robot_type": robot_type,
@@ -133,12 +140,14 @@ def info_factory(features_factory):
             "total_frames": total_frames,
             "total_tasks": total_tasks,
             "total_videos": total_videos,
+            "total_audio": total_audio,
             "total_chunks": total_chunks,
             "chunks_size": chunks_size,
             "fps": fps,
             "splits": {},
             "data_path": data_path,
             "video_path": video_path if use_videos else None,
+            "audio_path": audio_path,
             "features": features,
         }

@@ -162,6 +171,14 @@ def stats_factory():
                     "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
                     "count": [10],
                 }
+            elif dtype == "audio":
+                stats[key] = {
+                    "mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
+                    "max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
+                    "min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
+                    "std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
+                    "count": [10],
+                }
             else:
                 stats[key] = {
                     "max": np.full(shape, 1, dtype=dtype).tolist(),
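Audio stats are flat `(channels,)` vectors where image stats are `(3, 1, 1)`; in both cases the shape is chosen so the statistics broadcast against the raw data during normalization. A sketch:

```python
import numpy as np

# A (samples, channels) waveform normalized by (channels,) stats: broadcasting
# covers the sample axis, just as (3, 1, 1) image stats broadcast over (3, H, W).
audio = np.random.uniform(-1, 1, size=(48000, 2)).astype(np.float32)
mean = np.zeros((2,), dtype=np.float32)
std = np.full((2,), 0.5, dtype=np.float32)
normalized = (audio - mean) / std
assert normalized.shape == audio.shape
```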
@@ -0,0 +1,88 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from functools import cache
+from threading import Event, Thread
+
+import numpy as np
+
+from lerobot.common.utils.utils import capture_timestamp_utc
+from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS
+
+
+@cache
+def _generate_sound(duration: float, sample_rate: int, channels: int):
+    return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
+
+
+def query_devices(query_index: int):
+    return {
+        "name": "Mock Sound Device",
+        "index": query_index,
+        "max_input_channels": DUMMY_AUDIO_CHANNELS,
+        "default_samplerate": DEFAULT_SAMPLE_RATE,
+    }
+
+
+class InputStream:
+    def __init__(self, *args, **kwargs):
+        self._mock_dict = {
+            "channels": DUMMY_AUDIO_CHANNELS,
+            "samplerate": DEFAULT_SAMPLE_RATE,
+        }
+        self._is_active = False
+        self._audio_callback = kwargs.get("callback")
+
+        self.callback_thread = None
+        self.callback_thread_stop_event = None
+
+    def _acquisition_loop(self):
+        if self._audio_callback is not None:
+            while not self.callback_thread_stop_event.is_set():
+                # Simulate audio data acquisition
+                time.sleep(0.01)
+                self._audio_callback(
+                    _generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS),
+                    0.01 * DEFAULT_SAMPLE_RATE,
+                    capture_timestamp_utc(),
+                    None,
+                )
+
+    def start(self):
+        self.callback_thread_stop_event = Event()
+        self.callback_thread = Thread(target=self._acquisition_loop, args=())
+        self.callback_thread.daemon = True
+        self.callback_thread.start()
+
+        self._is_active = True
+
+    @property
+    def active(self):
+        return self._is_active
+
+    def stop(self):
+        if self.callback_thread_stop_event is not None:
+            self.callback_thread_stop_event.set()
+            self.callback_thread.join()
+            self.callback_thread = None
+            self.callback_thread_stop_event = None
+        self._is_active = False
+
+    def close(self):
+        if self._is_active:
+            self.stop()
+
+    def __del__(self):
+        if self._is_active:
+            self.stop()
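A minimal driver for the mock above, assuming `InputStream` refers to the class just defined: the callback receives roughly one 0.01 s chunk per loop iteration until `stop()` is called, mirroring the four-argument callback contract of `sounddevice.InputStream`.

```python
import time

collected = []

def audio_callback(indata, frames, timestamp, status):
    # Same (indata, frames, time, status) signature sounddevice uses.
    collected.append(indata.copy())

stream = InputStream(callback=audio_callback)
stream.start()
time.sleep(0.1)  # ~10 chunks of 0.01 s each
stream.stop()
print(len(collected), collected[0].shape)  # e.g. 9 (480, 2)
```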
@@ -0,0 +1,152 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for physical microphones and their mocked versions.
+If the physical microphone is not connected to the computer, or not working,
+the test will be skipped.
+
+Example of running a specific test:
+```bash
+pytest -sx tests/microphones/test_microphones.py::test_microphone
+```
+
+Example of running a test on a real microphone connected to the computer:
+```bash
+pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-False]'
+```
+
+Example of running a test on a mocked version of the microphone:
+```bash
+pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-True]'
+```
+"""
+
+import time
+
+import numpy as np
+import pytest
+from soundfile import read
+
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceAlreadyRecordingError,
+    RobotDeviceNotConnectedError,
+    RobotDeviceNotRecordingError,
+)
+from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
+
+# Maximum recording time difference between two consecutive audio recordings of the same duration.
+# Set to 0.02 seconds, i.e. twice the default size of the sounddevice callback buffer
+# (we tolerate the loss of one buffer).
+MAX_RECORDING_TIME_DIFFERENCE = 0.02
+
+DUMMY_RECORDING = "test_recording.wav"
+
+
+@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
+@require_microphone
+def test_microphone(tmp_path, request, microphone_type, mock):
+    """Test assumes that a recording handled with microphone.start_recording(output_file) and
+    stop_recording(), or with microphone.read(), leads to a sample that does not differ from the
+    requested duration by more than 0.1 seconds.
+    """
+
+    microphone_kwargs = {"microphone_type": microphone_type, "mock": mock}
+
+    # Test instantiating
+    microphone = make_microphone(**microphone_kwargs)
+
+    # Test that start_recording, stop_recording, read and disconnect before connecting raise an error
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.start_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.stop_recording()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotConnectedError):
+        microphone.disconnect()
+
+    # Test deleting the object without connecting first
+    del microphone
+
+    # Test connecting
+    microphone = make_microphone(**microphone_kwargs)
+    microphone.connect()
+    assert microphone.is_connected
+    assert microphone.sample_rate is not None
+    assert microphone.channels is not None
+
+    # Test that connecting twice raises an error
+    with pytest.raises(RobotDeviceAlreadyConnectedError):
+        microphone.connect()
+
+    # Test that read or stop_recording before start_recording raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.read()
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test start_recording
+    fpath = tmp_path / DUMMY_RECORDING
+    microphone.start_recording(fpath)
+    assert microphone.is_recording
+
+    # Test that start_recording twice raises an error
+    with pytest.raises(RobotDeviceAlreadyRecordingError):
+        microphone.start_recording()
+
+    # Test reading from the microphone
+    time.sleep(1.0)
+    audio_chunk = microphone.read()
+    assert isinstance(audio_chunk, np.ndarray)
+    assert audio_chunk.ndim == 2
+    _, c = audio_chunk.shape
+    assert c == len(microphone.channels)
+
+    # Test stop_recording
+    microphone.stop_recording()
+    assert fpath.exists()
+    assert not microphone.stream.active
+    assert microphone.record_thread is None
+
+    # Test that stop_recording twice raises an error
+    with pytest.raises(RobotDeviceNotRecordingError):
+        microphone.stop_recording()
+
+    # Test that reading and recording produce audio chunks of similar length
+    microphone.start_recording(tmp_path / DUMMY_RECORDING)
+    time.sleep(1.0)
+    audio_chunk = microphone.read()
+    microphone.stop_recording()
+
+    recorded_audio, recorded_sample_rate = read(fpath)
+    assert recorded_sample_rate == microphone.sample_rate
+
+    error_msg = (
+        "Recording time difference between read() and stop_recording(): "
+        f"{(len(audio_chunk) - len(recorded_audio)) / recorded_sample_rate} seconds"
+    )
+    np.testing.assert_allclose(
+        len(audio_chunk),
+        len(recorded_audio),
+        atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
+        err_msg=error_msg,
+    )
+
+    # Test disconnecting
+    microphone.disconnect()
+    assert not microphone.is_connected
+
+    # Test disconnecting with `__del__`
+    microphone = make_microphone(**microphone_kwargs)
+    microphone.connect()
+    del microphone
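Condensed, the happy path this test exercises looks as follows (mocked device, names taken from the diff above):

```python
import time

microphone = make_microphone(microphone_type="microphone", mock=True)
microphone.connect()

microphone.start_recording("test_recording.wav")
time.sleep(1.0)
chunk = microphone.read()    # ~1 s of (samples, channels) float32 audio
microphone.stop_recording()  # flushes remaining buffers to the output file

microphone.disconnect()
```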
@@ -100,6 +100,9 @@ def test_robot(tmp_path, request, robot_type, mock):
     robot.teleop_step()

     # Test data recorded during teleop are well formatted
+    for _, microphone in robot.microphones.items():
+        microphone.start_recording()
+
     observation, action = robot.teleop_step(record_data=True)
     # State
     assert "observation.state" in observation
@@ -112,6 +115,11 @@ def test_robot(tmp_path, request, robot_type, mock):
         assert f"observation.images.{name}" in observation
         assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
         assert observation[f"observation.images.{name}"].ndim == 3
+    # Microphones
+    for name in robot.microphones:
+        assert f"observation.audio.{name}" in observation
+        assert isinstance(observation[f"observation.audio.{name}"], torch.Tensor)
+        assert observation[f"observation.audio.{name}"].ndim == 2
     # Action
     assert "action" in action
     assert isinstance(action["action"], torch.Tensor)
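Side by side, the shape conventions checked for the two modalities (values illustrative):

```python
import torch

image = torch.zeros(3, 480, 640)  # (channels, height, width) -> ndim == 3
audio = torch.zeros(24000, 2)     # (samples, channels)       -> ndim == 2
assert image.ndim == 3 and audio.ndim == 2
```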
@@ -124,8 +132,9 @@ def test_robot(tmp_path, request, robot_type, mock):
     captured_observation = robot.capture_observation()
     assert set(captured_observation.keys()) == set(observation.keys())
     for name in captured_observation:
-        if "image" in name:
+        if "image" in name or "audio" in name:
             # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
+            # Also skipping for audio as audio chunks may be of different length
             continue
         torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
         assert captured_observation[name].shape == observation[name].shape
@@ -134,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
     robot.send_action(action["action"])

     # Test disconnecting
-    robot.disconnect()
+    robot.disconnect()  # Also stops any ongoing microphone recording
     assert not robot.is_connected
     for name in robot.follower_arms:
         assert not robot.follower_arms[name].is_connected
@@ -142,3 +151,5 @@ def test_robot(tmp_path, request, robot_type, mock):
         assert not robot.leader_arms[name].is_connected
     for name in robot.cameras:
         assert not robot.cameras[name].is_connected
+    for name in robot.microphones:
+        assert not robot.microphones[name].is_connected
@@ -22,9 +22,11 @@ from pathlib import Path
 import pytest
 import torch

-from lerobot import available_cameras, available_motors, available_robots
+from lerobot import available_cameras, available_microphones, available_motors, available_robots
 from lerobot.common.robot_devices.cameras.utils import Camera
 from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
+from lerobot.common.robot_devices.microphones.utils import Microphone
+from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
 from lerobot.common.robot_devices.motors.utils import MotorsBus
 from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
 from lerobot.common.utils.import_utils import is_package_available
@@ -39,6 +41,10 @@ TEST_CAMERA_TYPES = []
 for camera_type in available_cameras:
     TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]

+TEST_MICROPHONE_TYPES = []
+for microphone_type in available_microphones:
+    TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]
+
 TEST_MOTOR_TYPES = []
 for motor_type in available_motors:
     TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
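With `available_microphones == ["microphone"]`, the list expands to `[("microphone", True), ("microphone", False)]`, which pytest fans out into one mocked and one physical test case:

```python
import pytest

# Hypothetical test illustrating how the list above is consumed.
@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
def test_uses_microphone(microphone_type, mock):
    ...
```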
@@ -47,6 +53,9 @@ for motor_type in available_motors:
 OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
 INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))

+# Microphone index used for connecting a physical microphone
+MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))
+
 DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
 DYNAMIXEL_MOTORS = {
     "shoulder_pan": [1, "xl430-w250"],
@@ -253,6 +262,29 @@ def require_camera(func):
     return wrapper


+def require_microphone(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Access the pytest request context to get the is_microphone_available fixture
+        request = kwargs.get("request")
+        microphone_type = kwargs.get("microphone_type")
+        mock = kwargs.get("mock")
+
+        if request is None:
+            raise ValueError("The 'request' fixture must be an argument of the test function.")
+        if microphone_type is None:
+            raise ValueError("The 'microphone_type' must be an argument of the test function.")
+        if mock is None:
+            raise ValueError("The 'mock' variable must be an argument of the test function.")
+
+        if not mock and not request.getfixturevalue("is_microphone_available"):
+            pytest.skip(f"A {microphone_type} microphone is not available.")
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def require_motor(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
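A hypothetical test guarded by the decorator above: the physical variant (`mock=False`) is skipped unless the `is_microphone_available` fixture reports a device, while the mocked variant always runs.

```python
@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
@require_microphone
def test_microphone_smoke(request, microphone_type, mock):
    microphone = make_microphone(microphone_type=microphone_type, mock=mock)
    microphone.connect()
    assert microphone.is_connected
```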
@@ -315,6 +347,14 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
     raise ValueError(f"The camera type '{camera_type}' is not valid.")


+def make_microphone(microphone_type: str, **kwargs) -> Microphone:
+    if microphone_type == "microphone":
+        microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
+        return make_microphone_device(microphone_type, microphone_index=microphone_index, **kwargs)
+    else:
+        raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
+
+
 # TODO(rcadene, aliberts): remove this dark pattern that overrides
 def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
     if motor_type == "dynamixel":
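Usage sketch for the helper above; `microphone_index` falls back to the environment-driven `MICROPHONE_INDEX` when omitted:

```python
# Explicit device index (hypothetical value); mock=True bypasses real hardware.
microphone = make_microphone("microphone", microphone_index=2, mock=True)
```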