From a6d353c419c78603f0592ac39e9b77a6a1657984 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 5 Mar 2024 17:00:17 +0100 Subject: [PATCH] Fix --- lerobot/common/policies/diffusion/policy.py | 7 ++++--- lerobot/common/policies/factory.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 7ae0a529..37bc79a0 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -4,7 +4,6 @@ import time import hydra import torch import torch.nn as nn - from diffusion_policy.model.common.lr_scheduler import get_scheduler from .diffusion_unet_image_policy import DiffusionUnetImagePolicy @@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module): def __init__( self, cfg, + cfg_device, cfg_noise_scheduler, cfg_rgb_model, cfg_obs_encoder, @@ -62,8 +62,9 @@ class DiffusionPolicy(nn.Module): **kwargs, ) - self.device = torch.device("cuda") - self.diffusion.cuda() + self.device = torch.device(cfg_device) + if torch.cuda.is_available() and cfg_device == "cuda": + self.diffusion.cuda() self.ema = None if self.cfg.use_ema: diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 15a2c21d..9507586c 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -8,6 +8,7 @@ def make_policy(cfg): policy = DiffusionPolicy( cfg=cfg.policy, + cfg_device=cfg.device, cfg_noise_scheduler=cfg.noise_scheduler, cfg_rgb_model=cfg.rgb_model, cfg_obs_encoder=cfg.obs_encoder,