diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 52fed33d..1ad16a15 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -25,7 +25,7 @@ from torch.amp import GradScaler from torch.optim import Optimizer from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.sampler import PrioritizedSampler +from lerobot.common.datasets.sampler import EpisodeAwareSampler, PrioritizedSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.optim.factory import make_optimizer_and_scheduler @@ -166,18 +166,26 @@ def train(cfg: TrainPipelineConfig): # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): shuffle = False - sampler = PrioritizedSampler( - data_len=data_len, - alpha=0.6, - beta=0.1, - eps=1e-6, - replacement=True, - num_samples_per_epoch=data_len, + sampler = EpisodeAwareSampler( + dataset.episode_data_index, + drop_n_last_frames=cfg.policy.drop_n_last_frames, + shuffle=True, ) else: shuffle = True sampler = None + # TODO(pepijn): If experiment works integrate this + shuffle = False + sampler = PrioritizedSampler( + data_len=data_len, + alpha=0.6, + beta=0.1, + eps=1e-6, + replacement=True, + num_samples_per_epoch=data_len, + ) + dataloader = torch.utils.data.DataLoader( dataset, num_workers=cfg.num_workers,