Fixed eval.py on MPS (#702)

This commit is contained in:
Ilia Larchenko 2025-02-14 06:03:55 +07:00 committed by GitHub
parent 1e49cc4d60
commit c574eb4984
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 1 deletions

View File

@ -151,7 +151,9 @@ def rollout(
if return_observations: if return_observations:
all_observations.append(deepcopy(observation)) all_observations.append(deepcopy(observation))
observation = {key: observation[key].to(device, non_blocking=True) for key in observation} observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation) action = policy.select_action(observation)