Merge branch 'cache_dataset' into train_tdmpc
commit 1e2cabd4e0
@@ -28,6 +28,7 @@ def make_dataset(
         cfg.dataset_repo_id,
         split=split,
         delta_timestamps=delta_timestamps,
+        use_cache=cfg.training.dataset_use_cache,
     )
 
     if cfg.get("override_dataset_stats"):
@@ -27,7 +27,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         split: str = "train",
         transform: callable = None,
         delta_timestamps: dict[list[float]] | None = None,
+        use_cache: bool = False,
     ):
+        """
+        Args:
+            use_cache: Enable this to cache all items as tensors for faster data loading after the first
+                epoch. Useful if you have a small enough dataset to fit into memory. You may set multiple
+                workers for the PyTorch Dataloader but remember to set persistent_workers=True.
+        """
         super().__init__()
         self.repo_id = repo_id
         self.version = version
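
For context, a hedged sketch of how the new flag might be used from user code. Only `split`, `transform`, `delta_timestamps`, and `use_cache` are visible in the hunk above; the import path and the `repo_id`/`version` arguments and their values below are assumptions for illustration. The interplay between the cache and DataLoader workers is illustrated in a separate sketch at the end of this diff.

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # assumed module path

dataset = LeRobotDataset(
    repo_id="lerobot/pusht",  # illustrative repo id, not taken from this commit
    version="v1.0",           # assumed placeholder; the class stores a `version` attribute
    split="train",
    use_cache=True,           # build the in-memory cache during the first pass over the data
)
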
@@ -44,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.info = load_info(repo_id, version, root)
         if self.video:
             self.videos_dir = load_videos(repo_id, version, root)
+        self.cache = {} if use_cache else None
 
     @property
     def fps(self) -> int:
@@ -104,19 +112,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return 1 / self.fps - 1e-4
 
     def __len__(self):
-        return self.num_samples
+        return self.num_samples // 8
 
     def __getitem__(self, idx):
-        item = self.hf_dataset[idx]
+        if self.cache is not None and idx in self.cache:
+            item = self.cache[idx]
+        else:
+            item = self.hf_dataset[idx]
 
         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
                 self.hf_dataset,
                 self.episode_data_index,
                 self.delta_timestamps,
                 self.tolerance_s,
             )
 
         if self.video:
             item = load_from_videos(
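
Note that the hunk above only shows the read side of the cache; where this commit writes processed items back into `self.cache` is outside the visible context. A self-contained sketch of the assumed read-through/write-back pattern, with the real loading code replaced by a placeholder:

import torch


class CachedDatasetSketch(torch.utils.data.Dataset):
    """Illustration only; the write-back location in the real class is an assumption."""

    def __init__(self, hf_dataset, use_cache: bool = False):
        self.hf_dataset = hf_dataset
        self.cache = {} if use_cache else None

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        if self.cache is not None and idx in self.cache:
            return self.cache[idx]
        item = self.hf_dataset[idx]
        # ... delta_timestamps and video loading would happen here, as in the diff above ...
        if self.cache is not None:
            self.cache[idx] = item  # assumed write-back so later epochs hit the cache
        return item
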
@@ -26,6 +26,12 @@ training:
   save_freq: ???
   log_freq: 250
   save_model: true
+  # `dataset_use_cache` indicates whether to cache all dataset items as Tensors in RAM. Potentially useful for
+  # faster data loading with datasets small enough to fit in memory. If you wish to use dataloader workers,
+  # remember to set `dataloader_persistent_workers` to True.
+  dataset_use_cache: false
+  dataloader_num_workers: 4
+  dataloader_persistent_workers: false
 
 eval:
   n_episodes: 1
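
Because `dataset_use_cache` keeps every processed item in RAM, it is only worthwhile when the whole dataset fits in memory. A hedged helper for a rough size estimate, assuming items are dicts whose tensor values dominate the footprint (the exact item structure is not shown in this diff):

import torch


def estimate_cache_gib(dataset, probe_idx: int = 0) -> float:
    """Rough estimate of cache size, assuming every item is shaped like the probe item."""
    item = dataset[probe_idx]
    per_item_bytes = sum(
        v.element_size() * v.numel() for v in item.values() if torch.is_tensor(v)
    )
    return per_item_bytes * len(dataset) / 1024**3
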
@@ -10,6 +10,9 @@ training:
   online_steps_between_rollouts: 1
   online_sampling_ratio: 0.5
   online_env_seed: 10000
+  dataset_use_cache: true
+  dataloader_num_workers: 4
+  dataloader_persistent_workers: true
 
   batch_size: 256
   grad_clip_norm: 10.0
@@ -368,7 +368,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # create dataloader for offline training
     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=4,
+        num_workers=cfg.training.dataloader_num_workers,
+        persistent_workers=cfg.training.dataloader_persistent_workers,
         batch_size=cfg.training.batch_size,
         shuffle=True,
         pin_memory=cfg.device != "cpu",
@@ -415,7 +416,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     )
     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
-        num_workers=4,
+        num_workers=cfg.training.dataloader_num_workers,
+        persistent_workers=cfg.training.dataloader_persistent_workers,
         batch_size=cfg.training.batch_size,
         sampler=sampler,
         pin_memory=cfg.device != "cpu",
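
One reason `dataloader_persistent_workers` is surfaced alongside the cache flag: with `num_workers > 0`, each DataLoader worker process holds its own copy of the dataset object, and therefore its own `cache` dict, and non-persistent workers are torn down after every epoch, discarding those caches. A small self-contained illustration (not from the commit):

import torch


class CachingProbe(torch.utils.data.Dataset):
    def __init__(self, n: int = 16):
        self.n = n
        self.cache = {}

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if idx not in self.cache:
            self.cache[idx] = torch.randn(3)  # stands in for expensive decoding
        return self.cache[idx]


if __name__ == "__main__":
    loader = torch.utils.data.DataLoader(
        CachingProbe(),
        num_workers=2,
        persistent_workers=True,  # keep the per-worker caches alive between epochs
        batch_size=4,
    )
    for _epoch in range(2):  # the second epoch is served from the per-worker caches
        for _batch in loader:
            pass
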