clarifying math
parent e9eb262293
commit c50a62dd6d
@@ -187,7 +187,8 @@ class AbstractDataset(TensorDictReplayBuffer):
 # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
 # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
 # and x is the current batch mean. Some rearrangement is then required to avoid risking
-# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ.
+# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
 mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
 max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
 min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
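For reference, the rearrangement the new comment points to is: x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ = ((Nₙ - Bₙ)x̄ₙ₋₁ + Bₙxₙ) / Nₙ = x̄ₙ₋₁ + Bₙ(xₙ - x̄ₙ₋₁) / Nₙ, which avoids ever forming the potentially huge product Nₙ₋₁x̄ₙ₋₁. Below is a minimal standalone sketch of the same per-batch update applied over a stream of batches; it is not the repository's actual code, and it assumes each batch is a torch tensor of shape (B, ...).

import torch

def streaming_mean(batches):
    # Running state: item count Nₙ and running mean x̄ₙ.
    running_item_count = 0
    mean = None
    for batch in batches:
        this_batch_size = batch.shape[0]          # Bₙ
        running_item_count += this_batch_size     # Nₙ = Nₙ₋₁ + Bₙ
        batch_mean = batch.mean(dim=0)            # xₙ
        if mean is None:
            mean = torch.zeros_like(batch_mean)
        # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ  (the rearranged, overflow-safe form)
        mean = mean + this_batch_size * (batch_mean - mean) / running_item_count
    return mean

# Example: two batches of three scalars each; the result is the mean of all six items.
print(streaming_mean([torch.tensor([[1.0], [2.0], [3.0]]),
                      torch.tensor([[4.0], [5.0], [6.0]])]))
# tensor([3.5000])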