WIP stats (TODO: run tests on stats + cmpute them)

This commit is contained in:
Cadene 2024-04-04 16:36:03 +00:00
parent 1cdfbc8b52
commit c93ce35d8c
5 changed files with 157 additions and 286 deletions

View File

@ -1,234 +0,0 @@
import logging
from copy import deepcopy
from math import ceil
from pathlib import Path
from typing import Callable
import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
HF_USER = "lerobot"
class AbstractDataset(TensorDictReplayBuffer):
"""
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
These implementations can vary based on the source of the data, the environment the data pertains to,
or the specific kind of data manipulation applied.
Note:
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
functionality for storing and retrieving `TensorDict`-like data.
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
It is expected that these variants correspond to a HuggingFace dataset on the hub.
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
available_datasets: list[str] | None = None
def __init__(
self,
dataset_id: str,
version: str | None = None,
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
):
assert (
self.available_datasets is not None
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
assert (
dataset_id in self.available_datasets
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
self.dataset_id = dataset_id
self.version = version
self.shuffle = shuffle
self.root = root if root is None else Path(root)
if self.root is not None and self.version is not None:
logging.warning(
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
)
storage = self._download_or_load_dataset()
super().__init__(
storage=storage,
sampler=sampler,
writer=ImmutableDatasetWriter() if writer is None else writer,
collate_fn=_collate_id if collate_fn is None else collate_fn,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
transform=transform,
)
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> c",
("observation", "image"): "b c h w -> c 1 1",
("action",): "b c -> c",
}
@property
def image_keys(self) -> list:
return [("observation", "image")]
@property
def num_cameras(self) -> int:
return len(self.image_keys)
@property
def num_samples(self) -> int:
return len(self)
@property
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def transform(self):
return self._transform
def set_transform(self, transform):
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
if isinstance(transform, list):
self._transform = Compose(*transform)
else:
self._transform = Compose(transform)
else:
self._transform = transform
def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"
if stats_path.exists():
stats = torch.load(stats_path)
else:
logging.info(f"compute_stats and save to {stats_path}")
stats = self._compute_stats(batch_size)
torch.save(stats, stats_path)
return stats
def _download_or_load_dataset(self) -> torch.StorageBase:
if self.root is None:
self.data_dir = Path(
snapshot_download(
repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
)
)
else:
self.data_dir = self.root / self.dataset_id
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
def _compute_stats(self, batch_size: int = 32):
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation.
TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
full dataset (for handling very large datasets). The sampling would then have to be random
(preferably without replacement). Both stats computation loops would ideally sample the same
items.
"""
rb = TensorDictReplayBuffer(
storage=self._storage,
batch_size=32,
prefetch=True,
# Note: Due to be refactored soon. The point is that we should go through the whole dataset.
sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
for key in self.stats_patterns:
mean[key] = torch.tensor(0.0).float()
std[key] = torch.tensor(0.0).float()
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
# Compute mean, min, max.
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
batch = rb.sample()
this_batch_size = batch.batch_size[0]
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation.
batch_mean = einops.reduce(batch[key], pattern, "mean")
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
# Compute std.
first_batch_ = None
running_item_count = 0 # for online std computation
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
batch = rb.sample()
this_batch_size = batch.batch_size[0]
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in self.stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
for key in self.stats_patterns:
std[key] = torch.sqrt(std[key])
stats = TensorDict({}, batch_size=[])
for key in self.stats_patterns:
stats[(*key, "mean")] = mean[key]
stats[(*key, "std")] = std[key]
stats[(*key, "max")] = max[key]
stats[(*key, "min")] = min[key]
if key[0] == "observation":
# use same stats for the next observations
stats[("next", *key)] = stats[key]
return stats

View File

@ -4,7 +4,8 @@ from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.transforms import Prod
from lerobot.common.datasets.utils import compute_or_load_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
@ -41,9 +42,8 @@ def make_dataset(
# min_max_from_spec
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
stats = {}
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation.state"] = {}
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
@ -51,22 +51,30 @@ def make_dataset(
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
else:
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
# normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms = v2.Compose(
[
# TODO(rcadene): we need to do something about image_keys
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
# NormalizeTransform(
# stats,
# in_keys=[
# "observation.state",
# "action",
# ],
# mode=normalization_mode,
# ),
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
]
)

View File

@ -1,7 +1,11 @@
import io
import logging
import zipfile
from copy import deepcopy
from math import ceil
from pathlib import Path
import einops
import requests
import torch
import tqdm
@ -97,3 +101,100 @@ def load_data_with_delta_timestamps(
)
return data, is_pad
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
stats_path = dataset.data_dir / "stats.pth"
if stats_path.exists():
return torch.load(stats_path)
logging.info(f"compute_stats and save to {stats_path}")
if max_num_samples is None:
max_num_samples = len(dataset)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
for key in stats_patterns:
mean[key] = torch.tensor(0.0).float()
std[key] = torch.tensor(0.0).float()
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
for i, batch in enumerate(
tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = batch.batch_size[0]
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation.
batch_mean = einops.reduce(batch[key], pattern, "mean")
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
if i == ceil(max_num_samples / batch_size) - 1:
break
first_batch_ = None
running_item_count = 0 # for online std computation
for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")):
this_batch_size = batch.batch_size[0]
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
if i == ceil(max_num_samples / batch_size) - 1:
break
for key in stats_patterns:
std[key] = torch.sqrt(std[key])
stats = {}
for key in stats_patterns:
stats[key] = {
"mean": mean[key],
"std": std[key],
"max": max[key],
"min": min[key],
}
torch.save(stats, stats_path)
return stats

View File

@ -72,12 +72,12 @@ class NormalizeTransform(Transform):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
mean = self.stats[f"{inkey}.mean"]
std = self.stats[f"{inkey}.std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
min = self.stats[f"{inkey}.min"]
max = self.stats[f"{inkey}.max"]
# normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
@ -89,12 +89,12 @@ class NormalizeTransform(Transform):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
mean = self.stats[f"{inkey}.mean"]
std = self.stats[f"{inkey}.std"]
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
min = self.stats[f"{inkey}.min"]
max = self.stats[f"{inkey}.max"]
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item

View File

@ -1,10 +1,6 @@
import einops
import pytest
import torch
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.utils import init_hydra_config
import logging
from lerobot.common.datasets.factory import make_dataset
@ -52,32 +48,32 @@ def test_factory(env_name, dataset_id):
logging.warning(f'Missing "next.done" key in dataset {dataset}.')
def test_compute_stats():
"""Check that the statistics are computed correctly according to the stats_patterns property.
# def test_compute_stats():
# """Check that the statistics are computed correctly according to the stats_patterns property.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
)
buffer = make_offline_buffer(cfg)
# Get all of the data.
all_data = TensorDictReplayBuffer(
storage=buffer._storage,
batch_size=len(buffer),
sampler=SamplerWithoutReplacement(),
).sample().float()
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
for k, pattern in buffer.stats_patterns.items():
expected_mean = einops.reduce(all_data[k], pattern, "mean")
assert torch.allclose(computed_stats[k]["mean"], expected_mean)
assert torch.allclose(
computed_stats[k]["std"],
torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
)
assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
# We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
# because we are working with a small dataset).
# """
# cfg = init_hydra_config(
# DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
# )
# dataset = make_dataset(cfg)
# # Get all of the data.
# all_data = TensorDictReplayBuffer(
# storage=buffer._storage,
# batch_size=len(buffer),
# sampler=SamplerWithoutReplacement(),
# ).sample().float()
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the
# # dataset into even batches.
# computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
# for k, pattern in buffer.stats_patterns.items():
# expected_mean = einops.reduce(all_data[k], pattern, "mean")
# assert torch.allclose(computed_stats[k]["mean"], expected_mean)
# assert torch.allclose(
# computed_stats[k]["std"],
# torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
# )
# assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
# assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))