Integrate microphones in Robot class

This commit is contained in:
CarolinePascal 2025-03-28 17:15:19 +01:00
parent 58cd0bdf86
commit e6ea8e75c3
No known key found for this signature in database
4 changed files with 114 additions and 1 deletions

View File

@ -0,0 +1,37 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
import draccus
@dataclass
class MicrophoneConfigBase(draccus.ChoiceRegistry, abc.ABC):
    """Abstract root of the microphone configuration hierarchy.

    Concrete configurations attach themselves to this registry under a name
    via the ``register_subclass`` decorator inherited from
    ``draccus.ChoiceRegistry``.
    """

    @property
    def type(self) -> str:
        """Return the choice name under which this config's class is registered."""
        config_cls = self.__class__
        return self.get_choice_name(config_cls)
@MicrophoneConfigBase.register_subclass("microphone")
@dataclass
class MicrophoneConfig(MicrophoneConfigBase):
    """Settings for one microphone, registered under the "microphone" choice name.

    Only ``microphone_index`` is mandatory. The optional fields default to
    ``None``; how an unset value is resolved is decided by the code consuming
    this config, which is not part of this module.
    """

    # Required device identifier.
    # NOTE(review): presumably the index reported by the audio backend — confirm.
    microphone_index: int
    # Optional sampling rate (unset when None).
    sampling_rate: int | None = None
    # Optional list of channel numbers (unset when None).
    channels: list[int] | None = None
    # Optional sample data type, given as a string (unset when None).
    data_type: str | None = None
    # Mock flag; the robot configuration may switch it to True on its own.
    mock: bool = False

View File

@ -0,0 +1,43 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig, MicrophoneConfigBase
# Defines a microphone type
class Microphone(Protocol):
def connect(self): ...
def disconnect(self): ...
def start_recording(self, output_file: str | None = None): ...
def stop_recording(self): ...
def make_microphones_from_configs(
    microphone_configs: dict[str, MicrophoneConfigBase],
) -> dict[str, Microphone]:
    """Instantiate one microphone per configuration entry.

    Args:
        microphone_configs: Mapping from a microphone name to its configuration.

    Returns:
        A dict with the same keys, in the same order, as ``microphone_configs``,
        each mapped to the microphone built from the corresponding
        configuration. An empty mapping yields an empty dict.

    Raises:
        ValueError: If a configuration's ``type`` is not ``"microphone"``.
    """
    microphones = {}
    for key, cfg in microphone_configs.items():
        if cfg.type != "microphone":
            raise ValueError(f"The microphone type '{cfg.type}' is not valid.")

        # Kept function-local like the original import, but aliased so the
        # concrete class does not shadow the `Microphone` protocol used in
        # this function's return annotation.
        from lerobot.common.robot_devices.microphones.microphone import Microphone as MicrophoneDevice

        microphones[key] = MicrophoneDevice(cfg)
    return microphones
def make_microphone(microphone_type, **kwargs) -> Microphone:
    """Build a single microphone from a type name and configuration keywords.

    Args:
        microphone_type: Name of the microphone type; only ``"microphone"``
            is accepted.
        **kwargs: Forwarded unchanged to ``MicrophoneConfig``.

    Returns:
        The microphone constructed from the resulting ``MicrophoneConfig``.

    Raises:
        ValueError: If ``microphone_type`` is not ``"microphone"``.
    """
    if microphone_type != "microphone":
        raise ValueError(f"The microphone type '{microphone_type}' is not valid.")

    from lerobot.common.robot_devices.microphones.microphone import Microphone as MicrophoneDevice

    config = MicrophoneConfig(**kwargs)
    return MicrophoneDevice(config)

View File

@ -28,6 +28,7 @@ from lerobot.common.robot_devices.motors.configs import (
FeetechMotorsBusConfig, FeetechMotorsBusConfig,
MotorsBusConfig, MotorsBusConfig,
) )
from lerobot.common.robot_devices.microphones.configs import MicrophoneConfig
@dataclass @dataclass
@ -68,6 +69,9 @@ class ManipulatorRobotConfig(RobotConfig):
for cam in self.cameras.values(): for cam in self.cameras.values():
if not cam.mock: if not cam.mock:
cam.mock = True cam.mock = True
for mic in self.microphones.values():
if not mic.mock:
mic.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence): if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms: for name in self.follower_arms:

View File

@ -28,6 +28,7 @@ import numpy as np
import torch import torch
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.robots.utils import get_arm_id
@ -164,6 +165,7 @@ class ManipulatorRobot:
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras) self.cameras = make_cameras_from_configs(self.config.cameras)
self.microphones = make_microphones_from_configs(self.config.microphones)
self.is_connected = False self.is_connected = False
self.logs = {} self.logs = {}
@ -199,6 +201,18 @@ class ManipulatorRobot:
}, },
} }
@property
def microphones_features(self) -> dict:
    """Describe the audio feature contributed by each microphone.

    Returns:
        A dict with one ``"observation.audio.<name>"`` entry per key of
        ``self.microphones``. Each entry holds the microphone's ``data_type``
        as ``"dtype"``, a one-element tuple wrapping its ``channels`` as
        ``"shape"``, and ``"info"`` set to ``None``.
    """
    # NOTE(review): MicrophoneConfig declares `channels` as `list[int] | None`.
    # If the microphone object exposes that list unchanged, "shape" becomes a
    # tuple containing a list rather than a channel count — confirm whether
    # `len(mic.channels)` is what is intended here.
    return {
        f"observation.audio.{name}": {
            "dtype": device.data_type,
            "shape": (device.channels,),
            "info": None,
        }
        for name, device in self.microphones.items()
    }
@property @property
def features(self): def features(self):
return {**self.motor_features, **self.camera_features} return {**self.motor_features, **self.camera_features}
@ -211,6 +225,14 @@ class ManipulatorRobot:
def num_cameras(self): def num_cameras(self):
return len(self.cameras) return len(self.cameras)
@property
def has_microphone(self):
    """Return True when at least one microphone is attached to the robot."""
    microphone_count = len(self.microphones)
    return microphone_count > 0
@property
def num_microphones(self):
    """Return how many microphones are attached to the robot."""
    microphone_count = len(self.microphones)
    return microphone_count
@property @property
def available_arms(self): def available_arms(self):
available_arms = [] available_arms = []
@ -228,7 +250,7 @@ class ManipulatorRobot:
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice." "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
) )
if not self.leader_arms and not self.follower_arms and not self.cameras: if not self.leader_arms and not self.follower_arms and not self.cameras and not self.microphones:
raise ValueError( raise ValueError(
"ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class." "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class."
) )
@ -289,6 +311,10 @@ class ManipulatorRobot:
for name in self.cameras: for name in self.cameras:
self.cameras[name].connect() self.cameras[name].connect()
# Connect the microphones
for name in self.microphones:
self.microphones[name].connect()
self.is_connected = True self.is_connected = True
def activate_calibration(self): def activate_calibration(self):
@ -620,6 +646,9 @@ class ManipulatorRobot:
for name in self.cameras: for name in self.cameras:
self.cameras[name].disconnect() self.cameras[name].disconnect()
for name in self.microphones:
self.microphones[name].disconnect()
self.is_connected = False self.is_connected = False
def __del__(self): def __del__(self):