From cfa956bd3be5dd5e4709f89ea60349783c35c8a4 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 30 May 2024 15:12:55 +0100 Subject: [PATCH] squash --- .../compute_stats.py | 54 ++++- lerobot/common/datasets/factory.py | 36 ++- lerobot/common/datasets/lerobot_dataset.py | 216 +++++++++++++++++- lerobot/common/datasets/utils.py | 17 +- lerobot/configs/default.yaml | 4 + lerobot/scripts/push_dataset_to_hub.py | 2 +- lerobot/scripts/train.py | 49 +--- tests/scripts/save_dataset_to_safetensors.py | 2 +- tests/test_datasets.py | 38 ++- 9 files changed, 344 insertions(+), 74 deletions(-) rename lerobot/common/datasets/{push_dataset_to_hub => }/compute_stats.py (72%) diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/compute_stats.py similarity index 72% rename from lerobot/common/datasets/push_dataset_to_hub/compute_stats.py rename to lerobot/common/datasets/compute_stats.py index ec296658..bf1efdd3 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -16,17 +16,15 @@ from copy import deepcopy from math import ceil -import datasets import einops import torch import tqdm from datasets import Image -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.video_utils import VideoFrame -def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0): +def get_stats_einops_patterns(dataset, num_workers=0): """These einops patterns will be used to aggregate batches and compute statistics. Note: We assume the images are in channel first format @@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo return stats_patterns -def compute_stats( - dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None -): +def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None): + """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset.""" if max_num_samples is None: max_num_samples = len(dataset) @@ -159,3 +156,48 @@ def compute_stats( "min": min[key], } return stats + + +def consolidate_stats(ls_datasets) -> dict[str, torch.Tensor]: + """Consolidate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch. + + The final stats will have the union of all data keys from each of the datasets. + """ + data_keys = set() + for dataset in ls_datasets: + data_keys.update(dataset.stats.keys()) + stats = {k: {} for k in data_keys} + for data_key in data_keys: + for stat_key in ["min", "max"]: + # compute `max(dataset_0["max"], dataset_1["max"], ...)` + stats[data_key][stat_key] = einops.reduce( + torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0), + "n ... -> ...", + stat_key, + ) + total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats) + # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective + # dataset, then divide by total_samples to get the overall "mean". + # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of + # numerical overflow! + stats[data_key]["mean"] = sum( + d.stats[data_key]["mean"] * (d.num_samples / total_samples) + for d in ls_datasets + if data_key in d.stats + ) + # The derivation for standard deviation is a little more involved but is much in the same spirit as + # the computation of the mean. + # Given two sets of data where the statistics are known: + # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ] + # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined + # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of + # numerical overflow! + stats[data_key]["std"] = torch.sqrt( + sum( + (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2) + * (d.num_samples / total_samples) + for d in ls_datasets + if data_key in d.stats + ) + ) + return stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 7bdc2ca9..c04b11f2 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -16,9 +16,9 @@ import logging import torch -from omegaconf import OmegaConf +from omegaconf import ListConfig, OmegaConf -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset def resolve_delta_timestamps(cfg): @@ -35,11 +35,18 @@ def resolve_delta_timestamps(cfg): cfg.training.delta_timestamps[key] = eval(delta_timestamps[key]) -def make_dataset( - cfg, - split="train", -): - if cfg.env.name not in cfg.dataset_repo_id: +def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset: + """ + Args: + cfg: A Hydra config as per the LeRobot config scheme. + split: TODO(now) + Returns: + The LeRobotDataset. + """ + if not isinstance(cfg.dataset_repo_id, (str, ListConfig)): + raise ValueError("Expected cfg.dataset_repo_id to be either a single string or a list of strings.") + + if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id: logging.warning( f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your " f"environment ({cfg.env.name=})." @@ -49,11 +56,16 @@ def make_dataset( # TODO(rcadene): add data augmentations - dataset = LeRobotDataset( - cfg.dataset_repo_id, - split=split, - delta_timestamps=cfg.training.get("delta_timestamps"), - ) + if isinstance(cfg.dataset_repo_id, str): + dataset = LeRobotDataset( + cfg.dataset_repo_id, + split=split, + delta_timestamps=cfg.training.get("delta_timestamps"), + ) + else: + dataset = MultiLeRobotDataset( + cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps") + ) if cfg.get("override_dataset_stats"): for key, stats_dict in cfg.override_dataset_stats.items(): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 057e4770..3294ef6d 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -13,12 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from pathlib import Path +from typing import Callable import datasets import torch +import torch.utils +from lerobot.common.datasets.compute_stats import consolidate_stats from lerobot.common.datasets.utils import ( calculate_episode_data_index, load_episode_data_index, @@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset): version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: callable = None, + transform: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -171,7 +175,7 @@ class LeRobotDataset(torch.utils.data.Dataset): @classmethod def from_preloaded( cls, - repo_id: str, + repo_id: str = "from_preloaded", version: str | None = CODEBASE_VERSION, root: Path | None = None, split: str = "train", @@ -183,7 +187,15 @@ class LeRobotDataset(torch.utils.data.Dataset): stats=None, info=None, videos_dir=None, - ): + ) -> "LeRobotDataset": + """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem. + + It is especially useful when converting raw data into LeRobotDataset before saving the dataset + on the filesystem or uploading to the hub. + + Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially + meaningless depending on the downstream usage of the return dataset. + """ # create an empty object of type LeRobotDataset obj = cls.__new__(cls) obj.repo_id = repo_id @@ -195,6 +207,202 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.hf_dataset = hf_dataset obj.episode_data_index = episode_data_index obj.stats = stats - obj.info = info + obj.info = info if info is not None else {} obj.videos_dir = videos_dir return obj + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + version: str | None = CODEBASE_VERSION, + root: Path | None = DATA_DIR, + split: str = "train", + transform: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + ): + super().__init__() + self.repo_ids = repo_ids + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. + self._datasets = [ + LeRobotDataset( + repo_id, + version=version, + root=root, + split=split, + delta_timestamps=delta_timestamps, + transform=transform, + ) + for repo_id in repo_ids + ] + # Check that some properties are consistent across datasets. Note: We may relax some of these + # consistency requirements in future iterations of this class. + for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): + if dataset.info != self._datasets[0].info: + raise ValueError( + f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is " + "not yet supported." + ) + if set(dataset.features) != set(self._datasets[0].features): + # Use a warning here as we don't want to explicitly block this sort of inconsistency. + logging.warning( + f"Detected a mismatch in dataset features between {self.repo_ids[0]} and {repo_id}." + ) + # Disable any data keys that are not common across all of the datasets. Note: we may relax this + # restriction in future iterations of this class. For now, this is necessary at least for being able + # to use PyTorch's default DataLoader collate function. + self.disabled_data_keys = set() + intersection_data_keys = set(self._datasets[0].hf_dataset.features) + for dataset in self._datasets: + intersection_data_keys.intersection_update(dataset.hf_dataset.features) + if len(intersection_data_keys) == 0: + raise RuntimeError( + "Multiple datasets were provided but they had no keys common to all of them. The " + "multi-dataset functionality currently only keeps common keys." + ) + for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): + extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys) + logging.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_data_keys.update(extra_keys) + + self.version = version + self.root = root + self.split = split + self.transform = transform + self.delta_timestamps = delta_timestamps + self.stats = consolidate_stats(self._datasets) + + @property + def repo_id_to_index(self): + """Return a mapping from dataset repo_id to a dataset index automatically created by this class. + + This index is incorporated as a data key in the dictionary returned by `__getitem__`. + """ + return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} + + @property + def repo_index_to_id(self): + """Return the inverse mapping if repo_id_to_index.""" + return {v: k for k, v in self.repo_id_to_index} + + @property + def fps(self) -> int: + """Frames per second used during data collection. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].info["fps"] + + @property + def video(self) -> bool: + """Returns True if this dataset loads video frames from mp4 files. + + Returns False if it only loads images from png files. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].get("video", False) + + @property + def features(self) -> datasets.Features: + features = self.hf_datasets.features + for data_key in self.disabled_data_keys: + if data_key in features: + del features[data_key] + return features + + @property + def camera_keys(self) -> list[str]: + """Keys to access image and video stream from cameras.""" + keys = [] + for key, feats in self.hf_dataset.features.items(): + if key in self.disabled_data_keys: + pass + if isinstance(feats, (datasets.Image, VideoFrame)): + keys.append(key) + return keys + + @property + def video_frame_keys(self) -> list[str]: + """Keys to access video frames that requires to be decoded into images. + + Note: It is empty if the dataset contains images only, + or equal to `self.cameras` if the dataset contains videos only, + or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. + """ + video_frame_keys = [] + for key, feats in self.hf_dataset.features.items(): + if key in self.disabled_data_keys: + pass + if isinstance(feats, VideoFrame): + video_frame_keys.append(key) + return video_frame_keys + + @property + def num_samples(self) -> int: + """Number of samples/frames.""" + return sum(d.num_samples for d in self._datasets) + + @property + def num_episodes(self) -> int: + """Number of episodes.""" + return sum(d.num_episodes for d in self._datasets) + + @property + def tolerance_s(self) -> float: + """Tolerance in seconds used to discard loaded frames when their timestamps + are not close enough from the requested frames. It is only used when `delta_timestamps` + is provided or when loading video frames from mp4 files. + """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_samples: + start_idx += dataset.num_samples + dataset_idx += 1 + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_data_keys: + if data_key in item: + del item[data_key] + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Version: '{self.version}',\n" + f" Split: '{self.split}',\n" + f" Number of Samples: {self.num_samples},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.transform},\n" + f")" + ) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 86fef8d4..7f0d808e 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"): return outdict -def hf_transform_to_torch(items_dict): +def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to a channel last representation (h w c) of uint8 type, to a torch image representation @@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict): elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item: # video frame will be processed downstream pass + elif first_item is None: + pass else: items_dict[key] = [torch.tensor(x) for x in items_dict[key]] return items_dict @@ -317,14 +319,18 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc return episode_data_index -def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: - """ - Reset the `episode_index` of the provided HuggingFace Dataset. +def reset_episode_index(hf_dataset: datasets.Dataset, start_index: int = 0) -> datasets.Dataset: + """Reset the `episode_index` of the provided HuggingFace Dataset. `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0. This brings the `episode_index` to the required format. + + Args: + hf_dataset: Dataset for which we are resetting the indexing. + start_index: The episode index to start with for the new indexing. For most use cases in LeRobot this + should be left as 0. """ if len(hf_dataset) == 0: return hf_dataset @@ -337,7 +343,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()] return example - hf_dataset = hf_dataset.map(modify_ep_idx_func) + hf_dataset = hf_dataset.map(modify_ep_idx_func, input_columns=["episode_index"]) + return hf_dataset diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index f2238769..85b9ceea 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -23,6 +23,10 @@ use_amp: false # `seed` is used for training (eg: model initialization, dataset shuffling) # AND for the evaluation environments. seed: ??? +# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data +# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the +# "dataset_index" into the returned item. The index mapping is made according to the order in which the +# datsets are provided. dataset_repo_id: lerobot/pusht training: diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index c6eac5e9..52252b57 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -71,9 +71,9 @@ import torch from huggingface_hub import HfApi from safetensors.torch import save_file +from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw -from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats from lerobot.common.datasets.utils import flatten_dict diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index eb33b268..08ad6e66 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -16,7 +16,6 @@ import logging import time from contextlib import nullcontext -from copy import deepcopy from pathlib import Path from pprint import pformat @@ -28,6 +27,7 @@ from termcolor import colored from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps +from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir @@ -280,6 +280,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("make_dataset") offline_dataset = make_dataset(cfg) + if isinstance(offline_dataset, MultiLeRobotDataset): + logging.info( + "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " + f"{pformat(offline_dataset.repo_id_to_index , indent=2)}" + ) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, @@ -330,7 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No max_episodes_rendered=4, start_seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) + log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True) if cfg.wandb.enable: logger.log_video(eval_info["video_paths"][0], step, mode="eval") logging.info("Resume training") @@ -362,7 +367,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dl_iter = cycle(dataloader) policy.train() - is_offline = True for _ in range(step, cfg.training.offline_steps): if step == 0: logging.info("Start offline training on a fixed dataset") @@ -382,7 +386,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No ) if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) + log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True) # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, # so we pass in step + 1. @@ -390,41 +394,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No step += 1 - logging.info("End of offline training") - - if cfg.training.online_steps == 0: - if cfg.training.eval_freq > 0: - eval_env.close() - return - - # create an env dedicated to online episodes collection from policy rollout - online_training_env = make_env(cfg, n_envs=1) - - # create an empty online dataset similar to offline dataset - online_dataset = deepcopy(offline_dataset) - online_dataset.hf_dataset = {} - online_dataset.episode_data_index = {} - - # create dataloader for online training - concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) - weights = [1.0] * len(concat_dataset) - sampler = torch.utils.data.WeightedRandomSampler( - weights, num_samples=len(concat_dataset), replacement=True - ) - dataloader = torch.utils.data.DataLoader( - concat_dataset, - num_workers=4, - batch_size=cfg.training.batch_size, - sampler=sampler, - pin_memory=device.type != "cpu", - drop_last=False, - ) - - logging.info("End of online training") - - if cfg.training.eval_freq > 0: - eval_env.close() - online_training_env.close() + eval_env.close() + logging.info("End of training") @hydra.main(version_base="1.2", config_name="default", config_path="../configs") diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index 4aa8131f..6c421ae3 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts. Example usage: - `python tests/scripts/save_dataset_to_safetensors.py` + `DATA_DIR=tests/data python tests/scripts/save_dataset_to_safetensors.py` """ import shutil diff --git a/tests/test_datasets.py b/tests/test_datasets.py index afea16a5..c5adb748 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -25,21 +25,20 @@ from datasets import Dataset from safetensors.torch import load_file import lerobot -from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.lerobot_dataset import ( - LeRobotDataset, -) -from lerobot.common.datasets.push_dataset_to_hub.compute_stats import ( +from lerobot.common.datasets.compute_stats import ( compute_stats, + consolidate_stats, get_stats_einops_patterns, ) +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import ( flatten_dict, hf_transform_to_torch, load_previous_and_future_frames, unflatten_dict, ) -from lerobot.common.utils.utils import init_hydra_config +from lerobot.common.utils.utils import init_hydra_config, seeded_context from tests.utils import DEFAULT_CONFIG_PATH, DEVICE @@ -315,3 +314,30 @@ def test_backward_compatibility(repo_id): # i = dataset.episode_data_index["to"][-1].item() # load_and_compare(i - 2) # load_and_compare(i - 1) + + +def test_consolidate_stats(): + """Makes 3 basic datasets and checks that aggregate stats are computed correctly.""" + with seeded_context(0): + data_a = torch.rand(30, dtype=torch.float32) + data_b = torch.rand(20, dtype=torch.float32) + data_c = torch.rand(20, dtype=torch.float32) + + hf_dataset_1 = Dataset.from_dict( + {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)} + ) + hf_dataset_1.set_transform(hf_transform_to_torch) + hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}) + hf_dataset_2.set_transform(hf_transform_to_torch) + hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}) + hf_dataset_3.set_transform(hf_transform_to_torch) + dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1) + dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0) + dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2) + dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0) + dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3) + dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0) + stats = consolidate_stats([dataset_1, dataset_2, dataset_3]) + for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True): + for agg_fn in ["mean", "min", "max"]: + assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))