Support npu

This commit is contained in:
ruanafan 2025-04-07 11:35:06 +08:00
parent ae70f12378
commit 485affb658
2 changed files with 11 additions and 2 deletions

View File

@ -24,6 +24,9 @@ from pathlib import Path
import numpy as np
import torch
if hasattr(torch, 'npu'):
import torch_npu
logging.info("exists npu, import torch_npu")
def none_or_int(value):
@ -46,6 +49,9 @@ def auto_select_torch_device() -> torch.device:
elif torch.backends.mps.is_available():
logging.info("Metal backend detected, using cuda.")
return torch.device("mps")
elif torch_npu.npu.is_available():
logging.info("Npu backend detected, using npu.")
return torch.device("npu")
else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
return torch.device("cpu")
@ -94,8 +100,10 @@ def is_torch_device_available(try_device: str) -> bool:
return torch.backends.mps.is_available()
elif try_device == "cpu":
return True
elif try_device == "npu":
return torch_npu.npu.is_available()
else:
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, cpu or npu.")
def is_amp_available(device: str):

View File

@ -69,6 +69,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.random_utils import set_seed
@ -152,7 +153,7 @@ def rollout(
all_observations.append(deepcopy(observation))
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
key: observation[key].to(device, non_blocking=device.type != "cuda") for key in observation
}
# Infer "task" from attributes of environments.