This commit is contained in:
Simon Alibert 2024-03-05 17:00:17 +01:00
parent d6556e6519
commit a6d353c419
2 changed files with 5 additions and 3 deletions

View File

@ -4,7 +4,6 @@ import time
import hydra import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
from diffusion_policy.model.common.lr_scheduler import get_scheduler from diffusion_policy.model.common.lr_scheduler import get_scheduler
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module):
def __init__( def __init__(
self, self,
cfg, cfg,
cfg_device,
cfg_noise_scheduler, cfg_noise_scheduler,
cfg_rgb_model, cfg_rgb_model,
cfg_obs_encoder, cfg_obs_encoder,
@ -62,7 +62,8 @@ class DiffusionPolicy(nn.Module):
**kwargs, **kwargs,
) )
self.device = torch.device("cuda") self.device = torch.device(cfg_device)
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda() self.diffusion.cuda()
self.ema = None self.ema = None

View File

@ -8,6 +8,7 @@ def make_policy(cfg):
policy = DiffusionPolicy( policy = DiffusionPolicy(
cfg=cfg.policy, cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler, cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model, cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder, cfg_obs_encoder=cfg.obs_encoder,