WIP stats (TODO: run tests on stats + cmpute them)
This commit is contained in:
parent
1cdfbc8b52
commit
c93ce35d8c
|
@ -1,234 +0,0 @@
|
||||||
import logging
|
|
||||||
from copy import deepcopy
|
|
||||||
from math import ceil
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import torch
|
|
||||||
import torchrl
|
|
||||||
import tqdm
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from tensordict import TensorDict
|
|
||||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
|
||||||
from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
|
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
|
||||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
|
||||||
from torchrl.envs.transforms.transforms import Compose
|
|
||||||
|
|
||||||
HF_USER = "lerobot"
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractDataset(TensorDictReplayBuffer):
|
|
||||||
"""
|
|
||||||
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
|
|
||||||
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
|
|
||||||
These implementations can vary based on the source of the data, the environment the data pertains to,
|
|
||||||
or the specific kind of data manipulation applied.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
|
|
||||||
functionality for storing and retrieving `TensorDict`-like data.
|
|
||||||
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
|
|
||||||
It is expected that these variants correspond to a HuggingFace dataset on the hub.
|
|
||||||
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
|
|
||||||
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
|
||||||
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
|
||||||
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
|
||||||
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
|
||||||
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
|
||||||
1. set the required class attributes:
|
|
||||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
|
||||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
|
||||||
- for classes inheriting from `AbstractPolicy`: `name`
|
|
||||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
|
||||||
3. update variables in `tests/test_available.py` by importing your new class
|
|
||||||
"""
|
|
||||||
|
|
||||||
available_datasets: list[str] | None = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
version: str | None = None,
|
|
||||||
batch_size: int | None = None,
|
|
||||||
*,
|
|
||||||
shuffle: bool = True,
|
|
||||||
root: Path | None = None,
|
|
||||||
pin_memory: bool = False,
|
|
||||||
prefetch: int = None,
|
|
||||||
sampler: Sampler | None = None,
|
|
||||||
collate_fn: Callable | None = None,
|
|
||||||
writer: Writer | None = None,
|
|
||||||
transform: "torchrl.envs.Transform" = None,
|
|
||||||
):
|
|
||||||
assert (
|
|
||||||
self.available_datasets is not None
|
|
||||||
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
|
|
||||||
assert (
|
|
||||||
dataset_id in self.available_datasets
|
|
||||||
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
|
|
||||||
|
|
||||||
self.dataset_id = dataset_id
|
|
||||||
self.version = version
|
|
||||||
self.shuffle = shuffle
|
|
||||||
self.root = root if root is None else Path(root)
|
|
||||||
|
|
||||||
if self.root is not None and self.version is not None:
|
|
||||||
logging.warning(
|
|
||||||
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
|
||||||
)
|
|
||||||
|
|
||||||
storage = self._download_or_load_dataset()
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
storage=storage,
|
|
||||||
sampler=sampler,
|
|
||||||
writer=ImmutableDatasetWriter() if writer is None else writer,
|
|
||||||
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
prefetch=prefetch,
|
|
||||||
batch_size=batch_size,
|
|
||||||
transform=transform,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def stats_patterns(self) -> dict:
|
|
||||||
return {
|
|
||||||
("observation", "state"): "b c -> c",
|
|
||||||
("observation", "image"): "b c h w -> c 1 1",
|
|
||||||
("action",): "b c -> c",
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def image_keys(self) -> list:
|
|
||||||
return [("observation", "image")]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_cameras(self) -> int:
|
|
||||||
return len(self.image_keys)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_samples(self) -> int:
|
|
||||||
return len(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_episodes(self) -> int:
|
|
||||||
return len(self._storage._storage["episode"].unique())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def transform(self):
|
|
||||||
return self._transform
|
|
||||||
|
|
||||||
def set_transform(self, transform):
|
|
||||||
if not isinstance(transform, Compose):
|
|
||||||
# required since torchrl calls `len(self._transform)` downstream
|
|
||||||
if isinstance(transform, list):
|
|
||||||
self._transform = Compose(*transform)
|
|
||||||
else:
|
|
||||||
self._transform = Compose(transform)
|
|
||||||
else:
|
|
||||||
self._transform = transform
|
|
||||||
|
|
||||||
def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
|
|
||||||
stats_path = self.data_dir / "stats.pth"
|
|
||||||
if stats_path.exists():
|
|
||||||
stats = torch.load(stats_path)
|
|
||||||
else:
|
|
||||||
logging.info(f"compute_stats and save to {stats_path}")
|
|
||||||
stats = self._compute_stats(batch_size)
|
|
||||||
torch.save(stats, stats_path)
|
|
||||||
return stats
|
|
||||||
|
|
||||||
def _download_or_load_dataset(self) -> torch.StorageBase:
|
|
||||||
if self.root is None:
|
|
||||||
self.data_dir = Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.data_dir = self.root / self.dataset_id
|
|
||||||
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
|
|
||||||
|
|
||||||
def _compute_stats(self, batch_size: int = 32):
|
|
||||||
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation.
|
|
||||||
|
|
||||||
TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
|
|
||||||
full dataset (for handling very large datasets). The sampling would then have to be random
|
|
||||||
(preferably without replacement). Both stats computation loops would ideally sample the same
|
|
||||||
items.
|
|
||||||
"""
|
|
||||||
rb = TensorDictReplayBuffer(
|
|
||||||
storage=self._storage,
|
|
||||||
batch_size=32,
|
|
||||||
prefetch=True,
|
|
||||||
# Note: Due to be refactored soon. The point is that we should go through the whole dataset.
|
|
||||||
sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# mean and std will be computed incrementally while max and min will track the running value.
|
|
||||||
mean, std, max, min = {}, {}, {}, {}
|
|
||||||
for key in self.stats_patterns:
|
|
||||||
mean[key] = torch.tensor(0.0).float()
|
|
||||||
std[key] = torch.tensor(0.0).float()
|
|
||||||
max[key] = torch.tensor(-float("inf")).float()
|
|
||||||
min[key] = torch.tensor(float("inf")).float()
|
|
||||||
|
|
||||||
# Compute mean, min, max.
|
|
||||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
|
||||||
# surprises when rerunning the sampler.
|
|
||||||
first_batch = None
|
|
||||||
running_item_count = 0 # for online mean computation
|
|
||||||
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
|
||||||
batch = rb.sample()
|
|
||||||
this_batch_size = batch.batch_size[0]
|
|
||||||
running_item_count += this_batch_size
|
|
||||||
if first_batch is None:
|
|
||||||
first_batch = deepcopy(batch)
|
|
||||||
for key, pattern in self.stats_patterns.items():
|
|
||||||
batch[key] = batch[key].float()
|
|
||||||
# Numerically stable update step for mean computation.
|
|
||||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
|
||||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
|
||||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
|
||||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
|
||||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
|
||||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
|
||||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
|
||||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
|
||||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
|
||||||
|
|
||||||
# Compute std.
|
|
||||||
first_batch_ = None
|
|
||||||
running_item_count = 0 # for online std computation
|
|
||||||
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
|
||||||
batch = rb.sample()
|
|
||||||
this_batch_size = batch.batch_size[0]
|
|
||||||
running_item_count += this_batch_size
|
|
||||||
# Sanity check to make sure the batches are still in the same order as before.
|
|
||||||
if first_batch_ is None:
|
|
||||||
first_batch_ = deepcopy(batch)
|
|
||||||
for key in self.stats_patterns:
|
|
||||||
assert torch.equal(first_batch_[key], first_batch[key])
|
|
||||||
for key, pattern in self.stats_patterns.items():
|
|
||||||
batch[key] = batch[key].float()
|
|
||||||
# Numerically stable update step for mean computation (where the mean is over squared
|
|
||||||
# residuals).See notes in the mean computation loop above.
|
|
||||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
|
||||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
|
||||||
|
|
||||||
for key in self.stats_patterns:
|
|
||||||
std[key] = torch.sqrt(std[key])
|
|
||||||
|
|
||||||
stats = TensorDict({}, batch_size=[])
|
|
||||||
for key in self.stats_patterns:
|
|
||||||
stats[(*key, "mean")] = mean[key]
|
|
||||||
stats[(*key, "std")] = std[key]
|
|
||||||
stats[(*key, "max")] = max[key]
|
|
||||||
stats[(*key, "min")] = min[key]
|
|
||||||
|
|
||||||
if key[0] == "observation":
|
|
||||||
# use same stats for the next observations
|
|
||||||
stats[("next", *key)] = stats[key]
|
|
||||||
return stats
|
|
|
@ -4,7 +4,8 @@ from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
from lerobot.common.transforms import Prod
|
from lerobot.common.datasets.utils import compute_or_load_stats
|
||||||
|
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||||
|
|
||||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||||
|
@ -41,9 +42,8 @@ def make_dataset(
|
||||||
# min_max_from_spec
|
# min_max_from_spec
|
||||||
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||||
|
|
||||||
stats = {}
|
|
||||||
|
|
||||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||||
|
stats = {}
|
||||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||||
stats["observation.state"] = {}
|
stats["observation.state"] = {}
|
||||||
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||||
|
@ -51,22 +51,30 @@ def make_dataset(
|
||||||
stats["action"] = {}
|
stats["action"] = {}
|
||||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
# instantiate a one frame dataset with light transform
|
||||||
|
stats_dataset = clsfunc(
|
||||||
|
dataset_id=cfg.dataset_id,
|
||||||
|
root=DATA_DIR,
|
||||||
|
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||||
|
)
|
||||||
|
stats = compute_or_load_stats(stats_dataset)
|
||||||
|
|
||||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||||
# normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||||
|
|
||||||
transforms = v2.Compose(
|
transforms = v2.Compose(
|
||||||
[
|
[
|
||||||
# TODO(rcadene): we need to do something about image_keys
|
# TODO(rcadene): we need to do something about image_keys
|
||||||
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||||
# NormalizeTransform(
|
NormalizeTransform(
|
||||||
# stats,
|
stats,
|
||||||
# in_keys=[
|
in_keys=[
|
||||||
# "observation.state",
|
"observation.state",
|
||||||
# "action",
|
"action",
|
||||||
# ],
|
],
|
||||||
# mode=normalization_mode,
|
mode=normalization_mode,
|
||||||
# ),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
import zipfile
|
import zipfile
|
||||||
|
from copy import deepcopy
|
||||||
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import einops
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
@ -97,3 +101,100 @@ def load_data_with_delta_timestamps(
|
||||||
)
|
)
|
||||||
|
|
||||||
return data, is_pad
|
return data, is_pad
|
||||||
|
|
||||||
|
|
||||||
|
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||||
|
stats_path = dataset.data_dir / "stats.pth"
|
||||||
|
if stats_path.exists():
|
||||||
|
return torch.load(stats_path)
|
||||||
|
|
||||||
|
logging.info(f"compute_stats and save to {stats_path}")
|
||||||
|
|
||||||
|
if max_num_samples is None:
|
||||||
|
max_num_samples = len(dataset)
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=4,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
# pin_memory=cfg.device != "cpu",
|
||||||
|
drop_last=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
stats_patterns = {
|
||||||
|
"action": "b c -> c",
|
||||||
|
"observation.state": "b c -> c",
|
||||||
|
}
|
||||||
|
for key in dataset.image_keys:
|
||||||
|
stats_patterns[key] = "b c h w -> c 1 1"
|
||||||
|
|
||||||
|
# mean and std will be computed incrementally while max and min will track the running value.
|
||||||
|
mean, std, max, min = {}, {}, {}, {}
|
||||||
|
for key in stats_patterns:
|
||||||
|
mean[key] = torch.tensor(0.0).float()
|
||||||
|
std[key] = torch.tensor(0.0).float()
|
||||||
|
max[key] = torch.tensor(-float("inf")).float()
|
||||||
|
min[key] = torch.tensor(float("inf")).float()
|
||||||
|
|
||||||
|
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||||
|
# surprises when rerunning the sampler.
|
||||||
|
first_batch = None
|
||||||
|
running_item_count = 0 # for online mean computation
|
||||||
|
for i, batch in enumerate(
|
||||||
|
tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||||
|
):
|
||||||
|
this_batch_size = batch.batch_size[0]
|
||||||
|
running_item_count += this_batch_size
|
||||||
|
if first_batch is None:
|
||||||
|
first_batch = deepcopy(batch)
|
||||||
|
for key, pattern in stats_patterns.items():
|
||||||
|
batch[key] = batch[key].float()
|
||||||
|
# Numerically stable update step for mean computation.
|
||||||
|
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||||
|
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||||
|
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||||
|
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||||
|
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||||
|
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||||
|
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||||
|
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||||
|
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||||
|
|
||||||
|
if i == ceil(max_num_samples / batch_size) - 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
first_batch_ = None
|
||||||
|
running_item_count = 0 # for online std computation
|
||||||
|
for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")):
|
||||||
|
this_batch_size = batch.batch_size[0]
|
||||||
|
running_item_count += this_batch_size
|
||||||
|
# Sanity check to make sure the batches are still in the same order as before.
|
||||||
|
if first_batch_ is None:
|
||||||
|
first_batch_ = deepcopy(batch)
|
||||||
|
for key in stats_patterns:
|
||||||
|
assert torch.equal(first_batch_[key], first_batch[key])
|
||||||
|
for key, pattern in stats_patterns.items():
|
||||||
|
batch[key] = batch[key].float()
|
||||||
|
# Numerically stable update step for mean computation (where the mean is over squared
|
||||||
|
# residuals).See notes in the mean computation loop above.
|
||||||
|
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||||
|
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||||
|
|
||||||
|
if i == ceil(max_num_samples / batch_size) - 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
for key in stats_patterns:
|
||||||
|
std[key] = torch.sqrt(std[key])
|
||||||
|
|
||||||
|
stats = {}
|
||||||
|
for key in stats_patterns:
|
||||||
|
stats[key] = {
|
||||||
|
"mean": mean[key],
|
||||||
|
"std": std[key],
|
||||||
|
"max": max[key],
|
||||||
|
"min": min[key],
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(stats, stats_path)
|
||||||
|
return stats
|
||||||
|
|
|
@ -72,12 +72,12 @@ class NormalizeTransform(Transform):
|
||||||
if inkey not in item:
|
if inkey not in item:
|
||||||
continue
|
continue
|
||||||
if self.mode == "mean_std":
|
if self.mode == "mean_std":
|
||||||
mean = self.stats[inkey]["mean"]
|
mean = self.stats[f"{inkey}.mean"]
|
||||||
std = self.stats[inkey]["std"]
|
std = self.stats[f"{inkey}.std"]
|
||||||
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
||||||
else:
|
else:
|
||||||
min = self.stats[inkey]["min"]
|
min = self.stats[f"{inkey}.min"]
|
||||||
max = self.stats[inkey]["max"]
|
max = self.stats[f"{inkey}.max"]
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
item[outkey] = (item[inkey] - min) / (max - min)
|
item[outkey] = (item[inkey] - min) / (max - min)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
|
@ -89,12 +89,12 @@ class NormalizeTransform(Transform):
|
||||||
if inkey not in item:
|
if inkey not in item:
|
||||||
continue
|
continue
|
||||||
if self.mode == "mean_std":
|
if self.mode == "mean_std":
|
||||||
mean = self.stats[inkey]["mean"]
|
mean = self.stats[f"{inkey}.mean"]
|
||||||
std = self.stats[inkey]["std"]
|
std = self.stats[f"{inkey}.std"]
|
||||||
item[outkey] = item[inkey] * std + mean
|
item[outkey] = item[inkey] * std + mean
|
||||||
else:
|
else:
|
||||||
min = self.stats[inkey]["min"]
|
min = self.stats[f"{inkey}.min"]
|
||||||
max = self.stats[inkey]["max"]
|
max = self.stats[f"{inkey}.max"]
|
||||||
item[outkey] = (item[inkey] + 1) / 2
|
item[outkey] = (item[inkey] + 1) / 2
|
||||||
item[outkey] = item[outkey] * (max - min) + min
|
item[outkey] = item[outkey] * (max - min) + min
|
||||||
return item
|
return item
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
import einops
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
|
||||||
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
|
||||||
from lerobot.common.utils import init_hydra_config
|
from lerobot.common.utils import init_hydra_config
|
||||||
import logging
|
import logging
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
@ -52,32 +48,32 @@ def test_factory(env_name, dataset_id):
|
||||||
logging.warning(f'Missing "next.done" key in dataset {dataset}.')
|
logging.warning(f'Missing "next.done" key in dataset {dataset}.')
|
||||||
|
|
||||||
|
|
||||||
def test_compute_stats():
|
# def test_compute_stats():
|
||||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
# """Check that the statistics are computed correctly according to the stats_patterns property.
|
||||||
|
|
||||||
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
# We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||||
because we are working with a small dataset).
|
# because we are working with a small dataset).
|
||||||
"""
|
# """
|
||||||
cfg = init_hydra_config(
|
# cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
|
# DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
|
||||||
)
|
# )
|
||||||
buffer = make_offline_buffer(cfg)
|
# dataset = make_dataset(cfg)
|
||||||
# Get all of the data.
|
# # Get all of the data.
|
||||||
all_data = TensorDictReplayBuffer(
|
# all_data = TensorDictReplayBuffer(
|
||||||
storage=buffer._storage,
|
# storage=buffer._storage,
|
||||||
batch_size=len(buffer),
|
# batch_size=len(buffer),
|
||||||
sampler=SamplerWithoutReplacement(),
|
# sampler=SamplerWithoutReplacement(),
|
||||||
).sample().float()
|
# ).sample().float()
|
||||||
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||||
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||||
# dataset into even batches.
|
# # dataset into even batches.
|
||||||
computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
|
# computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
|
||||||
for k, pattern in buffer.stats_patterns.items():
|
# for k, pattern in buffer.stats_patterns.items():
|
||||||
expected_mean = einops.reduce(all_data[k], pattern, "mean")
|
# expected_mean = einops.reduce(all_data[k], pattern, "mean")
|
||||||
assert torch.allclose(computed_stats[k]["mean"], expected_mean)
|
# assert torch.allclose(computed_stats[k]["mean"], expected_mean)
|
||||||
assert torch.allclose(
|
# assert torch.allclose(
|
||||||
computed_stats[k]["std"],
|
# computed_stats[k]["std"],
|
||||||
torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
|
# torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
|
||||||
)
|
# )
|
||||||
assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
|
# assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
|
||||||
assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
|
# assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
|
||||||
|
|
Loading…
Reference in New Issue