diff --git a/lerobot/common/robot_devices/microphones/configs.py b/lerobot/common/robot_devices/microphones/configs.py
new file mode 100644
index 00000000..eb59be49
--- /dev/null
+++ b/lerobot/common/robot_devices/microphones/configs.py
@@ -0,0 +1,45 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from dataclasses import dataclass
+
+import draccus
+
+
+@dataclass
+class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
+    """Base class for microphone configs; subclasses register themselves under a choice name."""
+
+    @property
+    def type(self) -> str:
+        """Return the choice name under which this config class is registered."""
+        return self.get_choice_name(self.__class__)
+
+
+@MicrophoneConfigBase.register_subclass("microphone")
+@dataclass
+class MicrophoneConfig(MicrophoneConfigBase):
+    """
+    Dataclass for microphone configuration.
+
+    Only `microphone_index` is required; the remaining fields default to `None`/`False`.
+    """
+
+    microphone_index: int
+    sampling_rate: int | None = None
+    channels: list[int] | None = None
+    data_type: str | None = None
+    # May be forced to True by `ManipulatorRobotConfig`, which flips it on for its microphones.
+    mock: bool = False
diff --git a/lerobot/common/robot_devices/microphones/utils.py b/lerobot/common/robot_devices/microphones/utils.py
new file mode 100644
index 00000000..ea5790dc
--- /dev/null
+++ b/lerobot/common/robot_devices/microphones/utils.py
@@ -0,0 +1,75 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Protocol
+
+from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase
+
+
+# Defines a microphone type
+class Microphone(Protocol):
+    """Structural interface shared by microphone implementations."""
+
+    def connect(self): ...
+    def disconnect(self): ...
+    def start_recording(self, output_file: str | None = None): ...
+    def stop_recording(self): ...
+
+
+def make_microphones_from_configs(
+    microphone_configs: dict[str, MicrophoneConfigBase],
+) -> dict[str, Microphone]:
+    """Instantiate one microphone per entry of `microphone_configs`.
+
+    Args:
+        microphone_configs: Mapping from microphone name to its configuration.
+
+    Returns:
+        A dict with the same keys as `microphone_configs`, each mapped to the microphone
+        built from the corresponding config.
+
+    Raises:
+        ValueError: If a config's `type` is not "microphone".
+    """
+    microphones = {}
+
+    for key, cfg in microphone_configs.items():
+        if cfg.type == "microphone":
+            # Aliased so the concrete class does not shadow the `Microphone` protocol above.
+            from lerobot.common.robot_devices.microphones.microphone import Microphone as MicrophoneDevice
+
+            microphones[key] = MicrophoneDevice(cfg)
+        else:
+            raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
+
+    return microphones
+
+
+def make_microphone(microphone_type, **kwargs) -> Microphone:
+    """Build a single microphone of type `microphone_type`.
+
+    Args:
+        microphone_type: Name of the microphone type; only "microphone" is supported.
+        **kwargs: Forwarded to `MicrophoneConfig`.
+
+    Raises:
+        ValueError: If `microphone_type` is not "microphone".
+    """
+    if microphone_type == "microphone":
+        # Aliased so the concrete class does not shadow the `Microphone` protocol above.
+        from lerobot.common.robot_devices.microphones.microphone import Microphone as MicrophoneDevice
+
+        return MicrophoneDevice(MicrophoneConfig(**kwargs))
+    else:
+        raise ValueError(f"The microphone type '{microphone_type}' is not valid.")
diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py
index e940b442..9d72a46b 100644
--- a/lerobot/common/robot_devices/robots/configs.py
+++ b/lerobot/common/robot_devices/robots/configs.py
@@ -28,6 +28,7 @@ from lerobot.common.robot_devices.motors.configs import (
     FeetechMotorsBusConfig,
     MotorsBusConfig,
 )
+from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
 
 
 @dataclass
@@ -68,6 +69,10 @@ class ManipulatorRobotConfig(RobotConfig):
             for cam in self.cameras.values():
                 if not cam.mock:
                     cam.mock = True
+            # NOTE(review): assumes this config declares a `microphones` dict field - confirm.
+            for mic in self.microphones.values():
+                if not mic.mock:
+                    mic.mock = True
 
         if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
             for name in self.follower_arms:
diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index 9173abc6..9212bb9c 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -28,6 +28,7 @@ import numpy as np
 import torch
 
 from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
+from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
 from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
 from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
 from lerobot.common.robot_devices.robots.utils import get_arm_id
@@ -164,6 +165,7 @@ class ManipulatorRobot:
         self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
         self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
         self.cameras = make_cameras_from_configs(self.config.cameras)
+        self.microphones = make_microphones_from_configs(self.config.microphones)
         self.is_connected = False
         self.logs = {}
 
@@ -198,6 +200,22 @@ class ManipulatorRobot:
                 "names": state_names,
             },
         }
+
+    @property
+    def microphones_features(self) -> dict:
+        """Build one feature entry per microphone, keyed by `observation.audio.<name>`."""
+        mic_ft = {}
+        for mic_key, mic in self.microphones.items():
+            # `MicrophoneConfig.channels` is declared as `list[int] | None`; assuming the device
+            # mirrors it, the shape entry must be the channel count rather than the list itself.
+            channels = mic.channels
+            num_channels = len(channels) if isinstance(channels, (list, tuple)) else channels
+            mic_ft[f"observation.audio.{mic_key}"] = {
+                "dtype": mic.data_type,
+                "shape": (num_channels,),
+                "info": None,
+            }
+        return mic_ft
 
     @property
     def features(self):
@@ -210,6 +228,16 @@ class ManipulatorRobot:
     @property
     def num_cameras(self):
         return len(self.cameras)
+
+    @property
+    def has_microphone(self):
+        """Whether at least one microphone is attached to this robot."""
+        return len(self.microphones) > 0
+
+    @property
+    def num_microphones(self):
+        """Number of microphones attached to this robot."""
+        return len(self.microphones)
 
     @property
     def available_arms(self):
@@ -228,7 +256,7 @@ class ManipulatorRobot:
                 "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
             )
 
-        if not self.leader_arms and not self.follower_arms and not self.cameras:
+        if not self.leader_arms and not self.follower_arms and not self.cameras and not self.microphones:
             raise ValueError(
                 "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class."
             )
@@ -289,6 +317,10 @@ class ManipulatorRobot:
         for name in self.cameras:
             self.cameras[name].connect()
 
+        # Connect the microphones
+        for name in self.microphones:
+            self.microphones[name].connect()
+
         self.is_connected = True
 
     def activate_calibration(self):
@@ -620,6 +652,9 @@ class ManipulatorRobot:
         for name in self.cameras:
             self.cameras[name].disconnect()
 
+        for name in self.microphones:
+            self.microphones[name].disconnect()
+
         self.is_connected = False
 
     def __del__(self):