parent
4c3d8b061e
commit
9874652c2f
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.datasets.utils import compute_or_load_stats
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
|
@ -59,7 +60,15 @@ def make_dataset(
|
|||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_or_load_stats(stats_dataset)
|
||||
|
||||
# load stats if the file exists already or compute stats and save it
|
||||
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
|
||||
if precomputed_stats_path.exists():
|
||||
stats = torch.load(precomputed_stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
||||
stats = compute_stats(stats_dataset)
|
||||
torch.save(stats, stats_path)
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
|
@ -103,13 +102,18 @@ def load_data_with_delta_timestamps(
|
|||
return data, is_pad
|
||||
|
||||
|
||||
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
stats_path = dataset.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
return torch.load(stats_path)
|
||||
def get_stats_einops_patterns(dataset):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics."""
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
return stats_patterns
|
||||
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
|
||||
def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
else:
|
||||
|
@ -124,13 +128,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
|||
drop_last=False,
|
||||
)
|
||||
|
||||
# these einops patterns will be used to aggregate batches and compute statistics
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
|
@ -201,7 +200,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
|||
"min": min[key],
|
||||
}
|
||||
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
from lerobot.common.transforms import Prod
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
import logging
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
@ -81,28 +87,58 @@ def test_factory(env_name, dataset_id, policy_name):
|
|||
assert key in item, f"{key}"
|
||||
|
||||
|
||||
# def test_compute_stats():
|
||||
# """Check that the statistics are computed correctly according to the stats_patterns property.
|
||||
def test_compute_stats():
|
||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||
|
||||
# We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||
# because we are working with a small dataset).
|
||||
# """
|
||||
# cfg = init_hydra_config(
|
||||
# DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
|
||||
# )
|
||||
# dataset = make_dataset(cfg)
|
||||
# # Get all of the data.
|
||||
# all_data = dataset.data_dict
|
||||
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||
# # dataset into even batches.
|
||||
# computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
|
||||
# for k, pattern in buffer.stats_patterns.items():
|
||||
# expected_mean = einops.reduce(all_data[k], pattern, "mean")
|
||||
# assert torch.allclose(computed_stats[k]["mean"], expected_mean)
|
||||
# assert torch.allclose(
|
||||
# computed_stats[k]["std"],
|
||||
# torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
|
||||
# )
|
||||
# assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
|
||||
# assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
|
||||
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||
because we are working with a small dataset).
|
||||
"""
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
# get transform to convert images from uint8 [0,255] to float32 [0,1]
|
||||
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
|
||||
|
||||
dataset = XarmDataset(
|
||||
dataset_id="xarm_lift_medium",
|
||||
root=DATA_DIR,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||
# dataset into even batches.
|
||||
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
|
||||
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
|
||||
# get all frames from the dataset in the same dtype and range as during compute_stats
|
||||
data_dict = transform(dataset.data_dict)
|
||||
|
||||
# compute stats based on all frames from the dataset without any batching
|
||||
expected_stats = {}
|
||||
for k, pattern in stats_patterns.items():
|
||||
expected_stats[k] = {}
|
||||
expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
|
||||
expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
|
||||
expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
|
||||
expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
|
||||
|
||||
# test computed stats match expected stats
|
||||
for k in stats_patterns:
|
||||
assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
|
||||
assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
|
||||
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
|
||||
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
# TODO(rcadene): check that the stats used for training are correct too
|
||||
# # load stats that are expected to match the ones returned by computed_stats
|
||||
# assert (dataset.data_dir / "stats.pth").exists()
|
||||
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
|
||||
|
||||
# # test loaded stats match expected stats
|
||||
# for k in stats_patterns:
|
||||
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
|
||||
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
|
||||
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
|
||||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
|
Loading…
Reference in New Issue