diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 7ae0a529..37bc79a0 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -4,7 +4,6 @@ import time import hydra import torch import torch.nn as nn - from diffusion_policy.model.common.lr_scheduler import get_scheduler from .diffusion_unet_image_policy import DiffusionUnetImagePolicy @@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module): def __init__( self, cfg, + cfg_device, cfg_noise_scheduler, cfg_rgb_model, cfg_obs_encoder, @@ -62,8 +62,9 @@ class DiffusionPolicy(nn.Module): **kwargs, ) - self.device = torch.device("cuda") - self.diffusion.cuda() + self.device = torch.device(cfg_device) + if torch.cuda.is_available() and cfg_device == "cuda": + self.diffusion.cuda() self.ema = None if self.cfg.use_ema: diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 15a2c21d..9507586c 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -8,6 +8,7 @@ def make_policy(cfg): policy = DiffusionPolicy( cfg=cfg.policy, + cfg_device=cfg.device, cfg_noise_scheduler=cfg.noise_scheduler, cfg_rgb_model=cfg.rgb_model, cfg_obs_encoder=cfg.obs_encoder,