[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in: parent 9c667d347c, commit 0cb9345f06
@@ -15,7 +15,8 @@
 # limitations under the License.
 import numpy as np
-from lerobot.common.datasets.utils import load_image_as_numpy, load_audio_from_path
+
+from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy
 
 
 def estimate_num_samples(
     dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
@@ -70,17 +71,19 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
     return images
 
+
 def sample_audio_from_path(audio_path: str) -> np.ndarray:
     data = load_audio_from_path(audio_path)
     sampled_indices = sample_indices(len(data))
 
-    return(data[sampled_indices])
+    return data[sampled_indices]
+
 
 def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
     sampled_indices = sample_indices(len(data))
     return data[sampled_indices]
 
 
 def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
     return {
         "min": np.min(array, axis=axis, keepdims=keepdims),
@@ -103,7 +106,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
         elif features[key]["dtype"] == "audio":
             try:
                 ep_ft_array = sample_audio_from_path(data[0])
-            except TypeError: #Should only be triggered for LeKiwi robot
+            except TypeError:  # Should only be triggered for LeKiwi robot
                 ep_ft_array = sample_audio_from_data(data)
             axes_to_reduce = 0
             keepdims = True
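Taken together, the hunks above make the episode-stats path audio-aware: `compute_episode_stats` samples an audio array (from a file path, or directly from raw data on LeKiwi) and reduces it over the time axis. A minimal sketch of that pipeline in isolation, using a synthetic two-channel waveform rather than a real recording:

```python
import numpy as np

from lerobot.common.datasets.compute_stats import get_feature_stats, sample_audio_from_data

# Synthetic 1 s stereo recording at 16 kHz, shaped (frames, channels).
audio = np.random.uniform(-1, 1, (16000, 2)).astype(np.float32)

# Subsample frames, then reduce over axis 0 to get per-channel statistics.
samples = sample_audio_from_data(audio)
stats = get_feature_stats(samples, axis=0, keepdims=True)  # min/max/mean/std/count
```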
@@ -23,6 +23,7 @@ import datasets
 import numpy as np
 import packaging.version
 import PIL.Image
+import soundfile as sf
 import torch
 import torch.utils
 from datasets import concatenate_datasets, load_dataset
@@ -34,13 +35,12 @@ from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
+    DEFAULT_AUDIO_CHUNK_DURATION,
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
     DEFAULT_RAW_AUDIO_PATH,
-    DEFAULT_COMPRESSED_AUDIO_PATH,
     INFO_PATH,
     TASKS_PATH,
-    DEFAULT_AUDIO_CHUNK_DURATION,
     append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
@@ -70,17 +70,16 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
-    decode_video_frames,
-    encode_video_frames,
-    encode_audio,
     decode_audio,
+    decode_video_frames,
+    encode_audio,
+    encode_video_frames,
+    get_audio_info,
     get_safe_default_codec,
     get_video_info,
-    get_audio_info,
 )
-from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.microphones.utils import Microphone
-import soundfile as sf
+from lerobot.common.robot_devices.robots.utils import Robot
 
 CODEBASE_VERSION = "v2.1"
 
@@ -152,7 +151,9 @@ class LeRobotDatasetMetadata:
 
     def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
         episode_chunk = self.get_episode_chunk(episode_index)
-        fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index)
+        fpath = self.audio_path.format(
+            episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
+        )
         return self.root / fpath
 
     def get_episode_chunk(self, ep_index: int) -> int:
@@ -211,7 +212,12 @@ class LeRobotDatasetMetadata:
     @property
     def audio_camera_keys_mapping(self) -> dict[str, str]:
         """Mapping between camera keys and audio keys when both are linked."""
-        return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None and self.features[camera_key]["dtype"] == "video"}
+        return {
+            self.features[camera_key]["audio"]: camera_key
+            for camera_key in self.camera_keys
+            if self.features[camera_key]["audio"] is not None
+            and self.features[camera_key]["dtype"] == "video"
+        }
 
     @property
     def names(self) -> dict[str, list | dict]:
@@ -325,7 +331,9 @@ class LeRobotDatasetMetadata:
         been encoded the same way. Also, this means it assumes the first episode exists.
         """
         for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
-            if not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]):
+            if not self.features[key].get("info", None) or (
+                len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]
+            ):
                 audio_path = self.root / self.get_compressed_audio_file_path(0, key)
                 self.info["features"][key]["info"] = get_audio_info(audio_path)
 
@@ -518,7 +526,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else get_safe_default_codec()
-        self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
+        self.audio_backend = (
+            audio_backend if audio_backend else "ffmpeg"
+        )  # Waiting for torchcodec release #TODO(CarolinePascal)
         self.delta_indices = None
 
         # Unused attributes
@@ -543,7 +553,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video
+            self.download_episodes(
+                download_videos
+            )  # Should load audio as well #TODO(CarolinePascal): separate audio from video
             self.hf_dataset = self.load_hf_dataset()
 
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -736,7 +748,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return query_timestamps
 
-    #TODO(CarolinePascal): add variable query durations
+    # TODO(CarolinePascal): add variable query durations
     def _get_query_timestamps_audio(
         self,
         current_ts: float,
@@ -773,14 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
-    #TODO(CarolinePascal): add variable query durations
-    def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
+    # TODO(CarolinePascal): add variable query durations
+    def _query_audio(
+        self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
+    ) -> dict[str, torch.Tensor]:
         item = {}
         for audio_key, query_ts in query_timestamps.items():
-            #Audio stored with video in a single .mp4 file
+            # Audio stored with video in a single .mp4 file
             if audio_key in self.meta.audio_camera_keys_mapping:
-                audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key])
-            #Audio stored alone in a separate .m4a file
+                audio_path = self.root / self.meta.get_video_file_path(
+                    ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
+                )
+            # Audio stored alone in a separate .m4a file
             else:
                 audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
             audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
@@ -929,11 +945,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
         This function will start recording audio from the microphone and save it to disk.
         """
 
-        audio_dir = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key).parent
+        audio_dir = self._get_raw_audio_file_path(
+            self.num_episodes, "observation.audio." + microphone_key
+        ).parent
         if not audio_dir.is_dir():
             audio_dir.mkdir(parents=True, exist_ok=True)
 
-        microphone.start_recording(output_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key))
+        microphone.start_recording(
+            output_file=self._get_raw_audio_file_path(
+                self.num_episodes, "observation.audio." + microphone_key
+            )
+        )
 
     def save_episode(self, episode_data: dict | None = None) -> None:
         """
@@ -983,8 +1005,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         if self.meta.robot_type.startswith("lekiwi"):
             for key in self.meta.audio_keys:
-                audio_path = self._get_raw_audio_file_path(episode_index=self.episode_buffer["episode_index"][0], audio_key=key)
-                with sf.SoundFile(audio_path, mode='w', samplerate=self.meta.features[key]["info"]["sample_rate"], channels=self.meta.features[key]["shape"][0]) as file:
+                audio_path = self._get_raw_audio_file_path(
+                    episode_index=self.episode_buffer["episode_index"][0], audio_key=key
+                )
+                with sf.SoundFile(
+                    audio_path,
+                    mode="w",
+                    samplerate=self.meta.features[key]["info"]["sample_rate"],
+                    channels=self.meta.features[key]["shape"][0],
+                ) as file:
                     file.write(episode_buffer[key])
 
         ep_stats = compute_episode_stats(episode_buffer, self.features)
@@ -1180,7 +1209,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
-        obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
+        obj.audio_backend = (
+            audio_backend if audio_backend is not None else "ffmpeg"
+        )  # Waiting for torchcodec release #TODO(CarolinePascal)
         return obj
 
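The reformatted `audio_camera_keys_mapping` property is easier to read spelled out on a toy feature table; the feature names below are hypothetical, chosen only to show which entries survive the comprehension:

```python
features = {
    "observation.images.laptop": {"dtype": "video", "audio": "observation.audio.laptop"},
    "observation.images.phone": {"dtype": "video", "audio": None},  # no linked audio: skipped
}

mapping = {
    features[camera_key]["audio"]: camera_key
    for camera_key in features
    if features[camera_key]["audio"] is not None and features[camera_key]["dtype"] == "video"
}
print(mapping)  # {'observation.audio.laptop': 'observation.images.laptop'}
```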
@@ -33,9 +33,8 @@ from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
 from huggingface_hub.errors import RevisionNotFoundError
 from PIL import Image as PILImage
-from torchvision import transforms
-
 from soundfile import read
+from torchvision import transforms
 
 from lerobot.common.datasets.backward_compatibility import (
     V21_MESSAGE,
@@ -260,10 +259,12 @@ def load_image_as_numpy(
     img_array /= 255.0
     return img_array
 
+
 def load_audio_from_path(fpath: str | Path) -> np.ndarray:
     audio_data, _ = read(fpath, dtype="float32")
     return audio_data
 
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
@@ -793,18 +794,24 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
 
     return error_message
 
+
 def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
     error_message = ""
     if isinstance(value, np.ndarray):
         actual_shape = value.shape
         c = expected_shape
-        if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]): #The number of frames might be different
-            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
+        if len(actual_shape) != 2 or (
+            actual_shape[-1] != c[-1] and actual_shape[0] != c[0]
+        ):  # The number of frames might be different
+            error_message += (
+                f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
+            )
     else:
         error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
 
     return error_message
 
+
 def validate_feature_string(name: str, value: str):
     if not isinstance(value, str):
         return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
@@ -25,12 +25,11 @@ from typing import Any, ClassVar
 
 import pyarrow as pa
 import torch
-import torchvision
 import torchaudio
+import torchvision
 from datasets.features.features import register_feature
-from PIL import Image
-
 from numpy import ceil
+from PIL import Image
 
 
 def get_safe_default_codec():
@@ -42,6 +41,7 @@ def get_safe_default_codec():
     )
     return "pyav"
 
+
 def decode_audio(
     audio_path: Path | str,
     timestamps: list[float],
@@ -68,30 +68,30 @@ def decode_audio(
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
 
 def decode_audio_torchvision(
     audio_path: Path | str,
     timestamps: list[float],
     duration: float,
     log_loaded_timestamps: bool = False,
 ) -> torch.Tensor:
-    #TODO(CarolinePascal) : add channels selection
+    # TODO(CarolinePascal): add channels selection
     audio_path = str(audio_path)
 
     reader = torchaudio.io.StreamReader(src=audio_path)
     audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
 
-    #TODO(CarolinePascal) : sort timestamps ?
+    # TODO(CarolinePascal): sort timestamps?
 
     reader.add_basic_audio_stream(
-        frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough
-        buffer_chunk_size = -1, #No dropping frames
-        format = "fltp", #Format as float32
+        frames_per_chunk=int(ceil(duration * audio_sample_rate)),  # Too much is better than not enough
+        buffer_chunk_size=-1,  # No dropping frames
+        format="fltp",  # Format as float32
     )
 
     audio_chunks = []
     for ts in timestamps:
-        reader.seek(ts) #Default to closest audio sample
+        reader.seek(ts)  # Default to closest audio sample
         status = reader.fill_buffer()
         if status != 0:
             logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
@@ -99,7 +99,9 @@ def decode_audio_torchvision(
         current_audio_chunk = reader.pop_chunks()[0]
 
         if log_loaded_timestamps:
-            logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}")
+            logging.info(
+                f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
+            )
 
         audio_chunks.append(current_audio_chunk)
 
@@ -108,6 +110,7 @@ def decode_audio_torchvision(
     assert len(timestamps) == len(audio_chunks)
     return audio_chunks
 
+
 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
@@ -137,6 +140,7 @@ def decode_video_frames(
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
+
 def decode_video_frames_torchvision(
     video_path: Path | str,
     timestamps: list[float],
@@ -234,6 +238,7 @@ def decode_video_frames_torchvision(
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
@@ -308,6 +313,7 @@ def decode_video_frames_torchcodec(
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+
 def encode_audio(
     input_path: Path | str,
     output_path: Path | str,
@@ -344,6 +350,7 @@ def encode_audio(
             f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
         )
 
+
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
@@ -353,7 +360,7 @@ def encode_video_frames(
     pix_fmt: str = "yuv420p",
     g: int | None = 2,
     crf: int | None = 30,
-    acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
+    acodec: str = "aac",  # TODO(CarolinePascal): investigate the Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
     fast_decode: int = 0,
     log_level: str | None = "error",
     overwrite: bool = False,
@@ -375,11 +382,13 @@ def encode_video_frames(
     if audio_path is not None:
         audio_path = Path(audio_path)
         audio_path.parent.mkdir(parents=True, exist_ok=True)
-        ffmpeg_audio_args.update(OrderedDict(
+        ffmpeg_audio_args.update(
+            OrderedDict(
                 [
                     ("-i", str(audio_path)),
                 ]
-        ))
+            )
+        )
 
     ffmpeg_encoding_args = OrderedDict(
         [
@@ -487,6 +496,7 @@ def get_audio_info(video_path: Path | str) -> dict:
         "audio.channel_layout": audio_stream_info.get("channel_layout", None),
     }
 
+
 def get_video_info(video_path: Path | str) -> dict:
     ffprobe_video_cmd = [
         "ffprobe",
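For orientation, `_query_audio` above drives this module through a single entry point: `decode_audio(path, timestamps, duration, backend)`, which returns one fixed-duration chunk per requested timestamp. A hedged usage sketch, assuming the `"ffmpeg"` backend string used as the `LeRobotDataset` default is accepted here and that the placeholder path points at a recorded episode:

```python
from lerobot.common.datasets.video_utils import decode_audio

# One 0.5 s chunk per timestamp; the path and timestamps are illustrative.
chunks = decode_audio(
    "episode_000000.m4a",
    timestamps=[0.0, 1.0, 2.0],
    duration=0.5,
    backend="ffmpeg",  # assumption: the dataset's default backend string
)
```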
@@ -252,9 +252,11 @@ def control_loop(
     timestamp = 0
     start_episode_t = time.perf_counter()
 
-    if dataset is not None and not robot.robot_type.startswith("lekiwi"): #For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
+    if (
+        dataset is not None and not robot.robot_type.startswith("lekiwi")
+    ):  # For now, LeKiwi only supports frame audio recording (which may lead to audio chunk loss, extended post-processing, and increased memory usage)
         for microphone_key, microphone in robot.microphones.items():
-            #Start recording both in file writing and data reading mode
+            # Start recording both in file-writing and data-reading mode
             dataset.add_microphone_recording(microphone, microphone_key)
     else:
         for _, microphone in robot.microphones.items():
@@ -17,12 +17,14 @@ from dataclasses import dataclass
 
 import draccus
 
+
 @dataclass
 class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
     @property
     def type(self) -> str:
         return self.get_choice_name(self.__class__)
 
+
 @MicrophoneConfigBase.register_subclass("microphone")
 @dataclass
 class MicrophoneConfig(MicrophoneConfigBase):
|
@ -17,32 +17,31 @@ This file contains utilities for recording audio from a microhone.
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import logging
|
||||
from threading import Thread, Event
|
||||
from multiprocessing import Process
|
||||
from queue import Empty
|
||||
|
||||
from queue import Queue as thread_Queue
|
||||
from threading import Event as thread_Event
|
||||
from multiprocessing import JoinableQueue as process_Queue
|
||||
from multiprocessing import Event as process_Event
|
||||
|
||||
from os import getcwd
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import time
|
||||
from multiprocessing import Event as process_Event
|
||||
from multiprocessing import JoinableQueue as process_Queue
|
||||
from multiprocessing import Process
|
||||
from os import getcwd
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from queue import Queue as thread_Queue
|
||||
from threading import Event, Thread
|
||||
from threading import Event as thread_Event
|
||||
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceAlreadyRecordingError,
|
||||
RobotDeviceNotConnectedError,
|
||||
RobotDeviceNotRecordingError,
|
||||
RobotDeviceAlreadyRecordingError,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
|
||||
def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
|
||||
microphones = []
|
||||
|
@ -69,11 +68,10 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
|
|||
|
||||
return microphones
|
||||
|
||||
def record_audio_from_microphones(
|
||||
output_dir: Path,
|
||||
microphone_ids: list[int] | None = None,
|
||||
record_time_s: float = 2.0):
|
||||
|
||||
def record_audio_from_microphones(
|
||||
output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0
|
||||
):
|
||||
if microphone_ids is None or len(microphone_ids) == 0:
|
||||
microphones = find_microphones()
|
||||
microphone_ids = [m["index"] for m in microphones]
|
||||
|
@ -104,13 +102,14 @@ def record_audio_from_microphones(
|
|||
for microphone in microphones:
|
||||
microphone.stop_recording()
|
||||
|
||||
#Remark : recording may be resumed here if needed
|
||||
# Remark : recording may be resumed here if needed
|
||||
|
||||
for microphone in microphones:
|
||||
microphone.disconnect()
|
||||
|
||||
print(f"Images have been saved to {output_dir}")
|
||||
|
||||
|
||||
class Microphone:
|
||||
"""
|
||||
The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows).
|
||||
|
@ -138,20 +137,20 @@ class Microphone:
|
|||
self.config = config
|
||||
self.microphone_index = config.microphone_index
|
||||
|
||||
#Store the recording sample rate and channels
|
||||
# Store the recording sample rate and channels
|
||||
self.sample_rate = config.sample_rate
|
||||
self.channels = config.channels
|
||||
|
||||
self.mock = config.mock
|
||||
|
||||
#Input audio stream
|
||||
# Input audio stream
|
||||
self.stream = None
|
||||
|
||||
#Thread-safe concurrent queue to store the recorded/read audio
|
||||
# Thread-safe concurrent queue to store the recorded/read audio
|
||||
self.record_queue = None
|
||||
self.read_queue = None
|
||||
|
||||
#Thread to handle data reading and file writing in a separate thread (safely)
|
||||
# Thread to handle data reading and file writing in a separate thread (safely)
|
||||
self.record_thread = None
|
||||
self.record_stop_event = None
|
||||
|
||||
|
@ -162,14 +161,16 @@ class Microphone:
|
|||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"Microphone {self.microphone_index} is already connected."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.microphones.mock_sounddevice as sd
|
||||
else:
|
||||
import sounddevice as sd
|
||||
|
||||
#Check if the provided microphone index does match an input device
|
||||
# Check if the provided microphone index does match an input device
|
||||
is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0
|
||||
|
||||
if not is_index_input:
|
||||
|
@ -179,16 +180,18 @@ class Microphone:
|
|||
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
|
||||
)
|
||||
|
||||
#Check if provided recording parameters are compatible with the microphone
|
||||
# Check if provided recording parameters are compatible with the microphone
|
||||
actual_microphone = sd.query_devices(self.microphone_index)
|
||||
|
||||
if self.sample_rate is not None :
|
||||
if self.sample_rate is not None:
|
||||
if self.sample_rate > actual_microphone["default_samplerate"]:
|
||||
raise OSError(
|
||||
f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
|
||||
)
|
||||
elif self.sample_rate < actual_microphone["default_samplerate"]:
|
||||
logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.")
|
||||
logging.warning(
|
||||
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||
)
|
||||
else:
|
||||
self.sample_rate = int(actual_microphone["default_samplerate"])
|
||||
|
||||
|
@ -198,41 +201,48 @@ class Microphone:
|
|||
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
|
||||
)
|
||||
else:
|
||||
self.channels = np.arange(1, actual_microphone["max_input_channels"]+1)
|
||||
self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1)
|
||||
|
||||
# Get channels index instead of number for slicing
|
||||
self.channels = np.array(self.channels) - 1
|
||||
|
||||
#Create the audio stream
|
||||
# Create the audio stream
|
||||
self.stream = sd.InputStream(
|
||||
device=self.microphone_index,
|
||||
samplerate=self.sample_rate,
|
||||
channels=max(self.channels)+1,
|
||||
channels=max(self.channels) + 1,
|
||||
dtype="float32",
|
||||
callback=self._audio_callback,
|
||||
)
|
||||
#Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers.
|
||||
#However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency.
|
||||
# Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers.
|
||||
# However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency.
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def _audio_callback(self, indata, frames, time, status) -> None :
|
||||
def _audio_callback(self, indata, frames, time, status) -> None:
|
||||
if status:
|
||||
logging.warning(status)
|
||||
# Slicing makes copy unecessary
|
||||
# Two separate queues are necessary because .get() also pops the data from the queue
|
||||
if self.is_writing:
|
||||
self.record_queue.put(indata[:,self.channels])
|
||||
self.read_queue.put(indata[:,self.channels])
|
||||
self.record_queue.put(indata[:, self.channels])
|
||||
self.read_queue.put(indata[:, self.channels])
|
||||
|
||||
@staticmethod
|
||||
def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
|
||||
#Can only be run on a single process/thread for file writing safety
|
||||
with sf.SoundFile(output_file, mode='x', samplerate=sample_rate,
|
||||
channels=max(channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
|
||||
# Can only be run on a single process/thread for file writing safety
|
||||
with sf.SoundFile(
|
||||
output_file,
|
||||
mode="x",
|
||||
samplerate=sample_rate,
|
||||
channels=max(channels) + 1,
|
||||
subtype=sf.default_subtype(output_file.suffix[1:]),
|
||||
) as file:
|
||||
while not event.is_set():
|
||||
try:
|
||||
file.write(queue.get(timeout=0.02)) #Timeout set as twice the usual sounddevice buffer size
|
||||
file.write(
|
||||
queue.get(timeout=0.02)
|
||||
) # Timeout set as twice the usual sounddevice buffer size
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
continue
|
||||
|
@ -256,7 +266,6 @@ class Microphone:
|
|||
return audio_readings
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
if not self.is_recording:
|
||||
|
@ -274,21 +283,22 @@ class Microphone:
|
|||
|
||||
return audio_readings
|
||||
|
||||
def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False) -> None:
|
||||
|
||||
def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
if self.is_recording:
|
||||
raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
|
||||
raise RobotDeviceAlreadyRecordingError(
|
||||
f"Microphone {self.microphone_index} is already recording."
|
||||
)
|
||||
|
||||
#Reset queues
|
||||
# Reset queues
|
||||
self.read_queue = thread_Queue()
|
||||
if multiprocessing:
|
||||
self.record_queue = process_Queue()
|
||||
else:
|
||||
self.record_queue = thread_Queue()
|
||||
|
||||
#Write recordings into a file if output_file is provided
|
||||
# Write recordings into a file if output_file is provided
|
||||
if output_file is not None:
|
||||
output_file = Path(output_file)
|
||||
if output_file.exists():
|
||||
|
@ -296,10 +306,28 @@ class Microphone:
|
|||
|
||||
if multiprocessing:
|
||||
self.record_stop_event = process_Event()
|
||||
self.record_thread = Process(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
|
||||
self.record_thread = Process(
|
||||
target=Microphone._record_loop,
|
||||
args=(
|
||||
self.record_queue,
|
||||
self.record_stop_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.record_stop_event = thread_Event()
|
||||
self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
|
||||
self.record_thread = Thread(
|
||||
target=Microphone._record_loop,
|
||||
args=(
|
||||
self.record_queue,
|
||||
self.record_stop_event,
|
||||
self.sample_rate,
|
||||
self.channels,
|
||||
output_file,
|
||||
),
|
||||
)
|
||||
self.record_thread.daemon = True
|
||||
self.record_thread.start()
|
||||
|
||||
|
@ -309,15 +337,14 @@ class Microphone:
|
|||
self.stream.start()
|
||||
|
||||
def stop_recording(self) -> None:
|
||||
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
||||
|
||||
if self.stream.active:
|
||||
self.stream.stop() #Wait for all buffers to be processed
|
||||
#Remark : stream.abort() flushes the buffers !
|
||||
self.stream.stop() # Wait for all buffers to be processed
|
||||
# Remark : stream.abort() flushes the buffers !
|
||||
self.is_recording = False
|
||||
|
||||
if self.record_thread is not None:
|
||||
|
@ -329,7 +356,6 @@ class Microphone:
|
|||
self.is_writing = False
|
||||
|
||||
def disconnect(self) -> None:
|
||||
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||
|
||||
|
@ -343,6 +369,7 @@ class Microphone:
|
|||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset."
|
||||
|
|
|
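The reformatting does not change the recording lifecycle, which reads as follows end to end; a sketch with an illustrative device index and output path (other `MicrophoneConfig` fields keep their defaults):

```python
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.microphone import Microphone

mic = Microphone(MicrophoneConfig(microphone_index=0))  # index 0 is illustrative

mic.connect()                                     # validates index, sample rate, channels
mic.start_recording(output_file="episode_0.wav")  # spawns the file-writing thread
audio = mic.read()                                # chunks accumulated since the last read
mic.stop_recording()                              # stops the stream, joins the writer
mic.disconnect()
```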
@@ -16,28 +16,33 @@ from typing import Protocol
 
 from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase
 
+
 # Defines a microphone type
 class Microphone(Protocol):
     def connect(self): ...
     def disconnect(self): ...
-    def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False): ...
+    def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ...
     def stop_recording(self): ...
 
+
 def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]:
     microphones = {}
 
     for key, cfg in microphone_configs.items():
         if cfg.type == "microphone":
             from lerobot.common.robot_devices.microphones.microphone import Microphone
 
             microphones[key] = Microphone(cfg)
         else:
             raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
 
     return microphones
 
+
 def make_microphone(microphone_type, **kwargs) -> Microphone:
     if microphone_type == "microphone":
         from lerobot.common.robot_devices.microphones.microphone import Microphone
 
         return Microphone(MicrophoneConfig(**kwargs))
     else:
         raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
@@ -23,12 +23,12 @@ from lerobot.common.robot_devices.cameras.configs import (
     IntelRealSenseCameraConfig,
     OpenCVCameraConfig,
 )
+from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
 from lerobot.common.robot_devices.motors.configs import (
     DynamixelMotorsBusConfig,
     FeetechMotorsBusConfig,
     MotorsBusConfig,
 )
-from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
 
 
 @dataclass
@@ -51,6 +51,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
             latest_images_dict.update(local_dict)
         time.sleep(0.01)
 
+
 def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
     while not stop_event.is_set():
         local_dict = {}
@@ -60,6 +61,7 @@ def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
         with audio_lock:
             latest_audio_dict.update(local_dict)
 
+
 def calibrate_follower_arm(motors_bus, calib_dir_str):
     """
     Calibrates the follower arm. Attempts to load an existing calibration file;
@@ -149,12 +151,14 @@ def run_lekiwi(robot_config):
     cam_thread.start()
 
     # Start the microphone recording and capture thread.
-    #TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead !
+    # TODO(CarolinePascal): Leverage multi-core processing with a multiprocessing.Process instead!
     latest_audio_dict = {}
     audio_lock = threading.Lock()
     audio_stop_event = threading.Event()
     microphone_thread = threading.Thread(
-        target=run_microphone_capture, args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), daemon=True
+        target=run_microphone_capture,
+        args=(microphones, audio_lock, latest_audio_dict, audio_stop_event),
+        daemon=True,
     )
     for microphone in microphones.values():
         microphone.start_recording()
@@ -231,7 +235,7 @@ def run_lekiwi(robot_config):
     # Build the observation dictionary.
     observation = {
         "images": images_dict_copy,
-        "audio": audio_dict_copy, #TODO(CarolinePascal) : This is a nasty way to do it, sorry.
+        "audio": audio_dict_copy,  # TODO(CarolinePascal): This is a nasty way to do it, sorry.
         "present_speed": current_velocity,
         "follower_arm_state": follower_arm_state,
     }
@@ -211,7 +211,7 @@ class ManipulatorRobot:
             "dtype": "audio",
             "shape": (len(mic.channels),),
             "names": "channels",
-            "info" : {"sample_rate": mic.sample_rate},
+            "info": {"sample_rate": mic.sample_rate},
         }
         return mic_ft
 
@@ -173,7 +173,7 @@ class MobileManipulator:
             "dtype": "audio",
             "shape": (len(mic.channels),),
             "names": "channels",
-            "info" : {"sample_rate": mic.sample_rate},
+            "info": {"sample_rate": mic.sample_rate},
         }
         return mic_ft
 
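Both robot classes now expose microphones through the same feature entry; concretely, a hypothetical two-channel microphone sampled at 16 kHz would be advertised as:

```python
mic_ft = {
    "observation.audio.microphone": {  # hypothetical key naming
        "dtype": "audio",
        "shape": (2,),  # (len(mic.channels),)
        "names": "channels",
        "info": {"sample_rate": 16000},  # mic.sample_rate
    }
}
```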
@@ -69,11 +69,13 @@ class RobotDeviceNotRecordingError(Exception):
     """Exception raised when the robot device is not recording."""
 
     def __init__(
-        self, message="This robot device is not recording. Try calling `robot_device.start_recording()` first."
+        self,
+        message="This robot device is not recording. Try calling `robot_device.start_recording()` first.",
     ):
         self.message = message
         super().__init__(self.message)
 
+
 class RobotDeviceAlreadyRecordingError(Exception):
     """Exception raised when the robot device is already recording."""
 
@@ -19,9 +19,9 @@ import traceback
 import pytest
 from serial import SerialException
 
-from lerobot import available_cameras, available_motors, available_robots, available_microphones
+from lerobot import available_cameras, available_microphones, available_motors, available_robots
 from lerobot.common.robot_devices.robots.utils import make_robot
-from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone
+from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus
 
 # Import fixture modules as plugins
 pytest_plugins = [
@@ -73,10 +73,12 @@ def is_robot_available(robot_type):
 def is_camera_available(camera_type):
     return _check_component_availability(camera_type, available_cameras, make_camera)
 
+
 @pytest.fixture
 def is_microphone_available(microphone_type):
     return _check_component_availability(microphone_type, available_microphones, make_microphone)
 
+
 @pytest.fixture
 def is_motor_available(motor_type):
     return _check_component_availability(motor_type, available_motors, make_motors_bus)
@@ -25,9 +25,9 @@ from lerobot.common.datasets.compute_stats import (
     compute_episode_stats,
     estimate_num_samples,
     get_feature_stats,
-    sample_images,
-    sample_audio_from_path,
     sample_audio_from_data,
+    sample_audio_from_path,
+    sample_images,
     sample_indices,
 )
@@ -35,8 +35,10 @@ from lerobot.common.datasets.compute_stats import (
 def mock_load_image_as_numpy(path, dtype, channel_first):
     return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
 
+
 def mock_load_audio(path):
-    return np.ones((16000,2), dtype=np.float32)
+    return np.ones((16000, 2), dtype=np.float32)
 
+
 @pytest.fixture
 def sample_array():
@@ -74,6 +76,7 @@ def test_sample_images(mock_load):
     assert images.dtype == np.uint8
     assert len(images) == estimate_num_samples(100)
 
+
 @patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
 def test_sample_audio_from_path(mock_load):
     audio_path = "audio.wav"
@@ -83,6 +86,7 @@ def test_sample_audio_from_path(mock_load):
     assert audio_samples.dtype == np.float32
     assert len(audio_samples) == estimate_num_samples(16000)
 
+
 def test_sample_audio_from_data(mock_load):
     audio_data = np.ones((16000, 2), dtype=np.float32)
     audio_samples = sample_audio_from_data(audio_data)
@@ -91,6 +95,7 @@ def test_sample_audio_from_data(mock_load):
     assert audio_samples.dtype == np.float32
     assert len(audio_samples) == estimate_num_samples(16000)
 
+
 def test_get_feature_stats_images():
     data = np.random.rand(100, 3, 32, 32)
     stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
@@ -98,13 +103,15 @@ def test_get_feature_stats_images():
     np.testing.assert_equal(stats["count"], np.array([100]))
     assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
 
+
 def test_get_feature_stats_audio():
-    data = np.random.uniform(-1, 1, (16000,2))
+    data = np.random.uniform(-1, 1, (16000, 2))
     stats = get_feature_stats(data, axis=0, keepdims=True)
     assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
     np.testing.assert_equal(stats["count"], np.array([16000]))
     assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
 
+
 def test_get_feature_stats_axis_0_keepdims(sample_array):
     expected = {
         "min": np.array([[1, 2, 3]]),
@@ -172,10 +179,11 @@ def test_compute_episode_stats():
         "observation.state": {"dtype": "numeric"},
     }
 
-    with patch(
-        "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
-    ), patch(
-        "lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio
+    with (
+        patch(
+            "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
+        ),
+        patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
     ):
         stats = compute_episode_stats(episode_data, features)
 
@@ -16,6 +16,7 @@
 import json
 import logging
 import re
+import time
 from copy import deepcopy
 from itertools import chain
 from pathlib import Path
@@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import (
     MultiLeRobotDataset,
 )
 from lerobot.common.datasets.utils import (
+    DEFAULT_AUDIO_CHUNK_DURATION,
     create_branch,
     flatten_dict,
     unflatten_dict,
@@ -44,12 +46,9 @@ from lerobot.common.policies.factory import make_policy_config
 from lerobot.common.robot_devices.robots.utils import make_robot
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
-from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS
-from tests.utils import require_x86_64_kernel
+from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
+from tests.utils import make_microphone, require_x86_64_kernel
 
-from tests.utils import make_microphone
-import time
-from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
 
 @pytest.fixture
 def image_dataset(tmp_path, empty_lerobot_dataset_factory):
@@ -66,6 +65,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
     }
     return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
 
+
 @pytest.fixture
 def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
     features = {
@@ -79,6 +79,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
     }
     return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
 
+
 def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     """
     Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
@@ -336,6 +337,7 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
     with pytest.raises(ValueError):
         image_array_to_pil_image(image)
 
+
 def test_add_frame_audio(audio_dataset):
     dataset = audio_dataset
 
@@ -349,7 +351,10 @@ def test_add_frame_audio(audio_dataset):
 
     dataset.save_episode()
 
-    assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sample_rate),DUMMY_AUDIO_CHANNELS))
+    assert dataset[0]["observation.audio.microphone"].shape == torch.Size(
+        (int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS)
+    )
 
 
 # TODO(aliberts):
 # - [ ] test various attributes & state from init and create
|
@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = {
|
|||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"},
|
||||
"laptop": {
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
"audio": "laptop",
|
||||
},
|
||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
|
|
|
@@ -26,10 +26,10 @@ import torch
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_COMPRESSED_AUDIO_PATH,
     DEFAULT_FEATURES,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
-    DEFAULT_COMPRESSED_AUDIO_PATH,
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
@@ -11,19 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import time
 from functools import cache
-
-from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DEFAULT_SAMPLE_RATE
+from threading import Event, Thread
 
 import numpy as np
 
 from lerobot.common.utils.utils import capture_timestamp_utc
-from threading import Thread, Event
-import time
+from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS
 
 
 @cache
 def _generate_sound(duration: float, sample_rate: int, channels: int):
     return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
 
+
 def query_devices(query_index: int):
     return {
         "name": "Mock Sound Device",
@@ -32,6 +34,7 @@ def query_devices(query_index: int):
         "default_samplerate": DEFAULT_SAMPLE_RATE,
     }
 
+
 class InputStream:
     def __init__(self, *args, **kwargs):
         self._mock_dict = {
@@ -49,7 +52,12 @@ class InputStream:
         while not self.callback_thread_stop_event.is_set():
             # Simulate audio data acquisition
             time.sleep(0.01)
-            self._audio_callback(_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), 0.01*DEFAULT_SAMPLE_RATE, capture_timestamp_utc(), None)
+            self._audio_callback(
+                _generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS),
+                0.01 * DEFAULT_SAMPLE_RATE,
+                capture_timestamp_utc(),
+                None,
+            )
 
     def start(self):
         self.callback_thread_stop_event = Event()
@@ -78,5 +86,3 @@ class InputStream:
     def __del__(self):
         if self._is_active:
             self.stop()
-
-
@@ -32,20 +32,27 @@ pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-Tr
 ```
 """
 
-import numpy as np
 import time
 
+import numpy as np
 import pytest
 from soundfile import read
 
-from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceAlreadyRecordingError,
+    RobotDeviceNotConnectedError,
+    RobotDeviceNotRecordingError,
+)
 from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
 
-#Maximum recording tie difference between two consecutive audio recordings of the same duration.
-#Set to 0.02 seconds as twice the default size of sounddvice callback buffer (i.e. we tolerate the loss of one buffer).
+# Maximum recording time difference between two consecutive audio recordings of the same duration.
+# Set to 0.02 seconds, i.e. twice the default size of the sounddevice callback buffer (we tolerate the loss of one buffer).
 MAX_RECORDING_TIME_DIFFERENCE = 0.02
 
 DUMMY_RECORDING = "test_recording.wav"
 
 
 @pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
 @require_microphone
 def test_microphone(tmp_path, request, microphone_type, mock):
@@ -126,10 +133,13 @@ def test_microphone(tmp_path, request, microphone_type, mock):
 
     error_msg = (
         "Recording time difference between read() and stop_recording()",
-        (len(audio_chunk) - len(recorded_audio))/MAX_RECORDING_TIME_DIFFERENCE,
+        (len(audio_chunk) - len(recorded_audio)) / MAX_RECORDING_TIME_DIFFERENCE,
     )
     np.testing.assert_allclose(
-        len(audio_chunk), len(recorded_audio), atol=recorded_sample_rate*MAX_RECORDING_TIME_DIFFERENCE, err_msg=error_msg
+        len(audio_chunk),
+        len(recorded_audio),
+        atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
+        err_msg=error_msg,
     )
 
     # Test disconnecting
@@ -143,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
     robot.send_action(action["action"])
 
     # Test disconnecting
-    robot.disconnect() #Also handles microphone recording stop, life is beautiful
+    robot.disconnect()  # Also handles microphone recording stop, life is beautiful
     assert not robot.is_connected
     for name in robot.follower_arms:
         assert not robot.follower_arms[name].is_connected
@@ -22,13 +22,13 @@ from pathlib import Path
 import pytest
 import torch
 
-from lerobot import available_cameras, available_motors, available_robots, available_microphones
+from lerobot import available_cameras, available_microphones, available_motors, available_robots
 from lerobot.common.robot_devices.cameras.utils import Camera
 from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
-from lerobot.common.robot_devices.motors.utils import MotorsBus
-from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
 from lerobot.common.robot_devices.microphones.utils import Microphone
 from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
+from lerobot.common.robot_devices.motors.utils import MotorsBus
+from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
 from lerobot.common.utils.import_utils import is_package_available
 
 DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
@@ -261,6 +261,7 @@ def require_camera(func):
 
     return wrapper
 
+
 def require_microphone(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -283,6 +284,7 @@ def require_microphone(func):
 
     return wrapper
 
+
 def require_motor(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -344,6 +346,7 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
     else:
         raise ValueError(f"The camera type '{camera_type}' is not valid.")
 
+
 def make_microphone(microphone_type: str, **kwargs) -> Microphone:
     if microphone_type == "microphone":
         microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
@@ -351,6 +354,7 @@ def make_microphone(microphone_type: str, **kwargs) -> Microphone:
     else:
         raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
 
+
 # TODO(rcadene, aliberts): remove this dark pattern that overrides
 def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
     if motor_type == "dynamixel":