This commit is contained in:
Alexander Soare 2024-04-03 09:56:46 +01:00
parent a6ec4fbf58
commit caf4ffcf65
1 changed file with 7 additions and 1 deletion

View File

@@ -152,7 +152,13 @@ class AbstractDataset(TensorDictReplayBuffer):
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
def _compute_stats(self, batch_size: int = 32):
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation."""
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation.
TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
full dataset (for handling very large datasets). The sampling would then have to be random
(preferably without replacement). Both stats computation loops would ideally sample the same
items.
"""
rb = TensorDictReplayBuffer(
storage=self._storage,
batch_size=32,