From 95293d459d2005151328e08272069c66bceaa0ca Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 2 Apr 2024 16:40:33 +0100 Subject: [PATCH 1/9] fix stats computation --- lerobot/common/datasets/abstract.py | 79 ++++++++++++++++++----------- tests/test_datasets.py | 30 +++++++++++ 2 files changed, 78 insertions(+), 31 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index a81de49b..a098c1ae 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -1,4 +1,6 @@ import logging +from copy import deepcopy +from math import ceil from pathlib import Path from typing import Callable @@ -9,7 +11,7 @@ import tqdm from huggingface_hub import snapshot_download from tensordict import TensorDict from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import Sampler +from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.envs.transforms.transforms import Compose @@ -128,7 +130,7 @@ class AbstractDataset(TensorDictReplayBuffer): else: self._transform = transform - def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict: + def compute_or_load_stats(self, num_batch: int | None = None, batch_size: int = 32) -> TensorDict: stats_path = self.data_dir / "stats.pth" if stats_path.exists(): stats = torch.load(stats_path) @@ -149,50 +151,65 @@ class AbstractDataset(TensorDictReplayBuffer): self.data_dir = self.root / self.dataset_id return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer")) - def _compute_stats(self, num_batch=100, batch_size=32): + def _compute_stats(self, num_batch: int | None = None, batch_size: int = 32): + """Compute dataset statistics including minimum, maximum, mean, and standard deviation. + + If `num_batch` is specified, we draw `num_batch` batches of size `batch_size` to compute the + statistics. If `num_batch` is not specified, we just consume the whole dataset (default). + """ rb = TensorDictReplayBuffer( storage=self._storage, - batch_size=batch_size, + batch_size=32, prefetch=True, + # Note: Due to be refactored soon. The point is that we should go through the whole dataset. + sampler=SamplerWithoutReplacement(drop_last=False, shuffle=num_batch is not None), ) + # mean and std will be computed incrementally while max and min will track the running value. mean, std, max, min = {}, {}, {}, {} + for key in self.stats_patterns: + mean[key] = torch.tensor(0.0).float() + std[key] = torch.tensor(0.0).float() + max[key] = torch.tensor(-float("inf")).float() + min[key] = torch.tensor(float("inf")).float() # compute mean, min, max - for _ in tqdm.tqdm(range(num_batch)): + # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get + # surprises when rerunning the sampler. 
+ first_batch = None + for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))): batch = rb.sample() + if first_batch is None: + first_batch = deepcopy(batch) for key, pattern in self.stats_patterns.items(): batch[key] = batch[key].float() - if key not in mean: - # first batch initialize mean, min, max - mean[key] = einops.reduce(batch[key], pattern, "mean") - max[key] = einops.reduce(batch[key], pattern, "max") - min[key] = einops.reduce(batch[key], pattern, "min") - else: - mean[key] += einops.reduce(batch[key], pattern, "mean") - max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) - min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) - batch = rb.sample() - - for key in self.stats_patterns: - mean[key] /= num_batch - - # compute std, min, max - for _ in tqdm.tqdm(range(num_batch)): - batch = rb.sample() - for key, pattern in self.stats_patterns.items(): - batch[key] = batch[key].float() - batch_mean = einops.reduce(batch[key], pattern, "mean") - if key not in std: - # first batch initialize std - std[key] = (batch_mean - mean[key]) ** 2 - else: - std[key] += (batch_mean - mean[key]) ** 2 + # Sum over batch then divide by total number of samples. + mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0] max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) for key in self.stats_patterns: - std[key] = torch.sqrt(std[key] / num_batch) + mean[key] = mean[key] / len(rb) + + # Compute std + first_batch_ = None + for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))): + batch = rb.sample() + # Sanity check to make sure the batches are still in the same order as before. + if first_batch_ is None: + first_batch_ = deepcopy(batch) + for key in self.stats_patterns: + assert torch.equal(first_batch_[key], first_batch[key]) + for key, pattern in self.stats_patterns.items(): + batch[key] = batch[key].float() + # Sum over batch then divide by total number of samples. + std[key] = ( + std[key] + + einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0] + ) + + for key in self.stats_patterns: + std[key] = torch.sqrt(std[key] / len(rb)) stats = TensorDict({}, batch_size=[]) for key in self.stats_patterns: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index adaefcf5..6d6a020d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,5 +1,8 @@ +import einops import pytest import torch +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.utils import init_hydra_config @@ -30,3 +33,30 @@ def test_factory(env_name, dataset_id): # TODO(rcadene): we assume for now that image normalization takes place in the model assert img.max() <= 1.0 assert img.min() >= 0.0 + + +def test_compute_stats(): + """Check that the correct statistics are computed. + + We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do + because we are working with a small dataset). + + This test does not check that the stats_patterns are correct (instead, it relies on them). + """ + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"] + ) + buffer = make_offline_buffer(cfg) + # Get all of the data. 
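+    # Sampling `len(buffer)` frames without replacement in a single batch returns every frame exactly
+    # once, so the expected statistics below are computed over the full dataset.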
+ all_data = TensorDictReplayBuffer( + storage=buffer._storage, + batch_size=len(buffer), + sampler=SamplerWithoutReplacement(), + ).sample().float() + computed_stats = buffer._compute_stats() + for k, pattern in buffer.stats_patterns.items(): + expected_mean = einops.reduce(all_data[k], pattern, "mean") + assert torch.allclose(computed_stats[k]["mean"], expected_mean) + assert torch.allclose(computed_stats[k]["std"], torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))) + assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min")) + assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max")) From a6edb85da42bdfad189e0f86f7a55cce84d4bfdf Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 2 Apr 2024 16:52:38 +0100 Subject: [PATCH 2/9] Remove random sampling --- lerobot/common/datasets/abstract.py | 18 +++++++----------- tests/test_datasets.py | 4 +++- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index a098c1ae..0aa89d65 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -130,13 +130,13 @@ class AbstractDataset(TensorDictReplayBuffer): else: self._transform = transform - def compute_or_load_stats(self, num_batch: int | None = None, batch_size: int = 32) -> TensorDict: + def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict: stats_path = self.data_dir / "stats.pth" if stats_path.exists(): stats = torch.load(stats_path) else: logging.info(f"compute_stats and save to {stats_path}") - stats = self._compute_stats(num_batch, batch_size) + stats = self._compute_stats(batch_size) torch.save(stats, stats_path) return stats @@ -151,18 +151,14 @@ class AbstractDataset(TensorDictReplayBuffer): self.data_dir = self.root / self.dataset_id return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer")) - def _compute_stats(self, num_batch: int | None = None, batch_size: int = 32): - """Compute dataset statistics including minimum, maximum, mean, and standard deviation. - - If `num_batch` is specified, we draw `num_batch` batches of size `batch_size` to compute the - statistics. If `num_batch` is not specified, we just consume the whole dataset (default). - """ + def _compute_stats(self, batch_size: int = 32): + """Compute dataset statistics including minimum, maximum, mean, and standard deviation.""" rb = TensorDictReplayBuffer( storage=self._storage, batch_size=32, prefetch=True, # Note: Due to be refactored soon. The point is that we should go through the whole dataset. - sampler=SamplerWithoutReplacement(drop_last=False, shuffle=num_batch is not None), + sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False), ) # mean and std will be computed incrementally while max and min will track the running value. @@ -177,7 +173,7 @@ class AbstractDataset(TensorDictReplayBuffer): # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get # surprises when rerunning the sampler. 
first_batch = None - for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))): + for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))): batch = rb.sample() if first_batch is None: first_batch = deepcopy(batch) @@ -193,7 +189,7 @@ class AbstractDataset(TensorDictReplayBuffer): # Compute std first_batch_ = None - for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))): + for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))): batch = rb.sample() # Sanity check to make sure the batches are still in the same order as before. if first_batch_ is None: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 6d6a020d..fa45c646 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -53,7 +53,9 @@ def test_compute_stats(): batch_size=len(buffer), sampler=SamplerWithoutReplacement(), ).sample().float() - computed_stats = buffer._compute_stats() + # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched + # computation of the statistics. + computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75)) for k, pattern in buffer.stats_patterns.items(): expected_mean = einops.reduce(all_data[k], pattern, "mean") assert torch.allclose(computed_stats[k]["mean"], expected_mean) From 148df1c1d5a99c1f1a03654568d16a4a5ade3a4f Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 2 Apr 2024 16:57:25 +0100 Subject: [PATCH 3/9] add comment on test --- tests/test_datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index fa45c646..5146b86c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -54,7 +54,8 @@ def test_compute_stats(): sampler=SamplerWithoutReplacement(), ).sample().float() # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched - # computation of the statistics. + # computation of the statistics. While doing this, we also make sure it works when we don't divide the + # dataset into even batches. 
computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75)) for k, pattern in buffer.stats_patterns.items(): expected_mean = einops.reduce(all_data[k], pattern, "mean") From c3234adc7da8ded797c145680b6e71177f4217f5 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 2 Apr 2024 16:59:19 +0100 Subject: [PATCH 4/9] fix indentation --- tests/test_datasets.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5146b86c..837e3188 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -60,6 +60,9 @@ def test_compute_stats(): for k, pattern in buffer.stats_patterns.items(): expected_mean = einops.reduce(all_data[k], pattern, "mean") assert torch.allclose(computed_stats[k]["mean"], expected_mean) - assert torch.allclose(computed_stats[k]["std"], torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))) + assert torch.allclose( + computed_stats[k]["std"], + torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")) + ) assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min")) assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max")) From 72429531979df77d813817b6c14b5d4d0c7db84d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 2 Apr 2024 19:19:13 +0100 Subject: [PATCH 5/9] revision --- tests/test_datasets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 837e3188..df41b03f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -36,12 +36,10 @@ def test_factory(env_name, dataset_id): def test_compute_stats(): - """Check that the correct statistics are computed. + """Check that the statistics are computed correctly according to the stats_patterns property. We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do because we are working with a small dataset). - - This test does not check that the stats_patterns are correct (instead, it relies on them). """ cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"] From e9eb26229317092dcdd81776526a7b535722c4b3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 3 Apr 2024 09:44:20 +0100 Subject: [PATCH 6/9] numerically sound mean computation --- lerobot/common/datasets/abstract.py | 33 +++++++++++++++++------------ tests/test_datasets.py | 11 ++++++---- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 0aa89d65..18840a9e 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -169,28 +169,36 @@ class AbstractDataset(TensorDictReplayBuffer): max[key] = torch.tensor(-float("inf")).float() min[key] = torch.tensor(float("inf")).float() - # compute mean, min, max + # Compute mean, min, max. # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get # surprises when rerunning the sampler. 
         first_batch = None
+        running_item_count = 0  # for online mean computation
         for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
             if first_batch is None:
                 first_batch = deepcopy(batch)
             for key, pattern in self.stats_patterns.items():
                 batch[key] = batch[key].float()
-                # Sum over batch then divide by total number of samples.
-                mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0]
+                # Numerically stable update step for mean computation.
+                batch_mean = einops.reduce(batch[key], pattern, "mean")
+                # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+                # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+                # and x is the current batch mean. Some rearrangement is then required to avoid risking
+                # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ.
+                mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
                 max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                 min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

-        for key in self.stats_patterns:
-            mean[key] = mean[key] / len(rb)
-
-        # Compute std
+        # Compute std.
         first_batch_ = None
+        running_item_count = 0  # for online std computation
         for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
             # Sanity check to make sure the batches are still in the same order as before.
             if first_batch_ is None:
                 first_batch_ = deepcopy(batch)
                 for key in self.stats_patterns:
                     assert torch.equal(first_batch_[key], first_batch[key])
             for key, pattern in self.stats_patterns.items():
                 batch[key] = batch[key].float()
-                # Sum over batch then divide by total number of samples.
-                std[key] = (
-                    std[key]
-                    + einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0]
-                )
+                # Numerically stable update step for mean computation (where the mean is over squared
+                # residuals). See notes in the mean computation loop above.
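+                # Note that at this point `mean[key]` already holds the final mean from the first pass,
+                # so this running average of squared residuals converges to the exact population variance
+                # rather than a streaming approximation of it.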
+                batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+                std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count

         for key in self.stats_patterns:
-            std[key] = torch.sqrt(std[key] / len(rb))
+            std[key] = torch.sqrt(std[key])

         stats = TensorDict({}, batch_size=[])
         for key in self.stats_patterns:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index df41b03f..05c9e80d 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -58,9 +58,12 @@ def test_compute_stats():
     for k, pattern in buffer.stats_patterns.items():
         expected_mean = einops.reduce(all_data[k], pattern, "mean")
         assert torch.allclose(computed_stats[k]["mean"], expected_mean)
-        assert torch.allclose(
-            computed_stats[k]["std"],
-            torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
-        )
+        try:
+            assert torch.allclose(
+                computed_stats[k]["std"],
+                torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
+            )
+        except:
+            breakpoint()
         assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
         assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))

From c50a62dd6dd4d21c41c26f9afa4f30f18fc90fd7 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 3 Apr 2024 09:47:38 +0100
Subject: [PATCH 7/9] clarifying math

---
 lerobot/common/datasets/abstract.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 18840a9e..13be4cab 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -187,7 +187,8 @@ class AbstractDataset(TensorDictReplayBuffer):
                 # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
                 # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
                 # and x is the current batch mean. Some rearrangement is then required to avoid risking
-                # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ.
+                # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+                # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
                 mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
                 max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                 min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

From a6ec4fbf58c221047172b358c650d169b5b5a544 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 3 Apr 2024 09:53:15 +0100
Subject: [PATCH 8/9] remove try-catch

---
 tests/test_datasets.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 05c9e80d..df41b03f 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -58,12 +58,9 @@ def test_compute_stats():
     for k, pattern in buffer.stats_patterns.items():
         expected_mean = einops.reduce(all_data[k], pattern, "mean")
         assert torch.allclose(computed_stats[k]["mean"], expected_mean)
-        try:
-            assert torch.allclose(
-                computed_stats[k]["std"],
-                torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
-            )
-        except:
-            breakpoint()
+        assert torch.allclose(
+            computed_stats[k]["std"],
+            torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
+        )
         assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
         assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))

From caf4ffcf654168f5a3bba892059d781a9b159628 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 3 Apr 2024 09:56:46 +0100
Subject: [PATCH 9/9] add TODO

---
 lerobot/common/datasets/abstract.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 13be4cab..e9e9c610 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -152,7 +152,13 @@ class AbstractDataset(TensorDictReplayBuffer):
         return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))

     def _compute_stats(self, batch_size: int = 32):
-        """Compute dataset statistics including minimum, maximum, mean, and standard deviation."""
+        """Compute dataset statistics including minimum, maximum, mean, and standard deviation.
+
+        TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
+        full dataset (for handling very large datasets). The sampling would then have to be random
+        (preferably without replacement). Both stats computation loops would ideally sample the same
+        items.
+        """
         rb = TensorDictReplayBuffer(
             storage=self._storage,
             batch_size=32,
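
As a quick, standalone sanity check of the update rule introduced in patches 6 and 7, the sketch below accumulates a running mean and a running mean of squared residuals over uneven batches using x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ, then compares the result against a direct one-pass computation, in the same spirit as test_compute_stats. It is independent of the lerobot API: the data, shapes, and batch sizes are made up for illustration, and plain torch reductions stand in for the einops.reduce patterns.

# Standalone sanity check for the streaming mean/std update used in _compute_stats.
# Assumes only torch is installed; all names and shapes here are illustrative.
import torch

torch.manual_seed(0)
data = torch.randn(1000, 7)  # e.g. 1000 frames of a 7-dim state

# Pass 1: running mean, with a deliberately uneven final batch.
batch_size = 96  # does not divide 1000
mean = torch.zeros(7)
n = 0
for start in range(0, len(data), batch_size):
    batch = data[start : start + batch_size]
    b = len(batch)
    n += b
    # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
    mean = mean + b * (batch.mean(dim=0) - mean) / n

# Pass 2: running mean of squared residuals about the final mean from pass 1.
var = torch.zeros(7)
n = 0
for start in range(0, len(data), batch_size):
    batch = data[start : start + batch_size]
    b = len(batch)
    n += b
    var = var + b * (((batch - mean) ** 2).mean(dim=0) - var) / n
std = var.sqrt()

# Direct one-pass reference, analogous to the expected values in test_compute_stats.
assert torch.allclose(mean, data.mean(dim=0), atol=1e-6)
assert torch.allclose(std, ((data - data.mean(dim=0)) ** 2).mean(dim=0).sqrt(), atol=1e-6)
print("streaming mean/std match the direct computation")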