Merge branch 'cache_dataset' into train_tdmpc

Alexander Soare 2024-05-07 09:07:15 +01:00
commit 1e2cabd4e0
5 changed files with 35 additions and 12 deletions

View File

@@ -28,6 +28,7 @@ def make_dataset(
         cfg.dataset_repo_id,
         split=split,
         delta_timestamps=delta_timestamps,
+        use_cache=cfg.training.dataset_use_cache,
     )

     if cfg.get("override_dataset_stats"):

View File

@@ -27,7 +27,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         split: str = "train",
         transform: callable = None,
         delta_timestamps: dict[list[float]] | None = None,
+        use_cache: bool = False,
     ):
+        """
+        Args:
+            use_cache: Enable this to cache all items as tensors for faster data loading after the first
+                epoch. Useful if you have a small enough dataset to fit into memory. You may set multiple
+                workers for the PyTorch Dataloader but remember to set persistent_workers=True.
+        """
         super().__init__()
         self.repo_id = repo_id
         self.version = version
@@ -44,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.info = load_info(repo_id, version, root)
         if self.video:
             self.videos_dir = load_videos(repo_id, version, root)
+        self.cache = {} if use_cache else None

     @property
     def fps(self) -> int:
@@ -104,19 +112,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return 1 / self.fps - 1e-4

     def __len__(self):
-        return self.num_samples
+        return self.num_samples // 8

     def __getitem__(self, idx):
-        item = self.hf_dataset[idx]
+        if self.cache is not None and idx in self.cache:
+            item = self.cache[idx]
+        else:
+            item = self.hf_dataset[idx]

         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
                 self.hf_dataset,
                 self.episode_data_index,
                 self.delta_timestamps,
                 self.tolerance_s,
             )

         if self.video:
             item = load_from_videos(
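
Only the read side of the cache is visible above (the hunk is truncated at the call to load_from_videos); the point where decoded items are written back into self.cache is not shown in this excerpt. As a rough sketch of how such a dict-based item cache usually fits into a Dataset's __getitem__, assuming items are stored only after the delta-timestamp and video post-processing (the write path below is an assumption, not code from this commit):

    # Illustrative sketch, not the commit's code: serve from RAM when possible,
    # otherwise load, post-process, then remember the result for later epochs.
    def __getitem__(self, idx):
        if self.cache is not None and idx in self.cache:
            return self.cache[idx]

        item = self.hf_dataset[idx]
        # ... delta-timestamp and video loading as in the diff above ...

        if self.cache is not None:
            self.cache[idx] = item  # assumed: cache only the fully post-processed item
        return item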

View File

@@ -26,6 +26,12 @@ training:
   save_freq: ???
   log_freq: 250
   save_model: true
+  # `dataset_use_cache` indicates whether to cache all dataset items as Tensors in RAM. Potentially useful for
+  # faster data loading with datasets small enough to fit in memory. If you wish to use dataloader workers,
+  # remember to set `dataloader_persistent_workers` to True.
+  dataset_use_cache: false
+  dataloader_num_workers: 4
+  dataloader_persistent_workers: false

 eval:
   n_episodes: 1
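
The comment above is the crux of how the three new options work together: with dataloader_num_workers > 0 each worker process holds its own copy of the dataset, so each worker builds its own cache, and PyTorch discards the workers (and their caches) at the end of every epoch unless persistent_workers=True. A minimal usage sketch under those assumptions (import path, repo id, and values are illustrative, not taken from this commit):

    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # assumed import path

    # Cache decoded items in RAM after the first pass over the data.
    dataset = LeRobotDataset("lerobot/pusht", use_cache=True)  # example repo id

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        # Keep worker processes (and the per-worker caches they build) alive
        # across epochs; otherwise the caches are rebuilt every epoch.
        persistent_workers=True,
        batch_size=256,
        shuffle=True,
    )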

View File

@@ -10,6 +10,9 @@ training:
   online_steps_between_rollouts: 1
   online_sampling_ratio: 0.5
   online_env_seed: 10000
+  dataset_use_cache: true
+  dataloader_num_workers: 4
+  dataloader_persistent_workers: true

   batch_size: 256
   grad_clip_norm: 10.0

View File

@@ -368,7 +368,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=cfg.training.dataloader_num_workers,
+        persistent_workers=cfg.training.dataloader_persistent_workers,
         batch_size=cfg.training.batch_size,
         shuffle=True,
         pin_memory=cfg.device != "cpu",
@@ -415,7 +416,8 @@
     )
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
-        num_workers=4,
+        num_workers=cfg.training.dataloader_num_workers,
+        persistent_workers=cfg.training.dataloader_persistent_workers,
         batch_size=cfg.training.batch_size,
         sampler=sampler,
         pin_memory=cfg.device != "cpu",