Move diffusers 'get_scheduler' import

Simon Alibert 2024-05-05 14:47:14 +02:00
parent 53d8f6b785
commit dc56e9b930
1 changed file with 2 additions and 1 deletion


@@ -8,7 +8,6 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
-from diffusers.optimization import get_scheduler
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -55,6 +54,8 @@ def make_optimizer_and_scheduler(cfg, policy):
         cfg.training.adam_weight_decay,
     )
     assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+    from diffusers.optimization import get_scheduler
+
     lr_scheduler = get_scheduler(
         cfg.training.lr_scheduler,
         optimizer=optimizer,
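
For context, the effect of moving the import into the function body is that `diffusers` is only loaded when the scheduler is actually built, rather than whenever the training script is imported. Below is a minimal sketch of that deferred-import pattern, not the full lerobot function: the `cfg.policy.name == "diffusion"` branch condition, `cfg.training.lr`, `policy.parameters()`, and the omitted `get_scheduler` keyword arguments are assumptions for illustration only.

import torch


def make_optimizer_and_scheduler(cfg, policy):
    optimizer = torch.optim.Adam(
        policy.parameters(),                 # simplified; assumed parameter source
        lr=cfg.training.lr,                  # assumed config field
        weight_decay=cfg.training.adam_weight_decay,
    )
    lr_scheduler = None
    if cfg.policy.name == "diffusion":       # assumed branch condition
        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
        # Deferred import: `diffusers` is only pulled in when this branch runs.
        from diffusers.optimization import get_scheduler

        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            # further kwargs are truncated in the hunk above and omitted here
        )
    return optimizer, lr_scheduler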