This commit is contained in:
Simon Alibert 2024-03-05 17:00:17 +01:00
parent d6556e6519
commit a6d353c419
2 changed files with 5 additions and 3 deletions

View File

@ -4,7 +4,6 @@ import time
import hydra
import torch
import torch.nn as nn
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module):
def __init__(
self,
cfg,
cfg_device,
cfg_noise_scheduler,
cfg_rgb_model,
cfg_obs_encoder,
@ -62,8 +62,9 @@ class DiffusionPolicy(nn.Module):
**kwargs,
)
self.device = torch.device("cuda")
self.diffusion.cuda()
self.device = torch.device(cfg_device)
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()
self.ema = None
if self.cfg.use_ema:

View File

@ -8,6 +8,7 @@ def make_policy(cfg):
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,