From ff87379b3d2babf9b681b0840622c03bd02185b6 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 30 May 2024 16:07:33 +0100 Subject: [PATCH] revision --- lerobot/common/datasets/compute_stats.py | 6 ++++++ lerobot/common/datasets/factory.py | 13 +++++++++++-- lerobot/common/datasets/lerobot_dataset.py | 5 ----- tests/scripts/save_dataset_to_safetensors.py | 2 +- tests/test_datasets.py | 1 + 5 files changed, 19 insertions(+), 8 deletions(-) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 68210ca2..a69bc573 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -162,6 +162,12 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]: """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch. The final stats will have the union of all data keys from each of the datasets. + + The final stats will have the union of all data keys from each of the datasets. For instance: + - new_max = max(max_dataset_0, max_dataset_1, ...) + - new_min = min(min_dataset_0, min_dataset_1, ...) + - new_mean = (mean of all data) + - new_std = (std of all data) """ data_keys = set() for dataset in ls_datasets: diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index c04b11f2..b48a9211 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -39,12 +39,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData """ Args: cfg: A Hydra config as per the LeRobot config scheme. - split: TODO(now) + split: Select the data subset used to create an instance of LeRobotDataset. + All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train". + Thus, by default, `split="train"` selects all the available data. `split` aims to work like the + slicer in the hugging face datasets: + https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits + As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or + `split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`. Returns: The LeRobotDataset. """ if not isinstance(cfg.dataset_repo_id, (str, ListConfig)): - raise ValueError("Expected cfg.dataset_repo_id to be either a single string or a list of strings.") + raise ValueError( + "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of " + "strings to load multiple datasets." + ) if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id: logging.warning( diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 7298dd08..a87c3ee8 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -251,11 +251,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is " "not yet supported." ) - if set(dataset.features) != set(self._datasets[0].features): - # Use a warning here as we don't want to explicitly block this sort of inconsistency. - logging.warning( - f"Detected a mismatch in dataset features between {self.repo_ids[0]} and {repo_id}." - ) # Disable any data keys that are not common across all of the datasets. Note: we may relax this # restriction in future iterations of this class. For now, this is necessary at least for being able # to use PyTorch's default DataLoader collate function. diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index 6c421ae3..4aa8131f 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts. Example usage: - `DATA_DIR=tests/data python tests/scripts/save_dataset_to_safetensors.py` + `python tests/scripts/save_dataset_to_safetensors.py` """ import shutil diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 1c9af2c2..e01fc52c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -350,3 +350,4 @@ def test_aggregate_stats(): for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True): for agg_fn in ["mean", "min", "max"]: assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)) + assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))