Adding missing features for audio frames verification and stats

This commit is contained in:
CarolinePascal 2025-04-04 19:48:57 +02:00
parent a18d0e4678
commit b00e866c60
No known key found for this signature in database
6 changed files with 41 additions and 6 deletions

View File

@ -15,8 +15,7 @@
# limitations under the License.
import numpy as np
from lerobot.common.datasets.utils import load_image_as_numpy
from lerobot.common.datasets.utils import load_image_as_numpy, load_audio
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
@ -71,6 +70,12 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def sample_audio(audio_path: str) -> np.ndarray:
data = load_audio(audio_path)
sampled_indices = sample_indices(len(data))
return(data[sampled_indices])
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
return {
@ -91,6 +96,10 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
ep_ft_array = sample_images(data) # data is a list of image paths
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
elif features[key]["dtype"] == "audio":
ep_ft_array = sample_audio(data[0])
axes_to_reduce = 0
keepdims = True
else:
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis

View File

@ -906,6 +906,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
elif self.features[key]["dtype"] == "audio":
if frame_index == 0:
audio_path = self._get_raw_audio_file_path(
episode_index=self.episode_buffer["episode_index"], audio_key=key
)
self.episode_buffer[key].append(str(audio_path))
else:
self.episode_buffer[key].append(frame[key])

View File

@ -35,6 +35,8 @@ from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from soundfile import read
from lerobot.common.datasets.backward_compatibility import (
V21_MESSAGE,
BackwardCompatibilityError,
@ -258,6 +260,9 @@ def load_image_as_numpy(
img_array /= 255.0
return img_array
def load_audio(fpath: str | Path) -> np.ndarray:
audio_data, _ = read(fpath, dtype="float32")
return audio_data
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
@ -752,6 +757,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "audio":
return validate_feature_audio(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
@ -792,6 +799,17 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
return error_message
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c = expected_shape
if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]): #The number of frames might be different
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
else:
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str):
if not isinstance(value, str):

View File

@ -279,9 +279,7 @@ def control_loop(
action = {"action": action}
if dataset is not None:
#Remove audio frames which are directly written in a dedicated file
audioless_observation = {key: observation[key] for key in observation if key not in robot.microphones}
frame = {**audioless_observation, **action, "task": single_task}
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)

View File

@ -242,7 +242,7 @@ class Microphone:
with self.read_queue.mutex:
self.read_queue.queue.clear()
#self.read_queue.all_tasks_done.notify_all()
audio_readings = np.array(audio_readings).reshape(-1, len(self.channels))
audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels))
return audio_readings

View File

@ -556,6 +556,8 @@ class ManipulatorRobot:
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
for name in self.microphones:
obs_dict[f"observation.audio.{name}"] = audio[name]
return obs_dict, action_dict
@ -604,6 +606,8 @@ class ManipulatorRobot:
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
for name in self.microphones:
obs_dict[f"observation.audio.{name}"] = audio[name]
return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor: