Add PyTorch-compatible format conversion for audio
This commit is contained in:
parent
44af02a334
commit
a18d0e4678
|
@ -86,6 +86,7 @@ def decode_audio_torchvision(
|
||||||
reader.add_basic_audio_stream(
|
reader.add_basic_audio_stream(
|
||||||
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
|
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
|
||||||
buffer_chunk_size = -1, #No dropping frames
|
buffer_chunk_size = -1, #No dropping frames
|
||||||
|
format = "fltp", #Format as float32
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_chunks = []
|
audio_chunks = []
|
||||||
|
@ -103,7 +104,6 @@ def decode_audio_torchvision(
|
||||||
audio_chunks.append(current_audio_chunk)
|
audio_chunks.append(current_audio_chunk)
|
||||||
|
|
||||||
audio_chunks = torch.stack(audio_chunks)
|
audio_chunks = torch.stack(audio_chunks)
|
||||||
#TODO(CarolinePascal) : pytorch format conversion ?
|
|
||||||
|
|
||||||
assert len(timestamps) == len(audio_chunks)
|
assert len(timestamps) == len(audio_chunks)
|
||||||
return audio_chunks
|
return audio_chunks
|
||||||
|
|
|
@ -112,11 +112,14 @@ def predict_action(observation, policy, device, use_amp):
|
||||||
torch.inference_mode(),
|
torch.inference_mode(),
|
||||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||||
):
|
):
|
||||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
|
||||||
for name in observation:
|
for name in observation:
|
||||||
|
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||||
if "image" in name:
|
if "image" in name:
|
||||||
observation[name] = observation[name].type(torch.float32) / 255
|
observation[name] = observation[name].type(torch.float32) / 255
|
||||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||||
|
# Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension
|
||||||
|
if "audio" in name:
|
||||||
|
observation[name] = observation[name].permute(1, 0).contiguous()
|
||||||
observation[name] = observation[name].unsqueeze(0)
|
observation[name] = observation[name].unsqueeze(0)
|
||||||
observation[name] = observation[name].to(device)
|
observation[name] = observation[name].to(device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue