Merge pull request #40 from huggingface/user/aliberts/2024_03_20_enable_mps_device
Enable mps backend for Apple silicon devices
Commit: 1bd50122be
ACT policy:

```diff
@@ -7,6 +7,7 @@ import torchvision.transforms as transforms

 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.act.detr_vae import build
+from lerobot.common.utils import get_safe_torch_device


 def build_act_model_and_optimizer(cfg):
@@ -45,7 +46,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
         super().__init__(n_action_steps)
         self.cfg = cfg
         self.n_action_steps = n_action_steps
-        self.device = device
+        self.device = get_safe_torch_device(device)
         self.model, self.optimizer = build_act_model_and_optimizer(cfg)
         self.kl_weight = self.cfg.kl_weight
         logging.info(f"KL Weight {self.kl_weight}")
```
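Previously the constructor stored the raw `device` string; it now goes through `get_safe_torch_device` (added in the `lerobot.common.utils` hunk further down), which validates availability before anything is allocated. A minimal sketch of the failure mode this avoids, using only standard PyTorch calls:

```python
import torch

# torch.device() performs no availability check: this succeeds even on a
# machine with no GPU, and the error only surfaces later, at the first
# tensor allocation on that device.
device = torch.device("cuda")

# Failing fast at configuration time instead, as the new helper does:
if not torch.cuda.is_available():
    raise AssertionError("cuda requested but not available")
```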
Diffusion policy:

```diff
@@ -8,6 +8,7 @@ from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
 from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
+from lerobot.common.utils import get_safe_torch_device


 class DiffusionPolicy(AbstractPolicy):
@@ -62,9 +63,8 @@ class DiffusionPolicy(AbstractPolicy):
             **kwargs,
         )

-        self.device = torch.device(cfg_device)
-        if torch.cuda.is_available() and cfg_device == "cuda":
-            self.diffusion.cuda()
+        self.device = get_safe_torch_device(cfg_device)
+        self.diffusion.to(self.device)

         self.ema = None
         if self.cfg.use_ema:
```
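Besides routing the config string through the helper, this hunk replaces the CUDA-only `.cuda()` call with `.to(self.device)`, which handles cuda, mps, and cpu uniformly. A short sketch of the pattern (the module and shapes are illustrative, not from the diff):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)

# Old pattern: CUDA-only, silently leaves the module on CPU otherwise.
if torch.cuda.is_available():
    model.cuda()

# New pattern: a single call covers every backend the config may name.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)
```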
TDMPC policy:

```diff
@@ -10,6 +10,7 @@ import torch.nn as nn

 import lerobot.common.policies.tdmpc.helper as h
 from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.utils import get_safe_torch_device

 FIRST_FRAME = 0

@@ -94,9 +95,10 @@ class TDMPC(AbstractPolicy):
        self.action_dim = cfg.action_dim

        self.cfg = cfg
-       self.device = torch.device(device)
+       self.device = get_safe_torch_device(device)
        self.std = h.linear_schedule(cfg.std_schedule, 0)
-       self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
+       self.model = TOLD(cfg)
+       self.model.to(self.device)
        self.model_target = deepcopy(self.model)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
```
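The TOLD model is now always constructed the same way and then moved with `.to(self.device)`, with the Adam optimizers created afterwards over the moved parameters. That ordering is the conventional safe choice: optimizer state is lazily allocated on the same device as the parameters it tracks. A sketch of the pattern, with a stand-in module:

```python
import torch
import torch.nn as nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Build, move, then create the optimizer, matching the order in the diff.
model = nn.Linear(4, 4)
model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
```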
Shared utils (lerobot.common.utils):

```diff
@@ -6,6 +6,26 @@ import numpy as np
 import torch


+def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
+    match cfg_device:
+        case "cuda":
+            assert torch.cuda.is_available()
+            device = torch.device("cuda")
+        case "mps":
+            assert torch.backends.mps.is_available()
+            device = torch.device("mps")
+        case "cpu":
+            device = torch.device("cpu")
+            if log:
+                logging.warning("Using CPU, this will be slow.")
+        case _:
+            device = torch.device(cfg_device)
+            if log:
+                logging.warning(f"Using custom {cfg_device} device.")
+
+    return device
+
+
 def set_seed(seed):
     """Set seed for reproducibility."""
     random.seed(seed)
```
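With the helper defined, any caller can turn a config string into a validated `torch.device` in one line (note the `match` statement requires Python 3.10+). A minimal usage sketch; the tensor is illustrative:

```python
import torch

from lerobot.common.utils import get_safe_torch_device

# "cuda" and "mps" assert that the backend is actually available; "cpu"
# logs a slowness warning when log=True; any other string is passed
# through to torch.device() with a warning about the custom device.
device = get_safe_torch_device("mps", log=True)
x = torch.zeros(2, 3, device=device)
```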
Evaluation script (lerobot.scripts.eval):

```diff
@@ -18,7 +18,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import init_logging, set_seed
+from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed


 def write_video(video_path, stacked_frames, fps):
@@ -35,6 +35,7 @@ def eval_policy(
     fps: int = 15,
     return_first_video: bool = False,
 ):
-    policy.eval()
+    if policy is not None:
+        policy.eval()
     start = time.time()
     sum_rewards = []
@@ -55,6 +56,7 @@ def eval_policy(
     with torch.inference_mode():
         # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
         # envs are done the first time. But we only use the first rollout. This is a waste of compute.
-        policy.clear_action_queue()
+        if policy is not None:
+            policy.clear_action_queue()
         rollout = env.rollout(
             max_steps=max_steps,
@@ -128,10 +130,8 @@ def eval(cfg: dict, out_dir=None):

     init_logging()

-    if cfg.device == "cuda":
-        assert torch.cuda.is_available()
-    else:
-        logging.warning("Using CPU, this will be slow.")
+    # Check device is available
+    get_safe_torch_device(cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
```
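Two things change here: the `if policy is not None` guards let `eval_policy` run without a policy at all, and the old CUDA-only block in `eval` collapses into a single `get_safe_torch_device(cfg.device, log=True)` call used purely for its assertion and logging side effects. Assuming signatures along the lines suggested by the hunks (the exact parameters of `make_env` and `eval_policy` are not shown in this diff), a policy-free evaluation might look like:

```python
from lerobot.common.envs.factory import make_env
from lerobot.scripts.eval import eval_policy

# Hypothetical usage: with the new guards, policy=None no longer crashes
# on policy.eval() or policy.clear_action_queue(), so the environment
# can be rolled out on its own. cfg is the script's Hydra-style config.
env = make_env(cfg)
results = eval_policy(env, policy=None, max_steps=100, fps=15)
```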
Training script:

```diff
@@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import format_big_number, init_logging, set_seed
+from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed
 from lerobot.scripts.eval import eval_policy


@@ -117,10 +117,8 @@ def train(cfg: dict, out_dir=None, job_name=None):

     init_logging()

-    if cfg.device == "cuda":
-        assert torch.cuda.is_available()
-    else:
-        logging.warning("Using CPU, this will be slow.")
+    # Check device is available
+    get_safe_torch_device(cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
```
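`train` gets the same simplification as `eval`. The `cudnn.benchmark` and `allow_tf32` lines that follow the check stay unconditional; both are CUDA-specific knobs, so on mps or cpu they should be inert. An explicit guard would make that intent visible, as a sketch (cfg as in the script above):

```python
import torch

# Only meaningful when running on CUDA; effectively inert elsewhere.
if cfg.device == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
```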