From 4631d36c0518519ae2114ece414a7cb9d83bbacb Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Wed, 20 Mar 2024 18:38:55 +0100 Subject: [PATCH] Add get_safe_torch_device in policies --- lerobot/common/policies/act/policy.py | 3 ++- lerobot/common/policies/diffusion/policy.py | 6 +++--- lerobot/common/policies/tdmpc/policy.py | 6 ++++-- lerobot/common/utils.py | 20 ++++++++++++++++++++ lerobot/scripts/eval.py | 14 +++++++------- lerobot/scripts/train.py | 8 +++----- 6 files changed, 39 insertions(+), 18 deletions(-) diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 539cdcf5..0a0ee405 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -7,6 +7,7 @@ import torchvision.transforms as transforms from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.act.detr_vae import build +from lerobot.common.utils import get_safe_torch_device def build_act_model_and_optimizer(cfg): @@ -45,7 +46,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): super().__init__(n_action_steps) self.cfg = cfg self.n_action_steps = n_action_steps - self.device = device + self.device = get_safe_torch_device(device) self.model, self.optimizer = build_act_model_and_optimizer(cfg) self.kl_weight = self.cfg.kl_weight logging.info(f"KL Weight {self.kl_weight}") diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 2c47f172..dee5aa64 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -8,6 +8,7 @@ from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder +from lerobot.common.utils import get_safe_torch_device class DiffusionPolicy(AbstractPolicy): @@ -62,9 +63,8 @@ class DiffusionPolicy(AbstractPolicy): **kwargs, ) - self.device = torch.device(cfg_device) - if torch.cuda.is_available() and cfg_device == "cuda": - self.diffusion.cuda() + self.device = get_safe_torch_device(cfg_device) + self.diffusion.to(self.device) self.ema = None if self.cfg.use_ema: diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 320f6f2b..5bb0da43 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -10,6 +10,7 @@ import torch.nn as nn import lerobot.common.policies.tdmpc.helper as h from lerobot.common.policies.abstract import AbstractPolicy +from lerobot.common.utils import get_safe_torch_device FIRST_FRAME = 0 @@ -94,9 +95,10 @@ class TDMPC(AbstractPolicy): self.action_dim = cfg.action_dim self.cfg = cfg - self.device = torch.device(device) + self.device = get_safe_torch_device(device) self.std = h.linear_schedule(cfg.std_schedule, 0) - self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg) + self.model = TOLD(cfg) + self.model.to(self.device) self.model_target = deepcopy(self.model) self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr) diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index d174d4b5..a56543b7 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -6,6 +6,26 @@ import numpy as np import torch +def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: + match cfg_device: + case "cuda": + assert torch.cuda.is_available() + device = torch.device("cuda") + case "mps": + assert torch.backends.mps.is_available() + device = torch.device("mps") + case "cpu": + device = torch.device("cpu") + if log: + logging.warning("Using CPU, this will be slow.") + case _: + device = torch.device(cfg_device) + if log: + logging.warning(f"Using custom {cfg_device} device.") + + return device + + def set_seed(seed): """Set seed for reproducibility.""" random.seed(seed) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 41d58b91..76deb2fe 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -18,7 +18,7 @@ from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import init_logging, set_seed +from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed def write_video(video_path, stacked_frames, fps): @@ -35,7 +35,8 @@ def eval_policy( fps: int = 15, return_first_video: bool = False, ): - policy.eval() + if policy is not None: + policy.eval() start = time.time() sum_rewards = [] max_rewards = [] @@ -55,7 +56,8 @@ def eval_policy( with torch.inference_mode(): # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all # envs are done the first time. But we only use the first rollout. This is a waste of compute. - policy.clear_action_queue() + if policy is not None: + policy.clear_action_queue() rollout = env.rollout( max_steps=max_steps, policy=policy, @@ -128,10 +130,8 @@ def eval(cfg: dict, out_dir=None): init_logging() - if cfg.device == "cuda": - assert torch.cuda.is_available() - else: - logging.warning("Using CPU, this will be slow.") + # Check device is available + get_safe_torch_device(cfg.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 242c77bc..872b80df 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import format_big_number, init_logging, set_seed +from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed from lerobot.scripts.eval import eval_policy @@ -117,10 +117,8 @@ def train(cfg: dict, out_dir=None, job_name=None): init_logging() - if cfg.device == "cuda": - assert torch.cuda.is_available() - else: - logging.warning("Using CPU, this will be slow.") + # Check device is available + get_safe_torch_device(cfg.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True