from copy import deepcopy
from math import ceil

import datasets
import einops
import torch
import tqdm
from datasets import Image

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import VideoFrame


def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
    """Return the einops patterns used to aggregate batches and compute statistics.

    Note: We assume the images are in channel-first format.
    """
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=2,
        shuffle=False,
    )
    batch = next(iter(dataloader))

    stats_patterns = {}
    for key, feats_type in dataset.features.items():
        # sanity check that tensors are not float64
        assert batch[key].dtype != torch.float64

        if isinstance(feats_type, (VideoFrame, Image)):
            # sanity check that images are channel first
            _, c, h, w = batch[key].shape
            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"

            # sanity check that images are float32 in range [0,1]
            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
            assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"

            stats_patterns[key] = "b c h w -> c 1 1"
        elif batch[key].ndim == 2:
            stats_patterns[key] = "b c -> c"
        elif batch[key].ndim == 1:
            stats_patterns[key] = "b -> 1"
        else:
            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")

    return stats_patterns
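

# As an illustration (key names here are hypothetical, not guaranteed by any particular
# dataset), `get_stats_einops_patterns` might return something like the following for a
# dataset with one camera image, a state/action vector per frame, and scalar indices:
#
#   {
#       "observation.images.cam": "b c h w -> c 1 1",  # per-channel stats over batch and pixels
#       "observation.state": "b c -> c",               # per-dimension stats over the batch
#       "action": "b c -> c",
#       "index": "b -> 1",                             # a single scalar stat over the batch
#   }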


def compute_stats(
    dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
):
    if max_num_samples is None:
        max_num_samples = len(dataset)

    # for more info on why we need to set the same number of workers, see `load_from_videos`
    stats_patterns = get_stats_einops_patterns(dataset, num_workers)

    # mean and std will be computed incrementally while max and min will track the running value.
    mean, std, max, min = {}, {}, {}, {}
    for key in stats_patterns:
        mean[key] = torch.tensor(0.0).float()
        std[key] = torch.tensor(0.0).float()
        max[key] = torch.tensor(-float("inf")).float()
        min[key] = torch.tensor(float("inf")).float()

    def create_seeded_dataloader(dataset, batch_size, seed):
        generator = torch.Generator()
        generator.manual_seed(seed)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
            generator=generator,
        )
        return dataloader

    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
    # surprises when rerunning the sampler.
    first_batch = None
    running_item_count = 0  # for online mean computation
    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
    ):
        this_batch_size = len(batch["index"])
        running_item_count += this_batch_size
        if first_batch is None:
            first_batch = deepcopy(batch)
        for key, pattern in stats_patterns.items():
            batch[key] = batch[key].float()
            # Numerically stable update step for mean computation.
            batch_mean = einops.reduce(batch[key], pattern, "mean")
            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
            # and x is the current batch mean. Some rearrangement is then required to avoid risking
            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
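            # Quick sanity check of this update rule: with a first batch {1, 2} the running
            # mean is 1.5 with N=2; a second batch {3} gives B=1, x=3, N=3, so
            # x̄ = 1.5 + 1 * (3 - 1.5) / 3 = 2.0, which matches (1 + 2 + 3) / 3.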
            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

        if i == ceil(max_num_samples / batch_size) - 1:
            break

    first_batch_ = None
    running_item_count = 0  # for online std computation
    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
    ):
        this_batch_size = len(batch["index"])
        running_item_count += this_batch_size
        # Sanity check to make sure the batches are still in the same order as before.
        if first_batch_ is None:
            first_batch_ = deepcopy(batch)
            for key in stats_patterns:
                assert torch.equal(first_batch_[key], first_batch[key])
        for key, pattern in stats_patterns.items():
            batch[key] = batch[key].float()
            # Numerically stable update step for mean computation (where the mean is over squared
            # residuals). See notes in the mean computation loop above.
            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count

        if i == ceil(max_num_samples / batch_size) - 1:
            break

    # Take the square root to go from the mean of squared residuals (variance) to std.
    for key in stats_patterns:
        std[key] = torch.sqrt(std[key])

    stats = {}
    for key in stats_patterns:
        stats[key] = {
            "mean": mean[key],
            "std": std[key],
            "max": max[key],
            "min": min[key],
        }
    return stats
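

if __name__ == "__main__":
    # Minimal usage sketch: compute statistics for a dataset and print the shape of each
    # stat tensor. The repo_id below is a hypothetical placeholder; substitute any dataset
    # loadable by `LeRobotDataset`.
    dataset = LeRobotDataset("lerobot/example_dataset")  # hypothetical repo_id
    stats = compute_stats(dataset, batch_size=32, num_workers=8)
    for key, key_stats in stats.items():
        print(key, {name: tensor.shape for name, tensor in key_stats.items()})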