Alexander Soare 2024-05-30 16:07:33 +01:00
parent 57e464b5bb
commit ff87379b3d
5 changed files with 19 additions and 8 deletions

View File

@@ -162,6 +162,12 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
     """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
 
     The final stats will have the union of all data keys from each of the datasets.
+
+    The final stats will have the union of all data keys from each of the datasets. For instance:
+    - new_max = max(max_dataset_0, max_dataset_1, ...)
+    - new_min = min(min_dataset_0, min_dataset_1, ...)
+    - new_mean = (mean of all data)
+    - new_std = (std of all data)
     """
     data_keys = set()
     for dataset in ls_datasets:
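The aggregation rules listed in the docstring above can be realized without touching the raw data, as long as each dataset's stats also carry a sample count. Below is a minimal sketch under those assumptions; the helper name `combine_stats` and its `counts` argument are hypothetical and not the code in this commit.

import torch

def combine_stats(per_dataset_stats: list[dict[str, torch.Tensor]], counts: list[int]) -> dict[str, torch.Tensor]:
    # Hypothetical helper, not part of this commit: each entry of `per_dataset_stats`
    # holds "max", "min", "mean" and "std" tensors for one dataset, and counts[i] is
    # that dataset's number of frames. Assumes "std" is the population std.
    n = torch.tensor(counts, dtype=torch.float32)
    w = (n / n.sum())[:, None]  # per-dataset weights, shape (num_datasets, 1)
    maxes = torch.stack([s["max"] for s in per_dataset_stats])
    mins = torch.stack([s["min"] for s in per_dataset_stats])
    means = torch.stack([s["mean"] for s in per_dataset_stats])
    stds = torch.stack([s["std"] for s in per_dataset_stats])
    new_max = maxes.max(dim=0).values   # max of the per-dataset maxima
    new_min = mins.min(dim=0).values    # min of the per-dataset minima
    new_mean = (w * means).sum(dim=0)   # count-weighted mean of the per-dataset means
    # Population std over all data via E[x^2] - E[x]^2, combined from per-dataset moments.
    second_moment = (w * (stds**2 + means**2)).sum(dim=0)
    new_std = torch.sqrt(second_moment - new_mean**2)
    return {"max": new_max, "min": new_min, "mean": new_mean, "std": new_std}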

View File

@@ -39,12 +39,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
     """
     Args:
         cfg: A Hydra config as per the LeRobot config scheme.
-        split: TODO(now)
+        split: Select the data subset used to create an instance of LeRobotDataset.
+            All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
+            Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
+            split slicing in the Hugging Face datasets library:
+            https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
+            As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
+            `split="train[n:]"` to load the frames from index n to the end. For instance `split="train[:1000]"`.
     Returns:
         The LeRobotDataset.
     """
     if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
-        raise ValueError("Expected cfg.dataset_repo_id to be either a single string or a list of strings.")
+        raise ValueError(
+            "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
+            "strings to load multiple datasets."
+        )
 
     if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
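The new `split` docstring only promises `train[:n]` and `train[n:]` style slices. As a rough sketch of how such a string could be interpreted, the helper below parses it into a split name and frame indices; `parse_split` is hypothetical and not the code this commit adds.

import re

def parse_split(split: str) -> tuple[str, int | None, int | None]:
    # Hypothetical helper, not part of this commit: turn "train", "train[:1000]" or
    # "train[100:]" into (name, start, stop) indices suitable for slicing frames.
    match = re.fullmatch(r"(\w+)(?:\[(\d*):(\d*)\])?", split)
    if match is None:
        raise ValueError(f"Unsupported split syntax: {split}")
    name, start, stop = match.groups()
    return name, int(start) if start else None, int(stop) if stop else None

# For instance, parse_split("train[:1000]") returns ("train", None, 1000).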

View File

@@ -251,11 +251,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                     f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
                     "not yet supported."
                 )
-            if set(dataset.features) != set(self._datasets[0].features):
-                # Use a warning here as we don't want to explicitly block this sort of inconsistency.
-                logging.warning(
-                    f"Detected a mismatch in dataset features between {self.repo_ids[0]} and {repo_id}."
-                )
         # Disable any data keys that are not common across all of the datasets. Note: we may relax this
         # restriction in future iterations of this class. For now, this is necessary at least for being able
         # to use PyTorch's default DataLoader collate function.
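The comment above says non-common data keys are disabled rather than rejected. One way that intersection could be computed is sketched below; `common_data_keys` is an assumed illustration, not a method this class actually defines.

import logging

def common_data_keys(datasets) -> set[str]:
    # Sketch only: keep the data keys shared by every dataset so that PyTorch's default
    # DataLoader collate function sees a consistent set of keys in every sample.
    keys = set(datasets[0].features)
    for dataset in datasets[1:]:
        extra = set(dataset.features) ^ keys  # keys not shared with the running intersection
        if extra:
            logging.warning(f"Ignoring non-common data keys: {sorted(extra)}")
        keys &= set(dataset.features)
    return keys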

View File

@@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write
 doesn't need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
 
 Example usage:
-`DATA_DIR=tests/data python tests/scripts/save_dataset_to_safetensors.py`
+`python tests/scripts/save_dataset_to_safetensors.py`
 """
 
 import shutil

View File

@@ -350,3 +350,4 @@ def test_aggregate_stats():
     for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
         for agg_fn in ["mean", "min", "max"]:
             assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
+        assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
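The added assertion compares against `torch.std(data, correction=0)`, i.e. the population standard deviation, which is the natural quantity when std is aggregated over all data rather than estimated per dataset. A quick self-contained check of that identity, using made-up data purely for illustration:

import torch

data = torch.randn(100)
# Population std (correction=0) equals sqrt(E[x^2] - E[x]^2), which is what an
# aggregate over all the data naturally produces.
assert torch.allclose(
    torch.std(data, correction=0),
    torch.sqrt((data**2).mean() - data.mean() ** 2),
    atol=1e-6,
)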