Merge pull request #66 from alexander-soare/fix_stats_computation

fix stats computation
Alexander Soare 2024-04-03 10:03:47 +01:00 committed by GitHub
commit 920e0d118b
2 changed files with 92 additions and 31 deletions

View File

@@ -1,4 +1,6 @@
 import logging
+from copy import deepcopy
+from math import ceil
 from pathlib import Path
 from typing import Callable
@@ -9,7 +11,7 @@ import tqdm
 from huggingface_hub import snapshot_download
 from tensordict import TensorDict
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler
+from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from torchrl.envs.transforms.transforms import Compose
@@ -128,13 +130,13 @@ class AbstractDataset(TensorDictReplayBuffer):
         else:
             self._transform = transform
 
-    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
+    def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
         if stats_path.exists():
             stats = torch.load(stats_path)
         else:
             logging.info(f"compute_stats and save to {stats_path}")
-            stats = self._compute_stats(num_batch, batch_size)
+            stats = self._compute_stats(batch_size)
             torch.save(stats, stats_path)
         return stats
@@ -149,50 +151,75 @@ class AbstractDataset(TensorDictReplayBuffer):
         self.data_dir = self.root / self.dataset_id
         return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
 
-    def _compute_stats(self, num_batch=100, batch_size=32):
+    def _compute_stats(self, batch_size: int = 32):
         """Compute dataset statistics including minimum, maximum, mean, and standard deviation.
+
+        TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
+        full dataset (for handling very large datasets). The sampling would then have to be random
+        (preferably without replacement). Both stats computation loops would ideally sample the same
+        items.
         """
         rb = TensorDictReplayBuffer(
             storage=self._storage,
-            batch_size=batch_size,
+            batch_size=32,
             prefetch=True,
+            # Note: Due to be refactored soon. The point is that we should go through the whole dataset.
+            sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
         )
+        # mean and std will be computed incrementally while max and min will track the running value.
         mean, std, max, min = {}, {}, {}, {}
-        # compute mean, min, max
-        for _ in tqdm.tqdm(range(num_batch)):
-            batch = rb.sample()
-            for key, pattern in self.stats_patterns.items():
-                batch[key] = batch[key].float()
-                if key not in mean:
-                    # first batch initialize mean, min, max
-                    mean[key] = einops.reduce(batch[key], pattern, "mean")
-                    max[key] = einops.reduce(batch[key], pattern, "max")
-                    min[key] = einops.reduce(batch[key], pattern, "min")
-                else:
-                    mean[key] += einops.reduce(batch[key], pattern, "mean")
-                    max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-                    min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-            batch = rb.sample()
         for key in self.stats_patterns:
-            mean[key] /= num_batch
+            mean[key] = torch.tensor(0.0).float()
+            std[key] = torch.tensor(0.0).float()
+            max[key] = torch.tensor(-float("inf")).float()
+            min[key] = torch.tensor(float("inf")).float()
 
-        # compute std, min, max
-        for _ in tqdm.tqdm(range(num_batch)):
+        # Compute mean, min, max.
+        # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+        # surprises when rerunning the sampler.
+        first_batch = None
+        running_item_count = 0  # for online mean computation
+        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
+            if first_batch is None:
+                first_batch = deepcopy(batch)
             for key, pattern in self.stats_patterns.items():
                 batch[key] = batch[key].float()
+                # Numerically stable update step for mean computation.
                 batch_mean = einops.reduce(batch[key], pattern, "mean")
-                if key not in std:
-                    # first batch initialize std
-                    std[key] = (batch_mean - mean[key]) ** 2
-                else:
-                    std[key] += (batch_mean - mean[key]) ** 2
+                # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+                # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+                # and x is the current batch mean. Some rearrangement is then required to avoid risking
+                # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+                # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
+                mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
                 max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                 min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
 
+        # Compute std.
+        first_batch_ = None
+        running_item_count = 0  # for online std computation
+        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
+            batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
+            # Sanity check to make sure the batches are still in the same order as before.
+            if first_batch_ is None:
+                first_batch_ = deepcopy(batch)
+                for key in self.stats_patterns:
+                    assert torch.equal(first_batch_[key], first_batch[key])
+            for key, pattern in self.stats_patterns.items():
+                batch[key] = batch[key].float()
+                # Numerically stable update step for mean computation (where the mean is over squared
+                # residuals). See notes in the mean computation loop above.
+                batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+                std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
 
         for key in self.stats_patterns:
-            std[key] = torch.sqrt(std[key] / num_batch)
+            std[key] = torch.sqrt(std[key])
 
         stats = TensorDict({}, batch_size=[])
         for key in self.stats_patterns:
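
The running-mean update in the hint comment above can be checked in isolation. The following is a minimal standalone sketch, not code from this repository: the data, seed, and batch split are invented for illustration. Pass one maintains the running mean via x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ; pass two applies the same update to squared residuals about that mean; both results are then compared against one-shot statistics, with an uneven final batch standing in for drop_last=False.

# Standalone sanity check for the streaming mean/std updates used in _compute_stats above.
# Illustrative sketch only: the data, seed, and batch split are invented for the demo.
import torch

torch.manual_seed(0)
data = torch.randn(100) * 3 + 7  # arbitrary data with non-trivial mean and std
batches = list(data.split(32))   # sizes 32 + 32 + 32 + 4: uneven final batch, as with drop_last=False

# Pass 1: running mean, x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ.
mean, n = torch.tensor(0.0), 0
for b in batches:
    n += b.numel()
    mean = mean + b.numel() * (b.mean() - mean) / n

# Pass 2: the identical update applied to squared residuals about the pass-1 mean.
var, n = torch.tensor(0.0), 0
for b in batches:
    n += b.numel()
    var = var + b.numel() * (((b - mean) ** 2).mean() - var) / n
std = var.sqrt()

# Agrees with one-shot statistics over the full data.
assert torch.allclose(mean, data.mean())
assert torch.allclose(std, ((data - data.mean()) ** 2).mean().sqrt())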

View File

@@ -1,5 +1,8 @@
+import einops
 import pytest
 import torch
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.utils import init_hydra_config
@@ -30,3 +33,34 @@ def test_factory(env_name, dataset_id):
     # TODO(rcadene): we assume for now that image normalization takes place in the model
     assert img.max() <= 1.0
     assert img.min() >= 0.0
+
+
+def test_compute_stats():
+    """Check that the statistics are computed correctly according to the stats_patterns property.
+
+    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
+    because we are working with a small dataset).
+    """
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
+    )
+    buffer = make_offline_buffer(cfg)
+
+    # Get all of the data.
+    all_data = TensorDictReplayBuffer(
+        storage=buffer._storage,
+        batch_size=len(buffer),
+        sampler=SamplerWithoutReplacement(),
+    ).sample().float()
+
+    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
+    # dataset into even batches.
+    computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
+
+    for k, pattern in buffer.stats_patterns.items():
+        expected_mean = einops.reduce(all_data[k], pattern, "mean")
+        assert torch.allclose(computed_stats[k]["mean"], expected_mean)
+        assert torch.allclose(
+            computed_stats[k]["std"],
+            torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
+        )
+        assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
+        assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))