diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 0aa89d65..18840a9e 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -169,28 +169,36 @@ class AbstractDataset(TensorDictReplayBuffer):
             max[key] = torch.tensor(-float("inf")).float()
             min[key] = torch.tensor(float("inf")).float()
 
-        # compute mean, min, max
+        # Compute mean, min, max.
         # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
         # surprises when rerunning the sampler.
         first_batch = None
+        running_item_count = 0  # for online mean computation
         for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
             if first_batch is None:
                 first_batch = deepcopy(batch)
             for key, pattern in self.stats_patterns.items():
                 batch[key] = batch[key].float()
-                # Sum over batch then divide by total number of samples.
-                mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0]
+                # Numerically stable update step for mean computation.
+                batch_mean = einops.reduce(batch[key], pattern, "mean")
+                # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+                # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+                # and x is the current batch mean. Some rearrangement is then required to avoid risking
+                # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ.
+                mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
                 max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                 min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
 
-        for key in self.stats_patterns:
-            mean[key] = mean[key] / len(rb)
-
-        # Compute std
+        # Compute std.
         first_batch_ = None
+        running_item_count = 0  # for online std computation
         for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
             batch = rb.sample()
+            this_batch_size = batch.batch_size[0]
+            running_item_count += this_batch_size
             # Sanity check to make sure the batches are still in the same order as before.
             if first_batch_ is None:
                 first_batch_ = deepcopy(batch)
@@ -198,14 +206,13 @@ class AbstractDataset(TensorDictReplayBuffer):
                     assert torch.equal(first_batch_[key], first_batch[key])
             for key, pattern in self.stats_patterns.items():
                 batch[key] = batch[key].float()
-                # Sum over batch then divide by total number of samples.
-                std[key] = (
-                    std[key]
-                    + einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0]
-                )
+                # Numerically stable update step for mean computation (where the mean is over squared
+                # residuals). See notes in the mean computation loop above.
+                batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+                std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
 
         for key in self.stats_patterns:
-            std[key] = torch.sqrt(std[key] / len(rb))
+            std[key] = torch.sqrt(std[key])
 
         stats = TensorDict({}, batch_size=[])
         for key in self.stats_patterns:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index df41b03f..05c9e80d 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -58,9 +58,7 @@ def test_compute_stats():
     for k, pattern in buffer.stats_patterns.items():
         expected_mean = einops.reduce(all_data[k], pattern, "mean")
         assert torch.allclose(computed_stats[k]["mean"], expected_mean)
-        assert torch.allclose(
-            computed_stats[k]["std"],
-            torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
-        )
+        expected_std = torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
+        assert torch.allclose(computed_stats[k]["std"], expected_std)
         assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
         assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
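
Not part of the patch: below is a minimal, self-contained sketch of the running-update rule the diff applies per stats key, with random tensors standing in for replay-buffer batches (names like `data` and `batches` are illustrative only). It mirrors the two passes above (one for the mean, one for the mean of squared residuals) and checks the result against the naive full-dataset computation. The rearranged form `x + B * (batch_stat - x) / N` avoids accumulating `batch_stat * item_count`, which is the overflow risk the comments allude to.

```python
# Standalone sketch, not part of the patch: same per-key update rule as the diff.
import torch

torch.manual_seed(0)
data = torch.randn(1000, 3) * 5 + 2  # stand-in for one stats key over the whole dataset
batches = data.split(128)            # stand-in for the rb.sample() batches (last one is smaller)

# Pass 1: running mean.
mean = torch.zeros(3)
running_item_count = 0
for batch in batches:
    this_batch_size = batch.shape[0]
    running_item_count += this_batch_size
    batch_mean = batch.mean(dim=0)
    # x̄ₙ = x̄ₙ₋₁ + Bₙ (xₙ - x̄ₙ₋₁) / Nₙ, the rearranged form used in the diff.
    mean = mean + this_batch_size * (batch_mean - mean) / running_item_count

# Pass 2: running mean of squared residuals, using the final mean from pass 1.
var = torch.zeros(3)
running_item_count = 0
for batch in batches:
    this_batch_size = batch.shape[0]
    running_item_count += this_batch_size
    batch_var = ((batch - mean) ** 2).mean(dim=0)
    var = var + this_batch_size * (batch_var - var) / running_item_count
std = torch.sqrt(var)

# Agreement with the naive full-dataset computation (as in test_compute_stats).
expected_mean = data.mean(dim=0)
expected_std = torch.sqrt(((data - expected_mean) ** 2).mean(dim=0))
assert torch.allclose(mean, expected_mean, atol=1e-5)
assert torch.allclose(std, expected_std, atol=1e-5)
```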