Adding pytorch compatible conversion for audio

This commit is contained in:
CarolinePascal 2025-04-04 18:31:00 +02:00
parent 44af02a334
commit a18d0e4678
No known key found for this signature in database
2 changed files with 5 additions and 2 deletions

View File

@ -86,6 +86,7 @@ def decode_audio_torchvision(
reader.add_basic_audio_stream( reader.add_basic_audio_stream(
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
buffer_chunk_size = -1, #No dropping frames buffer_chunk_size = -1, #No dropping frames
format = "fltp", #Format as float32
) )
audio_chunks = [] audio_chunks = []
@ -103,7 +104,6 @@ def decode_audio_torchvision(
audio_chunks.append(current_audio_chunk) audio_chunks.append(current_audio_chunk)
audio_chunks = torch.stack(audio_chunks) audio_chunks = torch.stack(audio_chunks)
#TODO(CarolinePascal) : pytorch format conversion ?
assert len(timestamps) == len(audio_chunks) assert len(timestamps) == len(audio_chunks)
return audio_chunks return audio_chunks

View File

@ -112,11 +112,14 @@ def predict_action(observation, policy, device, use_amp):
torch.inference_mode(), torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
): ):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation: for name in observation:
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
if "image" in name: if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous() observation[name] = observation[name].permute(2, 0, 1).contiguous()
# Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension
if "audio" in name:
observation[name] = observation[name].permute(1, 0).contiguous()
observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device) observation[name] = observation[name].to(device)