Merge branch 'cache_dataset' into train_tdmpc
This commit is contained in:
commit
1e2cabd4e0
|
@ -28,6 +28,7 @@ def make_dataset(
|
|||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
use_cache=cfg.training.dataset_use_cache,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
|
|
@ -27,7 +27,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
use_cache: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
use_cache: Enable this to cache all items as tensors for faster data loading after the first
|
||||
epoch. Useful if you have a small enough dataset to fit into memory. You may set multiple
|
||||
workers for the PyTorch Dataloader but remember to set persistent_workers=True.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.version = version
|
||||
|
@ -44,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.info = load_info(repo_id, version, root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, version, root)
|
||||
self.cache = {} if use_cache else None
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
|
@ -104,9 +112,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
return self.num_samples // 8
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.cache is not None and idx in self.cache:
|
||||
item = self.cache[idx]
|
||||
else:
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
|
|
|
@ -26,6 +26,12 @@ training:
|
|||
save_freq: ???
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
# `dataset_use_cache` indicates whether to cache all dataset items as Tensors in RAM. Potentially useful for
|
||||
# faster data loading with datasets small enough to fit in memory. If you wish to use dataloader workers,
|
||||
# remember to set `dataloader_persistent_workers to True.
|
||||
dataset_use_cache: false
|
||||
dataloader_num_workers: 4
|
||||
dataloader_persistent_workers: false
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
|
|
|
@ -10,6 +10,9 @@ training:
|
|||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
dataset_use_cache: true
|
||||
dataloader_num_workers: 4
|
||||
dataloader_persistent_workers: true
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
|
|
|
@ -368,7 +368,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
# create dataloader for offline training
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
offline_dataset,
|
||||
num_workers=4,
|
||||
num_workers=cfg.training.dataloader_num_workers,
|
||||
persistent_workers=cfg.training.dataloader_persistent_workers,
|
||||
batch_size=cfg.training.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
|
@ -415,7 +416,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
num_workers=4,
|
||||
num_workers=cfg.training.dataloader_num_workers,
|
||||
persistent_workers=cfg.training.dataloader_persistent_workers,
|
||||
batch_size=cfg.training.batch_size,
|
||||
sampler=sampler,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
|
|
Loading…
Reference in New Issue