From a6edb85da42bdfad189e0f86f7a55cce84d4bfdf Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Tue, 2 Apr 2024 16:52:38 +0100
Subject: [PATCH] Remove random sampling

---
 lerobot/common/datasets/abstract.py | 18 +++++++-----------
 tests/test_datasets.py              |  4 +++-
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index a098c1ae..0aa89d65 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -130,13 +130,13 @@ class AbstractDataset(TensorDictReplayBuffer):
         else:
             self._transform = transform
 
-    def compute_or_load_stats(self, num_batch: int | None = None, batch_size: int = 32) -> TensorDict:
+    def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
         if stats_path.exists():
             stats = torch.load(stats_path)
         else:
             logging.info(f"compute_stats and save to {stats_path}")
-            stats = self._compute_stats(num_batch, batch_size)
+            stats = self._compute_stats(batch_size)
             torch.save(stats, stats_path)
         return stats
 
@@ -151,18 +151,14 @@ class AbstractDataset(TensorDictReplayBuffer):
         self.data_dir = self.root / self.dataset_id
         return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
 
-    def _compute_stats(self, num_batch: int | None = None, batch_size: int = 32):
-        """Compute dataset statistics including minimum, maximum, mean, and standard deviation.
-
-        If `num_batch` is specified, we draw `num_batch` batches of size `batch_size` to compute the
-        statistics. If `num_batch` is not specified, we just consume the whole dataset (default).
-        """
+    def _compute_stats(self, batch_size: int = 32):
+        """Compute dataset statistics including minimum, maximum, mean, and standard deviation."""
         rb = TensorDictReplayBuffer(
             storage=self._storage,
             batch_size=32,
             prefetch=True,
             # Note: Due to be refactored soon. The point is that we should go through the whole dataset.
-            sampler=SamplerWithoutReplacement(drop_last=False, shuffle=num_batch is not None),
+            sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
         )
 
         # mean and std will be computed incrementally while max and min will track the running value.
@@ -177,7 +173,7 @@ class AbstractDataset(TensorDictReplayBuffer):
         # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
         # surprises when rerunning the sampler.
         first_batch = None
-        for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))):
+        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
             if first_batch is None:
                 first_batch = deepcopy(batch)
@@ -193,7 +189,7 @@ class AbstractDataset(TensorDictReplayBuffer):
 
         # Compute std
         first_batch_ = None
-        for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))):
+        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
             # Sanity check to make sure the batches are still in the same order as before.
             if first_batch_ is None:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 6d6a020d..fa45c646 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -53,7 +53,9 @@ def test_compute_stats():
         batch_size=len(buffer),
         sampler=SamplerWithoutReplacement(),
     ).sample().float()
-    computed_stats = buffer._compute_stats()
+    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+    # computation of the statistics.
+    computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
     for k, pattern in buffer.stats_patterns.items():
         expected_mean = einops.reduce(all_data[k], pattern, "mean")
         assert torch.allclose(computed_stats[k]["mean"], expected_mean)
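A minimal sketch of the batched, two-pass statistics computation this patch exercises, assuming plain tensors rather than the repo's TensorDictReplayBuffer; the function and variable names below are illustrative only, not the library's API:

    from math import ceil

    import torch

    def compute_stats(data: torch.Tensor, batch_size: int = 32) -> dict[str, torch.Tensor]:
        """Compute min, max, mean, and std over dim 0 by iterating in fixed-size batches."""
        n = len(data)
        mean = torch.zeros(data.shape[1:])
        maximum = torch.full(data.shape[1:], -float("inf"))
        minimum = torch.full(data.shape[1:], float("inf"))
        # First pass: accumulate the mean and track running min/max, visiting every sample exactly once
        # (no shuffling, no replacement), which is the point of the sampler change in this patch.
        for i in range(ceil(n / batch_size)):
            batch = data[i * batch_size : (i + 1) * batch_size]
            mean += batch.sum(dim=0) / n
            maximum = torch.maximum(maximum, batch.amax(dim=0))
            minimum = torch.minimum(minimum, batch.amin(dim=0))
        # Second pass: the std needs the final mean, so it gets its own sweep over the data.
        var = torch.zeros(data.shape[1:])
        for i in range(ceil(n / batch_size)):
            batch = data[i * batch_size : (i + 1) * batch_size]
            var += ((batch - mean) ** 2).sum(dim=0) / n
        return {"min": minimum, "max": maximum, "mean": mean, "std": var.sqrt()}

    # Quick check against a single-pass reference, mirroring test_compute_stats: the batch size is
    # deliberately smaller than the dataset so that more than one batch is consumed.
    data = torch.randn(100, 3)
    stats = compute_stats(data, batch_size=int(len(data) * 0.75))
    assert torch.allclose(stats["mean"], data.mean(dim=0))
    assert torch.allclose(stats["std"], data.std(dim=0, unbiased=False))
    assert torch.equal(stats["min"], data.amin(dim=0))
    assert torch.equal(stats["max"], data.amax(dim=0))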