diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 6cbc8265..fcfd63ae 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -8,7 +8,6 @@ import hydra import torch from datasets import concatenate_datasets from datasets.utils import disable_progress_bars, enable_progress_bars -from diffusers.optimization import get_scheduler from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle @@ -55,6 +54,8 @@ def make_optimizer_and_scheduler(cfg, policy): cfg.training.adam_weight_decay, ) assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training." + from diffusers.optimization import get_scheduler + lr_scheduler = get_scheduler( cfg.training.lr_scheduler, optimizer=optimizer,