numerically sound mean computation
This commit is contained in:
parent
7242953197
commit
e9eb262293
|
@ -169,28 +169,36 @@ class AbstractDataset(TensorDictReplayBuffer):
|
|||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
# compute mean, min, max
|
||||
# Compute mean, min, max.
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
||||
batch = rb.sample()
|
||||
this_batch_size = batch.batch_size[0]
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Sum over batch then divide by total number of samples.
|
||||
mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0]
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ.
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
for key in self.stats_patterns:
|
||||
mean[key] = mean[key] / len(rb)
|
||||
|
||||
# Compute std
|
||||
# Compute std.
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
||||
batch = rb.sample()
|
||||
this_batch_size = batch.batch_size[0]
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
|
@ -198,14 +206,13 @@ class AbstractDataset(TensorDictReplayBuffer):
|
|||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Sum over batch then divide by total number of samples.
|
||||
std[key] = (
|
||||
std[key]
|
||||
+ einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0]
|
||||
)
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
for key in self.stats_patterns:
|
||||
std[key] = torch.sqrt(std[key] / len(rb))
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = TensorDict({}, batch_size=[])
|
||||
for key in self.stats_patterns:
|
||||
|
|
|
@ -58,9 +58,12 @@ def test_compute_stats():
|
|||
for k, pattern in buffer.stats_patterns.items():
|
||||
expected_mean = einops.reduce(all_data[k], pattern, "mean")
|
||||
assert torch.allclose(computed_stats[k]["mean"], expected_mean)
|
||||
try:
|
||||
assert torch.allclose(
|
||||
computed_stats[k]["std"],
|
||||
torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
|
||||
)
|
||||
except:
|
||||
breakpoint()
|
||||
assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
|
||||
assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
|
||||
|
|
Loading…
Reference in New Issue