Make using sampler in train.py more explicit

Radek Osmulski 2024-05-30 07:43:29 +10:00
parent aca5fd2f37
commit ee6d4c31d9
1 changed file with 12 additions and 5 deletions

train.py

@@ -348,15 +348,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logging.info("Resume training")
 
     # create dataloader for offline training
+    if cfg.get("drop_n_last_frames"):
+        shuffle = False
+        sampler = EpisodeAwareSampler(
+            offline_dataset.episode_data_index,
+            drop_n_last_frames=cfg.drop_n_last_frames,
+            shuffle=True,
+        )
+    else:
+        shuffle = True
+        sampler = None
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
         num_workers=4,
         batch_size=cfg.training.batch_size,
-        sampler=EpisodeAwareSampler(
-            offline_dataset.episode_data_index,
-            drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0,
-            shuffle=True,
-        ),
+        shuffle=shuffle,
+        sampler=sampler,
         pin_memory=device.type != "cpu",
         drop_last=False,
     )
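Why the explicit shuffle/sampler pairing matters: torch.utils.data.DataLoader raises a ValueError when shuffle=True is combined with a custom sampler, so the two arguments have to be toggled together, exactly as the if/else above does. The sketch below illustrates the pattern with a toy stand-in; ToyEpisodeAwareSampler and its (start, end) episode boundaries are hypothetical illustrations of the idea, not LeRobot's actual EpisodeAwareSampler implementation.

    import torch
    from torch.utils.data import DataLoader, Sampler, TensorDataset


    class ToyEpisodeAwareSampler(Sampler):
        """Hypothetical stand-in: yields frame indices episode by episode,
        skipping the last `drop_n_last_frames` frames of each episode."""

        def __init__(self, episode_boundaries, drop_n_last_frames=0, shuffle=True):
            # episode_boundaries: (start, end) index pairs, end-exclusive
            self.indices = [
                i
                for start, end in episode_boundaries
                for i in range(start, max(start, end - drop_n_last_frames))
            ]
            self.shuffle = shuffle

        def __iter__(self):
            if self.shuffle:
                order = torch.randperm(len(self.indices)).tolist()
            else:
                order = range(len(self.indices))
            return (self.indices[i] for i in order)

        def __len__(self):
            return len(self.indices)


    dataset = TensorDataset(torch.arange(10).float())
    # Two episodes (frames 0-4 and 5-9); drop the last 2 frames of each,
    # so indices 3, 4, 8 and 9 are never sampled.
    sampler = ToyEpisodeAwareSampler([(0, 5), (5, 10)], drop_n_last_frames=2)

    # shuffle=True together with a sampler would raise a ValueError,
    # hence the explicit shuffle=False / sampler pairing in the commit.
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False, sampler=sampler)
    for batch in dataloader:
        print(batch)

Keeping both branches visible at the DataLoader call site also removes the old hasattr(cfg, "drop_n_last_frames") fallback, which silently built an EpisodeAwareSampler with drop_n_last_frames=0 even when the config never asked for one.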