diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index 563a7b81..2a44be6f 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -24,6 +24,9 @@ from pathlib import Path
 import numpy as np
 import torch
+if hasattr(torch, "npu"):  # torch.npu is registered by the torch_npu extension
+    import torch_npu  # noqa: F401
+    logging.info("exists npu, import torch_npu")
 
 
 def none_or_int(value):
     if value == "None":
@@ -46,6 +49,10 @@ def auto_select_torch_device() -> torch.device:
     elif torch.backends.mps.is_available():
         logging.info("Metal backend detected, using cuda.")
         return torch.device("mps")
+    # hasattr guard avoids a NameError on machines where torch_npu was never imported.
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        logging.info("Npu backend detected, using npu.")
+        return torch.device("npu")
     else:
         logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
         return torch.device("cpu")
@@ -94,8 +101,10 @@ def is_torch_device_available(try_device: str) -> bool:
         return torch.backends.mps.is_available()
     elif try_device == "cpu":
         return True
+    elif try_device == "npu":
+        return hasattr(torch, "npu") and torch.npu.is_available()
     else:
-        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, cpu or npu.")
 
 
 def is_amp_available(device: str):
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 9790f8b3..5d425c11 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -69,6 +69,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
 from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.io_utils import write_video
 from lerobot.common.utils.random_utils import set_seed
@@ -152,7 +153,7 @@ def rollout(
     all_observations.append(deepcopy(observation))
 
     observation = {
-        key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
+        key: observation[key].to(device, non_blocking=device.type in ("cuda", "npu")) for key in observation
     }
 
     # Infer "task" from attributes of environments.