From a18d0e46782418ce5e9681f7897c4e1c582d9478 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Fri, 4 Apr 2025 18:31:00 +0200
Subject: [PATCH] Adding pytorch compatible conversion for audio

---
 lerobot/common/datasets/video_utils.py        | 2 +-
 lerobot/common/robot_devices/control_utils.py | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index 44d5a1a5..4c96a400 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -86,6 +86,7 @@ def decode_audio_torchvision(
     reader.add_basic_audio_stream(
         frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
         buffer_chunk_size = -1, #No dropping frames
+        format = "fltp", #Format as float32
     )
 
     audio_chunks = []
@@ -103,7 +104,6 @@ def decode_audio_torchvision(
         audio_chunks.append(current_audio_chunk)
 
     audio_chunks = torch.stack(audio_chunks)
-    #TODO(CarolinePascal) : pytorch format conversion ?
 
     assert len(timestamps) == len(audio_chunks)
     return audio_chunks
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index d54a9e13..9c2f0f47 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -112,11 +112,14 @@ def predict_action(observation, policy, device, use_amp):
         torch.inference_mode(),
         torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
     ):
-        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
         for name in observation:
+            # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
             if "image" in name:
                 observation[name] = observation[name].type(torch.float32) / 255
                 observation[name] = observation[name].permute(2, 0, 1).contiguous()
+            # Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension
+            if "audio" in name:
+                observation[name] = observation[name].permute(1, 0).contiguous()
             observation[name] = observation[name].unsqueeze(0)
             observation[name] = observation[name].to(device)
 
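
Note for reviewers (not part of the patch): requesting format = "fltp" makes torchaudio's StreamReader yield float32 samples already scaled to [-1, 1], so no extra dtype conversion is needed on the decoding side; the remaining step at inference time is the channel-first permutation and batching added in predict_action. The snippet below is a minimal standalone sketch of that conversion only; the sample count, channel count, and variable names are illustrative assumptions, not values taken from the repository.

    import torch

    # Hypothetical raw audio observation from the recording buffer:
    # float32 samples in [-1, 1], shaped (n_samples, n_channels).
    n_samples, n_channels = 48000, 2  # illustrative values
    audio = torch.empty(n_samples, n_channels, dtype=torch.float32).uniform_(-1.0, 1.0)

    # Same steps as the added audio branch in predict_action():
    # channel first, contiguous memory, then a leading batch dimension.
    audio = audio.permute(1, 0).contiguous()  # (n_channels, n_samples)
    audio = audio.unsqueeze(0)                # (1, n_channels, n_samples)

    print(audio.shape)  # torch.Size([1, 2, 48000])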