fix: default float64 type must be cast into float32 for audio
This commit is contained in:
parent
89697d86e7
commit
4ddba296f7
|
@ -117,8 +117,9 @@ def predict_action(observation, policy, device, use_amp):
|
|||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
# Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension
|
||||
# Convert to pytorch format: channel first and float32 in [-1,1] with batch dimension
|
||||
if "audio" in name:
|
||||
observation[name] = observation[name].type(torch.float32)
|
||||
observation[name] = observation[name].permute(1, 0).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
|
Loading…
Reference in New Issue