Adding pytorch compatible conversion for audio
This commit is contained in:
parent
44af02a334
commit
a18d0e4678
|
@ -86,6 +86,7 @@ def decode_audio_torchvision(
|
|||
reader.add_basic_audio_stream(
|
||||
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
|
||||
buffer_chunk_size = -1, #No dropping frames
|
||||
format = "fltp", #Format as float32
|
||||
)
|
||||
|
||||
audio_chunks = []
|
||||
|
@ -103,7 +104,6 @@ def decode_audio_torchvision(
|
|||
audio_chunks.append(current_audio_chunk)
|
||||
|
||||
audio_chunks = torch.stack(audio_chunks)
|
||||
#TODO(CarolinePascal) : pytorch format conversion ?
|
||||
|
||||
assert len(timestamps) == len(audio_chunks)
|
||||
return audio_chunks
|
||||
|
|
|
@ -112,11 +112,14 @@ def predict_action(observation, policy, device, use_amp):
|
|||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
# Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension
|
||||
if "audio" in name:
|
||||
observation[name] = observation[name].permute(1, 0).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
|
|
Loading…
Reference in New Issue