Adding dtype="audio" by default in microphone features

This commit is contained in:
CarolinePascal 2025-04-07 19:08:53 +02:00
parent 96ed10f90d
commit e743f846a7
No known key found for this signature in database
2 changed files with 4 additions and 9 deletions

View File

@ -403,13 +403,7 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
key: {"dtype": "video" if use_videos else "image", **ft} key: {"dtype": "video" if use_videos else "image", **ft}
for key, ft in robot.camera_features.items() for key, ft in robot.camera_features.items()
} }
microphones_ft = {} return {**robot.motor_features, **camera_ft, **robot.microphone_features, **DEFAULT_FEATURES}
if robot.microphones:
microphones_ft = {
key: {"dtype": "audio", **ft}
for key, ft in robot.microphones_features.items()
}
return {**robot.motor_features, **camera_ft, **microphones_ft, **DEFAULT_FEATURES}
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:

View File

@ -203,11 +203,12 @@ class ManipulatorRobot:
} }
@property @property
def microphones_features(self) -> dict: def microphone_features(self) -> dict:
mic_ft = {} mic_ft = {}
for mic_key, mic in self.microphones.items(): for mic_key, mic in self.microphones.items():
key = f"observation.audio.{mic_key}" key = f"observation.audio.{mic_key}"
mic_ft[key] = { mic_ft[key] = {
"dtype": "audio",
"shape": (len(mic.channels),), "shape": (len(mic.channels),),
"names": "channels", "names": "channels",
"info" : None, "info" : None,
@ -216,7 +217,7 @@ class ManipulatorRobot:
@property @property
def features(self): def features(self):
return {**self.motor_features, **self.camera_features, **self.microphones_features} return {**self.motor_features, **self.camera_features, **self.microphone_features}
@property @property
def has_camera(self): def has_camera(self):