Fix
This commit is contained in:
parent
d6556e6519
commit
a6d353c419
|
@ -4,7 +4,6 @@ import time
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||||
|
|
||||||
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||||
|
@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg,
|
cfg,
|
||||||
|
cfg_device,
|
||||||
cfg_noise_scheduler,
|
cfg_noise_scheduler,
|
||||||
cfg_rgb_model,
|
cfg_rgb_model,
|
||||||
cfg_obs_encoder,
|
cfg_obs_encoder,
|
||||||
|
@ -62,8 +62,9 @@ class DiffusionPolicy(nn.Module):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.device = torch.device("cuda")
|
self.device = torch.device(cfg_device)
|
||||||
self.diffusion.cuda()
|
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||||
|
self.diffusion.cuda()
|
||||||
|
|
||||||
self.ema = None
|
self.ema = None
|
||||||
if self.cfg.use_ema:
|
if self.cfg.use_ema:
|
||||||
|
|
|
@ -8,6 +8,7 @@ def make_policy(cfg):
|
||||||
|
|
||||||
policy = DiffusionPolicy(
|
policy = DiffusionPolicy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
|
cfg_device=cfg.device,
|
||||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||||
cfg_rgb_model=cfg.rgb_model,
|
cfg_rgb_model=cfg.rgb_model,
|
||||||
cfg_obs_encoder=cfg.obs_encoder,
|
cfg_obs_encoder=cfg.obs_encoder,
|
||||||
|
|
Loading…
Reference in New Issue