diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index a81de49b..a098c1ae 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -1,4 +1,6 @@
 import logging
+from copy import deepcopy
+from math import ceil
 from pathlib import Path
 from typing import Callable
 
@@ -9,7 +11,7 @@ import tqdm
 from huggingface_hub import snapshot_download
 from tensordict import TensorDict
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler
+from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from torchrl.envs.transforms.transforms import Compose
@@ -128,7 +130,7 @@ class AbstractDataset(TensorDictReplayBuffer):
         else:
             self._transform = transform
 
-    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
+    def compute_or_load_stats(self, num_batch: int | None = None, batch_size: int = 32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
         if stats_path.exists():
             stats = torch.load(stats_path)
@@ -149,50 +151,69 @@ class AbstractDataset(TensorDictReplayBuffer):
         self.data_dir = self.root / self.dataset_id
         return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
 
-    def _compute_stats(self, num_batch=100, batch_size=32):
+    def _compute_stats(self, num_batch: int | None = None, batch_size: int = 32):
+        """Compute dataset statistics including minimum, maximum, mean, and standard deviation.
+
+        If `num_batch` is specified, we draw `num_batch` batches of size `batch_size` to compute the
+        statistics. If `num_batch` is not specified, we consume the whole dataset (the default).
+        """
         rb = TensorDictReplayBuffer(
             storage=self._storage,
             batch_size=batch_size,
             prefetch=True,
+            # Note: Due to be refactored soon. The point is that we should go through the whole dataset.
+            sampler=SamplerWithoutReplacement(drop_last=False, shuffle=num_batch is not None),
         )
+        # mean and std are computed incrementally, while max and min track the running extrema.
         mean, std, max, min = {}, {}, {}, {}
+        for key in self.stats_patterns:
+            mean[key] = torch.tensor(0.0).float()
+            std[key] = torch.tensor(0.0).float()
+            max[key] = torch.tensor(-float("inf")).float()
+            min[key] = torch.tensor(float("inf")).float()
 
         # compute mean, min, max
-        for _ in tqdm.tqdm(range(num_batch)):
+        # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+        # surprises when rerunning the sampler.
+        first_batch = None
+        running_item_count = 0  # for normalizing the mean by the number of samples actually seen
+        for _ in tqdm.tqdm(range(num_batch or ceil(len(rb) / batch_size))):
             batch = rb.sample()
+            running_item_count += batch.batch_size[0]
+            if first_batch is None:
+                first_batch = deepcopy(batch)
             for key, pattern in self.stats_patterns.items():
                 batch[key] = batch[key].float()
-                if key not in mean:
-                    # first batch initialize mean, min, max
-                    mean[key] = einops.reduce(batch[key], pattern, "mean")
-                    max[key] = einops.reduce(batch[key], pattern, "max")
-                    min[key] = einops.reduce(batch[key], pattern, "min")
-                else:
-                    mean[key] += einops.reduce(batch[key], pattern, "mean")
-                    max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-                    min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-            batch = rb.sample()
-
-        for key in self.stats_patterns:
-            mean[key] /= num_batch
-
-        # compute std, min, max
-        for _ in tqdm.tqdm(range(num_batch)):
-            batch = rb.sample()
-            for key, pattern in self.stats_patterns.items():
-                batch[key] = batch[key].float()
-                batch_mean = einops.reduce(batch[key], pattern, "mean")
-                if key not in std:
-                    # first batch initialize std
-                    std[key] = (batch_mean - mean[key]) ** 2
-                else:
-                    std[key] += (batch_mean - mean[key]) ** 2
+                # Sum over batch then divide by total number of samples.
+                mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0]
                 max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                 min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
 
         for key in self.stats_patterns:
-            std[key] = torch.sqrt(std[key] / num_batch)
+            mean[key] = mean[key] / running_item_count
+
+        # compute std
+        first_batch_ = None
+        running_item_count = 0  # reset for the second pass
+        for _ in tqdm.tqdm(range(num_batch or ceil(len(rb) / batch_size))):
+            batch = rb.sample()
+            running_item_count += batch.batch_size[0]
+            # Sanity check to make sure the batches come back in the same order as before. This only
+            # holds when we traverse the whole dataset without shuffling (`num_batch is None`).
+            if first_batch_ is None and num_batch is None:
+                first_batch_ = deepcopy(batch)
+                for key in self.stats_patterns:
+                    assert torch.equal(first_batch_[key], first_batch[key])
+            for key, pattern in self.stats_patterns.items():
+                batch[key] = batch[key].float()
+                # Sum over batch then divide by total number of samples.
+                std[key] = (
+                    std[key]
+                    + einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0]
+                )
+
+        for key in self.stats_patterns:
+            std[key] = torch.sqrt(std[key] / running_item_count)
 
         stats = TensorDict({}, batch_size=[])
         for key in self.stats_patterns:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index adaefcf5..6d6a020d 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,5 +1,8 @@
+import einops
 import pytest
 import torch
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.utils import init_hydra_config
@@ -30,3 +33,31 @@ def test_factory(env_name, dataset_id):
     # TODO(rcadene): we assume for now that image normalization takes place in the model
     assert img.max() <= 1.0
     assert img.min() >= 0.0
+
+
+def test_compute_stats():
+    """Check that the correct statistics are computed.
+
+    We compare against a straight min, mean, max, and std computed over all the data in one pass
+    (which we can do because we are working with a small dataset).
+
+    This test does not check that the stats_patterns are correct (instead, it relies on them).
+    """
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
+    )
+    buffer = make_offline_buffer(cfg)
+    # Get all of the data.
+    all_data = TensorDictReplayBuffer(
+        storage=buffer._storage,
+        batch_size=len(buffer),
+        sampler=SamplerWithoutReplacement(),
+    ).sample().float()
+    computed_stats = buffer._compute_stats()
+    for k, pattern in buffer.stats_patterns.items():
+        expected_mean = einops.reduce(all_data[k], pattern, "mean")
+        assert torch.allclose(computed_stats[k]["mean"], expected_mean)
+        expected_std = torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
+        assert torch.allclose(computed_stats[k]["std"], expected_std)
+        assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
+        assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
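
A quick way to convince yourself of the weighted accumulation used in `_compute_stats` (sum each batch's mean scaled by its batch size, then divide by the total number of samples): the following minimal sketch, with made-up data and names independent of the lerobot classes, reproduces a single-pass mean and population std from batched passes, even when the last batch is short.

import torch

# 101 items so the last batch of 32 is short; float64 keeps allclose comfortable.
data = torch.randn(101, 3, dtype=torch.double)
batches = data.split(32)
total = sum(len(b) for b in batches)

# First pass: accumulate per-batch means weighted by batch size, then normalize.
mean = sum(b.mean(dim=0) * len(b) for b in batches) / total
# Second pass: same weighting applied to the squared deviations from the mean.
var = sum(((b - mean) ** 2).mean(dim=0) * len(b) for b in batches) / total

# Matches a single pass over all the data (population statistics, correction=0).
assert torch.allclose(mean, data.mean(dim=0))
assert torch.allclose(var.sqrt(), data.std(dim=0, correction=0))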
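Both passes also lean on `SamplerWithoutReplacement(drop_last=False, shuffle=False)` (the default `num_batch is None` mode) visiting every index exactly once per pass, in a repeatable order. A toy sketch of that torchrl behavior, with a made-up buffer size and contents:

import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(10),
    batch_size=4,
    # drop_last=False keeps the final short batch; shuffle=False makes passes repeatable.
    sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
)
rb.extend(TensorDict({"x": torch.arange(10)}, batch_size=[10]))

# ceil(10 / 4) == 3 batches of sizes 4, 4, 2: every item is seen exactly once per pass.
seen = torch.cat([rb.sample()["x"] for _ in range(3)])
assert seen.tolist() == list(range(10))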