fix stats computation

This commit is contained in:
Alexander Soare 2024-04-02 16:40:33 +01:00
parent 11cbf1bea1
commit 95293d459d
2 changed files with 78 additions and 31 deletions
lerobot/common/datasets
tests

View File

@ -1,4 +1,6 @@
import logging
from copy import deepcopy
from math import ceil
from pathlib import Path
from typing import Callable
@ -9,7 +11,7 @@ import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
@ -128,7 +130,7 @@ class AbstractDataset(TensorDictReplayBuffer):
else:
self._transform = transform
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
def compute_or_load_stats(self, num_batch: int | None = None, batch_size: int = 32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"
if stats_path.exists():
stats = torch.load(stats_path)
@ -149,50 +151,65 @@ class AbstractDataset(TensorDictReplayBuffer):
self.data_dir = self.root / self.dataset_id
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
def _compute_stats(self, num_batch=100, batch_size=32):
def _compute_stats(self, num_batch: int | None = None, batch_size: int = 32):
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation.
If `num_batch` is specified, we draw `num_batch` batches of size `batch_size` to compute the
statistics. If `num_batch` is not specified, we just consume the whole dataset (default).
"""
rb = TensorDictReplayBuffer(
storage=self._storage,
batch_size=batch_size,
batch_size=32,
prefetch=True,
# Note: Due to be refactored soon. The point is that we should go through the whole dataset.
sampler=SamplerWithoutReplacement(drop_last=False, shuffle=num_batch is not None),
)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
for key in self.stats_patterns:
mean[key] = torch.tensor(0.0).float()
std[key] = torch.tensor(0.0).float()
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
# compute mean, min, max
for _ in tqdm.tqdm(range(num_batch)):
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))):
batch = rb.sample()
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
if key not in mean:
# first batch initialize mean, min, max
mean[key] = einops.reduce(batch[key], pattern, "mean")
max[key] = einops.reduce(batch[key], pattern, "max")
min[key] = einops.reduce(batch[key], pattern, "min")
else:
mean[key] += einops.reduce(batch[key], pattern, "mean")
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
batch = rb.sample()
for key in self.stats_patterns:
mean[key] /= num_batch
# compute std, min, max
for _ in tqdm.tqdm(range(num_batch)):
batch = rb.sample()
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
batch_mean = einops.reduce(batch[key], pattern, "mean")
if key not in std:
# first batch initialize std
std[key] = (batch_mean - mean[key]) ** 2
else:
std[key] += (batch_mean - mean[key]) ** 2
# Sum over batch then divide by total number of samples.
mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0]
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
for key in self.stats_patterns:
std[key] = torch.sqrt(std[key] / num_batch)
mean[key] = mean[key] / len(rb)
# Compute std
first_batch_ = None
for _ in tqdm.tqdm(num_batch or range(ceil(len(rb) / batch_size))):
batch = rb.sample()
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in self.stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
# Sum over batch then divide by total number of samples.
std[key] = (
std[key]
+ einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0]
)
for key in self.stats_patterns:
std[key] = torch.sqrt(std[key] / len(rb))
stats = TensorDict({}, batch_size=[])
for key in self.stats_patterns:

View File

@ -1,5 +1,8 @@
import einops
import pytest
import torch
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.utils import init_hydra_config
@ -30,3 +33,30 @@ def test_factory(env_name, dataset_id):
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert img.max() <= 1.0
assert img.min() >= 0.0
def test_compute_stats():
"""Check that the correct statistics are computed.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
This test does not check that the stats_patterns are correct (instead, it relies on them).
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
)
buffer = make_offline_buffer(cfg)
# Get all of the data.
all_data = TensorDictReplayBuffer(
storage=buffer._storage,
batch_size=len(buffer),
sampler=SamplerWithoutReplacement(),
).sample().float()
computed_stats = buffer._compute_stats()
for k, pattern in buffer.stats_patterns.items():
expected_mean = einops.reduce(all_data[k], pattern, "mean")
assert torch.allclose(computed_stats[k]["mean"], expected_mean)
assert torch.allclose(computed_stats[k]["std"], torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")))
assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))