diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 06c821d7..79e11744 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -348,15 +348,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("Resume training") # create dataloader for offline training + if cfg.get("drop_n_last_frames"): + shuffle = False + sampler = EpisodeAwareSampler( + offline_dataset.episode_data_index, + drop_n_last_frames=cfg.drop_n_last_frames, + shuffle=True, + ) + else: + shuffle = True + sampler = None dataloader = torch.utils.data.DataLoader( offline_dataset, num_workers=4, batch_size=cfg.training.batch_size, - sampler=EpisodeAwareSampler( - offline_dataset.episode_data_index, - drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0, - shuffle=True, - ), + shuffle=shuffle, + sampler=sampler, pin_memory=device.type != "cpu", drop_last=False, )