[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
9c667d347c
commit
0cb9345f06
|
@ -15,7 +15,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import load_image_as_numpy, load_audio_from_path
|
from lerobot.common.datasets.utils import load_audio_from_path, load_image_as_numpy
|
||||||
|
|
||||||
|
|
||||||
def estimate_num_samples(
|
def estimate_num_samples(
|
||||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||||
|
@ -70,17 +71,19 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def sample_audio_from_path(audio_path: str) -> np.ndarray:
|
|
||||||
|
|
||||||
|
def sample_audio_from_path(audio_path: str) -> np.ndarray:
|
||||||
data = load_audio_from_path(audio_path)
|
data = load_audio_from_path(audio_path)
|
||||||
sampled_indices = sample_indices(len(data))
|
sampled_indices = sample_indices(len(data))
|
||||||
|
|
||||||
return(data[sampled_indices])
|
return data[sampled_indices]
|
||||||
|
|
||||||
|
|
||||||
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
|
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
|
||||||
sampled_indices = sample_indices(len(data))
|
sampled_indices = sample_indices(len(data))
|
||||||
return data[sampled_indices]
|
return data[sampled_indices]
|
||||||
|
|
||||||
|
|
||||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||||
return {
|
return {
|
||||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||||
|
@ -103,9 +106,9 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
||||||
elif features[key]["dtype"] == "audio":
|
elif features[key]["dtype"] == "audio":
|
||||||
try:
|
try:
|
||||||
ep_ft_array = sample_audio_from_path(data[0])
|
ep_ft_array = sample_audio_from_path(data[0])
|
||||||
except TypeError: #Should only be triggered for LeKiwi robot
|
except TypeError: # Should only be triggered for LeKiwi robot
|
||||||
ep_ft_array = sample_audio_from_data(data)
|
ep_ft_array = sample_audio_from_data(data)
|
||||||
axes_to_reduce = 0
|
axes_to_reduce = 0
|
||||||
keepdims = True
|
keepdims = True
|
||||||
else:
|
else:
|
||||||
ep_ft_array = data # data is already a np.ndarray
|
ep_ft_array = data # data is already a np.ndarray
|
||||||
|
|
|
@ -23,6 +23,7 @@ import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import packaging.version
|
import packaging.version
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
import torch.utils
|
import torch.utils
|
||||||
from datasets import concatenate_datasets, load_dataset
|
from datasets import concatenate_datasets, load_dataset
|
||||||
|
@ -34,13 +35,12 @@ from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||||
DEFAULT_FEATURES,
|
DEFAULT_FEATURES,
|
||||||
DEFAULT_IMAGE_PATH,
|
DEFAULT_IMAGE_PATH,
|
||||||
DEFAULT_RAW_AUDIO_PATH,
|
DEFAULT_RAW_AUDIO_PATH,
|
||||||
DEFAULT_COMPRESSED_AUDIO_PATH,
|
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
TASKS_PATH,
|
TASKS_PATH,
|
||||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
|
||||||
append_jsonlines,
|
append_jsonlines,
|
||||||
backward_compatible_episodes_stats,
|
backward_compatible_episodes_stats,
|
||||||
check_delta_timestamps,
|
check_delta_timestamps,
|
||||||
|
@ -70,17 +70,16 @@ from lerobot.common.datasets.utils import (
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import (
|
from lerobot.common.datasets.video_utils import (
|
||||||
VideoFrame,
|
VideoFrame,
|
||||||
decode_video_frames,
|
|
||||||
encode_video_frames,
|
|
||||||
encode_audio,
|
|
||||||
decode_audio,
|
decode_audio,
|
||||||
|
decode_video_frames,
|
||||||
|
encode_audio,
|
||||||
|
encode_video_frames,
|
||||||
|
get_audio_info,
|
||||||
get_safe_default_codec,
|
get_safe_default_codec,
|
||||||
get_video_info,
|
get_video_info,
|
||||||
get_audio_info,
|
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
|
||||||
from lerobot.common.robot_devices.microphones.utils import Microphone
|
from lerobot.common.robot_devices.microphones.utils import Microphone
|
||||||
import soundfile as sf
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
|
|
||||||
CODEBASE_VERSION = "v2.1"
|
CODEBASE_VERSION = "v2.1"
|
||||||
|
|
||||||
|
@ -149,10 +148,12 @@ class LeRobotDatasetMetadata:
|
||||||
ep_chunk = self.get_episode_chunk(ep_index)
|
ep_chunk = self.get_episode_chunk(ep_index)
|
||||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||||
return Path(fpath)
|
return Path(fpath)
|
||||||
|
|
||||||
def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
def get_compressed_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
||||||
episode_chunk = self.get_episode_chunk(episode_index)
|
episode_chunk = self.get_episode_chunk(episode_index)
|
||||||
fpath = self.audio_path.format(episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index)
|
fpath = self.audio_path.format(
|
||||||
|
episode_chunk=episode_chunk, audio_key=audio_key, episode_index=episode_index
|
||||||
|
)
|
||||||
return self.root / fpath
|
return self.root / fpath
|
||||||
|
|
||||||
def get_episode_chunk(self, ep_index: int) -> int:
|
def get_episode_chunk(self, ep_index: int) -> int:
|
||||||
|
@ -167,7 +168,7 @@ class LeRobotDatasetMetadata:
|
||||||
def video_path(self) -> str | None:
|
def video_path(self) -> str | None:
|
||||||
"""Formattable string for the video files."""
|
"""Formattable string for the video files."""
|
||||||
return self.info["video_path"]
|
return self.info["video_path"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_path(self) -> str | None:
|
def audio_path(self) -> str | None:
|
||||||
"""Formattable string for the audio files."""
|
"""Formattable string for the audio files."""
|
||||||
|
@ -202,16 +203,21 @@ class LeRobotDatasetMetadata:
|
||||||
def camera_keys(self) -> list[str]:
|
def camera_keys(self) -> list[str]:
|
||||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_keys(self) -> list[str]:
|
def audio_keys(self) -> list[str]:
|
||||||
"""Keys to access audio modalities."""
|
"""Keys to access audio modalities."""
|
||||||
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
|
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_camera_keys_mapping(self) -> dict[str, str]:
|
def audio_camera_keys_mapping(self) -> dict[str, str]:
|
||||||
"""Mapping between camera keys and audio keys when both are linked."""
|
"""Mapping between camera keys and audio keys when both are linked."""
|
||||||
return {self.features[camera_key]["audio"]:camera_key for camera_key in self.camera_keys if self.features[camera_key]["audio"] is not None and self.features[camera_key]["dtype"] == "video"}
|
return {
|
||||||
|
self.features[camera_key]["audio"]: camera_key
|
||||||
|
for camera_key in self.camera_keys
|
||||||
|
if self.features[camera_key]["audio"] is not None
|
||||||
|
and self.features[camera_key]["dtype"] == "video"
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self) -> dict[str, list | dict]:
|
def names(self) -> dict[str, list | dict]:
|
||||||
|
@ -325,7 +331,9 @@ class LeRobotDatasetMetadata:
|
||||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||||
"""
|
"""
|
||||||
for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
|
for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
|
||||||
if not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]):
|
if not self.features[key].get("info", None) or (
|
||||||
|
len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]
|
||||||
|
):
|
||||||
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
|
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
|
||||||
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||||
|
|
||||||
|
@ -518,7 +526,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
self.audio_backend = audio_backend if audio_backend else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
|
self.audio_backend = (
|
||||||
|
audio_backend if audio_backend else "ffmpeg"
|
||||||
|
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
|
||||||
# Unused attributes
|
# Unused attributes
|
||||||
|
@ -543,7 +553,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.hf_dataset = self.load_hf_dataset()
|
self.hf_dataset = self.load_hf_dataset()
|
||||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||||
self.download_episodes(download_videos) #Sould load audio as well #TODO(CarolinePascal): separate audio from video
|
self.download_episodes(
|
||||||
|
download_videos
|
||||||
|
) # Sould load audio as well #TODO(CarolinePascal): separate audio from video
|
||||||
self.hf_dataset = self.load_hf_dataset()
|
self.hf_dataset = self.load_hf_dataset()
|
||||||
|
|
||||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||||
|
@ -735,13 +747,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
query_timestamps[key] = [current_ts]
|
query_timestamps[key] = [current_ts]
|
||||||
|
|
||||||
return query_timestamps
|
return query_timestamps
|
||||||
|
|
||||||
#TODO(CarolinePascal): add variable query durations
|
# TODO(CarolinePascal): add variable query durations
|
||||||
def _get_query_timestamps_audio(
|
def _get_query_timestamps_audio(
|
||||||
self,
|
self,
|
||||||
current_ts: float,
|
current_ts: float,
|
||||||
query_indices: dict[str, list[int]] | None = None,
|
query_indices: dict[str, list[int]] | None = None,
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query_timestamps = {}
|
query_timestamps = {}
|
||||||
for key in self.meta.audio_keys:
|
for key in self.meta.audio_keys:
|
||||||
if query_indices is not None and key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
|
@ -773,14 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
#TODO(CarolinePascal): add variable query durations
|
# TODO(CarolinePascal): add variable query durations
|
||||||
def _query_audio(self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int) -> dict[str, torch.Tensor]:
|
def _query_audio(
|
||||||
|
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
item = {}
|
item = {}
|
||||||
for audio_key, query_ts in query_timestamps.items():
|
for audio_key, query_ts in query_timestamps.items():
|
||||||
#Audio stored with video in a single .mp4 file
|
# Audio stored with video in a single .mp4 file
|
||||||
if audio_key in self.meta.audio_camera_keys_mapping:
|
if audio_key in self.meta.audio_camera_keys_mapping:
|
||||||
audio_path = self.root / self.meta.get_video_file_path(ep_idx, self.meta.audio_camera_keys_mapping[audio_key])
|
audio_path = self.root / self.meta.get_video_file_path(
|
||||||
#Audio stored alone in a separate .m4a file
|
ep_idx, self.meta.audio_camera_keys_mapping[audio_key]
|
||||||
|
)
|
||||||
|
# Audio stored alone in a separate .m4a file
|
||||||
else:
|
else:
|
||||||
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
|
audio_path = self.root / self.meta.get_compressed_audio_file_path(ep_idx, audio_key)
|
||||||
audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
|
audio_chunk = decode_audio(audio_path, query_ts, query_duration, self.audio_backend)
|
||||||
|
@ -855,7 +871,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||||
)
|
)
|
||||||
return self.root / fpath
|
return self.root / fpath
|
||||||
|
|
||||||
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
|
||||||
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
|
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
|
||||||
return self.root / fpath
|
return self.root / fpath
|
||||||
|
@ -929,11 +945,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
This function will start recording audio from the microphone and save it to disk.
|
This function will start recording audio from the microphone and save it to disk.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio_dir = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key).parent
|
audio_dir = self._get_raw_audio_file_path(
|
||||||
|
self.num_episodes, "observation.audio." + microphone_key
|
||||||
|
).parent
|
||||||
if not audio_dir.is_dir():
|
if not audio_dir.is_dir():
|
||||||
audio_dir.mkdir(parents=True, exist_ok=True)
|
audio_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
microphone.start_recording(output_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key))
|
microphone.start_recording(
|
||||||
|
output_file=self._get_raw_audio_file_path(
|
||||||
|
self.num_episodes, "observation.audio." + microphone_key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def save_episode(self, episode_data: dict | None = None) -> None:
|
def save_episode(self, episode_data: dict | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -983,8 +1005,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
if self.meta.robot_type.startswith("lekiwi"):
|
if self.meta.robot_type.startswith("lekiwi"):
|
||||||
for key in self.meta.audio_keys:
|
for key in self.meta.audio_keys:
|
||||||
audio_path = self._get_raw_audio_file_path(episode_index=self.episode_buffer["episode_index"][0], audio_key=key)
|
audio_path = self._get_raw_audio_file_path(
|
||||||
with sf.SoundFile(audio_path, mode='w', samplerate=self.meta.features[key]["info"]["sample_rate"], channels=self.meta.features[key]["shape"][0]) as file:
|
episode_index=self.episode_buffer["episode_index"][0], audio_key=key
|
||||||
|
)
|
||||||
|
with sf.SoundFile(
|
||||||
|
audio_path,
|
||||||
|
mode="w",
|
||||||
|
samplerate=self.meta.features[key]["info"]["sample_rate"],
|
||||||
|
channels=self.meta.features[key]["shape"][0],
|
||||||
|
) as file:
|
||||||
file.write(episode_buffer[key])
|
file.write(episode_buffer[key])
|
||||||
|
|
||||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||||
|
@ -996,7 +1025,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
if len(self.meta.audio_keys) > 0:
|
if len(self.meta.audio_keys) > 0:
|
||||||
_ = self.encode_episode_audio(episode_index)
|
_ = self.encode_episode_audio(episode_index)
|
||||||
|
|
||||||
# `meta.save_episode` be executed after encoding the videos
|
# `meta.save_episode` be executed after encoding the videos
|
||||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||||
|
|
||||||
|
@ -1113,7 +1142,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
|
encode_video_frames(img_dir, video_path, self.fps, audio_path=audio_path, overwrite=True)
|
||||||
|
|
||||||
return video_paths
|
return video_paths
|
||||||
|
|
||||||
def encode_episode_audio(self, episode_index: int) -> dict:
|
def encode_episode_audio(self, episode_index: int) -> dict:
|
||||||
"""
|
"""
|
||||||
Use ffmpeg to convert .wav raw audio files into .m4a audio files.
|
Use ffmpeg to convert .wav raw audio files into .m4a audio files.
|
||||||
|
@ -1124,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()):
|
for audio_key in set(self.meta.audio_keys) - set(self.meta.audio_camera_keys_mapping.keys()):
|
||||||
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
|
input_audio_path = self.root / self._get_raw_audio_file_path(episode_index, audio_key)
|
||||||
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
|
output_audio_path = self.root / self.meta.get_compressed_audio_file_path(episode_index, audio_key)
|
||||||
|
|
||||||
audio_paths[audio_key] = str(output_audio_path)
|
audio_paths[audio_key] = str(output_audio_path)
|
||||||
if output_audio_path.is_file():
|
if output_audio_path.is_file():
|
||||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||||
|
@ -1180,7 +1209,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.delta_indices = None
|
obj.delta_indices = None
|
||||||
obj.episode_data_index = None
|
obj.episode_data_index = None
|
||||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
obj.audio_backend = audio_backend if audio_backend is not None else "ffmpeg" #Waiting for torchcodec release #TODO(CarolinePascal)
|
obj.audio_backend = (
|
||||||
|
audio_backend if audio_backend is not None else "ffmpeg"
|
||||||
|
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -33,9 +33,8 @@ from datasets.table import embed_table_storage
|
||||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||||
from huggingface_hub.errors import RevisionNotFoundError
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
from soundfile import read
|
from soundfile import read
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
from lerobot.common.datasets.backward_compatibility import (
|
from lerobot.common.datasets.backward_compatibility import (
|
||||||
V21_MESSAGE,
|
V21_MESSAGE,
|
||||||
|
@ -260,10 +259,12 @@ def load_image_as_numpy(
|
||||||
img_array /= 255.0
|
img_array /= 255.0
|
||||||
return img_array
|
return img_array
|
||||||
|
|
||||||
|
|
||||||
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
|
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
|
||||||
audio_data, _ = read(fpath, dtype="float32")
|
audio_data, _ = read(fpath, dtype="float32")
|
||||||
return audio_data
|
return audio_data
|
||||||
|
|
||||||
|
|
||||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||||
|
@ -731,7 +732,7 @@ def validate_features_presence(
|
||||||
):
|
):
|
||||||
error_message = ""
|
error_message = ""
|
||||||
missing_features = expected_features - actual_features
|
missing_features = expected_features - actual_features
|
||||||
missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
|
missing_features = {feature for feature in missing_features if "observation.audio" not in feature}
|
||||||
extra_features = actual_features - (expected_features | optional_features)
|
extra_features = actual_features - (expected_features | optional_features)
|
||||||
|
|
||||||
if missing_features or extra_features:
|
if missing_features or extra_features:
|
||||||
|
@ -793,18 +794,24 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
|
||||||
|
|
||||||
return error_message
|
return error_message
|
||||||
|
|
||||||
|
|
||||||
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
|
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
|
||||||
error_message = ""
|
error_message = ""
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
actual_shape = value.shape
|
actual_shape = value.shape
|
||||||
c = expected_shape
|
c = expected_shape
|
||||||
if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]): #The number of frames might be different
|
if len(actual_shape) != 2 or (
|
||||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
|
actual_shape[-1] != c[-1] and actual_shape[0] != c[0]
|
||||||
|
): # The number of frames might be different
|
||||||
|
error_message += (
|
||||||
|
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
|
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
|
||||||
|
|
||||||
return error_message
|
return error_message
|
||||||
|
|
||||||
|
|
||||||
def validate_feature_string(name: str, value: str):
|
def validate_feature_string(name: str, value: str):
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||||
|
|
|
@ -25,12 +25,11 @@ from typing import Any, ClassVar
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
import torchvision
|
||||||
from datasets.features.features import register_feature
|
from datasets.features.features import register_feature
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from numpy import ceil
|
from numpy import ceil
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
def get_safe_default_codec():
|
def get_safe_default_codec():
|
||||||
|
@ -42,6 +41,7 @@ def get_safe_default_codec():
|
||||||
)
|
)
|
||||||
return "pyav"
|
return "pyav"
|
||||||
|
|
||||||
|
|
||||||
def decode_audio(
|
def decode_audio(
|
||||||
audio_path: Path | str,
|
audio_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
|
@ -68,30 +68,30 @@ def decode_audio(
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported video backend: {backend}")
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
def decode_audio_torchvision(
|
def decode_audio_torchvision(
|
||||||
audio_path: Path | str,
|
audio_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
duration: float,
|
duration: float,
|
||||||
log_loaded_timestamps: bool = False,
|
log_loaded_timestamps: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# TODO(CarolinePascal) : add channels selection
|
||||||
#TODO(CarolinePascal) : add channels selection
|
|
||||||
audio_path = str(audio_path)
|
audio_path = str(audio_path)
|
||||||
|
|
||||||
reader = torchaudio.io.StreamReader(src=audio_path)
|
reader = torchaudio.io.StreamReader(src=audio_path)
|
||||||
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
|
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
|
||||||
|
|
||||||
#TODO(CarolinePascal) : sort timestamps ?
|
# TODO(CarolinePascal) : sort timestamps ?
|
||||||
|
|
||||||
reader.add_basic_audio_stream(
|
reader.add_basic_audio_stream(
|
||||||
frames_per_chunk = int(ceil(duration * audio_sample_rate)), #Too much is better than not enough
|
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
|
||||||
buffer_chunk_size = -1, #No dropping frames
|
buffer_chunk_size=-1, # No dropping frames
|
||||||
format = "fltp", #Format as float32
|
format="fltp", # Format as float32
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_chunks = []
|
audio_chunks = []
|
||||||
for ts in timestamps:
|
for ts in timestamps:
|
||||||
reader.seek(ts) #Default to closest audio sample
|
reader.seek(ts) # Default to closest audio sample
|
||||||
status = reader.fill_buffer()
|
status = reader.fill_buffer()
|
||||||
if status != 0:
|
if status != 0:
|
||||||
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
||||||
|
@ -99,15 +99,18 @@ def decode_audio_torchvision(
|
||||||
current_audio_chunk = reader.pop_chunks()[0]
|
current_audio_chunk = reader.pop_chunks()[0]
|
||||||
|
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"audio chunk loaded at starting timestamp={current_audio_chunk["pts"]:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}")
|
logging.info(
|
||||||
|
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
audio_chunks.append(current_audio_chunk)
|
audio_chunks.append(current_audio_chunk)
|
||||||
|
|
||||||
audio_chunks = torch.stack(audio_chunks)
|
audio_chunks = torch.stack(audio_chunks)
|
||||||
|
|
||||||
assert len(timestamps) == len(audio_chunks)
|
assert len(timestamps) == len(audio_chunks)
|
||||||
return audio_chunks
|
return audio_chunks
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames(
|
def decode_video_frames(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
|
@ -137,6 +140,7 @@ def decode_video_frames(
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported video backend: {backend}")
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames_torchvision(
|
def decode_video_frames_torchvision(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
|
@ -234,6 +238,7 @@ def decode_video_frames_torchvision(
|
||||||
assert len(timestamps) == len(closest_frames)
|
assert len(timestamps) == len(closest_frames)
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames_torchcodec(
|
def decode_video_frames_torchcodec(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
|
@ -308,6 +313,7 @@ def decode_video_frames_torchcodec(
|
||||||
assert len(timestamps) == len(closest_frames)
|
assert len(timestamps) == len(closest_frames)
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
def encode_audio(
|
def encode_audio(
|
||||||
input_path: Path | str,
|
input_path: Path | str,
|
||||||
output_path: Path | str,
|
output_path: Path | str,
|
||||||
|
@ -344,6 +350,7 @@ def encode_audio(
|
||||||
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def encode_video_frames(
|
def encode_video_frames(
|
||||||
imgs_dir: Path | str,
|
imgs_dir: Path | str,
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
|
@ -353,7 +360,7 @@ def encode_video_frames(
|
||||||
pix_fmt: str = "yuv420p",
|
pix_fmt: str = "yuv420p",
|
||||||
g: int | None = 2,
|
g: int | None = 2,
|
||||||
crf: int | None = 30,
|
crf: int | None = 30,
|
||||||
acodec: str = "aac", #TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
|
acodec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
|
||||||
fast_decode: int = 0,
|
fast_decode: int = 0,
|
||||||
log_level: str | None = "error",
|
log_level: str | None = "error",
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
|
@ -375,16 +382,18 @@ def encode_video_frames(
|
||||||
if audio_path is not None:
|
if audio_path is not None:
|
||||||
audio_path = Path(audio_path)
|
audio_path = Path(audio_path)
|
||||||
audio_path.parent.mkdir(parents=True, exist_ok=True)
|
audio_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
ffmpeg_audio_args.update(OrderedDict(
|
ffmpeg_audio_args.update(
|
||||||
[
|
OrderedDict(
|
||||||
("-i", str(audio_path)),
|
[
|
||||||
]
|
("-i", str(audio_path)),
|
||||||
))
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
ffmpeg_encoding_args = OrderedDict(
|
ffmpeg_encoding_args = OrderedDict(
|
||||||
[
|
[
|
||||||
("-pix_fmt", pix_fmt),
|
("-pix_fmt", pix_fmt),
|
||||||
("-vcodec", vcodec),
|
("-vcodec", vcodec),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if g is not None:
|
if g is not None:
|
||||||
|
@ -398,7 +407,7 @@ def encode_video_frames(
|
||||||
|
|
||||||
if audio_path is not None:
|
if audio_path is not None:
|
||||||
ffmpeg_encoding_args["-acodec"] = acodec
|
ffmpeg_encoding_args["-acodec"] = acodec
|
||||||
|
|
||||||
if log_level is not None:
|
if log_level is not None:
|
||||||
ffmpeg_encoding_args["-loglevel"] = str(log_level)
|
ffmpeg_encoding_args["-loglevel"] = str(log_level)
|
||||||
|
|
||||||
|
@ -487,6 +496,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_video_info(video_path: Path | str) -> dict:
|
def get_video_info(video_path: Path | str) -> dict:
|
||||||
ffprobe_video_cmd = [
|
ffprobe_video_cmd = [
|
||||||
"ffprobe",
|
"ffprobe",
|
||||||
|
|
|
@ -77,8 +77,8 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||||
key = f"read_camera_{name}_dt_s"
|
key = f"read_camera_{name}_dt_s"
|
||||||
if key in robot.logs:
|
if key in robot.logs:
|
||||||
log_dt(f"dtR{name}", robot.logs[key])
|
log_dt(f"dtR{name}", robot.logs[key])
|
||||||
|
|
||||||
for name in robot.microphones:
|
for name in robot.microphones:
|
||||||
key = f"read_microphone_{name}_dt_s"
|
key = f"read_microphone_{name}_dt_s"
|
||||||
if key in robot.logs:
|
if key in robot.logs:
|
||||||
log_dt(f"dtR{name}", robot.logs[key])
|
log_dt(f"dtR{name}", robot.logs[key])
|
||||||
|
@ -252,9 +252,11 @@ def control_loop(
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
|
|
||||||
if dataset is not None and not robot.robot_type.startswith("lekiwi"): #For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
|
if (
|
||||||
|
dataset is not None and not robot.robot_type.startswith("lekiwi")
|
||||||
|
): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
|
||||||
for microphone_key, microphone in robot.microphones.items():
|
for microphone_key, microphone in robot.microphones.items():
|
||||||
#Start recording both in file writing and data reading mode
|
# Start recording both in file writing and data reading mode
|
||||||
dataset.add_microphone_recording(microphone, microphone_key)
|
dataset.add_microphone_recording(microphone, microphone_key)
|
||||||
else:
|
else:
|
||||||
for _, microphone in robot.microphones.items():
|
for _, microphone in robot.microphones.items():
|
||||||
|
|
|
@ -17,12 +17,14 @@ from dataclasses import dataclass
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
|
class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
|
|
||||||
@MicrophoneConfigBase.register_subclass("microphone")
|
@MicrophoneConfigBase.register_subclass("microphone")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MicrophoneConfig(MicrophoneConfigBase):
|
class MicrophoneConfig(MicrophoneConfigBase):
|
||||||
|
@ -33,4 +35,4 @@ class MicrophoneConfig(MicrophoneConfigBase):
|
||||||
microphone_index: int
|
microphone_index: int
|
||||||
sample_rate: int | None = None
|
sample_rate: int | None = None
|
||||||
channels: list[int] | None = None
|
channels: list[int] | None = None
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
|
@ -13,36 +13,35 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file contains utilities for recording audio from a microphone.
|
This file contains utilities for recording audio from a microphone.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import soundfile as sf
|
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
from threading import Thread, Event
|
|
||||||
from multiprocessing import Process
|
|
||||||
from queue import Empty
|
|
||||||
|
|
||||||
from queue import Queue as thread_Queue
|
|
||||||
from threading import Event as thread_Event
|
|
||||||
from multiprocessing import JoinableQueue as process_Queue
|
|
||||||
from multiprocessing import Event as process_Event
|
|
||||||
|
|
||||||
from os import getcwd
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
from multiprocessing import Event as process_Event
|
||||||
|
from multiprocessing import JoinableQueue as process_Queue
|
||||||
|
from multiprocessing import Process
|
||||||
|
from os import getcwd
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty
|
||||||
|
from queue import Queue as thread_Queue
|
||||||
|
from threading import Event, Thread
|
||||||
|
from threading import Event as thread_Event
|
||||||
|
|
||||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
|
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
|
||||||
from lerobot.common.robot_devices.utils import (
|
from lerobot.common.robot_devices.utils import (
|
||||||
RobotDeviceAlreadyConnectedError,
|
RobotDeviceAlreadyConnectedError,
|
||||||
|
RobotDeviceAlreadyRecordingError,
|
||||||
RobotDeviceNotConnectedError,
|
RobotDeviceNotConnectedError,
|
||||||
RobotDeviceNotRecordingError,
|
RobotDeviceNotRecordingError,
|
||||||
RobotDeviceAlreadyRecordingError,
|
|
||||||
)
|
)
|
||||||
|
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
|
|
||||||
|
|
||||||
def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
|
def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
|
||||||
microphones = []
|
microphones = []
|
||||||
|
@ -69,11 +68,10 @@ def find_microphones(raise_when_empty=False, mock=False) -> list[dict]:
|
||||||
|
|
||||||
return microphones
|
return microphones
|
||||||
|
|
||||||
def record_audio_from_microphones(
|
|
||||||
output_dir: Path,
|
|
||||||
microphone_ids: list[int] | None = None,
|
|
||||||
record_time_s: float = 2.0):
|
|
||||||
|
|
||||||
|
def record_audio_from_microphones(
|
||||||
|
output_dir: Path, microphone_ids: list[int] | None = None, record_time_s: float = 2.0
|
||||||
|
):
|
||||||
if microphone_ids is None or len(microphone_ids) == 0:
|
if microphone_ids is None or len(microphone_ids) == 0:
|
||||||
microphones = find_microphones()
|
microphones = find_microphones()
|
||||||
microphone_ids = [m["index"] for m in microphones]
|
microphone_ids = [m["index"] for m in microphones]
|
||||||
|
@ -104,13 +102,14 @@ def record_audio_from_microphones(
|
||||||
for microphone in microphones:
|
for microphone in microphones:
|
||||||
microphone.stop_recording()
|
microphone.stop_recording()
|
||||||
|
|
||||||
#Remark : recording may be resumed here if needed
|
# Remark : recording may be resumed here if needed
|
||||||
|
|
||||||
for microphone in microphones:
|
for microphone in microphones:
|
||||||
microphone.disconnect()
|
microphone.disconnect()
|
||||||
|
|
||||||
print(f"Images have been saved to {output_dir}")
|
print(f"Images have been saved to {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
class Microphone:
|
class Microphone:
|
||||||
"""
|
"""
|
||||||
The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
|
The Microphone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
|
||||||
|
@ -138,20 +137,20 @@ class Microphone:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.microphone_index = config.microphone_index
|
self.microphone_index = config.microphone_index
|
||||||
|
|
||||||
#Store the recording sample rate and channels
|
# Store the recording sample rate and channels
|
||||||
self.sample_rate = config.sample_rate
|
self.sample_rate = config.sample_rate
|
||||||
self.channels = config.channels
|
self.channels = config.channels
|
||||||
|
|
||||||
self.mock = config.mock
|
self.mock = config.mock
|
||||||
|
|
||||||
#Input audio stream
|
# Input audio stream
|
||||||
self.stream = None
|
self.stream = None
|
||||||
|
|
||||||
#Thread-safe concurrent queue to store the recorded/read audio
|
# Thread-safe concurrent queue to store the recorded/read audio
|
||||||
self.record_queue = None
|
self.record_queue = None
|
||||||
self.read_queue = None
|
self.read_queue = None
|
||||||
|
|
||||||
#Thread to handle data reading and file writing in a separate thread (safely)
|
# Thread to handle data reading and file writing in a separate thread (safely)
|
||||||
self.record_thread = None
|
self.record_thread = None
|
||||||
self.record_stop_event = None
|
self.record_stop_event = None
|
||||||
|
|
||||||
|
@ -162,14 +161,16 @@ class Microphone:
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self) -> None:
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
raise RobotDeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
|
raise RobotDeviceAlreadyConnectedError(
|
||||||
|
f"Microphone {self.microphone_index} is already connected."
|
||||||
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.microphones.mock_sounddevice as sd
|
import tests.microphones.mock_sounddevice as sd
|
||||||
else:
|
else:
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
|
|
||||||
#Check if the provided microphone index does match an input device
|
# Check if the provided microphone index does match an input device
|
||||||
is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0
|
is_index_input = sd.query_devices(self.microphone_index)["max_input_channels"] > 0
|
||||||
|
|
||||||
if not is_index_input:
|
if not is_index_input:
|
||||||
|
@ -178,17 +179,19 @@ class Microphone:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
|
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
|
||||||
)
|
)
|
||||||
|
|
||||||
#Check if provided recording parameters are compatible with the microphone
|
# Check if provided recording parameters are compatible with the microphone
|
||||||
actual_microphone = sd.query_devices(self.microphone_index)
|
actual_microphone = sd.query_devices(self.microphone_index)
|
||||||
|
|
||||||
if self.sample_rate is not None :
|
if self.sample_rate is not None:
|
||||||
if self.sample_rate > actual_microphone["default_samplerate"]:
|
if self.sample_rate > actual_microphone["default_samplerate"]:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
|
f"Provided sample rate {self.sample_rate} is higher than the sample rate of the microphone {actual_microphone['default_samplerate']}."
|
||||||
)
|
)
|
||||||
elif self.sample_rate < actual_microphone["default_samplerate"]:
|
elif self.sample_rate < actual_microphone["default_samplerate"]:
|
||||||
logging.warning("Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted.")
|
logging.warning(
|
||||||
|
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.sample_rate = int(actual_microphone["default_samplerate"])
|
self.sample_rate = int(actual_microphone["default_samplerate"])
|
||||||
|
|
||||||
|
@ -198,45 +201,52 @@ class Microphone:
|
||||||
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
|
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_microphone['max_input_channels']}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.channels = np.arange(1, actual_microphone["max_input_channels"]+1)
|
self.channels = np.arange(1, actual_microphone["max_input_channels"] + 1)
|
||||||
|
|
||||||
# Get channels index instead of number for slicing
|
# Get channels index instead of number for slicing
|
||||||
self.channels = np.array(self.channels) - 1
|
self.channels = np.array(self.channels) - 1
|
||||||
|
|
||||||
#Create the audio stream
|
# Create the audio stream
|
||||||
self.stream = sd.InputStream(
|
self.stream = sd.InputStream(
|
||||||
device=self.microphone_index,
|
device=self.microphone_index,
|
||||||
samplerate=self.sample_rate,
|
samplerate=self.sample_rate,
|
||||||
channels=max(self.channels)+1,
|
channels=max(self.channels) + 1,
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
callback=self._audio_callback,
|
callback=self._audio_callback,
|
||||||
)
|
)
|
||||||
#Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receives same length buffers.
|
# Remark : the blocksize parameter could be passed to the stream to ensure that audio_callback always receives same length buffers.
|
||||||
#However, this may lead to additional latency. We thus stick to blocksize=0 which means that audio_callback will receive varying length buffers, but with no additional latency.
|
# However, this may lead to additional latency. We thus stick to blocksize=0 which means that audio_callback will receive varying length buffers, but with no additional latency.
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
def _audio_callback(self, indata, frames, time, status) -> None :
|
def _audio_callback(self, indata, frames, time, status) -> None:
|
||||||
if status:
|
if status:
|
||||||
logging.warning(status)
|
logging.warning(status)
|
||||||
# Slicing makes copy unnecessary
|
# Slicing makes copy unnecessary
|
||||||
# Two separate queues are necessary because .get() also pops the data from the queue
|
# Two separate queues are necessary because .get() also pops the data from the queue
|
||||||
if self.is_writing:
|
if self.is_writing:
|
||||||
self.record_queue.put(indata[:,self.channels])
|
self.record_queue.put(indata[:, self.channels])
|
||||||
self.read_queue.put(indata[:,self.channels])
|
self.read_queue.put(indata[:, self.channels])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
|
def _record_loop(queue, event: Event, sample_rate: int, channels: list[int], output_file: Path) -> None:
|
||||||
#Can only be run on a single process/thread for file writing safety
|
# Can only be run on a single process/thread for file writing safety
|
||||||
with sf.SoundFile(output_file, mode='x', samplerate=sample_rate,
|
with sf.SoundFile(
|
||||||
channels=max(channels)+1, subtype=sf.default_subtype(output_file.suffix[1:])) as file:
|
output_file,
|
||||||
|
mode="x",
|
||||||
|
samplerate=sample_rate,
|
||||||
|
channels=max(channels) + 1,
|
||||||
|
subtype=sf.default_subtype(output_file.suffix[1:]),
|
||||||
|
) as file:
|
||||||
while not event.is_set():
|
while not event.is_set():
|
||||||
try:
|
try:
|
||||||
file.write(queue.get(timeout=0.02)) #Timeout set as twice the usual sounddevice buffer size
|
file.write(
|
||||||
|
queue.get(timeout=0.02)
|
||||||
|
) # Timeout set as twice the usual sounddevice buffer size
|
||||||
queue.task_done()
|
queue.task_done()
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def _read(self) -> np.ndarray:
|
def _read(self) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Gets audio data from the queue and converts it to a numpy array.
|
Gets audio data from the queue and converts it to a numpy array.
|
||||||
|
@ -244,7 +254,7 @@ class Microphone:
|
||||||
-> CONS : Reading duration does not scale well with the number of channels and reading duration
|
-> CONS : Reading duration does not scale well with the number of channels and reading duration
|
||||||
"""
|
"""
|
||||||
audio_readings = np.empty((0, len(self.channels)))
|
audio_readings = np.empty((0, len(self.channels)))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
|
audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
|
||||||
|
@ -256,12 +266,11 @@ class Microphone:
|
||||||
return audio_readings
|
return audio_readings
|
||||||
|
|
||||||
def read(self) -> np.ndarray:
|
def read(self) -> np.ndarray:
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
if not self.is_recording:
|
if not self.is_recording:
|
||||||
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
audio_readings = self._read()
|
audio_readings = self._read()
|
||||||
|
@ -274,21 +283,22 @@ class Microphone:
|
||||||
|
|
||||||
return audio_readings
|
return audio_readings
|
||||||
|
|
||||||
def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False) -> None:
|
def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False) -> None:
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
if self.is_recording:
|
if self.is_recording:
|
||||||
raise RobotDeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
|
raise RobotDeviceAlreadyRecordingError(
|
||||||
|
f"Microphone {self.microphone_index} is already recording."
|
||||||
#Reset queues
|
)
|
||||||
|
|
||||||
|
# Reset queues
|
||||||
self.read_queue = thread_Queue()
|
self.read_queue = thread_Queue()
|
||||||
if multiprocessing:
|
if multiprocessing:
|
||||||
self.record_queue = process_Queue()
|
self.record_queue = process_Queue()
|
||||||
else:
|
else:
|
||||||
self.record_queue = thread_Queue()
|
self.record_queue = thread_Queue()
|
||||||
|
|
||||||
#Write recordings into a file if output_file is provided
|
# Write recordings into a file if output_file is provided
|
||||||
if output_file is not None:
|
if output_file is not None:
|
||||||
output_file = Path(output_file)
|
output_file = Path(output_file)
|
||||||
if output_file.exists():
|
if output_file.exists():
|
||||||
|
@ -296,28 +306,45 @@ class Microphone:
|
||||||
|
|
||||||
if multiprocessing:
|
if multiprocessing:
|
||||||
self.record_stop_event = process_Event()
|
self.record_stop_event = process_Event()
|
||||||
self.record_thread = Process(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
|
self.record_thread = Process(
|
||||||
|
target=Microphone._record_loop,
|
||||||
|
args=(
|
||||||
|
self.record_queue,
|
||||||
|
self.record_stop_event,
|
||||||
|
self.sample_rate,
|
||||||
|
self.channels,
|
||||||
|
output_file,
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.record_stop_event = thread_Event()
|
self.record_stop_event = thread_Event()
|
||||||
self.record_thread = Thread(target=Microphone._record_loop, args=(self.record_queue, self.record_stop_event, self.sample_rate, self.channels, output_file, ))
|
self.record_thread = Thread(
|
||||||
|
target=Microphone._record_loop,
|
||||||
|
args=(
|
||||||
|
self.record_queue,
|
||||||
|
self.record_stop_event,
|
||||||
|
self.sample_rate,
|
||||||
|
self.channels,
|
||||||
|
output_file,
|
||||||
|
),
|
||||||
|
)
|
||||||
self.record_thread.daemon = True
|
self.record_thread.daemon = True
|
||||||
self.record_thread.start()
|
self.record_thread.start()
|
||||||
|
|
||||||
self.is_writing = True
|
self.is_writing = True
|
||||||
|
|
||||||
self.is_recording = True
|
self.is_recording = True
|
||||||
self.stream.start()
|
self.stream.start()
|
||||||
|
|
||||||
def stop_recording(self) -> None:
|
def stop_recording(self) -> None:
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
if not self.is_recording:
|
if not self.is_recording:
|
||||||
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
raise RobotDeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
|
||||||
|
|
||||||
if self.stream.active:
|
if self.stream.active:
|
||||||
self.stream.stop() #Wait for all buffers to be processed
|
self.stream.stop() # Wait for all buffers to be processed
|
||||||
#Remark : stream.abort() flushes the buffers !
|
# Remark : stream.abort() flushes the buffers !
|
||||||
self.is_recording = False
|
self.is_recording = False
|
||||||
|
|
||||||
if self.record_thread is not None:
|
if self.record_thread is not None:
|
||||||
|
@ -329,7 +356,6 @@ class Microphone:
|
||||||
self.is_writing = False
|
self.is_writing = False
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
raise RobotDeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
|
||||||
|
|
||||||
|
@ -342,7 +368,8 @@ class Microphone:
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if getattr(self, "is_connected", False):
|
if getattr(self, "is_connected", False):
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset."
|
description="Records audio using `Microphone` for all microphones connected to the computer, or a selected subset."
|
||||||
|
|
|
@ -16,28 +16,33 @@ from typing import Protocol
|
||||||
|
|
||||||
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase
|
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase
|
||||||
|
|
||||||
|
|
||||||
# Defines a microphone type
|
# Defines a microphone type
|
||||||
class Microphone(Protocol):
|
class Microphone(Protocol):
|
||||||
def connect(self): ...
|
def connect(self): ...
|
||||||
def disconnect(self): ...
|
def disconnect(self): ...
|
||||||
def start_recording(self, output_file : str | None = None, multiprocessing : bool | None = False): ...
|
def start_recording(self, output_file: str | None = None, multiprocessing: bool | None = False): ...
|
||||||
def stop_recording(self): ...
|
def stop_recording(self): ...
|
||||||
|
|
||||||
|
|
||||||
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]:
|
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfigBase]) -> list[Microphone]:
|
||||||
microphones = {}
|
microphones = {}
|
||||||
|
|
||||||
for key, cfg in microphone_configs.items():
|
for key, cfg in microphone_configs.items():
|
||||||
if cfg.type == "microphone":
|
if cfg.type == "microphone":
|
||||||
from lerobot.common.robot_devices.microphones.microphone import Microphone
|
from lerobot.common.robot_devices.microphones.microphone import Microphone
|
||||||
|
|
||||||
microphones[key] = Microphone(cfg)
|
microphones[key] = Microphone(cfg)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
|
raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
return microphones
|
return microphones
|
||||||
|
|
||||||
|
|
||||||
def make_microphone(microphone_type, **kwargs) -> Microphone:
|
def make_microphone(microphone_type, **kwargs) -> Microphone:
|
||||||
if microphone_type == "microphone":
|
if microphone_type == "microphone":
|
||||||
from lerobot.common.robot_devices.microphones.microphone import Microphone
|
from lerobot.common.robot_devices.microphones.microphone import Microphone
|
||||||
|
|
||||||
return Microphone(MicrophoneConfig(**kwargs))
|
return Microphone(MicrophoneConfig(**kwargs))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
|
raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
|
||||||
|
|
|
@ -23,12 +23,12 @@ from lerobot.common.robot_devices.cameras.configs import (
|
||||||
IntelRealSenseCameraConfig,
|
IntelRealSenseCameraConfig,
|
||||||
OpenCVCameraConfig,
|
OpenCVCameraConfig,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
|
||||||
from lerobot.common.robot_devices.motors.configs import (
|
from lerobot.common.robot_devices.motors.configs import (
|
||||||
DynamixelMotorsBusConfig,
|
DynamixelMotorsBusConfig,
|
||||||
FeetechMotorsBusConfig,
|
FeetechMotorsBusConfig,
|
||||||
MotorsBusConfig,
|
MotorsBusConfig,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -51,6 +51,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
|
||||||
latest_images_dict.update(local_dict)
|
latest_images_dict.update(local_dict)
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
|
||||||
def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
|
def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
local_dict = {}
|
local_dict = {}
|
||||||
|
@ -60,6 +61,7 @@ def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_even
|
||||||
with audio_lock:
|
with audio_lock:
|
||||||
latest_audio_dict.update(local_dict)
|
latest_audio_dict.update(local_dict)
|
||||||
|
|
||||||
|
|
||||||
def calibrate_follower_arm(motors_bus, calib_dir_str):
|
def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||||
"""
|
"""
|
||||||
Calibrates the follower arm. Attempts to load an existing calibration file;
|
Calibrates the follower arm. Attempts to load an existing calibration file;
|
||||||
|
@ -149,12 +151,14 @@ def run_lekiwi(robot_config):
|
||||||
cam_thread.start()
|
cam_thread.start()
|
||||||
|
|
||||||
# Start the microphone recording and capture thread.
|
# Start the microphone recording and capture thread.
|
||||||
#TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead !
|
# TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead !
|
||||||
latest_audio_dict = {}
|
latest_audio_dict = {}
|
||||||
audio_lock = threading.Lock()
|
audio_lock = threading.Lock()
|
||||||
audio_stop_event = threading.Event()
|
audio_stop_event = threading.Event()
|
||||||
microphone_thread = threading.Thread(
|
microphone_thread = threading.Thread(
|
||||||
target=run_microphone_capture, args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), daemon=True
|
target=run_microphone_capture,
|
||||||
|
args=(microphones, audio_lock, latest_audio_dict, audio_stop_event),
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
for microphone in microphones.values():
|
for microphone in microphones.values():
|
||||||
microphone.start_recording()
|
microphone.start_recording()
|
||||||
|
@ -231,7 +235,7 @@ def run_lekiwi(robot_config):
|
||||||
# Build the observation dictionary.
|
# Build the observation dictionary.
|
||||||
observation = {
|
observation = {
|
||||||
"images": images_dict_copy,
|
"images": images_dict_copy,
|
||||||
"audio": audio_dict_copy, #TODO(CarolinePascal) : This is a nasty way to do it, sorry.
|
"audio": audio_dict_copy, # TODO(CarolinePascal) : This is a nasty way to do it, sorry.
|
||||||
"present_speed": current_velocity,
|
"present_speed": current_velocity,
|
||||||
"follower_arm_state": follower_arm_state,
|
"follower_arm_state": follower_arm_state,
|
||||||
}
|
}
|
||||||
|
|
|
@ -201,7 +201,7 @@ class ManipulatorRobot:
|
||||||
"names": state_names,
|
"names": state_names,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def microphone_features(self) -> dict:
|
def microphone_features(self) -> dict:
|
||||||
mic_ft = {}
|
mic_ft = {}
|
||||||
|
@ -211,7 +211,7 @@ class ManipulatorRobot:
|
||||||
"dtype": "audio",
|
"dtype": "audio",
|
||||||
"shape": (len(mic.channels),),
|
"shape": (len(mic.channels),),
|
||||||
"names": "channels",
|
"names": "channels",
|
||||||
"info" : {"sample_rate": mic.sample_rate},
|
"info": {"sample_rate": mic.sample_rate},
|
||||||
}
|
}
|
||||||
return mic_ft
|
return mic_ft
|
||||||
|
|
||||||
|
@ -226,11 +226,11 @@ class ManipulatorRobot:
|
||||||
@property
|
@property
|
||||||
def num_cameras(self):
|
def num_cameras(self):
|
||||||
return len(self.cameras)
|
return len(self.cameras)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_microphone(self):
|
def has_microphone(self):
|
||||||
return len(self.microphones) > 0
|
return len(self.microphones) > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_microphones(self):
|
def num_microphones(self):
|
||||||
return len(self.microphones)
|
return len(self.microphones)
|
||||||
|
|
|
@ -163,7 +163,7 @@ class MobileManipulator:
|
||||||
"names": combined_names,
|
"names": combined_names,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def microphone_features(self) -> dict:
|
def microphone_features(self) -> dict:
|
||||||
mic_ft = {}
|
mic_ft = {}
|
||||||
|
@ -173,7 +173,7 @@ class MobileManipulator:
|
||||||
"dtype": "audio",
|
"dtype": "audio",
|
||||||
"shape": (len(mic.channels),),
|
"shape": (len(mic.channels),),
|
||||||
"names": "channels",
|
"names": "channels",
|
||||||
"info" : {"sample_rate": mic.sample_rate},
|
"info": {"sample_rate": mic.sample_rate},
|
||||||
}
|
}
|
||||||
return mic_ft
|
return mic_ft
|
||||||
|
|
||||||
|
@ -188,11 +188,11 @@ class MobileManipulator:
|
||||||
@property
|
@property
|
||||||
def num_cameras(self):
|
def num_cameras(self):
|
||||||
return len(self.cameras)
|
return len(self.cameras)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_microphone(self):
|
def has_microphone(self):
|
||||||
return len(self.microphones) > 0
|
return len(self.microphones) > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_microphones(self):
|
def num_microphones(self):
|
||||||
return len(self.microphones)
|
return len(self.microphones)
|
||||||
|
@ -512,7 +512,7 @@ class MobileManipulator:
|
||||||
# Create silence using the microphone's configured channels
|
# Create silence using the microphone's configured channels
|
||||||
frame = np.zeros((1, len(microphone.channels)), dtype=np.float32)
|
frame = np.zeros((1, len(microphone.channels)), dtype=np.float32)
|
||||||
obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame)
|
obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame)
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
|
@ -69,11 +69,13 @@ class RobotDeviceNotRecordingError(Exception):
|
||||||
"""Exception raised when the robot device is not recording."""
|
"""Exception raised when the robot device is not recording."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, message="This robot device is not recording. Try calling `robot_device.start_recording()` first."
|
self,
|
||||||
|
message="This robot device is not recording. Try calling `robot_device.start_recording()` first.",
|
||||||
):
|
):
|
||||||
self.message = message
|
self.message = message
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
class RobotDeviceAlreadyRecordingError(Exception):
|
class RobotDeviceAlreadyRecordingError(Exception):
|
||||||
"""Exception raised when the robot device is already recording."""
|
"""Exception raised when the robot device is already recording."""
|
||||||
|
|
||||||
|
|
|
@ -19,9 +19,9 @@ import traceback
|
||||||
import pytest
|
import pytest
|
||||||
from serial import SerialException
|
from serial import SerialException
|
||||||
|
|
||||||
from lerobot import available_cameras, available_motors, available_robots, available_microphones
|
from lerobot import available_cameras, available_microphones, available_motors, available_robots
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||||
from tests.utils import DEVICE, make_camera, make_motors_bus, make_microphone
|
from tests.utils import DEVICE, make_camera, make_microphone, make_motors_bus
|
||||||
|
|
||||||
# Import fixture modules as plugins
|
# Import fixture modules as plugins
|
||||||
pytest_plugins = [
|
pytest_plugins = [
|
||||||
|
@ -73,10 +73,12 @@ def is_robot_available(robot_type):
|
||||||
def is_camera_available(camera_type):
|
def is_camera_available(camera_type):
|
||||||
return _check_component_availability(camera_type, available_cameras, make_camera)
|
return _check_component_availability(camera_type, available_cameras, make_camera)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def is_microphone_available(microphone_type):
|
def is_microphone_available(microphone_type):
|
||||||
return _check_component_availability(microphone_type, available_microphones, make_microphone)
|
return _check_component_availability(microphone_type, available_microphones, make_microphone)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def is_motor_available(motor_type):
|
def is_motor_available(motor_type):
|
||||||
return _check_component_availability(motor_type, available_motors, make_motors_bus)
|
return _check_component_availability(motor_type, available_motors, make_motors_bus)
|
||||||
|
|
|
@ -25,9 +25,9 @@ from lerobot.common.datasets.compute_stats import (
|
||||||
compute_episode_stats,
|
compute_episode_stats,
|
||||||
estimate_num_samples,
|
estimate_num_samples,
|
||||||
get_feature_stats,
|
get_feature_stats,
|
||||||
sample_images,
|
|
||||||
sample_audio_from_path,
|
|
||||||
sample_audio_from_data,
|
sample_audio_from_data,
|
||||||
|
sample_audio_from_path,
|
||||||
|
sample_images,
|
||||||
sample_indices,
|
sample_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,8 +35,10 @@ from lerobot.common.datasets.compute_stats import (
|
||||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def mock_load_audio(path):
|
def mock_load_audio(path):
|
||||||
return np.ones((16000,2), dtype=np.float32)
|
return np.ones((16000, 2), dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_array():
|
def sample_array():
|
||||||
|
@ -74,6 +76,7 @@ def test_sample_images(mock_load):
|
||||||
assert images.dtype == np.uint8
|
assert images.dtype == np.uint8
|
||||||
assert len(images) == estimate_num_samples(100)
|
assert len(images) == estimate_num_samples(100)
|
||||||
|
|
||||||
|
|
||||||
@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
|
@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
|
||||||
def test_sample_audio_from_path(mock_load):
|
def test_sample_audio_from_path(mock_load):
|
||||||
audio_path = "audio.wav"
|
audio_path = "audio.wav"
|
||||||
|
@ -83,6 +86,7 @@ def test_sample_audio_from_path(mock_load):
|
||||||
assert audio_samples.dtype == np.float32
|
assert audio_samples.dtype == np.float32
|
||||||
assert len(audio_samples) == estimate_num_samples(16000)
|
assert len(audio_samples) == estimate_num_samples(16000)
|
||||||
|
|
||||||
|
|
||||||
def test_sample_audio_from_data(mock_load):
|
def test_sample_audio_from_data(mock_load):
|
||||||
audio_data = np.ones((16000, 2), dtype=np.float32)
|
audio_data = np.ones((16000, 2), dtype=np.float32)
|
||||||
audio_samples = sample_audio_from_data(audio_data)
|
audio_samples = sample_audio_from_data(audio_data)
|
||||||
|
@ -91,6 +95,7 @@ def test_sample_audio_from_data(mock_load):
|
||||||
assert audio_samples.dtype == np.float32
|
assert audio_samples.dtype == np.float32
|
||||||
assert len(audio_samples) == estimate_num_samples(16000)
|
assert len(audio_samples) == estimate_num_samples(16000)
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_images():
|
def test_get_feature_stats_images():
|
||||||
data = np.random.rand(100, 3, 32, 32)
|
data = np.random.rand(100, 3, 32, 32)
|
||||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||||
|
@ -98,13 +103,15 @@ def test_get_feature_stats_images():
|
||||||
np.testing.assert_equal(stats["count"], np.array([100]))
|
np.testing.assert_equal(stats["count"], np.array([100]))
|
||||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_audio():
|
def test_get_feature_stats_audio():
|
||||||
data = np.random.uniform(-1, 1, (16000,2))
|
data = np.random.uniform(-1, 1, (16000, 2))
|
||||||
stats = get_feature_stats(data, axis=0, keepdims=True)
|
stats = get_feature_stats(data, axis=0, keepdims=True)
|
||||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||||
np.testing.assert_equal(stats["count"], np.array([16000]))
|
np.testing.assert_equal(stats["count"], np.array([16000]))
|
||||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||||
expected = {
|
expected = {
|
||||||
"min": np.array([[1, 2, 3]]),
|
"min": np.array([[1, 2, 3]]),
|
||||||
|
@ -172,10 +179,11 @@ def test_compute_episode_stats():
|
||||||
"observation.state": {"dtype": "numeric"},
|
"observation.state": {"dtype": "numeric"},
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
|
patch(
|
||||||
), patch(
|
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
|
||||||
"lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio
|
),
|
||||||
|
patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
|
||||||
):
|
):
|
||||||
stats = compute_episode_stats(episode_data, features)
|
stats = compute_episode_stats(episode_data, features)
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -35,6 +36,7 @@ from lerobot.common.datasets.lerobot_dataset import (
|
||||||
MultiLeRobotDataset,
|
MultiLeRobotDataset,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||||
create_branch,
|
create_branch,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
|
@ -44,12 +46,9 @@ from lerobot.common.policies.factory import make_policy_config
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID, DUMMY_AUDIO_CHANNELS
|
from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import make_microphone, require_x86_64_kernel
|
||||||
|
|
||||||
from tests.utils import make_microphone
|
|
||||||
import time
|
|
||||||
from lerobot.common.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
@ -66,6 +65,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
}
|
}
|
||||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {
|
features = {
|
||||||
|
@ -79,6 +79,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
}
|
}
|
||||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|
||||||
|
|
||||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||||
"""
|
"""
|
||||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||||
|
@ -336,6 +337,7 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
image_array_to_pil_image(image)
|
image_array_to_pil_image(image)
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_audio(audio_dataset):
|
def test_add_frame_audio(audio_dataset):
|
||||||
dataset = audio_dataset
|
dataset = audio_dataset
|
||||||
|
|
||||||
|
@ -349,7 +351,10 @@ def test_add_frame_audio(audio_dataset):
|
||||||
|
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["observation.audio.microphone"].shape == torch.Size((int(DEFAULT_AUDIO_CHUNK_DURATION*microphone.sample_rate),DUMMY_AUDIO_CHANNELS))
|
assert dataset[0]["observation.audio.microphone"].shape == torch.Size(
|
||||||
|
(int(DEFAULT_AUDIO_CHUNK_DURATION * microphone.sample_rate), DUMMY_AUDIO_CHANNELS)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts):
|
# TODO(aliberts):
|
||||||
# - [ ] test various attributes & state from init and create
|
# - [ ] test various attributes & state from init and create
|
||||||
|
|
|
@ -29,7 +29,12 @@ DUMMY_MOTOR_FEATURES = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
DUMMY_CAMERA_FEATURES = {
|
DUMMY_CAMERA_FEATURES = {
|
||||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None, "audio": "laptop"},
|
"laptop": {
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
"audio": "laptop",
|
||||||
|
},
|
||||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||||
}
|
}
|
||||||
DEFAULT_FPS = 30
|
DEFAULT_FPS = 30
|
||||||
|
|
|
@ -26,10 +26,10 @@ import torch
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_COMPRESSED_AUDIO_PATH,
|
||||||
DEFAULT_FEATURES,
|
DEFAULT_FEATURES,
|
||||||
DEFAULT_PARQUET_PATH,
|
DEFAULT_PARQUET_PATH,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
DEFAULT_COMPRESSED_AUDIO_PATH,
|
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,27 +11,30 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import time
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
from threading import Event, Thread
|
||||||
from tests.fixtures.constants import DUMMY_AUDIO_CHANNELS, DEFAULT_SAMPLE_RATE
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
from threading import Thread, Event
|
from tests.fixtures.constants import DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS
|
||||||
import time
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _generate_sound(duration: float, sample_rate: int, channels: int):
|
def _generate_sound(duration: float, sample_rate: int, channels: int):
|
||||||
return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
|
return np.random.uniform(-1, 1, size=(int(duration * sample_rate), channels)).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
def query_devices(query_index: int):
|
def query_devices(query_index: int):
|
||||||
return {
|
return {
|
||||||
"name": "Mock Sound Device",
|
"name": "Mock Sound Device",
|
||||||
"index": query_index,
|
"index": query_index,
|
||||||
"max_input_channels": DUMMY_AUDIO_CHANNELS,
|
"max_input_channels": DUMMY_AUDIO_CHANNELS,
|
||||||
"default_samplerate": DEFAULT_SAMPLE_RATE,
|
"default_samplerate": DEFAULT_SAMPLE_RATE,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class InputStream:
|
class InputStream:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._mock_dict = {
|
self._mock_dict = {
|
||||||
|
@ -49,7 +52,12 @@ class InputStream:
|
||||||
while not self.callback_thread_stop_event.is_set():
|
while not self.callback_thread_stop_event.is_set():
|
||||||
# Simulate audio data acquisition
|
# Simulate audio data acquisition
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
self._audio_callback(_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS), 0.01*DEFAULT_SAMPLE_RATE, capture_timestamp_utc(), None)
|
self._audio_callback(
|
||||||
|
_generate_sound(0.01, DEFAULT_SAMPLE_RATE, DUMMY_AUDIO_CHANNELS),
|
||||||
|
0.01 * DEFAULT_SAMPLE_RATE,
|
||||||
|
capture_timestamp_utc(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
self.callback_thread_stop_event = Event()
|
self.callback_thread_stop_event = Event()
|
||||||
|
@ -62,7 +70,7 @@ class InputStream:
|
||||||
@property
|
@property
|
||||||
def active(self):
|
def active(self):
|
||||||
return self._is_active
|
return self._is_active
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self.callback_thread_stop_event is not None:
|
if self.callback_thread_stop_event is not None:
|
||||||
self.callback_thread_stop_event.set()
|
self.callback_thread_stop_event.set()
|
||||||
|
@ -78,5 +86,3 @@ class InputStream:
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self._is_active:
|
if self._is_active:
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -32,20 +32,27 @@ pytest -sx 'tests/microphones/test_microphones.py::test_microphone[microphone-Tr
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from soundfile import read
|
from soundfile import read
|
||||||
|
|
||||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError, RobotDeviceNotRecordingError, RobotDeviceAlreadyRecordingError
|
from lerobot.common.robot_devices.utils import (
|
||||||
|
RobotDeviceAlreadyConnectedError,
|
||||||
|
RobotDeviceAlreadyRecordingError,
|
||||||
|
RobotDeviceNotConnectedError,
|
||||||
|
RobotDeviceNotRecordingError,
|
||||||
|
)
|
||||||
from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
|
from tests.utils import TEST_MICROPHONE_TYPES, make_microphone, require_microphone
|
||||||
|
|
||||||
#Maximum recording tie difference between two consecutive audio recordings of the same duration.
|
# Maximum recording tie difference between two consecutive audio recordings of the same duration.
|
||||||
#Set to 0.02 seconds as twice the default size of sounddvice callback buffer (i.e. we tolerate the loss of one buffer).
|
# Set to 0.02 seconds as twice the default size of sounddvice callback buffer (i.e. we tolerate the loss of one buffer).
|
||||||
MAX_RECORDING_TIME_DIFFERENCE = 0.02
|
MAX_RECORDING_TIME_DIFFERENCE = 0.02
|
||||||
|
|
||||||
DUMMY_RECORDING = "test_recording.wav"
|
DUMMY_RECORDING = "test_recording.wav"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
|
@pytest.mark.parametrize("microphone_type, mock", TEST_MICROPHONE_TYPES)
|
||||||
@require_microphone
|
@require_microphone
|
||||||
def test_microphone(tmp_path, request, microphone_type, mock):
|
def test_microphone(tmp_path, request, microphone_type, mock):
|
||||||
|
@ -92,7 +99,7 @@ def test_microphone(tmp_path, request, microphone_type, mock):
|
||||||
fpath = tmp_path / DUMMY_RECORDING
|
fpath = tmp_path / DUMMY_RECORDING
|
||||||
microphone.start_recording(fpath)
|
microphone.start_recording(fpath)
|
||||||
assert microphone.is_recording
|
assert microphone.is_recording
|
||||||
|
|
||||||
# Test start_recording twice raises an error
|
# Test start_recording twice raises an error
|
||||||
with pytest.raises(RobotDeviceAlreadyRecordingError):
|
with pytest.raises(RobotDeviceAlreadyRecordingError):
|
||||||
microphone.start_recording()
|
microphone.start_recording()
|
||||||
|
@ -126,10 +133,13 @@ def test_microphone(tmp_path, request, microphone_type, mock):
|
||||||
|
|
||||||
error_msg = (
|
error_msg = (
|
||||||
"Recording time difference between read() and stop_recording()",
|
"Recording time difference between read() and stop_recording()",
|
||||||
(len(audio_chunk) - len(recorded_audio))/MAX_RECORDING_TIME_DIFFERENCE,
|
(len(audio_chunk) - len(recorded_audio)) / MAX_RECORDING_TIME_DIFFERENCE,
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
len(audio_chunk), len(recorded_audio), atol=recorded_sample_rate*MAX_RECORDING_TIME_DIFFERENCE, err_msg=error_msg
|
len(audio_chunk),
|
||||||
|
len(recorded_audio),
|
||||||
|
atol=recorded_sample_rate * MAX_RECORDING_TIME_DIFFERENCE,
|
||||||
|
err_msg=error_msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test disconnecting
|
# Test disconnecting
|
||||||
|
@ -139,4 +149,4 @@ def test_microphone(tmp_path, request, microphone_type, mock):
|
||||||
# Test disconnecting with `__del__`
|
# Test disconnecting with `__del__`
|
||||||
microphone = make_microphone(**microphone_kwargs)
|
microphone = make_microphone(**microphone_kwargs)
|
||||||
microphone.connect()
|
microphone.connect()
|
||||||
del microphone
|
del microphone
|
||||||
|
|
|
@ -143,7 +143,7 @@ def test_robot(tmp_path, request, robot_type, mock):
|
||||||
robot.send_action(action["action"])
|
robot.send_action(action["action"])
|
||||||
|
|
||||||
# Test disconnecting
|
# Test disconnecting
|
||||||
robot.disconnect() #Also handles microphone recording stop, life is beautiful
|
robot.disconnect() # Also handles microphone recording stop, life is beautiful
|
||||||
assert not robot.is_connected
|
assert not robot.is_connected
|
||||||
for name in robot.follower_arms:
|
for name in robot.follower_arms:
|
||||||
assert not robot.follower_arms[name].is_connected
|
assert not robot.follower_arms[name].is_connected
|
||||||
|
|
|
@ -22,13 +22,13 @@ from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot import available_cameras, available_motors, available_robots, available_microphones
|
from lerobot import available_cameras, available_microphones, available_motors, available_robots
|
||||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||||
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
||||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
|
||||||
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
|
|
||||||
from lerobot.common.robot_devices.microphones.utils import Microphone
|
from lerobot.common.robot_devices.microphones.utils import Microphone
|
||||||
from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
|
from lerobot.common.robot_devices.microphones.utils import make_microphone as make_microphone_device
|
||||||
|
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||||
|
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
|
||||||
from lerobot.common.utils.import_utils import is_package_available
|
from lerobot.common.utils.import_utils import is_package_available
|
||||||
|
|
||||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||||
|
@ -261,6 +261,7 @@ def require_camera(func):
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_microphone(func):
|
def require_microphone(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
@ -283,6 +284,7 @@ def require_microphone(func):
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_motor(func):
|
def require_motor(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
@ -344,6 +346,7 @@ def make_camera(camera_type: str, **kwargs) -> Camera:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||||
|
|
||||||
|
|
||||||
def make_microphone(microphone_type: str, **kwargs) -> Microphone:
|
def make_microphone(microphone_type: str, **kwargs) -> Microphone:
|
||||||
if microphone_type == "microphone":
|
if microphone_type == "microphone":
|
||||||
microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
|
microphone_index = kwargs.pop("microphone_index", MICROPHONE_INDEX)
|
||||||
|
@ -351,6 +354,7 @@ def make_microphone(microphone_type: str, **kwargs) -> Microphone:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
|
raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
|
||||||
|
|
||||||
|
|
||||||
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
||||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||||
if motor_type == "dynamixel":
|
if motor_type == "dynamixel":
|
||||||
|
|
Loading…
Reference in New Issue