Support npu
parent ae70f12378
commit 485affb658
@@ -24,6 +24,9 @@ from pathlib import Path
 
 import numpy as np
 import torch
+if hasattr(torch, 'npu'):
+    import torch_npu
+    logging.info("exists npu, import torch_npu")
 
 
 def none_or_int(value):
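For reference, a sketch of an equivalent guard that keys off whether the torch_npu package (Huawei Ascend's PyTorch adapter) is installed, rather than off hasattr(torch, 'npu'). This is an alternative to the commit's check, not what the commit does; machines without the package simply skip the import.

import importlib.util
import logging

import torch  # torch must be imported before torch_npu

# Alternative guard (sketch): import torch_npu only when the package is installed,
# so machines without Ascend hardware or drivers are unaffected.
if importlib.util.find_spec("torch_npu") is not None:
    import torch_npu  # registers the "npu" device type with torch

    logging.info("torch_npu found, NPU available: %s", torch_npu.npu.is_available())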
@@ -46,6 +49,9 @@ def auto_select_torch_device() -> torch.device:
     elif torch.backends.mps.is_available():
         logging.info("Metal backend detected, using cuda.")
         return torch.device("mps")
+    elif torch_npu.npu.is_available():
+        logging.info("Npu backend detected, using npu.")
+        return torch.device("npu")
     else:
         logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
         return torch.device("cpu")
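Read together with the unchanged code above this hunk, the selection order becomes cuda, then mps, then npu, then cpu. Below is a sketch of the whole function after the patch; the cuda branch is assumed, since it sits outside the hunk, and the npu branch only works when the guarded torch_npu import above actually ran, otherwise the name torch_npu is undefined.

def auto_select_torch_device() -> torch.device:
    # Sketch: the cuda branch is assumed, as it is not shown in the hunk.
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        logging.info("Metal backend detected, using cuda.")
        return torch.device("mps")
    elif torch_npu.npu.is_available():  # requires the conditional torch_npu import above
        logging.info("Npu backend detected, using npu.")
        return torch.device("npu")
    else:
        logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
        return torch.device("cpu")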
@@ -94,8 +100,10 @@ def is_torch_device_available(try_device: str) -> bool:
         return torch.backends.mps.is_available()
     elif try_device == "cpu":
         return True
+    elif try_device == "npu":
+        return torch_npu.npu.is_available()
     else:
-        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, cpu or npu.")
 
 
 def is_amp_available(device: str):
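A usage sketch for the extended check. One caveat visible in the diff: when torch has no npu attribute, torch_npu is never imported, so querying "npu" raises a NameError instead of returning False. The loop below (candidate list illustrative) guards for that.

# Probe which device strings are usable on this machine.
for candidate in ("cuda", "mps", "cpu", "npu"):
    try:
        usable = is_torch_device_available(candidate)
    except NameError:
        # Happens for "npu" when torch_npu was never imported by the guard above.
        usable = False
    print(f"{candidate}: {'available' if usable else 'not available'}")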
@@ -69,6 +69,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
 from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.io_utils import write_video
 from lerobot.common.utils.random_utils import set_seed
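The remaining hunks touch the evaluation script. The new PI0Policy import is not exercised in the lines shown here; a hypothetical use would be loading the policy and moving it onto the auto-selected device (the repo id and call pattern below are illustrative, not taken from this diff).

# Hypothetical usage of the newly imported PI0Policy (identifiers illustrative).
device = auto_select_torch_device()
policy = PI0Policy.from_pretrained("lerobot/pi0")
policy.to(device)
policy.eval()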
@@ -152,7 +153,7 @@ def rollout(
         all_observations.append(deepcopy(observation))
 
         observation = {
-            key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
+            key: observation[key].to(device, non_blocking=device.type != "cuda") for key in observation
        }
 
         # Infer "task" from attributes of environments.
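The only behavioural change in this hunk is the flipped predicate: with != the copy is requested as non-blocking for every device type except cuda. A self-contained sketch of the same transfer pattern, with to_device as a hypothetical helper not present in the source.

import torch

def to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    # Hypothetical helper mirroring the patched rollout(): asynchronous copies
    # are requested for non-cuda devices, synchronous copies for cuda.
    return {key: value.to(device, non_blocking=device.type != "cuda") for key, value in batch.items()}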