Fix
This commit is contained in:
parent
d6556e6519
commit
a6d353c419
|
@ -4,7 +4,6 @@ import time
|
|||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
|
||||
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
|
@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
cfg_device,
|
||||
cfg_noise_scheduler,
|
||||
cfg_rgb_model,
|
||||
cfg_obs_encoder,
|
||||
|
@ -62,8 +62,9 @@ class DiffusionPolicy(nn.Module):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
self.diffusion.cuda()
|
||||
self.device = torch.device(cfg_device)
|
||||
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||
self.diffusion.cuda()
|
||||
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
|
|
|
@ -8,6 +8,7 @@ def make_policy(cfg):
|
|||
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
|
|
Loading…
Reference in New Issue