Support npu
This commit is contained in:
parent
ae70f12378
commit
485affb658
|
@ -24,6 +24,9 @@ from pathlib import Path
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
if hasattr(torch, 'npu'):
|
||||
import torch_npu
|
||||
logging.info("exists npu, import torch_npu")
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
|
@ -46,6 +49,9 @@ def auto_select_torch_device() -> torch.device:
|
|||
elif torch.backends.mps.is_available():
|
||||
logging.info("Metal backend detected, using cuda.")
|
||||
return torch.device("mps")
|
||||
elif torch_npu.npu.is_available():
|
||||
logging.info("Npu backend detected, using npu.")
|
||||
return torch.device("npu")
|
||||
else:
|
||||
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
||||
return torch.device("cpu")
|
||||
|
@ -94,8 +100,10 @@ def is_torch_device_available(try_device: str) -> bool:
|
|||
return torch.backends.mps.is_available()
|
||||
elif try_device == "cpu":
|
||||
return True
|
||||
elif try_device == "npu":
|
||||
return torch_npu.npu.is_available()
|
||||
else:
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, cpu or npu.")
|
||||
|
||||
|
||||
def is_amp_available(device: str):
|
||||
|
|
|
@ -69,6 +69,7 @@ from lerobot.common.envs.factory import make_env
|
|||
from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
|
@ -152,7 +153,7 @@ def rollout(
|
|||
all_observations.append(deepcopy(observation))
|
||||
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
key: observation[key].to(device, non_blocking=device.type != "cuda") for key in observation
|
||||
}
|
||||
|
||||
# Infer "task" from attributes of environments.
|
||||
|
|
Loading…
Reference in New Issue