Adding support for audio data recording and broadcasting for LeKiwi
This commit is contained in:
parent
1e5e631743
commit
ec8943db37
|
@ -15,7 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import load_image_as_numpy, load_audio
|
from lerobot.common.datasets.utils import load_image_as_numpy, load_audio_from_path
|
||||||
|
|
||||||
def estimate_num_samples(
|
def estimate_num_samples(
|
||||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||||
|
@ -70,13 +70,17 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def sample_audio(audio_path: str) -> np.ndarray:
|
def sample_audio_from_path(audio_path: str) -> np.ndarray:
|
||||||
|
|
||||||
data = load_audio(audio_path)
|
data = load_audio_from_path(audio_path)
|
||||||
sampled_indices = sample_indices(len(data))
|
sampled_indices = sample_indices(len(data))
|
||||||
|
|
||||||
return(data[sampled_indices])
|
return(data[sampled_indices])
|
||||||
|
|
||||||
|
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
|
||||||
|
sampled_indices = sample_indices(len(data))
|
||||||
|
return data[sampled_indices]
|
||||||
|
|
||||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||||
return {
|
return {
|
||||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||||
|
@ -97,7 +101,10 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
||||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||||
keepdims = True
|
keepdims = True
|
||||||
elif features[key]["dtype"] == "audio":
|
elif features[key]["dtype"] == "audio":
|
||||||
ep_ft_array = sample_audio(data[0])
|
try:
|
||||||
|
ep_ft_array = sample_audio_from_path(data[0])
|
||||||
|
except TypeError: #Should only be triggered for LeKiwi robot
|
||||||
|
ep_ft_array = sample_audio_from_data(data)
|
||||||
axes_to_reduce = 0
|
axes_to_reduce = 0
|
||||||
keepdims = True
|
keepdims = True
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -80,6 +80,7 @@ from lerobot.common.datasets.video_utils import (
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from lerobot.common.robot_devices.microphones.utils import Microphone
|
from lerobot.common.robot_devices.microphones.utils import Microphone
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
CODEBASE_VERSION = "v2.1"
|
CODEBASE_VERSION = "v2.1"
|
||||||
|
|
||||||
|
@ -324,7 +325,7 @@ class LeRobotDatasetMetadata:
|
||||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||||
"""
|
"""
|
||||||
for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
|
for key in set(self.audio_keys) - set(self.audio_camera_keys_mapping.keys()):
|
||||||
if not self.features[key].get("info", None):
|
if not self.features[key].get("info", None) or (len(self.features[key]["info"]) == 1 and "sample_rate" in self.features[key]["info"]):
|
||||||
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
|
audio_path = self.root / self.get_compressed_audio_file_path(0, key)
|
||||||
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||||
|
|
||||||
|
@ -910,11 +911,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self._save_image(frame[key], img_path)
|
self._save_image(frame[key], img_path)
|
||||||
self.episode_buffer[key].append(str(img_path))
|
self.episode_buffer[key].append(str(img_path))
|
||||||
elif self.features[key]["dtype"] == "audio":
|
elif self.features[key]["dtype"] == "audio":
|
||||||
if frame_index == 0:
|
if self.meta.robot_type.startswith("lekiwi"):
|
||||||
audio_path = self._get_raw_audio_file_path(
|
self.episode_buffer[key].append(frame[key])
|
||||||
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
else:
|
||||||
)
|
if frame_index == 0:
|
||||||
self.episode_buffer[key].append(str(audio_path))
|
audio_path = self._get_raw_audio_file_path(
|
||||||
|
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||||
|
)
|
||||||
|
self.episode_buffer[key].append(str(audio_path))
|
||||||
else:
|
else:
|
||||||
self.episode_buffer[key].append(frame[key])
|
self.episode_buffer[key].append(frame[key])
|
||||||
|
|
||||||
|
@ -966,12 +970,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
# index, episode_index, task_index are already processed above, and image and video
|
# index, episode_index, task_index are already processed above, and image and video
|
||||||
# are processed separately by storing image path and frame info as meta data
|
# are processed separately by storing image path and frame info as meta data
|
||||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]:
|
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||||
|
continue
|
||||||
|
elif ft["dtype"] == "audio":
|
||||||
|
if self.meta.robot_type.startswith("lekiwi"):
|
||||||
|
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
|
||||||
continue
|
continue
|
||||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||||
|
|
||||||
self._wait_image_writer()
|
self._wait_image_writer()
|
||||||
self._save_episode_table(episode_buffer, episode_index)
|
self._save_episode_table(episode_buffer, episode_index)
|
||||||
|
|
||||||
|
if self.meta.robot_type.startswith("lekiwi"):
|
||||||
|
for key in self.meta.audio_keys:
|
||||||
|
audio_path = self._get_raw_audio_file_path(episode_index=self.episode_buffer["episode_index"][0], audio_key=key)
|
||||||
|
with sf.SoundFile(audio_path, mode='w', samplerate=self.meta.features[key]["info"]["sample_rate"], channels=self.meta.features[key]["shape"][0]) as file:
|
||||||
|
file.write(episode_buffer[key])
|
||||||
|
|
||||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||||
|
|
||||||
if len(self.meta.video_keys) > 0:
|
if len(self.meta.video_keys) > 0:
|
||||||
|
|
|
@ -260,7 +260,7 @@ def load_image_as_numpy(
|
||||||
img_array /= 255.0
|
img_array /= 255.0
|
||||||
return img_array
|
return img_array
|
||||||
|
|
||||||
def load_audio(fpath: str | Path) -> np.ndarray:
|
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
|
||||||
audio_data, _ = read(fpath, dtype="float32")
|
audio_data, _ = read(fpath, dtype="float32")
|
||||||
return audio_data
|
return audio_data
|
||||||
|
|
||||||
|
|
|
@ -252,7 +252,7 @@ def control_loop(
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None and not robot.robot_type.startswith("lekiwi"): #For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
|
||||||
for microphone_key, microphone in robot.microphones.items():
|
for microphone_key, microphone in robot.microphones.items():
|
||||||
#Start recording both in file writing and data reading mode
|
#Start recording both in file writing and data reading mode
|
||||||
dataset.add_microphone_recording(microphone, microphone_key)
|
dataset.add_microphone_recording(microphone, microphone_key)
|
||||||
|
|
|
@ -51,6 +51,14 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
|
||||||
latest_images_dict.update(local_dict)
|
latest_images_dict.update(local_dict)
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def run_microphone_capture(microphones, audio_lock, latest_audio_dict, stop_event):
|
||||||
|
while not stop_event.is_set():
|
||||||
|
local_dict = {}
|
||||||
|
for name, microphone in microphones.items():
|
||||||
|
audio_readings = microphone.read()
|
||||||
|
local_dict[name] = audio_readings
|
||||||
|
with audio_lock:
|
||||||
|
latest_audio_dict.update(local_dict)
|
||||||
|
|
||||||
def calibrate_follower_arm(motors_bus, calib_dir_str):
|
def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||||
"""
|
"""
|
||||||
|
@ -94,6 +102,7 @@ def run_lekiwi(robot_config):
|
||||||
"""
|
"""
|
||||||
# Import helper functions and classes
|
# Import helper functions and classes
|
||||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
|
||||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
|
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
|
||||||
|
|
||||||
# Initialize cameras from the robot configuration.
|
# Initialize cameras from the robot configuration.
|
||||||
|
@ -101,6 +110,11 @@ def run_lekiwi(robot_config):
|
||||||
for cam in cameras.values():
|
for cam in cameras.values():
|
||||||
cam.connect()
|
cam.connect()
|
||||||
|
|
||||||
|
# Initialize microphones from the robot configuration.
|
||||||
|
microphones = make_microphones_from_configs(robot_config.microphones)
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
# Initialize the motors bus using the follower arm configuration.
|
# Initialize the motors bus using the follower arm configuration.
|
||||||
motor_config = robot_config.follower_arms.get("main")
|
motor_config = robot_config.follower_arms.get("main")
|
||||||
if motor_config is None:
|
if motor_config is None:
|
||||||
|
@ -134,6 +148,18 @@ def run_lekiwi(robot_config):
|
||||||
)
|
)
|
||||||
cam_thread.start()
|
cam_thread.start()
|
||||||
|
|
||||||
|
# Start the microphone recording and capture thread.
|
||||||
|
#TODO(CarolinePascal) : Leverage multi-core processing with a multiprocessing.Process instead !
|
||||||
|
latest_audio_dict = {}
|
||||||
|
audio_lock = threading.Lock()
|
||||||
|
audio_stop_event = threading.Event()
|
||||||
|
microphone_thread = threading.Thread(
|
||||||
|
target=run_microphone_capture, args=(microphones, audio_lock, latest_audio_dict, audio_stop_event), daemon=True
|
||||||
|
)
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.start_recording()
|
||||||
|
microphone_thread.start()
|
||||||
|
|
||||||
last_cmd_time = time.time()
|
last_cmd_time = time.time()
|
||||||
print("LeKiwi robot server started. Waiting for commands...")
|
print("LeKiwi robot server started. Waiting for commands...")
|
||||||
|
|
||||||
|
@ -198,9 +224,14 @@ def run_lekiwi(robot_config):
|
||||||
with images_lock:
|
with images_lock:
|
||||||
images_dict_copy = dict(latest_images_dict)
|
images_dict_copy = dict(latest_images_dict)
|
||||||
|
|
||||||
|
# Get the latest audio data.
|
||||||
|
with audio_lock:
|
||||||
|
audio_dict_copy = dict(latest_audio_dict)
|
||||||
|
|
||||||
# Build the observation dictionary.
|
# Build the observation dictionary.
|
||||||
observation = {
|
observation = {
|
||||||
"images": images_dict_copy,
|
"images": images_dict_copy,
|
||||||
|
"audio": audio_dict_copy, #TODO(CarolinePascal) : This is a nasty way to do it, sorry.
|
||||||
"present_speed": current_velocity,
|
"present_speed": current_velocity,
|
||||||
"follower_arm_state": follower_arm_state,
|
"follower_arm_state": follower_arm_state,
|
||||||
}
|
}
|
||||||
|
@ -217,6 +248,9 @@ def run_lekiwi(robot_config):
|
||||||
finally:
|
finally:
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
cam_thread.join()
|
cam_thread.join()
|
||||||
|
microphone_thread.join()
|
||||||
|
for microphone in microphones.values():
|
||||||
|
microphone.stop_recording()
|
||||||
robot.stop()
|
robot.stop()
|
||||||
motors_bus.disconnect()
|
motors_bus.disconnect()
|
||||||
cmd_socket.close()
|
cmd_socket.close()
|
||||||
|
|
|
@ -211,7 +211,7 @@ class ManipulatorRobot:
|
||||||
"dtype": "audio",
|
"dtype": "audio",
|
||||||
"shape": (len(mic.channels),),
|
"shape": (len(mic.channels),),
|
||||||
"names": "channels",
|
"names": "channels",
|
||||||
"info" : None,
|
"info" : {"sample_rate": mic.sample_rate},
|
||||||
}
|
}
|
||||||
return mic_ft
|
return mic_ft
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ import torch
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.common.robot_devices.microphones.utils import make_microphones_from_configs
|
||||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||||
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||||
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
||||||
|
@ -79,6 +80,7 @@ class MobileManipulator:
|
||||||
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
||||||
|
|
||||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||||
|
self.microphones = make_microphones_from_configs(self.config.microphones)
|
||||||
|
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
|
|
||||||
|
@ -133,6 +135,7 @@ class MobileManipulator:
|
||||||
"shape": (cam.height, cam.width, cam.channels),
|
"shape": (cam.height, cam.width, cam.channels),
|
||||||
"names": ["height", "width", "channels"],
|
"names": ["height", "width", "channels"],
|
||||||
"info": None,
|
"info": None,
|
||||||
|
"audio": "observation.audio." + cam.microphone if cam.microphone is not None else None,
|
||||||
}
|
}
|
||||||
return cam_ft
|
return cam_ft
|
||||||
|
|
||||||
|
@ -160,10 +163,23 @@ class MobileManipulator:
|
||||||
"names": combined_names,
|
"names": combined_names,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def microphone_features(self) -> dict:
|
||||||
|
mic_ft = {}
|
||||||
|
for mic_key, mic in self.microphones.items():
|
||||||
|
key = f"observation.audio.{mic_key}"
|
||||||
|
mic_ft[key] = {
|
||||||
|
"dtype": "audio",
|
||||||
|
"shape": (len(mic.channels),),
|
||||||
|
"names": "channels",
|
||||||
|
"info" : {"sample_rate": mic.sample_rate},
|
||||||
|
}
|
||||||
|
return mic_ft
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self):
|
def features(self):
|
||||||
return {**self.motor_features, **self.camera_features}
|
return {**self.motor_features, **self.camera_features, **self.microphone_features}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_camera(self):
|
def has_camera(self):
|
||||||
|
@ -172,6 +188,14 @@ class MobileManipulator:
|
||||||
@property
|
@property
|
||||||
def num_cameras(self):
|
def num_cameras(self):
|
||||||
return len(self.cameras)
|
return len(self.cameras)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_microphone(self):
|
||||||
|
return len(self.microphones) > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_microphones(self):
|
||||||
|
return len(self.microphones)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available_arms(self):
|
def available_arms(self):
|
||||||
|
@ -344,6 +368,7 @@ class MobileManipulator:
|
||||||
observation = json.loads(last_msg)
|
observation = json.loads(last_msg)
|
||||||
|
|
||||||
images_dict = observation.get("images", {})
|
images_dict = observation.get("images", {})
|
||||||
|
audio_dict = observation.get("audio", {})
|
||||||
new_speed = observation.get("present_speed", {})
|
new_speed = observation.get("present_speed", {})
|
||||||
new_arm_state = observation.get("follower_arm_state", None)
|
new_arm_state = observation.get("follower_arm_state", None)
|
||||||
|
|
||||||
|
@ -356,6 +381,11 @@ class MobileManipulator:
|
||||||
if frame_candidate is not None:
|
if frame_candidate is not None:
|
||||||
frames[cam_name] = frame_candidate
|
frames[cam_name] = frame_candidate
|
||||||
|
|
||||||
|
# Recieve audio
|
||||||
|
for microphone_name, audio_data in audio_dict.items():
|
||||||
|
if audio_data:
|
||||||
|
frames[microphone_name] = audio_data
|
||||||
|
|
||||||
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
||||||
if new_arm_state is not None and frames is not None:
|
if new_arm_state is not None and frames is not None:
|
||||||
self.last_frames = frames
|
self.last_frames = frames
|
||||||
|
@ -475,6 +505,14 @@ class MobileManipulator:
|
||||||
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
||||||
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
|
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
|
||||||
|
|
||||||
|
# Loop over each configured microphone
|
||||||
|
for microphone_name, microphone in self.microphones.items():
|
||||||
|
frame = frames.get(microphone_name, None)
|
||||||
|
if frame is None:
|
||||||
|
# Create silence using the microphone's configured channels
|
||||||
|
frame = np.zeros((1, len(microphone.channels)), dtype=np.float32)
|
||||||
|
obs_dict[f"observation.audio.{microphone_name}"] = torch.from_numpy(frame)
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
|
@ -26,7 +26,8 @@ from lerobot.common.datasets.compute_stats import (
|
||||||
estimate_num_samples,
|
estimate_num_samples,
|
||||||
get_feature_stats,
|
get_feature_stats,
|
||||||
sample_images,
|
sample_images,
|
||||||
sample_audio,
|
sample_audio_from_path,
|
||||||
|
sample_audio_from_data,
|
||||||
sample_indices,
|
sample_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,10 +74,18 @@ def test_sample_images(mock_load):
|
||||||
assert images.dtype == np.uint8
|
assert images.dtype == np.uint8
|
||||||
assert len(images) == estimate_num_samples(100)
|
assert len(images) == estimate_num_samples(100)
|
||||||
|
|
||||||
@patch("lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio)
|
@patch("lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
|
||||||
def test_sample_audio(mock_load):
|
def test_sample_audio_from_path(mock_load):
|
||||||
audio_path = "audio.wav"
|
audio_path = "audio.wav"
|
||||||
audio_samples = sample_audio(audio_path)
|
audio_samples = sample_audio_from_path(audio_path)
|
||||||
|
assert isinstance(audio_samples, np.ndarray)
|
||||||
|
assert audio_samples.shape[1] == 2
|
||||||
|
assert audio_samples.dtype == np.float32
|
||||||
|
assert len(audio_samples) == estimate_num_samples(16000)
|
||||||
|
|
||||||
|
def test_sample_audio_from_data(mock_load):
|
||||||
|
audio_data = np.ones((16000, 2), dtype=np.float32)
|
||||||
|
audio_samples = sample_audio_from_data(audio_data)
|
||||||
assert isinstance(audio_samples, np.ndarray)
|
assert isinstance(audio_samples, np.ndarray)
|
||||||
assert audio_samples.shape[1] == 2
|
assert audio_samples.shape[1] == 2
|
||||||
assert audio_samples.dtype == np.float32
|
assert audio_samples.dtype == np.float32
|
||||||
|
@ -166,7 +175,7 @@ def test_compute_episode_stats():
|
||||||
with patch(
|
with patch(
|
||||||
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
|
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
|
||||||
), patch(
|
), patch(
|
||||||
"lerobot.common.datasets.compute_stats.load_audio", side_effect=mock_load_audio
|
"lerobot.common.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio
|
||||||
):
|
):
|
||||||
stats = compute_episode_stats(episode_data, features)
|
stats = compute_episode_stats(episode_data, features)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue