fix: default float64 type must be cast into float32 for audio

This commit is contained in:
CarolinePascal 2025-04-11 18:55:37 +02:00
parent 89697d86e7
commit 4ddba296f7
No known key found for this signature in database
1 changed files with 2 additions and 1 deletions

View File

@ -117,8 +117,9 @@ def predict_action(observation, policy, device, use_amp):
if "image" in name: if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous() observation[name] = observation[name].permute(2, 0, 1).contiguous()
# Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension # Convert to pytorch format: channel first and float32 in [-1,1] with batch dimension
if "audio" in name: if "audio" in name:
observation[name] = observation[name].type(torch.float32)
observation[name] = observation[name].permute(1, 0).contiguous() observation[name] = observation[name].permute(1, 0).contiguous()
observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device) observation[name] = observation[name].to(device)