Adding missing features for audio frame verification and stats
This commit is contained in:
parent a18d0e4678
commit b00e866c60
|
@ -15,8 +15,7 @@
|
|||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.datasets.utils import load_image_as_numpy
|
||||
|
||||
from lerobot.common.datasets.utils import load_image_as_numpy, load_audio
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
|
@@ -71,6 +70,12 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
 
     return images
 
+
+def sample_audio(audio_path: str) -> np.ndarray:
+    data = load_audio(audio_path)
+    sampled_indices = sample_indices(len(data))
+    return data[sampled_indices]
+
 
 def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
     return {
@@ -91,6 +96,10 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
             ep_ft_array = sample_images(data)  # data is a list of image paths
             axes_to_reduce = (0, 2, 3)  # keep channel dim
             keepdims = True
+        elif features[key]["dtype"] == "audio":
+            ep_ft_array = sample_audio(data[0])
+            axes_to_reduce = 0
+            keepdims = True
         else:
             ep_ft_array = data  # data is already a np.ndarray
             axes_to_reduce = 0  # compute stats over the first axis
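Reviewer note: for orientation, a minimal sketch (not part of the commit) of what the new audio branch computes. `sample_audio` subsamples the waveform along the time axis, then `get_feature_stats` reduces over axis 0 with `keepdims=True`, leaving one statistic per channel; shapes below are illustrative, and the real `get_feature_stats` may return additional entries such as a sample count.

```python
import numpy as np

# Illustrative (num_frames, channels) waveform, e.g. 1 s of stereo audio at 48 kHz.
audio = np.random.randn(48_000, 2).astype(np.float32)

# Reducing over axis 0 with keepdims=True, as the audio branch does,
# yields per-channel statistics of shape (1, 2).
stats = {
    "min": audio.min(axis=0, keepdims=True),
    "max": audio.max(axis=0, keepdims=True),
    "mean": audio.mean(axis=0, keepdims=True),
    "std": audio.std(axis=0, keepdims=True),
}
assert all(v.shape == (1, 2) for v in stats.values())
```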
@@ -906,6 +906,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 img_path.parent.mkdir(parents=True, exist_ok=True)
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
+            elif self.features[key]["dtype"] == "audio":
+                if frame_index == 0:
+                    audio_path = self._get_raw_audio_file_path(
+                        episode_index=self.episode_buffer["episode_index"], audio_key=key
+                    )
+                    self.episode_buffer[key].append(str(audio_path))
             else:
                 self.episode_buffer[key].append(frame[key])
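Reviewer note: worth flagging the asymmetry this introduces — image keys get one path per frame, while an audio key gets a single path appended only at `frame_index == 0`, because the whole waveform is written to its own file. A rough illustration of the resulting buffer (key names and paths are hypothetical):

```python
# Hypothetical episode buffer after recording a 3-frame episode.
episode_buffer = {
    "observation.images.cam": ["frame_0.png", "frame_1.png", "frame_2.png"],  # one path per frame
    "observation.audio.mic": ["episode_000000_mic.wav"],  # single path for the whole episode
    "observation.state": ["state_0", "state_1", "state_2"],  # raw values, one per frame
}
```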
@@ -35,6 +35,8 @@ from huggingface_hub.errors import RevisionNotFoundError
 from PIL import Image as PILImage
 from torchvision import transforms
 
+from soundfile import read
+
 from lerobot.common.datasets.backward_compatibility import (
     V21_MESSAGE,
     BackwardCompatibilityError,
@@ -258,6 +260,9 @@ def load_image_as_numpy(
         img_array /= 255.0
     return img_array
 
+def load_audio(fpath: str | Path) -> np.ndarray:
+    audio_data, _ = read(fpath, dtype="float32")
+    return audio_data
 
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
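Reviewer note: as a quick sanity check on the helper (not part of the commit), `soundfile.read` returns a `(frames, channels)` float array for multi-channel files, a 1-D array for mono, plus the sample rate, which `load_audio` deliberately discards. The file name below is hypothetical:

```python
import soundfile as sf

audio, sample_rate = sf.read("episode_000000_mic.wav", dtype="float32")
print(audio.dtype, audio.shape)  # float32, (num_frames, num_channels) for stereo
print(sample_rate)               # e.g. 48000

# If downstream code assumes 2-D data, mono files can be forced into shape (frames, 1):
audio_2d, _ = sf.read("episode_000000_mic.wav", dtype="float32", always_2d=True)
```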
@@ -752,6 +757,8 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
         return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
     elif expected_dtype in ["image", "video"]:
         return validate_feature_image_or_video(name, expected_shape, value)
+    elif expected_dtype == "audio":
+        return validate_feature_audio(name, expected_shape, value)
     elif expected_dtype == "string":
         return validate_feature_string(name, value)
     else:
@@ -792,6 +799,17 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
 
     return error_message
 
+
+def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_shape = value.shape
+        c = expected_shape
+        if len(actual_shape) != 2 or (actual_shape[-1] != c[-1] and actual_shape[0] != c[0]):  # the number of frames may differ
+            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
+    else:
+        error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
 
 def validate_feature_string(name: str, value: str):
     if not isinstance(value, str):
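Reviewer note: a quick sketch of how the new validator behaves, assuming the declared shape carries the channel count in its last entry and `validate_feature_audio` from the hunk above is in scope; values are illustrative.

```python
import numpy as np

chunk = np.zeros((8000, 2), dtype=np.float32)  # 0.5 s of stereo audio at 16 kHz

# Passes: 2-D, and the channel dimension matches even though
# the number of frames differs from the declared shape.
assert validate_feature_audio("observation.audio.mic", (16000, 2), chunk) == ""

# Fails: not a 2-D array, so an error message is returned.
print(validate_feature_audio("observation.audio.mic", (16000, 2), chunk.ravel()))
```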
@@ -279,9 +279,7 @@ def control_loop(
             action = {"action": action}
 
         if dataset is not None:
-            # Remove audio frames which are directly written in a dedicated file
-            audioless_observation = {key: observation[key] for key in observation if key not in robot.microphones}
-            frame = {**audioless_observation, **action, "task": single_task}
+            frame = {**observation, **action, "task": single_task}
             dataset.add_frame(frame)
 
         # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
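Reviewer note: this removal pairs with the `add_frame` change above — the caller no longer strips microphone keys from the observation, because the dataset's `audio` dtype branch now routes them itself. Schematically (dummy data, hypothetical key names; the filter is shown with a `startswith` check for self-containment):

```python
observation = {
    "observation.state": [0.0, 0.0],
    "observation.audio.mic": "now handled by add_frame's audio branch",
}
action = {"action": [0.1, -0.1]}

# Old behaviour: strip audio keys by hand before building the frame.
audioless = {k: v for k, v in observation.items() if not k.startswith("observation.audio.")}

# New behaviour: pass everything through.
frame = {**observation, **action, "task": "pick up the cube"}
```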
@@ -242,7 +242,7 @@ class Microphone:
         with self.read_queue.mutex:
             self.read_queue.queue.clear()
             # self.read_queue.all_tasks_done.notify_all()
-        audio_readings = np.array(audio_readings).reshape(-1, len(self.channels))
+        audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels))
 
         return audio_readings
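Reviewer note: the only functional change here is the explicit `dtype=np.float32`, which keeps the in-memory readings consistent with the float32 files written to disk; without it, NumPy infers the dtype from the queued values. A small illustration:

```python
import numpy as np

chunks = [[0, 1], [2, 3]]  # queued readings; integers here for illustration

print(np.array(chunks).dtype)                    # int64 on most platforms
print(np.array(chunks, dtype=np.float32).dtype)  # float32, matching the audio files
```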
@@ -556,6 +556,8 @@ class ManipulatorRobot:
         action_dict["action"] = action
         for name in self.cameras:
             obs_dict[f"observation.images.{name}"] = images[name]
+        for name in self.microphones:
+            obs_dict[f"observation.audio.{name}"] = audio[name]
 
         return obs_dict, action_dict
@@ -604,6 +606,8 @@ class ManipulatorRobot:
         obs_dict["observation.state"] = state
         for name in self.cameras:
             obs_dict[f"observation.images.{name}"] = images[name]
+        for name in self.microphones:
+            obs_dict[f"observation.audio.{name}"] = audio[name]
         return obs_dict
 
     def send_action(self, action: torch.Tensor) -> torch.Tensor:
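Reviewer note: after these loops, the observation dictionary gains one `observation.audio.{name}` entry per configured microphone, alongside the existing state and image entries. Roughly (the device names after the final dot are hypothetical):

```python
# Schematic contents of obs_dict returned by these methods.
obs_dict = {
    "observation.state": "torch.Tensor of joint positions",
    "observation.images.laptop": "torch.Tensor image from camera 'laptop'",
    "observation.audio.mic": "audio readings from microphone 'mic'",
}
```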