Merge pull request #66 from alexander-soare/fix_stats_computation
fix stats computation
This commit is contained in:
commit
920e0d118b
|
@ -1,4 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
|
from copy import deepcopy
|
||||||
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
@ -9,7 +11,7 @@ import tqdm
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||||
from torchrl.data.replay_buffers.samplers import Sampler
|
from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||||
from torchrl.envs.transforms.transforms import Compose
|
from torchrl.envs.transforms.transforms import Compose
|
||||||
|
@ -128,13 +130,13 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||||
else:
|
else:
|
||||||
self._transform = transform
|
self._transform = transform
|
||||||
|
|
||||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
|
||||||
stats_path = self.data_dir / "stats.pth"
|
stats_path = self.data_dir / "stats.pth"
|
||||||
if stats_path.exists():
|
if stats_path.exists():
|
||||||
stats = torch.load(stats_path)
|
stats = torch.load(stats_path)
|
||||||
else:
|
else:
|
||||||
logging.info(f"compute_stats and save to {stats_path}")
|
logging.info(f"compute_stats and save to {stats_path}")
|
||||||
stats = self._compute_stats(num_batch, batch_size)
|
stats = self._compute_stats(batch_size)
|
||||||
torch.save(stats, stats_path)
|
torch.save(stats, stats_path)
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
@ -149,50 +151,75 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||||
self.data_dir = self.root / self.dataset_id
|
self.data_dir = self.root / self.dataset_id
|
||||||
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
|
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
|
||||||
|
|
||||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
def _compute_stats(self, batch_size: int = 32):
|
||||||
|
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation.
|
||||||
|
|
||||||
|
TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
|
||||||
|
full dataset (for handling very large datasets). The sampling would then have to be random
|
||||||
|
(preferably without replacement). Both stats computation loops would ideally sample the same
|
||||||
|
items.
|
||||||
|
"""
|
||||||
rb = TensorDictReplayBuffer(
|
rb = TensorDictReplayBuffer(
|
||||||
storage=self._storage,
|
storage=self._storage,
|
||||||
batch_size=batch_size,
|
batch_size=32,
|
||||||
prefetch=True,
|
prefetch=True,
|
||||||
|
# Note: Due to be refactored soon. The point is that we should go through the whole dataset.
|
||||||
|
sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# mean and std will be computed incrementally while max and min will track the running value.
|
||||||
mean, std, max, min = {}, {}, {}, {}
|
mean, std, max, min = {}, {}, {}, {}
|
||||||
|
|
||||||
# compute mean, min, max
|
|
||||||
for _ in tqdm.tqdm(range(num_batch)):
|
|
||||||
batch = rb.sample()
|
|
||||||
for key, pattern in self.stats_patterns.items():
|
|
||||||
batch[key] = batch[key].float()
|
|
||||||
if key not in mean:
|
|
||||||
# first batch initialize mean, min, max
|
|
||||||
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
|
||||||
max[key] = einops.reduce(batch[key], pattern, "max")
|
|
||||||
min[key] = einops.reduce(batch[key], pattern, "min")
|
|
||||||
else:
|
|
||||||
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
|
||||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
|
||||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
|
||||||
batch = rb.sample()
|
|
||||||
|
|
||||||
for key in self.stats_patterns:
|
for key in self.stats_patterns:
|
||||||
mean[key] /= num_batch
|
mean[key] = torch.tensor(0.0).float()
|
||||||
|
std[key] = torch.tensor(0.0).float()
|
||||||
|
max[key] = torch.tensor(-float("inf")).float()
|
||||||
|
min[key] = torch.tensor(float("inf")).float()
|
||||||
|
|
||||||
# compute std, min, max
|
# Compute mean, min, max.
|
||||||
for _ in tqdm.tqdm(range(num_batch)):
|
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||||
|
# surprises when rerunning the sampler.
|
||||||
|
first_batch = None
|
||||||
|
running_item_count = 0 # for online mean computation
|
||||||
|
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
|
this_batch_size = batch.batch_size[0]
|
||||||
|
running_item_count += this_batch_size
|
||||||
|
if first_batch is None:
|
||||||
|
first_batch = deepcopy(batch)
|
||||||
for key, pattern in self.stats_patterns.items():
|
for key, pattern in self.stats_patterns.items():
|
||||||
batch[key] = batch[key].float()
|
batch[key] = batch[key].float()
|
||||||
|
# Numerically stable update step for mean computation.
|
||||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||||
if key not in std:
|
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||||
# first batch initialize std
|
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||||
std[key] = (batch_mean - mean[key]) ** 2
|
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||||
else:
|
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||||
std[key] += (batch_mean - mean[key]) ** 2
|
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||||
|
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||||
|
|
||||||
|
# Compute std.
|
||||||
|
first_batch_ = None
|
||||||
|
running_item_count = 0 # for online std computation
|
||||||
|
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
||||||
|
batch = rb.sample()
|
||||||
|
this_batch_size = batch.batch_size[0]
|
||||||
|
running_item_count += this_batch_size
|
||||||
|
# Sanity check to make sure the batches are still in the same order as before.
|
||||||
|
if first_batch_ is None:
|
||||||
|
first_batch_ = deepcopy(batch)
|
||||||
|
for key in self.stats_patterns:
|
||||||
|
assert torch.equal(first_batch_[key], first_batch[key])
|
||||||
|
for key, pattern in self.stats_patterns.items():
|
||||||
|
batch[key] = batch[key].float()
|
||||||
|
# Numerically stable update step for mean computation (where the mean is over squared
|
||||||
|
# residuals).See notes in the mean computation loop above.
|
||||||
|
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||||
|
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||||
|
|
||||||
for key in self.stats_patterns:
|
for key in self.stats_patterns:
|
||||||
std[key] = torch.sqrt(std[key] / num_batch)
|
std[key] = torch.sqrt(std[key])
|
||||||
|
|
||||||
stats = TensorDict({}, batch_size=[])
|
stats = TensorDict({}, batch_size=[])
|
||||||
for key in self.stats_patterns:
|
for key in self.stats_patterns:
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
|
import einops
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||||
|
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.utils import init_hydra_config
|
from lerobot.common.utils import init_hydra_config
|
||||||
|
@ -30,3 +33,34 @@ def test_factory(env_name, dataset_id):
|
||||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||||
assert img.max() <= 1.0
|
assert img.max() <= 1.0
|
||||||
assert img.min() >= 0.0
|
assert img.min() >= 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_stats():
|
||||||
|
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||||
|
|
||||||
|
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||||
|
because we are working with a small dataset).
|
||||||
|
"""
|
||||||
|
cfg = init_hydra_config(
|
||||||
|
DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
|
||||||
|
)
|
||||||
|
buffer = make_offline_buffer(cfg)
|
||||||
|
# Get all of the data.
|
||||||
|
all_data = TensorDictReplayBuffer(
|
||||||
|
storage=buffer._storage,
|
||||||
|
batch_size=len(buffer),
|
||||||
|
sampler=SamplerWithoutReplacement(),
|
||||||
|
).sample().float()
|
||||||
|
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||||
|
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||||
|
# dataset into even batches.
|
||||||
|
computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
|
||||||
|
for k, pattern in buffer.stats_patterns.items():
|
||||||
|
expected_mean = einops.reduce(all_data[k], pattern, "mean")
|
||||||
|
assert torch.allclose(computed_stats[k]["mean"], expected_mean)
|
||||||
|
assert torch.allclose(
|
||||||
|
computed_stats[k]["std"],
|
||||||
|
torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
|
||||||
|
)
|
||||||
|
assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
|
||||||
|
assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
|
||||||
|
|
Loading…
Reference in New Issue