diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py
index 1149ec83..b24dbaf8 100644
--- a/lerobot/common/datasets/compute_stats.py
+++ b/lerobot/common/datasets/compute_stats.py
@@ -15,8 +15,7 @@
 # limitations under the License.
 
 import numpy as np
 
-from lerobot.common.datasets.utils import load_image_as_numpy
-
+from lerobot.common.datasets.utils import load_image_as_numpy, load_audio
 def estimate_num_samples(
     dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
@@ -71,6 +70,12 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
 
     return images
 
+def sample_audio(audio_path: str) -> np.ndarray:
+    """Load an episode's audio file and subsample its frames for statistics computation."""
+    data = load_audio(audio_path)
+    sampled_indices = sample_indices(len(data))
+
+    return data[sampled_indices]
 
 def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
     return {
@@ -91,6 +96,10 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
             ep_ft_array = sample_images(data)  # data is a list of image paths
             axes_to_reduce = (0, 2, 3)  # keep channel dim
             keepdims = True
+        elif features[key]["dtype"] == "audio":
+            ep_ft_array = sample_audio(data[0])  # data is a list with a single per-episode audio path
+            axes_to_reduce = 0  # reduce over the frame axis, keep the channel dim
+            keepdims = True
         else:
             ep_ft_array = data  # data is already a np.ndarray
             axes_to_reduce = 0  # compute stats over the first axis
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index a207845f..411a4676 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -906,6 +906,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 img_path.parent.mkdir(parents=True, exist_ok=True)
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
+            elif self.features[key]["dtype"] == "audio":
+                if frame_index == 0:  # the episode's audio is written to one dedicated file, so store its path once
+                    audio_path = self._get_raw_audio_file_path(
+                        episode_index=self.episode_buffer["episode_index"], audio_key=key
+                    )
+                    self.episode_buffer[key].append(str(audio_path))
             else:
                 self.episode_buffer[key].append(frame[key])
 
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 827e711b..a7bda7f2 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -35,6 +35,8 @@ from huggingface_hub.errors import RevisionNotFoundError
 from PIL import Image as PILImage
 from torchvision import transforms
 
+from soundfile import read
+
 from lerobot.common.datasets.backward_compatibility import (
     V21_MESSAGE,
     BackwardCompatibilityError,
@@ -258,6 +260,9 @@ def load_image_as_numpy(
         img_array /= 255.0
     return img_array
 
+def load_audio(fpath: str | Path) -> np.ndarray:
+    audio_data, _ = read(fpath, dtype="float32")  # soundfile returns (data, sample_rate); the rate is unused here
+    return audio_data
 
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
@@ -752,6 +757,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
         return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
     elif expected_dtype in ["image", "video"]:
         return validate_feature_image_or_video(name, expected_shape, value)
+    elif expected_dtype == "audio":
+        return validate_feature_audio(name, expected_shape, value)
     elif expected_dtype == "string":
         return validate_feature_string(name, value)
     else:
@@ -792,6 +799,17 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
 
     return error_message
 
+def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_shape = value.shape
+        c = expected_shape
+        if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]):  # the number of frames may differ between recordings
+            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
+    else:
+        error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
+
+    return error_message
 
 def validate_feature_string(name: str, value: str):
     if not isinstance(value, str):
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 9c2f0f47..7f71b1ee 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -279,9 +279,7 @@ def control_loop(
             action = {"action": action}
 
         if dataset is not None:
-            #Remove audio frames which are directly written in a dedicated file
-            audioless_observation = {key: observation[key] for key in observation if key not in robot.microphones}
-            frame = {**audioless_observation, **action, "task": single_task}
+            frame = {**observation, **action, "task": single_task}
             dataset.add_frame(frame)
 
         # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
diff --git a/lerobot/common/robot_devices/microphones/microphone.py b/lerobot/common/robot_devices/microphones/microphone.py
index 923d9b5e..8ab8b362 100644
--- a/lerobot/common/robot_devices/microphones/microphone.py
+++ b/lerobot/common/robot_devices/microphones/microphone.py
@@ -242,7 +242,7 @@ class Microphone:
         with self.read_queue.mutex:
             self.read_queue.queue.clear()
             #self.read_queue.all_tasks_done.notify_all()
-        audio_readings = np.array(audio_readings).reshape(-1, len(self.channels))
+        audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels))  # enforce float32 for downstream storage and stats
 
         return audio_readings
 
diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index b5dc60c8..7e849914 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -556,6 +556,8 @@ class ManipulatorRobot:
         action_dict["action"] = action
         for name in self.cameras:
             obs_dict[f"observation.images.{name}"] = images[name]
+        for name in self.microphones:
+            obs_dict[f"observation.audio.{name}"] = audio[name]
 
         return obs_dict, action_dict
 
@@ -604,7 +606,9 @@ class ManipulatorRobot:
         obs_dict["observation.state"] = state
         for name in self.cameras:
             obs_dict[f"observation.images.{name}"] = images[name]
+        for name in self.microphones:
+            obs_dict[f"observation.audio.{name}"] = audio[name]
 
         return obs_dict
 
     def send_action(self, action: torch.Tensor) -> torch.Tensor:
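
Review note: below is a minimal, self-contained sketch of the per-episode audio statistics path this diff wires up (load_audio -> sample_audio -> get_feature_stats with axes_to_reduce = 0). It is not repo code: the estimate_num_samples body, the sample_indices helper, and the synthetic WAV file are illustrative assumptions that only match the signatures and call sites shown above.

    import numpy as np
    import soundfile as sf

    def estimate_num_samples(
        dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
    ) -> int:
        # Stand-in body: subsample long recordings, keep short ones whole.
        return max(min(min_num_samples, dataset_len), min(int(dataset_len**power), max_num_samples))

    def sample_indices(data_len: int) -> np.ndarray:
        # Evenly spaced frame indices over the recording.
        return np.round(np.linspace(0, data_len - 1, estimate_num_samples(data_len))).astype(int)

    # Synthetic stereo clip so the sketch runs without a robot recording.
    sf.write("episode_0.wav", np.random.randn(48_000, 2).astype(np.float32), 48_000)

    audio, _ = sf.read("episode_0.wav", dtype="float32")  # (num_frames, channels), as load_audio returns
    sampled = audio[sample_indices(len(audio))]           # sample_audio: subsample frames, keep channels

    # get_feature_stats(sampled, axis=0, keepdims=True): per-channel stats over frames.
    stats = {
        "min": sampled.min(axis=0, keepdims=True),
        "max": sampled.max(axis=0, keepdims=True),
        "mean": sampled.mean(axis=0, keepdims=True),
        "std": sampled.std(axis=0, keepdims=True),
        "count": np.array([sampled.shape[0]]),
    }
    print({k: v.shape for k, v in stats.items()})  # min/max/mean/std come out with shape (1, 2); count is (1,)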
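
The shape rule in validate_feature_audio is easy to misread: a value is accepted iff it is a 2-D array whose last or first dimension matches the declared channel count, so both channels-last and channels-first layouts pass while the frame count is left unconstrained. A standalone restatement of that predicate (the expected_shape value here is a hypothetical stereo feature, not taken from this diff):

    import numpy as np

    expected_shape = (2,)  # hypothetical stereo microphone feature

    def audio_shape_ok(value: np.ndarray) -> bool:
        # validate_feature_audio's rejection condition, inverted to "is valid".
        c = expected_shape
        return value.ndim == 2 and (value.shape[-1] == c[-1] or value.shape[0] == c[0])

    print(audio_shape_ok(np.zeros((48_000, 2), np.float32)))  # True: channels-last
    print(audio_shape_ok(np.zeros((2, 48_000), np.float32)))  # True: channels-first is also accepted
    print(audio_shape_ok(np.zeros(48_000, np.float32)))       # False: 1-D mono is rejected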