numerically sound mean computation

Alexander Soare 2024-04-03 09:44:20 +01:00
parent 7242953197
commit e9eb262293
2 changed files with 27 additions and 17 deletions


@@ -169,28 +169,36 @@ class AbstractDataset(TensorDictReplayBuffer):
            max[key] = torch.tensor(-float("inf")).float()
            min[key] = torch.tensor(float("inf")).float()
        # compute mean, min, max
        # Compute mean, min, max.
        # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
        # surprises when rerunning the sampler.
        first_batch = None
        running_item_count = 0  # for online mean computation
        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
            batch = rb.sample()
            this_batch_size = batch.batch_size[0]
            running_item_count += this_batch_size
            if first_batch is None:
                first_batch = deepcopy(batch)
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                # Sum over batch then divide by total number of samples.
                mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0]
                # Numerically stable update step for mean computation.
                batch_mean = einops.reduce(batch[key], pattern, "mean")
                # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
                # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
                # and x is the current batch mean. Some rearrangement is then required to avoid risking
                # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ.
                mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
                max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
        for key in self.stats_patterns:
            mean[key] = mean[key] / len(rb)
        # Compute std
        # Compute std.
        first_batch_ = None
        running_item_count = 0  # for online std computation
        for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
            batch = rb.sample()
            this_batch_size = batch.batch_size[0]
            running_item_count += this_batch_size
            # Sanity check to make sure the batches are still in the same order as before.
            if first_batch_ is None:
                first_batch_ = deepcopy(batch)
@@ -198,14 +206,13 @@ class AbstractDataset(TensorDictReplayBuffer):
                    assert torch.equal(first_batch_[key], first_batch[key])
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                # Sum over batch then divide by total number of samples.
                std[key] = (
                    std[key]
                    + einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0]
                )
                # Numerically stable update step for mean computation (where the mean is over squared
                # residuals). See notes in the mean computation loop above.
                batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
                std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
        for key in self.stats_patterns:
            std[key] = torch.sqrt(std[key] / len(rb))
            std[key] = torch.sqrt(std[key])
        stats = TensorDict({}, batch_size=[])
        for key in self.stats_patterns:

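For reference, here is a minimal, self-contained sketch of the rearranged running-mean update that the new code uses (illustrative names such as `batches` and `running_mean`; this is not the repository's code). It checks that the streaming update matches a direct mean over all samples:

import torch

torch.manual_seed(0)
batches = [torch.randn(5, 3) for _ in range(4)]  # stand-ins for rb.sample() results

running_mean = torch.zeros(3)
running_item_count = 0
for batch in batches:
    this_batch_size = batch.shape[0]
    running_item_count += this_batch_size
    batch_mean = batch.mean(dim=0)
    # x̄ₙ = x̄ₙ₋₁ + Bₙ(xₙ - x̄ₙ₋₁) / Nₙ is algebraically equal to (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ,
    # but it never materialises the large unnormalised sum Nₙ₋₁x̄ₙ₋₁, so it is safer against overflow.
    running_mean += this_batch_size * (batch_mean - running_mean) / running_item_count

assert torch.allclose(running_mean, torch.cat(batches).mean(dim=0), atol=1e-6)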

@@ -58,9 +58,12 @@ def test_compute_stats():
    for k, pattern in buffer.stats_patterns.items():
        expected_mean = einops.reduce(all_data[k], pattern, "mean")
        assert torch.allclose(computed_stats[k]["mean"], expected_mean)
        assert torch.allclose(
            computed_stats[k]["std"],
            torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")),
        )
        assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
        assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
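The std check above compares the two-pass, batched computation against the direct formula. A standalone sketch of that check under the same streaming-update scheme (again with illustrative names, not the test's fixtures):

import torch

torch.manual_seed(0)
data = torch.randn(100, 3)
batches = list(data.split(32))  # uneven final batch, as with ceil(len(rb) / batch_size) draws

mean = data.mean(dim=0)  # assume a completed first pass already produced the mean
var = torch.zeros(3)
running_item_count = 0
for batch in batches:
    this_batch_size = batch.shape[0]
    running_item_count += this_batch_size
    # Running mean of squared residuals around the (fixed) first-pass mean.
    batch_var = ((batch - mean) ** 2).mean(dim=0)
    var += this_batch_size * (batch_var - var) / running_item_count

expected_std = torch.sqrt(((data - mean) ** 2).mean(dim=0))
assert torch.allclose(torch.sqrt(var), expected_std, atol=1e-6)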