diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index a81de49b..e9e9c610 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -1,4 +1,6 @@
 import logging
+from copy import deepcopy
+from math import ceil
 from pathlib import Path
 from typing import Callable
 
@@ -9,7 +11,7 @@ import tqdm
 from huggingface_hub import snapshot_download
 from tensordict import TensorDict
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler
+from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from torchrl.envs.transforms.transforms import Compose
@@ -128,13 +130,13 @@ class AbstractDataset(TensorDictReplayBuffer):
         else:
             self._transform = transform
 
-    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
+    def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
         if stats_path.exists():
             stats = torch.load(stats_path)
         else:
             logging.info(f"compute_stats and save to {stats_path}")
-            stats = self._compute_stats(num_batch, batch_size)
+            stats = self._compute_stats(batch_size)
             torch.save(stats, stats_path)
         return stats
 
@@ -149,50 +151,75 @@ class AbstractDataset(TensorDictReplayBuffer):
             self.data_dir = self.root / self.dataset_id
         return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
 
-    def _compute_stats(self, num_batch=100, batch_size=32):
+    def _compute_stats(self, batch_size: int = 32):
+        """Compute dataset statistics including minimum, maximum, mean, and standard deviation.
+
+        TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
+        full dataset (for handling very large datasets). The sampling would then have to be random
+        (preferably without replacement). Both stats computation loops would ideally sample the same
+        items.
+        """
         rb = TensorDictReplayBuffer(
             storage=self._storage,
             batch_size=batch_size,
             prefetch=True,
+            # Note: Due to be refactored soon. The point is that we should go through the whole dataset.
+            sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
         )
 
+        # mean and std will be computed incrementally while max and min will track the running value.
         mean, std, max, min = {}, {}, {}, {}
-
-        # compute mean, min, max
-        for _ in tqdm.tqdm(range(num_batch)):
-            batch = rb.sample()
-            for key, pattern in self.stats_patterns.items():
-                batch[key] = batch[key].float()
-                if key not in mean:
-                    # first batch initialize mean, min, max
-                    mean[key] = einops.reduce(batch[key], pattern, "mean")
-                    max[key] = einops.reduce(batch[key], pattern, "max")
-                    min[key] = einops.reduce(batch[key], pattern, "min")
-                else:
-                    mean[key] += einops.reduce(batch[key], pattern, "mean")
-                    max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-                    min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-            batch = rb.sample()
-
         for key in self.stats_patterns:
-            mean[key] /= num_batch
+            mean[key] = torch.tensor(0.0).float()
+            std[key] = torch.tensor(0.0).float()
+            max[key] = torch.tensor(-float("inf")).float()
+            min[key] = torch.tensor(float("inf")).float()
 
-        # compute std, min, max
-        for _ in tqdm.tqdm(range(num_batch)):
+        # Compute mean, min, max.
+        # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+        # surprises when rerunning the sampler.
+        first_batch = None
+        running_item_count = 0  # for online mean computation
+        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
+            if first_batch is None:
+                first_batch = deepcopy(batch)
             for key, pattern in self.stats_patterns.items():
                 batch[key] = batch[key].float()
+                # Numerically stable update step for mean computation.
                 batch_mean = einops.reduce(batch[key], pattern, "mean")
-                if key not in std:
-                    # first batch initialize std
-                    std[key] = (batch_mean - mean[key]) ** 2
-                else:
-                    std[key] += (batch_mean - mean[key]) ** 2
+                # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+                # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+                # and x is the current batch mean. Some rearrangement is then required to avoid risking
+                # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields:
+                # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
+                mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
                 max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                 min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
 
+        # Compute std.
+        first_batch_ = None
+        running_item_count = 0  # for online std computation
+        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
+            batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
+            # Sanity check to make sure the batches are still in the same order as before.
+            if first_batch_ is None:
+                first_batch_ = deepcopy(batch)
+                for key in self.stats_patterns:
+                    assert torch.equal(first_batch_[key], first_batch[key])
+            for key, pattern in self.stats_patterns.items():
+                batch[key] = batch[key].float()
+                # Numerically stable update step for mean computation (where the mean is over squared
+                # residuals). See notes in the mean computation loop above.
+                batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+                std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
+
         for key in self.stats_patterns:
-            std[key] = torch.sqrt(std[key] / num_batch)
+            std[key] = torch.sqrt(std[key])
 
         stats = TensorDict({}, batch_size=[])
         for key in self.stats_patterns:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index adaefcf5..df41b03f 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,5 +1,8 @@
+import einops
 import pytest
 import torch
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.utils import init_hydra_config
@@ -30,3 +33,34 @@ def test_factory(env_name, dataset_id):
     # TODO(rcadene): we assume for now that image normalization takes place in the model
     assert img.max() <= 1.0
     assert img.min() >= 0.0
+
+
+def test_compute_stats():
+    """Check that the statistics are computed correctly according to the stats_patterns property.
+
+    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
+    because we are working with a small dataset).
+    """
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
+    )
+    buffer = make_offline_buffer(cfg)
+    # Get all of the data.
+    all_data = TensorDictReplayBuffer(
+        storage=buffer._storage,
+        batch_size=len(buffer),
+        sampler=SamplerWithoutReplacement(),
+    ).sample().float()
+    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
+    # dataset into even batches.
+    computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
+    for k, pattern in buffer.stats_patterns.items():
+        expected_mean = einops.reduce(all_data[k], pattern, "mean")
+        assert torch.allclose(computed_stats[k]["mean"], expected_mean)
+        assert torch.allclose(
+            computed_stats[k]["std"],
+            torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
+        )
+        assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
+        assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))