From 148df1c1d5a99c1f1a03654568d16a4a5ade3a4f Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 2 Apr 2024 16:57:25 +0100 Subject: [PATCH] add comment on test --- tests/test_datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index fa45c646..5146b86c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -54,7 +54,8 @@ def test_compute_stats(): sampler=SamplerWithoutReplacement(), ).sample().float() # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched - # computation of the statistics. + # computation of the statistics. While doing this, we also make sure it works when we don't divide the + # dataset into even batches. computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75)) for k, pattern in buffer.stats_patterns.items(): expected_mean = einops.reduce(all_data[k], pattern, "mean")