[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-04-10 16:56:39 +00:00 committed by CarolinePascal
parent 9c667d347c
commit 0cb9345f06
No known key found for this signature in database
22 changed files with 329 additions and 196 deletions

View File

@@ -15,7 +15,8 @@
 # limitations under the License.
 import numpy as np
-from lerobot.common.datasets.utils import load_image_as_numpy, load_audio_from_path
+from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy
+
 
 def estimate_num_samples(
     dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
@@ -70,17 +71,19 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
     return images
 
+
 def sample_audio_from_path(audio_path: str) -> np.ndarray:
     data = load_audio_from_path(audio_path)
     sampled_indices = sample_indices(len(data))
-    return(data[sampled_indices])
+    return data[sampled_indices]
+
 
 def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
     sampled_indices = sample_indices(len(data))
     return data[sampled_indices]
 
 
 def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
     return {
         "min": np.min(array, axis=axis, keepdims=keepdims),
@@ -103,9 +106,9 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
         elif features[key]["dtype"] == "audio":
             try:
                 ep_ft_array = sample_audio_from_path(data[0])
-            except TypeError: #Should only be triggered for LeKiwi robot
+            except TypeError:  # Should only be triggered for LeKiwi robot
                 ep_ft_array = sample_audio_from_data(data)
             axes_to_reduce = 0
             keepdims = True
         else:
             ep_ft_array = data  # data is already a np.ndarray
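The hunks above are pure formatting, but they touch the audio subsampling path used for dataset statistics. For context, a minimal standalone sketch of that path; the bodies of estimate_num_samples and sample_indices are paraphrased from the signatures shown here, not copied from the repository:

import numpy as np


def estimate_num_samples(
    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
    # Sub-linear sampling budget: roughly dataset_len**0.75, clamped to [min, max].
    if dataset_len < min_num_samples:
        min_num_samples = dataset_len
    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))


def sample_indices(data_len: int) -> np.ndarray:
    # Spread the sample indices evenly over the sequence.
    num_samples = estimate_num_samples(data_len)
    return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int)


# Equivalent of sample_audio_from_data: subsample audio frames before computing stats.
audio = np.random.uniform(-1, 1, (16_000, 2)).astype(np.float32)  # 1 s of stereo audio
print(audio[sample_indices(len(audio))].shape)  # (estimate_num_samples(16000), 2)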

View File

@@ -23,6 +23,7 @@ import datasets
 import numpy as np
 import packaging.version
 import PIL.Image
+import soundfile as sf
 import torch
 import torch.utils
 from datasets import concatenate_datasets, load_dataset
@@ -34,13 +35,12 @@ from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
+    DEFAULT_AUDIO_CHUNK_DURATION,
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
     DEFAULT_RAW_AUDIO_PATH,
-    DEFAULT_COMPRESSED_AUDIO_PATH,
     INFO_PATH,
     TASKS_PATH,
-    DEFAULT_AUDIO_CHUNK_DURATION,
     append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
@@ -70,17 +70,16 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
-    decode_video_frames,
-    encode_video_frames,
-    encode_audio,
     decode_audio,
+    decode_video_frames,
+    encode_audio,
+    encode_video_frames,
+    get_audio_info,
     get_safe_default_codec,
     get_video_info,
-    get_audio_info,
 )
-from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.microphones.utils import Microphone
-import soundfile as sf
+from lerobot.common.robot_devices.robots.utils import Robot
 
 CODEBASE_VERSION = "v2.1"
@@ -149,10 +148,12 @@ class LeRobotDatasetMetadata:
         ep_chunk = self.get_episode_chunk(ep_index)
         fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)
 
     def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
         episode_chunk = self.get_episode_chunk(episode_index)
-        fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index)
+        fpath = self.audio_path.format(
+            episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
+        )
         return self.root / fpath
 
     def get_episode_chunk(self, ep_index: int) -> int:
@@ -167,7 +168,7 @@ class LeRobotDatasetMetadata:
     def video_path(self) -> str | None:
         """Formattable string for the video files."""
         return self.info["video_path"]
-
+
     @property
     def audio_path(self) -> str | None:
         """Formattable string for the audio files."""
@@ -202,16 +203,21 @@ class LeRobotDatasetMetadata:
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
 
     @property
     def audio_keys(self) -> list[str]:
         """Keys to access audio modalities."""
         return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
 
     @property
     def audio_camera_keys_mapping(self) -> dict[str, str]:
         """Mapping between camera keys and audio keys when both are linked."""
-        return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None and self.features[camera_key]["dtype"] == "video"}
+        return {
+            self.features[camera_key]["audio"]: camera_key
+            for camera_key in self.camera_keys
+            if self.features[camera_key]["audio"] is not None
+            and self.features[camera_key]["dtype"] == "video"
+        }
 
     @property
     def names(self) -> dict[str, list | dict]:
@@ -325,7 +331,9 @@ class LeRobotDatasetMetadata:
         been encoded the same way. Also, this means it assumes the first episode exists.
         """
         for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
-            if not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]):
+            if not self.features[key].get("info", None) or (
+                len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]
+            ):
                 audio_path = self.root / self.get_compressed_audio_file_path(0, key)
                 self.info["features"][key]["info"] = get_audio_info(audio_path)
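The reflowed audio_camera_keys_mapping comprehension is the pivot for the hunks below: audio that shares a container with a camera stream is keyed back to that camera. A small sketch with a hypothetical features dict (the key names and the per-camera "audio" field layout are assumptions based on this hunk):

features = {
    "observation.images.cam_top": {"dtype": "video", "audio": "observation.audio.cam_top"},
    "observation.images.cam_wrist": {"dtype": "image", "audio": None},
}
camera_keys = [key for key, ft in features.items() if ft["dtype"] in ["video", "image"]]

audio_camera_keys_mapping = {
    features[camera_key]["audio"]: camera_key
    for camera_key in camera_keys
    if features[camera_key]["audio"] is not None and features[camera_key]["dtype"] == "video"
}
print(audio_camera_keys_mapping)  # {'observation.audio.cam_top': 'observation.images.cam_top'}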
@@ -518,7 +526,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else get_safe_default_codec()
-        self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
+        self.audio_backend = (
+            audio_backend if audio_backend else "ffmpeg"
+        )  # Waiting for torchcodec release #TODO(CarolinePascal)
         self.delta_indices = None
 
         # Unused attributes
@@ -543,7 +553,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video
+            self.download_episodes(
+                download_videos
+            )  # Sould load audio as well #TODO(CarolinePascal): separate audio from video
             self.hf_dataset = self.load_hf_dataset()
 
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
@@ -735,13 +747,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 query_timestamps[key] = [current_ts]
 
         return query_timestamps
 
-    #TODO(CarolinePascal): add variable query durations
+    # TODO(CarolinePascal): add variable query durations
     def _get_query_timestamps_audio(
         self,
         current_ts: float,
         query_indices: dict[str, list[int]] | None = None,
     ) -> dict[str, list[float]]:
         query_timestamps = {}
 
         for key in self.meta.audio_keys:
             if query_indices is not None and key in query_indices:
@@ -773,14 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
-    #TODO(CarolinePascal): add variable query durations
-    def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
+    # TODO(CarolinePascal): add variable query durations
+    def _query_audio(
+        self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
+    ) -> dict[str, torch.Tensor]:
         item = {}
         for audio_key, query_ts in query_timestamps.items():
-            #Audio stored with video in a single .mp4 file
+            # Audio stored with video in a single .mp4 file
             if audio_key in self.meta.audio_camera_keys_mapping:
-                audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key])
-            #Audio stored alone in a separate .m4a file
+                audio_path = self.root / self.meta.get_video_file_path(
+                    ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
+                )
+            # Audio stored alone in a separate .m4a file
             else:
                 audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
             audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
@@ -855,7 +871,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             image_key=image_key, episode_index=episode_index, frame_index=frame_index
         )
         return self.root / fpath
-
+
     def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
         fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
         return self.root / fpath
@@ -929,11 +945,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
         This function will start recording audio from the microphone and save it to disk.
         """
-        audio_dir = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key).parent
+        audio_dir = self._get_raw_audio_file_path(
+            self.num_episodes, "observation.audio." + microphone_key
+        ).parent
         if not audio_dir.is_dir():
             audio_dir.mkdir(parents=True, exist_ok=True)
 
-        microphone.start_recording(output_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key))
+        microphone.start_recording(
+            output_file=self._get_raw_audio_file_path(
+                self.num_episodes, "observation.audio." + microphone_key
+            )
+        )
 
     def save_episode(self, episode_data: dict | None = None) -> None:
         """
@@ -983,8 +1005,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.meta.robot_type.startswith("lekiwi"):
             for key in self.meta.audio_keys:
-                audio_path = self._get_raw_audio_file_path(episode_index=self.episode_buffer["episode_index"][0], audio_key=key)
-                with sf.SoundFile(audio_path, mode='w', samplerate=self.meta.features[key]["info"]["sample_rate"], channels=self.meta.features[key]["shape"][0]) as file:
+                audio_path = self._get_raw_audio_file_path(
+                    episode_index=self.episode_buffer["episode_index"][0], audio_key=key
+                )
+                with sf.SoundFile(
+                    audio_path,
+                    mode="w",
+                    samplerate=self.meta.features[key]["info"]["sample_rate"],
+                    channels=self.meta.features[key]["shape"][0],
+                ) as file:
                     file.write(episode_buffer[key])
 
         ep_stats = compute_episode_stats(episode_buffer, self.features)
@@ -996,7 +1025,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if len(self.meta.audio_keys) > 0:
             _ = self.encode_episode_audio(episode_index)
-
+
         # `meta.save_episode` be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -1113,7 +1142,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
 
         return video_paths
-
+
     def encode_episode_audio(self, episode_index: int) -> dict:
         """
         Use ffmpeg to convert .wav raw audio files into .m4a audio files.
@@ -1124,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()):
             input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
             output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
-
+
             audio_paths[audio_key] = str(output_audio_path)
             if output_audio_path.is_file():
                 # Skip if video is already encoded. Could be the case when resuming data recording.
@@ -1180,7 +1209,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
-        obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
+        obj.audio_backend = (
+            audio_backend if audio_backend is not None else "ffmpeg"
+        )  # Waiting for torchcodec release #TODO(CarolinePascal)
         return obj
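_query_audio above branches on the audio/camera mapping to find the right container. The decision reduces to the following sketch; the path templates are hypothetical placeholders, only the branch logic comes from the hunk:

from pathlib import Path


def resolve_audio_path(root: Path, audio_key: str, ep_idx: int, mapping: dict[str, str]) -> Path:
    if audio_key in mapping:
        # Audio recorded alongside a camera lives in that camera's .mp4 container.
        return root / f"videos/{mapping[audio_key]}/episode_{ep_idx:06d}.mp4"  # hypothetical layout
    # Standalone microphones get their own compressed .m4a file.
    return root / f"audio/{audio_key}/episode_{ep_idx:06d}.m4a"  # hypothetical layout


mapping = {"observation.audio.cam_top": "observation.images.cam_top"}
print(resolve_audio_path(Path("/data"), "observation.audio.cam_top", 3, mapping))
print(resolve_audio_path(Path("/data"), "observation.audio.mic", 3, mapping))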

View File

@@ -33,9 +33,8 @@ from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
 from huggingface_hub.errors import RevisionNotFoundError
 from PIL import Image as PILImage
-from torchvision import transforms
 from soundfile import read
+from torchvision import transforms
 
 from lerobot.common.datasets.backward_compatibility import (
     V21_MESSAGE,
@@ -260,10 +259,12 @@ def load_image_as_numpy(
         img_array /= 255.0
     return img_array
 
+
 def load_audio_from_path(fpath: str | Path) -> np.ndarray:
     audio_data, _ = read(fpath, dtype="float32")
     return audio_data
 
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
@@ -731,7 +732,7 @@ def validate_features_presence(
 ):
     error_message = ""
     missing_features = expected_features - actual_features
-    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
+    missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
     extra_features = actual_features - (expected_features | optional_features)
 
     if missing_features or extra_features:
@@ -793,18 +794,24 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value
     return error_message
 
+
 def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
     error_message = ""
     if isinstance(value, np.ndarray):
         actual_shape = value.shape
         c = expected_shape
-        if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]): #The number of frames might be different
-            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
+        if len(actual_shape) != 2 or (
+            actual_shape[-1] != c[-1] and actual_shape[0] != c[0]
+        ):  # The number of frames might be different
+            error_message += (
+                f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
+            )
     else:
         error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
     return error_message
 
+
 def validate_feature_string(name: str, value: str):
     if not isinstance(value, str):
         return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"

View File

@@ -25,12 +25,11 @@ from typing import Any, ClassVar
 
 import pyarrow as pa
 import torch
-import torchvision
 import torchaudio
+import torchvision
 from datasets.features.features import register_feature
-from PIL import Image
 from numpy import ceil
+from PIL import Image
 
 
 def get_safe_default_codec():
@@ -42,6 +41,7 @@ def get_safe_default_codec():
         )
         return "pyav"
 
+
 def decode_audio(
     audio_path: Path | str,
     timestamps: list[float],
@@ -68,30 +68,30 @@ def decode_audio(
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
+
 def decode_audio_torchvision(
     audio_path: Path | str,
     timestamps: list[float],
     duration: float,
     log_loaded_timestamps: bool = False,
 ) -> torch.Tensor:
-
-    #TODO(CarolinePascal) : add channels selection
+    # TODO(CarolinePascal) : add channels selection
     audio_path = str(audio_path)
 
     reader = torchaudio.io.StreamReader(src=audio_path)
     audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
 
-    #TODO(CarolinePascal) : sort timestamps ?
+    # TODO(CarolinePascal) : sort timestamps ?
     reader.add_basic_audio_stream(
-        frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough
-        buffer_chunk_size = -1, #No dropping frames
-        format = "fltp", #Format as float32
+        frames_per_chunk=int(ceil(duration * audio_sample_rate)),  # Too much is better than not enough
+        buffer_chunk_size=-1,  # No dropping frames
+        format="fltp",  # Format as float32
     )
 
     audio_chunks = []
     for ts in timestamps:
-        reader.seek(ts) #Default to closest audio sample
+        reader.seek(ts)  # Default to closest audio sample
         status = reader.fill_buffer()
         if status != 0:
             logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
@@ -99,15 +99,18 @@ def decode_audio_torchvision(
         current_audio_chunk = reader.pop_chunks()[0]
 
         if log_loaded_timestamps:
-            logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}")
+            logging.info(
+                f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
+            )
 
         audio_chunks.append(current_audio_chunk)
 
     audio_chunks = torch.stack(audio_chunks)
 
     assert len(timestamps) == len(audio_chunks)
     return audio_chunks
 
+
 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
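Calling convention for the decoder reformatted above: one seek per requested timestamp, each yielding duration seconds of float32 audio. A hedged usage sketch; the file path is a placeholder, and the backend string is assumed to be the "ffmpeg" default used by LeRobotDataset in this commit (requires torchaudio built with FFmpeg support):

from lerobot.common.datasets.video_utils import decode_audio

# Fetch 0.5 s of audio starting at t=1.0 s and t=2.0 s from an episode file.
chunks = decode_audio(
    audio_path="episode_000000.m4a",  # placeholder path
    timestamps=[1.0, 2.0],
    duration=0.5,
    backend="ffmpeg",  # assumed to dispatch to decode_audio_torchvision above
)
print(chunks.shape)  # (2, n_frames, n_channels) after torch.stack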
@@ -137,6 +140,7 @@ def decode_video_frames(
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
+
 def decode_video_frames_torchvision(
     video_path: Path | str,
     timestamps: list[float],
@@ -234,6 +238,7 @@ def decode_video_frames_torchvision(
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
@@ -308,6 +313,7 @@ def decode_video_frames_torchcodec(
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+
 def encode_audio(
     input_path: Path | str,
     output_path: Path | str,
@@ -344,6 +350,7 @@ def encode_audio(
             f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
         )
 
+
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
@@ -353,7 +360,7 @@ def encode_video_frames(
     pix_fmt: str = "yuv420p",
     g: int | None = 2,
     crf: int | None = 30,
-    acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
+    acodec: str = "aac",  # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
     fast_decode: int = 0,
     log_level: str | None = "error",
     overwrite: bool = False,
@@ -375,16 +382,18 @@ def encode_video_frames(
     if audio_path is not None:
         audio_path = Path(audio_path)
         audio_path.parent.mkdir(parents=True, exist_ok=True)
-        ffmpeg_audio_args.update(OrderedDict(
-            [
-                ("-i", str(audio_path)),
-            ]
-        ))
+        ffmpeg_audio_args.update(
+            OrderedDict(
+                [
+                    ("-i", str(audio_path)),
+                ]
+            )
+        )
 
     ffmpeg_encoding_args = OrderedDict(
         [
             ("-pix_fmt", pix_fmt),
             ("-vcodec", vcodec),
         ]
     )
     if g is not None:
@@ -398,7 +407,7 @@ def encode_video_frames(
     if audio_path is not None:
         ffmpeg_encoding_args["-acodec"] = acodec
-
+
     if log_level is not None:
         ffmpeg_encoding_args["-loglevel"] = str(log_level)
@@ -487,6 +496,7 @@ def get_audio_info(video_path: Path | str) -> dict:
         "audio.channel_layout": audio_stream_info.get("channel_layout", None),
     }
 
+
 def get_video_info(video_path: Path | str) -> dict:
     ffprobe_video_cmd = [
         "ffprobe",

View File

@@ -77,8 +77,8 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
             key = f"read_camera_{name}_dt_s"
             if key in robot.logs:
                 log_dt(f"dtR{name}", robot.logs[key])
 
         for name in robot.microphones:
             key = f"read_microphone_{name}_dt_s"
             if key in robot.logs:
                 log_dt(f"dtR{name}", robot.logs[key])
@@ -252,9 +252,11 @@ def control_loop(
     timestamp = 0
     start_episode_t = time.perf_counter()
 
-    if dataset is not None and not robot.robot_type.startswith("lekiwi"): #For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
+    if (
+        dataset is not None and not robot.robot_type.startswith("lekiwi")
+    ):  # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
         for microphone_key, microphone in robot.microphones.items():
-            #Start recording both in file writing and data reading mode
+            # Start recording both in file writing and data reading mode
             dataset.add_microphone_recording(microphone, microphone_key)
     else:
         for _, microphone in robot.microphones.items():

View File

@@ -17,12 +17,14 @@ from dataclasses import dataclass
 
 import draccus
 
+
 @dataclass
 class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
     @property
     def type(self) -> str:
         return self.get_choice_name(self.__class__)
 
+
 @MicrophoneConfigBase.register_subclass("microphone")
 @dataclass
 class MicrophoneConfig(MicrophoneConfigBase):
@@ -33,4 +35,4 @@ class MicrophoneConfig(MicrophoneConfigBase):
     microphone_index: int
     sample_rate: int | None = None
     channels: list[int] | None = None
-    mock: bool = False
+    mock: bool = False
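For reference, the dataclass reformatted above is registered under the "microphone" choice and can be constructed directly. Values here are illustrative; channels are 1-based device channel numbers, which the Microphone class below converts to 0-based indices:

from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig

config = MicrophoneConfig(
    microphone_index=0,   # device index as reported by sounddevice
    sample_rate=16000,    # None -> fall back to the device default
    channels=[1, 2],      # None -> record all available channels
    mock=False,
)
print(config.type)  # "microphone", via the ChoiceRegistry property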

View File

@@ -13,36 +13,35 @@
 # limitations under the License.
 """
 This file contains utilities for recording audio from a microhone.
 """
 
 import argparse
-import soundfile as sf
-import numpy as np
 import logging
-from threading import Thread, Event
-from multiprocessing import Process
-from queue import Empty
-from queue import Queue as thread_Queue
-from threading import Event as thread_Event
-from multiprocessing import JoinableQueue as process_Queue
-from multiprocessing import Event as process_Event
-from os import getcwd
-from pathlib import Path
 import shutil
 import time
+from multiprocessing import Event as process_Event
+from multiprocessing import JoinableQueue as process_Queue
+from multiprocessing import Process
+from os import getcwd
+from pathlib import Path
+from queue import Empty
+from queue import Queue as thread_Queue
+from threading import Event, Thread
+from threading import Event as thread_Event
 
-from lerobot.common.utils.utils import capture_timestamp_utc
+import numpy as np
+import soundfile as sf
+
 from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
 from lerobot.common.robot_devices.utils import (
     RobotDeviceAlreadyConnectedError,
+    RobotDeviceAlreadyRecordingError,
     RobotDeviceNotConnectedError,
     RobotDeviceNotRecordingError,
-    RobotDeviceAlreadyRecordingError,
 )
+from lerobot.common.utils.utils import capture_timestamp_utc
 
 
 def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
     microphones = []
@@ -69,11 +68,10 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
     return microphones
 
-def record_audio_from_microphones(
-    output_dir: Path,
-    microphone_ids: list[int] | None = None,
-    record_time_s: float = 2.0):
 
+def record_audio_from_microphones(
+    output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0
+):
     if microphone_ids is None or len(microphone_ids) == 0:
         microphones = find_microphones()
         microphone_ids = [m["index"] for m in microphones]
@@ -104,13 +102,14 @@ def record_audio_from_microphones(
     for microphone in microphones:
         microphone.stop_recording()
 
-    #Remark : recording may be resumed here if needed
+    # Remark : recording may be resumed here if needed
 
     for microphone in microphones:
         microphone.disconnect()
 
     print(f"Images have been saved to {output_dir}")
 
+
 class Microphone:
     """
     The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, accross all OS (Linux, Mac, Windows).
@@ -138,20 +137,20 @@ class Microphone:
         self.config = config
         self.microphone_index = config.microphone_index
 
-        #Store the recording sample rate and channels
+        # Store the recording sample rate and channels
         self.sample_rate = config.sample_rate
         self.channels = config.channels
 
         self.mock = config.mock
 
-        #Input audio stream
+        # Input audio stream
         self.stream = None
 
-        #Thread-safe concurrent queue to store the recorded/read audio
+        # Thread-safe concurrent queue to store the recorded/read audio
         self.record_queue = None
         self.read_queue = None
 
-        #Thread to handle data reading and file writing in a separate thread (safely)
+        # Thread to handle data reading and file writing in a separate thread (safely)
         self.record_thread = None
         self.record_stop_event = None
@@ -162,14 +161,16 @@ class Microphone:
 
     def connect(self) -> None:
         if self.is_connected:
-            raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
+            raise RobotDeviceAlreadyConnectedError(
+                f"Microphone {self.microphone_index} is already connected."
+            )
 
         if self.mock:
             import tests.microphones.mock_sounddevice as sd
         else:
             import sounddevice as sd
 
-        #Check if the provided microphone index does match an input device
+        # Check if the provided microphone index does match an input device
         is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0
 
         if not is_index_input:
@@ -178,17 +179,19 @@ class Microphone:
             raise OSError(
                 f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
             )
 
-        #Check if provided recording parameters are compatible with the microphone
+        # Check if provided recording parameters are compatible with the microphone
         actual_microphone = sd.query_devices(self.microphone_index)
 
-        if self.sample_rate is not None :
+        if self.sample_rate is not None:
             if self.sample_rate > actual_microphone["default_samplerate"]:
                 raise OSError(
                     f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
                 )
             elif self.sample_rate < actual_microphone["default_samplerate"]:
-                logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.")
+                logging.warning(
+                    "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
+                )
         else:
             self.sample_rate = int(actual_microphone["default_samplerate"])
@@ -198,45 +201,52 @@ class Microphone:
                     f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
                 )
         else:
-            self.channels = np.arange(1, actual_microphone["max_input_channels"]+1)
+            self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1)
 
         # Get channels index instead of number for slicing
         self.channels = np.array(self.channels) - 1
 
-        #Create the audio stream
+        # Create the audio stream
         self.stream = sd.InputStream(
             device=self.microphone_index,
             samplerate=self.sample_rate,
-            channels=max(self.channels)+1,
+            channels=max(self.channels) + 1,
             dtype="float32",
             callback=self._audio_callback,
         )
-        #Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers.
-        #However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency.
+        # Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always recieve same length buffers.
+        # However, this may lead to additionnal latency. We thus stick to blocksize=0 which means that audio_callback will recieve varying length buffers, but with no addtional latency.
 
         self.is_connected = True
 
-    def _audio_callback(self, indata, frames, time, status) -> None :
+    def _audio_callback(self, indata, frames, time, status) -> None:
        if status:
            logging.warning(status)
        # Slicing makes copy unecessary
        # Two separate queues are necessary because .get() also pops the data from the queue
        if self.is_writing:
-            self.record_queue.put(indata[:,self.channels])
-        self.read_queue.put(indata[:,self.channels])
+            self.record_queue.put(indata[:, self.channels])
+        self.read_queue.put(indata[:, self.channels])
 
     @staticmethod
     def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
-        #Can only be run on a single process/thread for file writing safety
-        with sf.SoundFile(output_file, mode='x', samplerate=sample_rate,
-                        channels=max(channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
+        # Can only be run on a single process/thread for file writing safety
+        with sf.SoundFile(
+            output_file,
+            mode="x",
+            samplerate=sample_rate,
+            channels=max(channels) + 1,
+            subtype=sf.default_subtype(output_file.suffix[1:]),
+        ) as file:
             while not event.is_set():
                 try:
-                    file.write(queue.get(timeout=0.02)) #Timeout set as twice the usual sounddevice buffer size
+                    file.write(
+                        queue.get(timeout=0.02)
+                    )  # Timeout set as twice the usual sounddevice buffer size
                     queue.task_done()
                 except Empty:
                     continue
 
     def _read(self) -> np.ndarray:
         """
         Gets audio data from the queue and coverts it to a numpy array.
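The callback/writer split above relies on two queues because Queue.get() pops: the writer thread drains record_queue to disk while read() drains read_queue. A standalone sketch of that pattern with toy data, no sounddevice involved:

import queue
import threading

import numpy as np

record_queue: queue.Queue = queue.Queue()  # drained by the file-writer thread
read_queue: queue.Queue = queue.Queue()    # drained by read() for live access


def audio_callback(indata: np.ndarray) -> None:
    # Both consumers need every chunk, and get() removes it, hence two queues.
    record_queue.put(indata)
    read_queue.put(indata)


stop = threading.Event()
written = []


def record_loop() -> None:
    while not stop.is_set():
        try:
            written.append(record_queue.get(timeout=0.02))
            record_queue.task_done()
        except queue.Empty:
            continue


writer = threading.Thread(target=record_loop, daemon=True)
writer.start()
audio_callback(np.zeros((128, 2), dtype=np.float32))
record_queue.join()  # wait until the writer has drained the chunk
stop.set()
print(len(written), read_queue.qsize())  # 1 1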
@@ -244,7 +254,7 @@ class Microphone:
         -> CONS : Reading duration does not scale well with the number of channels and reading duration
         """
         audio_readings = np.empty((0, len(self.channels)))
-
+
         while True:
             try:
                 audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
@@ -256,12 +266,11 @@ class Microphone:
         return audio_readings
 
     def read(self) -> np.ndarray:
-
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if not self.is_recording:
             raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
 
         start_time = time.perf_counter()
 
         audio_readings = self._read()
@@ -274,21 +283,22 @@ class Microphone:
         return audio_readings
 
-    def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False) -> None:
-
+    def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if self.is_recording:
-            raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
+            raise RobotDeviceAlreadyRecordingError(
+                f"Microphone {self.microphone_index} is already recording."
+            )
 
-        #Reset queues
+        # Reset queues
         self.read_queue = thread_Queue()
         if multiprocessing:
             self.record_queue = process_Queue()
         else:
             self.record_queue = thread_Queue()
 
-        #Write recordings into a file if output_file is provided
+        # Write recordings into a file if output_file is provided
         if output_file is not None:
             output_file = Path(output_file)
             if output_file.exists():
@@ -296,28 +306,45 @@ class Microphone:
 
             if multiprocessing:
                 self.record_stop_event = process_Event()
-                self.record_thread = Process(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
+                self.record_thread = Process(
+                    target=Microphone._record_loop,
+                    args=(
+                        self.record_queue,
+                        self.record_stop_event,
+                        self.sample_rate,
+                        self.channels,
+                        output_file,
+                    ),
+                )
             else:
                 self.record_stop_event = thread_Event()
-                self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
+                self.record_thread = Thread(
+                    target=Microphone._record_loop,
+                    args=(
+                        self.record_queue,
+                        self.record_stop_event,
+                        self.sample_rate,
+                        self.channels,
+                        output_file,
+                    ),
+                )
             self.record_thread.daemon = True
             self.record_thread.start()
 
             self.is_writing = True
 
         self.is_recording = True
         self.stream.start()
 
     def stop_recording(self) -> None:
-
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
         if not self.is_recording:
             raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
 
         if self.stream.active:
-            self.stream.stop() #Wait for all buffers to be processed
-            #Remark : stream.abort() flushes the buffers !
+            self.stream.stop()  # Wait for all buffers to be processed
+            # Remark : stream.abort() flushes the buffers !
         self.is_recording = False
 
         if self.record_thread is not None:
@@ -329,7 +356,6 @@ class Microphone:
         self.is_writing = False
 
     def disconnect(self) -> None:
-
         if not self.is_connected:
             raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
@@ -342,7 +368,8 @@ class Microphone:
     def __del__(self):
         if getattr(self, "is_connected", False):
             self.disconnect()
-
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset."
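End to end, the class reads: connect, start_recording (optionally with a file sink), read chunks, stop, disconnect. A usage sketch based only on the methods visible in this diff; the device index and file name are illustrative:

import time

from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
from lerobot.common.robot_devices.microphones.microphone import Microphone

mic = Microphone(MicrophoneConfig(microphone_index=0))
mic.connect()
mic.start_recording(output_file="test.wav")  # also enables the file-writer thread
time.sleep(1.0)
chunk = mic.read()  # everything captured since recording started / the last read()
print(chunk.shape)  # (n_frames, n_channels), float32
mic.stop_recording()
mic.disconnect()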

View File

@@ -16,28 +16,33 @@ from typing import Protocol
 
 from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase
 
+
 # Defines a microphone type
 class Microphone(Protocol):
     def connect(self): ...
     def disconnect(self): ...
-    def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False): ...
+    def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ...
     def stop_recording(self): ...
 
+
 def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]:
     microphones = {}
 
     for key, cfg in microphone_configs.items():
         if cfg.type == "microphone":
             from lerobot.common.robot_devices.microphones.microphone import Microphone
+
             microphones[key] = Microphone(cfg)
         else:
             raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
 
     return microphones
 
+
 def make_microphone(microphone_type, **kwargs) -> Microphone:
     if microphone_type == "microphone":
         from lerobot.common.robot_devices.microphones.microphone import Microphone
+
         return Microphone(MicrophoneConfig(**kwargs))
     else:
         raise ValueError(f"The microphone type '{microphone_type}' is not valid.")

View File

@@ -23,12 +23,12 @@ from lerobot.common.robot_devices.cameras.configs import (
     IntelRealSenseCameraConfig,
     OpenCVCameraConfig,
 )
+from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
 from lerobot.common.robot_devices.motors.configs import (
     DynamixelMotorsBusConfig,
     FeetechMotorsBusConfig,
     MotorsBusConfig,
 )
-from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
 
 
 @dataclass

View File

@@ -51,6 +51,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
             latest_images_dict.update(local_dict)
         time.sleep(0.01)
 
+
 def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
     while not stop_event.is_set():
         local_dict = {}
@@ -60,6 +61,7 @@ def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_even
         with audio_lock:
             latest_audio_dict.update(local_dict)
 
+
 def calibrate_follower_arm(motors_bus, calib_dir_str):
     """
     Calibrates the follower arm. Attempts to load an existing calibration file;
@@ -149,12 +151,14 @@ def run_lekiwi(robot_config):
     cam_thread.start()
 
     # Start the microphone recording and capture thread.
-    #TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead !
+    # TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead !
     latest_audio_dict = {}
     audio_lock = threading.Lock()
     audio_stop_event = threading.Event()
     microphone_thread = threading.Thread(
-        target=run_microphone_capture, args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), daemon=True
+        target=run_microphone_capture,
+        args=(microphones, audio_lock, latest_audio_dict, audio_stop_event),
+        daemon=True,
     )
     for microphone in microphones.values():
         microphone.start_recording()
@@ -231,7 +235,7 @@ def run_lekiwi(robot_config):
             # Build the observation dictionary.
             observation = {
                 "images": images_dict_copy,
-                "audio": audio_dict_copy, #TODO(CarolinePascal) : This is a nasty way to do it, sorry.
+                "audio": audio_dict_copy,  # TODO(CarolinePascal) : This is a nasty way to do it, sorry.
                 "present_speed": current_velocity,
                 "follower_arm_state": follower_arm_state,
             }

View File

@@ -201,7 +201,7 @@ class ManipulatorRobot:
                 "names": state_names,
             },
         }
-
+
     @property
     def microphone_features(self) -> dict:
         mic_ft = {}
@@ -211,7 +211,7 @@ class ManipulatorRobot:
                 "dtype": "audio",
                 "shape": (len(mic.channels),),
                 "names": "channels",
-                "info" : {"sample_rate": mic.sample_rate},
+                "info": {"sample_rate": mic.sample_rate},
             }
         return mic_ft
@@ -226,11 +226,11 @@ class ManipulatorRobot:
     @property
     def num_cameras(self):
         return len(self.cameras)
-
+
     @property
     def has_microphone(self):
         return len(self.microphones) > 0
-
+
     @property
     def num_microphones(self):
         return len(self.microphones)

View File

@@ -163,7 +163,7 @@ class MobileManipulator:
                 "names": combined_names,
             },
         }
-
+
     @property
     def microphone_features(self) -> dict:
         mic_ft = {}
@@ -173,7 +173,7 @@ class MobileManipulator:
                 "dtype": "audio",
                 "shape": (len(mic.channels),),
                 "names": "channels",
-                "info" : {"sample_rate": mic.sample_rate},
+                "info": {"sample_rate": mic.sample_rate},
             }
         return mic_ft
@@ -188,11 +188,11 @@ class MobileManipulator:
     @property
     def num_cameras(self):
         return len(self.cameras)
-
+
     @property
     def has_microphone(self):
         return len(self.microphones) > 0
-
+
     @property
     def num_microphones(self):
         return len(self.microphones)
@@ -512,7 +512,7 @@ class MobileManipulator:
                 # Create silence using the microphone's configured channels
                 frame = np.zeros((1, len(microphone.channels)), dtype=np.float32)
                 obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame)
-
+
         return obs_dict
 
     def send_action(self, action: torch.Tensor) -> torch.Tensor:

View File

@@ -69,11 +69,13 @@ class RobotDeviceNotRecordingError(Exception):
     """Exception raised when the robot device is not recording."""
 
     def __init__(
-        self, message="This robot device is not recording. Try calling `robot_device.start_recording()` first."
+        self,
+        message="This robot device is not recording. Try calling `robot_device.start_recording()` first.",
     ):
         self.message = message
         super().__init__(self.message)
 
+
 class RobotDeviceAlreadyRecordingError(Exception):
     """Exception raised when the robot device is already recording."""

View File

@ -19,9 +19,9 @@ import traceback
import pytest import pytest
from serial import SerialException from serial import SerialException
from lerobot import available_cameras, available_motors, available_robots, available_microphones from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robot_devices.robots.utils import make_robot
from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus
# Import fixture modules as plugins # Import fixture modules as plugins
pytest_plugins = [ pytest_plugins = [
@ -73,10 +73,12 @@ def is_robot_available(robot_type):
def is_camera_available(camera_type): def is_camera_available(camera_type):
return _check_component_availability(camera_type, available_cameras, make_camera) return _check_component_availability(camera_type, available_cameras, make_camera)
@pytest.fixture @pytest.fixture
def is_microphone_available(microphone_type): def is_microphone_available(microphone_type):
return _check_component_availability(microphone_type, available_microphones, make_microphone) return _check_component_availability(microphone_type, available_microphones, make_microphone)
@pytest.fixture @pytest.fixture
def is_motor_available(motor_type): def is_motor_available(motor_type):
return _check_component_availability(motor_type, available_motors, make_motors_bus) return _check_component_availability(motor_type, available_motors, make_motors_bus)
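For illustration, a test can lean on this fixture to skip cleanly when no hardware is present (usage assumed; the parametrized `microphone_type` argument feeds the fixture of the same name):

```python
import pytest

@pytest.mark.parametrize("microphone_type", ["microphone"])
def test_audio_capture(is_microphone_available, microphone_type):
    if not is_microphone_available:
        pytest.skip(f"no '{microphone_type}' device available")
    # ... exercise the real microphone here ...
```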

View File

@ -25,9 +25,9 @@ from lerobot.common.datasets.compute_stats import (
compute_episode_stats, compute_episode_stats,
estimate_num_samples, estimate_num_samples,
get_feature_stats, get_feature_stats,
sample_images,
sample_audio_from_path,
sample_audio_from_data, sample_audio_from_data,
sample_audio_from_path,
sample_images,
sample_indices, sample_indices,
) )
@ -35,8 +35,10 @@ from lerobot.common.datasets.compute_stats import (
def mock_load_image_as_numpy(path, dtype, channel_first): def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
def mock_load_audio(path): def mock_load_audio(path):
return np.ones((16000,2), dtype=np.float32) return np.ones((16000, 2), dtype=np.float32)
@pytest.fixture @pytest.fixture
def sample_array(): def sample_array():
@ -74,6 +76,7 @@ def test_sample_images(mock_load):
assert images.dtype == np.uint8 assert images.dtype == np.uint8
assert len(images) == estimate_num_samples(100) assert len(images) == estimate_num_samples(100)
@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio) @patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load): def test_sample_audio_from_path(mock_load):
audio_path = "audio.wav" audio_path = "audio.wav"
@ -83,6 +86,7 @@ def test_sample_audio_from_path(mock_load):
assert audio_samples.dtype == np.float32 assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000) assert len(audio_samples) == estimate_num_samples(16000)
def test_sample_audio_from_data(mock_load): def test_sample_audio_from_data(mock_load):
audio_data = np.ones((16000, 2), dtype=np.float32) audio_data = np.ones((16000, 2), dtype=np.float32)
audio_samples = sample_audio_from_data(audio_data) audio_samples = sample_audio_from_data(audio_data)
@ -91,6 +95,7 @@ def test_sample_audio_from_data(mock_load):
assert audio_samples.dtype == np.float32 assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000) assert len(audio_samples) == estimate_num_samples(16000)
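To make the behaviour under test concrete: both samplers subsample along the first (time) axis and leave the channel axis intact. A quick sketch consistent with the assertions above:

```python
import numpy as np
from lerobot.common.datasets.compute_stats import estimate_num_samples, sample_audio_from_data

audio = np.random.uniform(-1, 1, (16000, 2)).astype(np.float32)  # 1 s of stereo audio
subset = sample_audio_from_data(audio)
assert subset.shape == (estimate_num_samples(16000), 2)  # channels preserved
```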
def test_get_feature_stats_images(): def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32) data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
@ -98,13 +103,15 @@ def test_get_feature_stats_images():
np.testing.assert_equal(stats["count"], np.array([100])) np.testing.assert_equal(stats["count"], np.array([100]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_audio(): def test_get_feature_stats_audio():
data = np.random.uniform(-1, 1, (16000,2)) data = np.random.uniform(-1, 1, (16000, 2))
stats = get_feature_stats(data, axis=0, keepdims=True) stats = get_feature_stats(data, axis=0, keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
np.testing.assert_equal(stats["count"], np.array([16000])) np.testing.assert_equal(stats["count"], np.array([16000]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_axis_0_keepdims(sample_array): def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = { expected = {
"min": np.array([[1, 2, 3]]), "min": np.array([[1, 2, 3]]),
@ -172,10 +179,11 @@ def test_compute_episode_stats():
"observation.state": {"dtype": "numeric"}, "observation.state": {"dtype": "numeric"},
} }
with patch( with (
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy patch(
), patch( "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
"lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio ),
patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
): ):
stats = compute_episode_stats(episode_data, features) stats = compute_episode_stats(episode_data, features)
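The reformatted block uses parenthesized context managers, supported natively from Python 3.10 onward; the pattern generalizes to any multi-patch setup. A self-contained example with standard-library targets chosen purely for illustration:

```python
from unittest.mock import patch

# Grouping several patches in one parenthesized `with` (Python 3.10+) avoids
# the old backslash-continuation and nested-with styles.
with (
    patch("os.getcwd", return_value="/tmp"),
    patch("os.cpu_count", return_value=1),
):
    import os
    assert os.getcwd() == "/tmp" and os.cpu_count() == 1
```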

View File

@ -16,6 +16,7 @@
import json import json
import logging import logging
import re import re
import time
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import (
MultiLeRobotDataset, MultiLeRobotDataset,
) )
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_AUDIO_CHUNK_DURATION,
create_branch, create_branch,
flatten_dict, flatten_dict,
unflatten_dict, unflatten_dict,
@ -44,12 +46,9 @@ from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import require_x86_64_kernel from tests.utils import make_microphone, require_x86_64_kernel
from tests.utils import make_microphone
import time
from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
@pytest.fixture @pytest.fixture
def image_dataset(tmp_path, empty_lerobot_dataset_factory): def image_dataset(tmp_path, empty_lerobot_dataset_factory):
@ -66,6 +65,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
} }
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@pytest.fixture @pytest.fixture
def audio_dataset(tmp_path, empty_lerobot_dataset_factory): def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
features = { features = {
@ -79,6 +79,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
} }
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
""" """
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
@ -336,6 +337,7 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
with pytest.raises(ValueError): with pytest.raises(ValueError):
image_array_to_pil_image(image) image_array_to_pil_image(image)
def test_add_frame_audio(audio_dataset): def test_add_frame_audio(audio_dataset):
dataset = audio_dataset dataset = audio_dataset
@ -349,7 +351,10 @@ def test_add_frame_audio(audio_dataset):
dataset.save_episode() dataset.save_episode()
assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sample_rate),DUMMY_AUDIO_CHANNELS)) assert dataset[0]["observation.audio.microphone"].shape == torch.Size(
(int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS)
)
# TODO(aliberts): # TODO(aliberts):
# - [ ] test various attributes & state from init and create # - [ ] test various attributes & state from init and create
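The expected tensor shape in the assertion above is just chunk duration times sample rate, per channel; a worked example with assumed values:

```python
# Assumed values for illustration; the real constants live in
# lerobot.common.datasets.utils and the microphone config.
DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds (assumption)
sample_rate = 16000                 # Hz (assumption)
DUMMY_AUDIO_CHANNELS = 2
expected = (int(DEFAULT_AUDIO_CHUNK_DURATION * sample_rate), DUMMY_AUDIO_CHANNELS)
assert expected == (8000, 2)
```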

View File

@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = {
}, },
} }
DUMMY_CAMERA_FEATURES = { DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"}, "laptop": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
"audio": "laptop",
},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
} }
DEFAULT_FPS = 30 DEFAULT_FPS = 30

View File

@ -26,10 +26,10 @@ import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_SIZE,
DEFAULT_COMPRESSED_AUDIO_PATH,
DEFAULT_FEATURES, DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH, DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH, DEFAULT_VIDEO_PATH,
DEFAULT_COMPRESSED_AUDIO_PATH,
get_hf_features_from_features, get_hf_features_from_features,
hf_transform_to_torch, hf_transform_to_torch,
) )

View File

@ -11,27 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import time
from functools import cache from functools import cache
from threading import Event, Thread
from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DEFAULT_SAMPLE_RATE
import numpy as np import numpy as np
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
from threading import Thread, Event from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS
import time
@cache @cache
def _generate_sound(duration: float, sample_rate: int, channels: int): def _generate_sound(duration: float, sample_rate: int, channels: int):
return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32) return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
def query_devices(query_index: int): def query_devices(query_index: int):
return { return {
"name": "Mock Sound Device", "name": "Mock Sound Device",
"index": query_index, "index": query_index,
"max_input_channels": DUMMY_AUDIO_CHANNELS, "max_input_channels": DUMMY_AUDIO_CHANNELS,
"default_samplerate": DEFAULT_SAMPLE_RATE, "default_samplerate": DEFAULT_SAMPLE_RATE,
} }
class InputStream: class InputStream:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._mock_dict = { self._mock_dict = {
@ -49,7 +52,12 @@ class InputStream:
while not self.callback_thread_stop_event.is_set(): while not self.callback_thread_stop_event.is_set():
# Simulate audio data acquisition # Simulate audio data acquisition
time.sleep(0.01) time.sleep(0.01)
self._audio_callback(_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), 0.01*DEFAULT_SAMPLE_RATE, capture_timestamp_utc(), None) self._audio_callback(
_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS),
0.01 * DEFAULT_SAMPLE_RATE,
capture_timestamp_utc(),
None,
)
def start(self): def start(self):
self.callback_thread_stop_event = Event() self.callback_thread_stop_event = Event()
@ -62,7 +70,7 @@ class InputStream:
@property @property
def active(self): def active(self):
return self._is_active return self._is_active
def stop(self): def stop(self):
if self.callback_thread_stop_event is not None: if self.callback_thread_stop_event is not None:
self.callback_thread_stop_event.set() self.callback_thread_stop_event.set()
@ -78,5 +86,3 @@ class InputStream:
def __del__(self): def __del__(self):
if self._is_active: if self._is_active:
self.stop() self.stop()
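Taken together, the mock reproduces sounddevice's callback contract: `callback(indata, frames, time, status)` invoked from a background thread. A usage sketch, assuming the constructor wires the `callback` kwarg to `_audio_callback` the way the real `sounddevice.InputStream` does:

```python
import time

# Sketch only: collect a few mock buffers; the callback-kwarg wiring is an
# assumption based on the diff above.
chunks = []
stream = InputStream(callback=lambda indata, frames, ts, status: chunks.append(indata))
stream.start()
time.sleep(0.05)  # let the callback thread deliver ~5 buffers of 10 ms each
stream.stop()
assert len(chunks) > 0
```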

View File

@ -32,20 +32,27 @@ pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-Tr
``` ```
""" """
import numpy as np
import time import time
import numpy as np
import pytest import pytest
from soundfile import read from soundfile import read
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceAlreadyRecordingError,
RobotDeviceNotConnectedError,
RobotDeviceNotRecordingError,
)
from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
#Maximum recording time difference between two consecutive audio recordings of the same duration. # Maximum recording time difference between two consecutive audio recordings of the same duration.
#Set to 0.02 seconds, i.e. twice the default size of the sounddevice callback buffer (we tolerate the loss of one buffer). # Set to 0.02 seconds, i.e. twice the default size of the sounddevice callback buffer (we tolerate the loss of one buffer).
MAX_RECORDING_TIME_DIFFERENCE = 0.02 MAX_RECORDING_TIME_DIFFERENCE = 0.02
DUMMY_RECORDING = "test_recording.wav" DUMMY_RECORDING = "test_recording.wav"
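Converted to samples, the tolerance is small; e.g. at an assumed 16 kHz rate:

```python
MAX_RECORDING_TIME_DIFFERENCE = 0.02  # seconds, from above
sample_rate = 16000                   # Hz (assumed for the example)
atol_samples = sample_rate * MAX_RECORDING_TIME_DIFFERENCE
assert atol_samples == 320  # i.e. up to 320 samples of drift are tolerated
```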
@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES) @pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
@require_microphone @require_microphone
def test_microphone(tmp_path, request, microphone_type, mock): def test_microphone(tmp_path, request, microphone_type, mock):
@ -92,7 +99,7 @@ def test_microphone(tmp_path, request, microphone_type, mock):
fpath = tmp_path / DUMMY_RECORDING fpath = tmp_path / DUMMY_RECORDING
microphone.start_recording(fpath) microphone.start_recording(fpath)
assert microphone.is_recording assert microphone.is_recording
# Test start_recording twice raises an error # Test start_recording twice raises an error
with pytest.raises(RobotDeviceAlreadyRecordingError): with pytest.raises(RobotDeviceAlreadyRecordingError):
microphone.start_recording() microphone.start_recording()
@ -126,10 +133,13 @@ def test_microphone(tmp_path, request, microphone_type, mock):
error_msg = ( error_msg = (
"Recording time difference between read() and stop_recording()", "Recording time difference between read() and stop_recording()",
(len(audio_chunk) - len(recorded_audio))/MAX_RECORDING_TIME_DIFFERENCE, (len(audio_chunk) - len(recorded_audio)) / MAX_RECORDING_TIME_DIFFERENCE,
) )
np.testing.assert_allclose( np.testing.assert_allclose(
len(audio_chunk), len(recorded_audio), atol=recorded_sample_rate*MAX_RECORDING_TIME_DIFFERENCE, err_msg=error_msg len(audio_chunk),
len(recorded_audio),
atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
err_msg=error_msg,
) )
# Test disconnecting # Test disconnecting
@ -139,4 +149,4 @@ def test_microphone(tmp_path, request, microphone_type, mock):
# Test disconnecting with `__del__` # Test disconnecting with `__del__`
microphone = make_microphone(**microphone_kwargs) microphone = make_microphone(**microphone_kwargs)
microphone.connect() microphone.connect()
del microphone del microphone

View File

@ -143,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
robot.send_action(action["action"]) robot.send_action(action["action"])
# Test disconnecting # Test disconnecting
robot.disconnect() #Also handles microphone recording stop, life is beautiful robot.disconnect() # Also handles microphone recording stop, life is beautiful
assert not robot.is_connected assert not robot.is_connected
for name in robot.follower_arms: for name in robot.follower_arms:
assert not robot.follower_arms[name].is_connected assert not robot.follower_arms[name].is_connected

View File

@ -22,13 +22,13 @@ from pathlib import Path
import pytest import pytest
import torch import torch
from lerobot import available_cameras, available_motors, available_robots, available_microphones from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
from lerobot.common.robot_devices.microphones.utils import Microphone from lerobot.common.robot_devices.microphones.utils import Microphone
from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
from lerobot.common.utils.import_utils import is_package_available from lerobot.common.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
@ -261,6 +261,7 @@ def require_camera(func):
return wrapper return wrapper
def require_microphone(func): def require_microphone(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -283,6 +284,7 @@ def require_microphone(func):
return wrapper return wrapper
def require_motor(func): def require_motor(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -344,6 +346,7 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
else: else:
raise ValueError(f"The camera type '{camera_type}' is not valid.") raise ValueError(f"The camera type '{camera_type}' is not valid.")
def make_microphone(microphone_type: str, **kwargs) -> Microphone: def make_microphone(microphone_type: str, **kwargs) -> Microphone:
if microphone_type == "microphone": if microphone_type == "microphone":
microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX) microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
@ -351,6 +354,7 @@ def make_microphone(microphone_type: str, **kwargs) -> Microphone:
else: else:
raise ValueError(f"The microphone type '{microphone_type}' is not valid.") raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
# TODO(rcadene, aliberts): remove this dark pattern that overrides # TODO(rcadene, aliberts): remove this dark pattern that overrides
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
if motor_type == "dynamixel": if motor_type == "dynamixel":