diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 63507cce..7e079476 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -1,10 +1,11 @@
+import logging
 import os
 from pathlib import Path
 
 import torch
 from torchvision.transforms import v2
 
-from lerobot.common.datasets.utils import compute_or_load_stats
+from lerobot.common.datasets.utils import compute_stats
 from lerobot.common.transforms import NormalizeTransform, Prod
 
 # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
@@ -59,7 +60,15 @@ def make_dataset(
             root=DATA_DIR,
             transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
         )
-        stats = compute_or_load_stats(stats_dataset)
+
+        # load stats if the file exists already or compute stats and save it
+        precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
+        if precomputed_stats_path.exists():
+            stats = torch.load(precomputed_stats_path)
+        else:
+            logging.info(f"compute_stats and save to {precomputed_stats_path}")
+            stats = compute_stats(stats_dataset)
+            torch.save(stats, precomputed_stats_path)
     else:
         stats = torch.load(stats_path)
 
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 3b4aacfc..cf8caa46 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -1,5 +1,4 @@
 import io
-import logging
 import zipfile
 from copy import deepcopy
 from math import ceil
@@ -103,13 +102,18 @@ def load_data_with_delta_timestamps(
     return data, is_pad
 
 
-def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
-    stats_path = dataset.data_dir / "stats.pth"
-    if stats_path.exists():
-        return torch.load(stats_path)
+def get_stats_einops_patterns(dataset):
+    """These einops patterns will be used to aggregate batches and compute statistics."""
+    stats_patterns = {
+        "action": "b c -> c",
+        "observation.state": "b c -> c",
+    }
+    for key in dataset.image_keys:
+        stats_patterns[key] = "b c h w -> c 1 1"
+    return stats_patterns
 
-    logging.info(f"compute_stats and save to {stats_path}")
 
+def compute_stats(dataset, batch_size=32, max_num_samples=None):
     if max_num_samples is None:
         max_num_samples = len(dataset)
     else:
@@ -124,13 +128,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
         drop_last=False,
     )
 
-    # these einops patterns will be used to aggregate batches and compute statistics
-    stats_patterns = {
-        "action": "b c -> c",
-        "observation.state": "b c -> c",
-    }
-    for key in dataset.image_keys:
-        stats_patterns[key] = "b c h w -> c 1 1"
+    # get einops patterns to aggregate batches and compute statistics
+    stats_patterns = get_stats_einops_patterns(dataset)
 
     # mean and std will be computed incrementally while max and min will track the running value.
     mean, std, max, min = {}, {}, {}, {}
@@ -201,7 +200,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
             "min": min[key],
         }
 
-    torch.save(stats, stats_path)
     return stats
 
 
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e24d7b4d..9b32ea25 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,6 +1,12 @@
+import os
+from pathlib import Path
+import einops
 import pytest
 import torch
 
+from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns
+from lerobot.common.datasets.xarm import XarmDataset
+from lerobot.common.transforms import Prod
 from lerobot.common.utils import init_hydra_config
 import logging
 from lerobot.common.datasets.factory import make_dataset
@@ -81,28 +87,58 @@ def test_factory(env_name, dataset_id, policy_name):
         assert key in item, f"{key}"
 
 
-# def test_compute_stats():
-#     """Check that the statistics are computed correctly according to the stats_patterns property.
+def test_compute_stats():
+    """Check that the statistics are computed correctly according to the stats_patterns property.
 
-#     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
-#     because we are working with a small dataset).
-#     """
-#     cfg = init_hydra_config(
-#         DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
-#     )
-#     dataset = make_dataset(cfg)
-#     # Get all of the data.
-#     all_data = dataset.data_dict
-#     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
-#     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
-#     # dataset into even batches.
-#     computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
-#     for k, pattern in buffer.stats_patterns.items():
-#         expected_mean = einops.reduce(all_data[k], pattern, "mean")
-#         assert torch.allclose(computed_stats[k]["mean"], expected_mean)
-#         assert torch.allclose(
-#             computed_stats[k]["std"],
-#             torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
-#         )
-#         assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
-#         assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
+    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
+    because we are working with a small dataset).
+    """
+    DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+
+    # get transform to convert images from uint8 [0,255] to float32 [0,1]
+    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
+
+    dataset = XarmDataset(
+        dataset_id="xarm_lift_medium",
+        root=DATA_DIR,
+        transform=transform,
+    )
+
+    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
+    # dataset into even batches.
+    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
+
+    # get einops patterns to aggregate batches and compute statistics
+    stats_patterns = get_stats_einops_patterns(dataset)
+
+    # get all frames from the dataset in the same dtype and range as during compute_stats
+    data_dict = transform(dataset.data_dict)
+
+    # compute stats based on all frames from the dataset without any batching
+    expected_stats = {}
+    for k, pattern in stats_patterns.items():
+        expected_stats[k] = {}
+        expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
+        expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
+        expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
+
+    # test computed stats match expected stats
+    for k in stats_patterns:
+        assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
+        assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
+        assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
+        assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
+
+    # TODO(rcadene): check that the stats used for training are correct too
+    # # load stats that are expected to match the ones returned by computed_stats
+    # assert (dataset.data_dir / "stats.pth").exists()
+    # loaded_stats = torch.load(dataset.data_dir / "stats.pth")
+
+    # # test loaded stats match expected stats
+    # for k in stats_patterns:
+    #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
+    #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
+    #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
+    #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
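For reference, a minimal sketch of how the refactored helpers fit together outside of make_dataset. This is illustrative only (not part of the diff) and assumes the xarm_lift_medium dataset is reachable through the usual DATA_DIR mechanism; the explicit torch.save at the end mirrors what the factory does after computing stats.

import os
from pathlib import Path

import torch

from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.transforms import Prod

# Resolve the dataset root the same way the new test does.
data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

# Scale images from uint8 [0, 255] to float32 [0, 1] so the statistics are
# computed in the same range the training transforms expect.
dataset = XarmDataset(
    dataset_id="xarm_lift_medium",
    root=data_dir,
    transform=Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0),
)

# One einops pattern per key, e.g. "b c h w -> c 1 1" for image keys.
print(get_stats_einops_patterns(dataset))

# Batched mean/std/min/max over the whole dataset, then cached next to the
# data so make_dataset can reload it instead of recomputing.
stats = compute_stats(dataset, batch_size=32)
torch.save(stats, dataset.data_dir / "stats.pth")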