Make using sampler in train.py more explicit
This commit is contained in:
parent
aca5fd2f37
commit
ee6d4c31d9
|
@ -348,15 +348,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
|
if cfg.get("drop_n_last_frames"):
|
||||||
|
shuffle = False
|
||||||
|
sampler = EpisodeAwareSampler(
|
||||||
|
offline_dataset.episode_data_index,
|
||||||
|
drop_n_last_frames=cfg.drop_n_last_frames,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shuffle = True
|
||||||
|
sampler = None
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
offline_dataset,
|
offline_dataset,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
sampler=EpisodeAwareSampler(
|
shuffle=shuffle,
|
||||||
offline_dataset.episode_data_index,
|
sampler=sampler,
|
||||||
drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0,
|
|
||||||
shuffle=True,
|
|
||||||
),
|
|
||||||
pin_memory=device.type != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue