commit ff87379b3d (parent 57e464b5bb)
@@ -162,6 +162,12 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
     """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
 
     The final stats will have the union of all data keys from each of the datasets.
+
+    The final stats will have the union of all data keys from each of the datasets. For instance:
+    - new_max = max(max_dataset_0, max_dataset_1, ...)
+    - new_min = min(min_dataset_0, min_dataset_1, ...)
+    - new_mean = (mean of all data)
+    - new_std = (std of all data)
     """
     data_keys = set()
     for dataset in ls_datasets:
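The docstring above only states the aggregation rules. As a rough illustration of how max/min/mean/std can be combined per data key without re-reading the raw data, here is a minimal sketch; the per-dataset stats layout and the `lengths` argument are assumptions for illustration, not the library's actual aggregate_stats implementation (which also handles the union of data keys).

import torch

def combine_stats(stats_list: list[dict[str, torch.Tensor]], lengths: list[int]) -> dict[str, torch.Tensor]:
    """Combine per-dataset stats for one data key (population std, i.e. correction=0)."""
    total = sum(lengths)
    # Weight each dataset by its number of frames.
    weights = torch.tensor([n / total for n in lengths])
    w = weights.view((-1,) + (1,) * stats_list[0]["mean"].ndim)
    new_max = torch.stack([s["max"] for s in stats_list]).amax(dim=0)
    new_min = torch.stack([s["min"] for s in stats_list]).amin(dim=0)
    new_mean = (w * torch.stack([s["mean"] for s in stats_list])).sum(dim=0)
    # Recover E[x^2] per dataset from std^2 + mean^2, combine, then recentre.
    second_moment = (w * torch.stack([s["std"] ** 2 + s["mean"] ** 2 for s in stats_list])).sum(dim=0)
    new_std = torch.sqrt(second_moment - new_mean**2)
    return {"max": new_max, "min": new_min, "mean": new_mean, "std": new_std}

With two datasets of, say, 30 and 20 frames, combine_stats([stats_a, stats_b], [30, 20]) reproduces the stats of the concatenated data.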
@@ -39,12 +39,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
     """
     Args:
         cfg: A Hydra config as per the LeRobot config scheme.
-        split: TODO(now)
+        split: Select the data subset used to create an instance of LeRobotDataset.
+            All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
+            Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
+            slicer in the hugging face datasets:
+            https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
+            As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
+            `split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
     Returns:
         The LeRobotDataset.
     """
     if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
-        raise ValueError("Expected cfg.dataset_repo_id to be either a single string or a list of strings.")
+        raise ValueError(
+            "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
+            "strings to load multiple datasets."
+        )
 
     if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
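The new docstring describes `split` as a restricted form of the Hugging Face datasets slice syntax. Below is a small sketch of how such a slice string could be parsed; the helper name and any behaviour beyond `train[:n]` / `train[n:]` are assumptions, not the library's own parser.

import re

def parse_split(split: str) -> tuple[str, int | None, int | None]:
    """Parse "train", "train[:1000]" or "train[1000:]" into (name, start, stop). Illustrative only."""
    match = re.fullmatch(r"(\w+)(?:\[(\d*):(\d*)\])?", split)
    if match is None:
        raise ValueError(f"Unsupported split syntax: {split}")
    name, start, stop = match.groups()
    return name, (int(start) if start else None), (int(stop) if stop else None)

# For instance, parse_split("train[:1000]") returns ("train", None, 1000).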
@@ -251,11 +251,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                     f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
                     "not yet supported."
                 )
-            if set(dataset.features) != set(self._datasets[0].features):
-                # Use a warning here as we don't want to explicitly block this sort of inconsistency.
-                logging.warning(
-                    f"Detected a mismatch in dataset features between {self.repo_ids[0]} and {repo_id}."
-                )
         # Disable any data keys that are not common across all of the datasets. Note: we may relax this
         # restriction in future iterations of this class. For now, this is necessary at least for being able
         # to use PyTorch's default DataLoader collate function.
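The trailing comment explains that keys not shared by every dataset are disabled so PyTorch's default collate function keeps working. A sketch of that intersection logic, assuming each dataset exposes a `features` mapping as in the diff (not the class's exact code):

def find_disabled_keys(datasets) -> set[str]:
    """Return the data keys that are not common to all datasets."""
    common = set(datasets[0].features)
    for ds in datasets[1:]:
        common &= set(ds.features)
    disabled = set()
    for ds in datasets:
        disabled |= set(ds.features) - common
    return disabled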
@@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write
 doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
 
 Example usage:
-    `DATA_DIR=tests/data python tests/scripts/save_dataset_to_safetensors.py`
+    `python tests/scripts/save_dataset_to_safetensors.py`
 """
 
 import shutil
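The docstring's example command regenerates the backward-compatibility test artifacts. As a rough idea of what writing dataset frames to safetensors can look like, here is a hedged sketch using the safetensors API; the frame layout and output paths are assumptions, not the script's actual behaviour.

from pathlib import Path

import torch
from safetensors.torch import save_file

def save_frames_as_artifacts(frames: list[dict[str, torch.Tensor]], out_dir: str) -> None:
    """Write each frame (a dict of tensors) to its own .safetensors file."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, frame in enumerate(frames):
        # safetensors requires contiguous tensors, so enforce that before saving.
        save_file({k: v.contiguous() for k, v in frame.items()}, str(out / f"frame_{i}.safetensors"))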
@@ -350,3 +350,4 @@ def test_aggregate_stats():
     for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
         for agg_fn in ["mean", "min", "max"]:
             assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
+        assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
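The added assertion compares the aggregated std against `torch.std(..., correction=0)`, i.e. the population standard deviation. A small self-contained check of the conventions the test relies on (the random tensor is a stand-in, not the test's actual data):

import einops
import torch

data = torch.rand(30)
# einops.reduce with "mean" keeps a size-1 dim, matching the shape of the stored stats.
assert torch.allclose(einops.reduce(data, "n -> 1", "mean"), data.mean().view(1))
# correction=0 gives the population std: sqrt(mean((x - mean(x))**2)).
assert torch.allclose(torch.std(data, correction=0), ((data - data.mean()) ** 2).mean().sqrt())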