From 659c69a1c07105351bd39ad89aafc2ea44b228d3 Mon Sep 17 00:00:00 2001 From: Remi Date: Thu, 25 Apr 2024 12:23:12 +0200 Subject: [PATCH] Refactor datasets into LeRobotDataset (#91) Co-authored-by: Alexander Soare --- CONTRIBUTING.md | 9 +- README.md | 41 ++++----- download_and_upload_dataset.py | 9 +- examples/2_load_lerobot_dataset.py | 44 ++++------ lerobot/__init__.py | 44 +++++----- lerobot/common/datasets/aloha.py | 78 ------------------ lerobot/common/datasets/factory.py | 27 ++---- .../datasets/{xarm.py => lerobot_dataset.py} | 40 ++++----- lerobot/common/datasets/pusht.py | 76 ----------------- lerobot/common/datasets/utils.py | 37 ++++++--- lerobot/configs/default.yaml | 3 +- lerobot/configs/env/aloha.yaml | 3 +- lerobot/configs/env/pusht.yaml | 3 +- lerobot/configs/env/xarm.yaml | 3 +- .../meta_data/episode_data_index.safetensors | Bin .../meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../aloha_sim_insertion_human/stats.pth | Bin .../train/data-00000-of-00001.arrow | Bin .../train/dataset_info.json | 0 .../train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../aloha_sim_insertion_scripted/stats.pth | Bin .../train/data-00000-of-00001.arrow | Bin .../train/dataset_info.json | 0 .../train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../aloha_sim_transfer_cube_human/stats.pth | Bin .../train/data-00000-of-00001.arrow | Bin .../train/dataset_info.json | 0 .../train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../stats.pth | Bin .../train/data-00000-of-00001.arrow | Bin .../train/dataset_info.json | 0 .../train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../{ => lerobot}/pusht/meta_data/info.json | 0 .../pusht/meta_data/stats.safetensors | Bin tests/data/{ => lerobot}/pusht/stats.pth | Bin .../pusht/train/data-00000-of-00001.arrow | Bin .../pusht/train/dataset_info.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../pusht/train/meta_data/info.json | 0 .../train/meta_data/stats_action.safetensors | Bin .../stats_observation.image.safetensors | Bin .../stats_observation.state.safetensors | Bin .../data/{ => lerobot}/pusht/train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../xarm_lift_medium/meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../{ => lerobot}/xarm_lift_medium/stats.pth | Bin .../train/data-00000-of-00001.arrow | Bin .../xarm_lift_medium/train/dataset_info.json | 0 .../xarm_lift_medium/train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../train/data-00000-of-00001.arrow | Bin .../train/dataset_info.json | 0 .../xarm_lift_medium_replay/train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../xarm_push_medium/meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../train/data-00000-of-00001.arrow | Bin .../xarm_push_medium/train/dataset_info.json | 0 .../xarm_push_medium/train/state.json | 0 .../meta_data/episode_data_index.safetensors | Bin .../meta_data/info.json | 0 .../meta_data/stats.safetensors | Bin .../train/data-00000-of-00001.arrow | Bin .../train/dataset_info.json | 0 .../xarm_push_medium_replay/train/state.json | 0 .../{ => lerobot}/pusht/frame_0.safetensors | Bin .../{ => lerobot}/pusht/frame_1.safetensors 
| Bin .../{ => lerobot}/pusht/frame_159.safetensors | Bin .../{ => lerobot}/pusht/frame_160.safetensors | Bin .../{ => lerobot}/pusht/frame_80.safetensors | Bin .../{ => lerobot}/pusht/frame_81.safetensors | Bin tests/scripts/save_dataset_to_safetensors.py | 12 +-- tests/test_available.py | 31 ++----- tests/test_datasets.py | 31 ++++--- tests/test_policies.py | 20 ++++- tests/test_visualize_dataset.py | 8 +- 90 files changed, 167 insertions(+), 352 deletions(-) delete mode 100644 lerobot/common/datasets/aloha.py rename lerobot/common/datasets/{xarm.py => lerobot_dataset.py} (62%) delete mode 100644 lerobot/common/datasets/pusht.py rename tests/data/{ => lerobot}/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_human/meta_data/info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_human/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_human/stats.pth (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_human/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_human/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_human/train/state.json (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_scripted/meta_data/info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_scripted/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_scripted/stats.pth (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_scripted/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_insertion_scripted/train/state.json (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_human/meta_data/info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_human/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_human/stats.pth (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_human/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_human/train/state.json (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_scripted/meta_data/info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_scripted/stats.pth (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_scripted/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/aloha_sim_transfer_cube_scripted/train/state.json (100%) rename tests/data/{ => lerobot}/pusht/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/pusht/meta_data/info.json (100%) rename tests/data/{ => lerobot}/pusht/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/pusht/stats.pth (100%) rename tests/data/{ => 
lerobot}/pusht/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/pusht/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/pusht/train/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/pusht/train/meta_data/info.json (100%) rename tests/data/{ => lerobot}/pusht/train/meta_data/stats_action.safetensors (100%) rename tests/data/{ => lerobot}/pusht/train/meta_data/stats_observation.image.safetensors (100%) rename tests/data/{ => lerobot}/pusht/train/meta_data/stats_observation.state.safetensors (100%) rename tests/data/{ => lerobot}/pusht/train/state.json (100%) rename tests/data/{ => lerobot}/xarm_lift_medium/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/xarm_lift_medium/meta_data/info.json (100%) rename tests/data/{ => lerobot}/xarm_lift_medium/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/xarm_lift_medium/stats.pth (100%) rename tests/data/{ => lerobot}/xarm_lift_medium/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/xarm_lift_medium/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/xarm_lift_medium/train/state.json (100%) rename tests/data/{ => lerobot}/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/xarm_lift_medium_replay/meta_data/info.json (100%) rename tests/data/{ => lerobot}/xarm_lift_medium_replay/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/xarm_lift_medium_replay/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/xarm_lift_medium_replay/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/xarm_lift_medium_replay/train/state.json (100%) rename tests/data/{ => lerobot}/xarm_push_medium/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/xarm_push_medium/meta_data/info.json (100%) rename tests/data/{ => lerobot}/xarm_push_medium/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/xarm_push_medium/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/xarm_push_medium/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/xarm_push_medium/train/state.json (100%) rename tests/data/{ => lerobot}/xarm_push_medium_replay/meta_data/episode_data_index.safetensors (100%) rename tests/data/{ => lerobot}/xarm_push_medium_replay/meta_data/info.json (100%) rename tests/data/{ => lerobot}/xarm_push_medium_replay/meta_data/stats.safetensors (100%) rename tests/data/{ => lerobot}/xarm_push_medium_replay/train/data-00000-of-00001.arrow (100%) rename tests/data/{ => lerobot}/xarm_push_medium_replay/train/dataset_info.json (100%) rename tests/data/{ => lerobot}/xarm_push_medium_replay/train/state.json (100%) rename tests/data/save_dataset_to_safetensors/{ => lerobot}/pusht/frame_0.safetensors (100%) rename tests/data/save_dataset_to_safetensors/{ => lerobot}/pusht/frame_1.safetensors (100%) rename tests/data/save_dataset_to_safetensors/{ => lerobot}/pusht/frame_159.safetensors (100%) rename tests/data/save_dataset_to_safetensors/{ => lerobot}/pusht/frame_160.safetensors (100%) rename tests/data/save_dataset_to_safetensors/{ => lerobot}/pusht/frame_80.safetensors (100%) rename tests/data/save_dataset_to_safetensors/{ => lerobot}/pusht/frame_81.safetensors (100%) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5b69a13a..3cee0e46 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -73,15 +73,14 @@ environments ([aloha](https://github.com/huggingface/gym-aloha), 
[pusht](https://github.com/huggingface/gym-pusht)) and follow the same api design. -When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps: -- Update `available_datasets` in `lerobot/__init__.py` -- Copy it in the required `available_datasets` class attribute +When implementing a new dataset loadable with LeRobotDataset, follow these steps: +- Update `available_datasets_per_env` in `lerobot/__init__.py` When implementing a new environment (e.g. `gym_aloha`), follow these steps: -- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py` +- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py` When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: -- Update `available_policies` in `lerobot/__init__.py` +- Update `available_policies` and `available_policies_per_env` in `lerobot/__init__.py` - Set the required `name` class attribute. - Update variables in `tests/test_available.py` by importing your new Policy class diff --git a/README.md b/README.md index 8b78ca3e..a44d0af0 100644 --- a/README.md +++ b/README.md @@ -118,30 +118,7 @@ wandb login ### Visualize datasets -You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities: -```python -""" Copy pasted from `examples/1_visualize_dataset.py` """ -import os -from pathlib import Path - -import lerobot -from lerobot.common.datasets.aloha import AlohaDataset -from lerobot.scripts.visualize_dataset import render_dataset -print(lerobot.available_datasets) -# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium'] - -# TODO(rcadene): remove DATA_DIR -dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR"))) - -video_paths = render_dataset( - dataset, - out_dir="outputs/visualize_dataset/example", - max_num_episodes=1, -) -print(video_paths) -# ['outputs/visualize_dataset/example/episode_0.mp4'] -``` +Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities. Or you can achieve the same result by executing our script from the command line: ```bash @@ -153,7 +130,7 @@ hydra.run.dir=outputs/visualize_dataset/example ### Evaluate a pretrained policy -Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation. +Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation. Or you can achieve the same result by executing our script from the command line: ```bash @@ -176,24 +153,30 @@ See `python lerobot/scripts/eval.py --help` for more instructions. ### Train your own policy -You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output! +Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
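For orientation, here is a minimal sketch of the dataset side of that workflow, using the `LeRobotDataset` class this PR introduces (the policy forward/backward pass is elided, and the wiring shown here is illustrative rather than the exact content of the example scripts):

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Downloads from the Hugging Face hub on first use (or loads from disk when `root` is set).
dataset = LeRobotDataset("lerobot/pusht")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

for batch in dataloader:
    # batch holds tensors keyed like "observation.image", "observation.state", "action"
    ...  # policy-specific forward pass, loss, and optimizer step go here
```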
In general, you can use our training script to easily train any policy on any environment: ```bash python lerobot/scripts/train.py \ env=aloha \ task=sim_insertion \ -dataset_id=aloha_sim_insertion_scripted \ +repo_id=lerobot/aloha_sim_insertion_scripted \ policy=act \ hydra.run.dir=outputs/train/aloha_act ``` +After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training, every checkpoint is already evaluated, but only on a small number of episodes for efficiency. Check out [examples](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance. + ## Contribute If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md). ### Add a new dataset +```python +# TODO(rcadene, AdilZouitine): rewrite this section +``` + To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: ```bash huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential @@ -255,6 +238,10 @@ python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir ### Add a pretrained policy +```python +# TODO(rcadene, alexander-soare): rewrite this section +``` + Once you have trained a policy you may upload it to the HuggingFace hub. Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME. diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py index 8e1e27ce..f37d5fa2 100644 --- a/download_and_upload_dataset.py +++ b/download_and_upload_dataset.py @@ -23,6 +23,7 @@ from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transf def download_and_upload(root, revision, dataset_id): + # TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht") if "pusht" in dataset_id: download_and_upload_pusht(root, revision, dataset_id) elif "xarm" in dataset_id: @@ -149,11 +150,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat # copy in tests folder, the first episode and the meta_data directory num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk( - f"tests/data/{dataset_id}/train" + f"tests/data/lerobot/{dataset_id}/train" ) - if Path(f"tests/data/{dataset_id}/meta_data").exists(): - shutil.rmtree(f"tests/data/{dataset_id}/meta_data") - shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data") + if Path(f"tests/data/lerobot/{dataset_id}/meta_data").exists(): - shutil.rmtree(f"tests/data/lerobot/{dataset_id}/meta_data") - shutil.copytree(meta_data_dir, f"tests/data/lerobot/{dataset_id}/meta_data") def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): diff --git a/examples/2_load_lerobot_dataset.py b/examples/2_load_lerobot_dataset.py index 4eaed238..26c78de1 100644 --- a/examples/2_load_lerobot_dataset.py +++ b/examples/2_load_lerobot_dataset.py @@ -1,5 +1,5 @@ """ -This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face. +This script demonstrates the use of the `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face. It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script: @@ -11,22 +11,6 @@ Features included in this script: - Demonstrating compatibility with PyTorch DataLoader for batch processing. The script ends with examples of how to batch process data using PyTorch's DataLoader. - -To try a different Hugging Face dataset, you can replace: -```python -dataset = PushtDataset() -``` -by one of these: -```python -dataset = XarmDataset("xarm_lift_medium") -dataset = XarmDataset("xarm_lift_medium_replay") -dataset = XarmDataset("xarm_push_medium") -dataset = XarmDataset("xarm_push_medium_replay") -dataset = AlohaDataset("aloha_sim_insertion_human") -dataset = AlohaDataset("aloha_sim_insertion_scripted") -dataset = AlohaDataset("aloha_sim_transfer_cube_human") -dataset = AlohaDataset("aloha_sim_transfer_cube_scripted") -``` """ from pathlib import Path @@ -34,31 +18,33 @@ from pathlib import Path import imageio import torch -from lerobot.common.datasets.pusht import PushtDataset +import lerobot +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human")) -# print("List of available datasets", lerobot.available_datasets) -# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', -# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', -# # 'pusht', 'xarm_lift_medium'] +print("List of available datasets", lerobot.available_datasets) +# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted', +# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted', +# # 'lerobot/pusht', 'lerobot/xarm_lift_medium'] +repo_id = "lerobot/pusht" -# You can easily load datasets from LeRobot -dataset = PushtDataset() +# You can easily load a dataset from a Hugging Face repository +dataset = LeRobotDataset(repo_id) -# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information). +# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information). # TODO(rcadene): update to make the print pretty print(f"{dataset=}") print(f"{dataset.hf_dataset=}") -# and provide additional utilities for robotics and compatibility with pytorch +# and provides additional utilities for robotics and compatibility with PyTorch print(f"number of samples/frames: {dataset.num_samples=}") print(f"number of episodes: {dataset.num_episodes=}") print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}") print(f"frames per second used during data collection: {dataset.fps=}") print(f"keys to access images from cameras: {dataset.image_keys=}") -# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5. +# While the LeRobotDataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. +# It may be freely replaced or modified in place. Here we use filtering to keep only frames from episode 5.
# TODO(rcadene): remove this example of accessing hf_dataset dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5) @@ -85,7 +71,7 @@ delta_timestamps = { # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future "action": [t / dataset.fps for t in range(64)], } -dataset = PushtDataset(delta_timestamps=delta_timestamps) +dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps) print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w) print(f"{dataset[0]['observation.state'].shape=}") # (8,c) print(f"{dataset[0]['action'].shape=}") # (64,c) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 70d7d7b0..5c01de19 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -8,31 +8,25 @@ Example: print(lerobot.available_envs) print(lerobot.available_tasks_per_env) print(lerobot.available_datasets) + print(lerobot.available_datasets_per_env) print(lerobot.available_policies) print(lerobot.available_policies_per_env) ``` -When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps: -- Update `available_datasets` in `lerobot/__init__.py` -- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets` +When implementing a new dataset loadable with LeRobotDataset, follow these steps: +- Update `available_datasets_per_env` in `lerobot/__init__.py` When implementing a new environment (e.g. `gym_aloha`), follow these steps: -- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py` +- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py` When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: -- Update `available_policies` in `lerobot/__init__.py` +- Update `available_policies` and `available_policies_per_env` in `lerobot/__init__.py` - Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class """ from lerobot.__version__ import __version__ # noqa: F401 -available_envs = [ - "aloha", - "pusht", - "xarm", -] - available_tasks_per_env = { "aloha": [ "AlohaInsertion-v0", @@ -41,22 +35,24 @@ available_tasks_per_env = { "pusht": ["PushT-v0"], "xarm": ["XarmLift-v0"], } +available_envs = list(available_tasks_per_env.keys()) -available_datasets = { +available_datasets_per_env = { "aloha": [ - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", + "lerobot/aloha_sim_insertion_human", + "lerobot/aloha_sim_insertion_scripted", + "lerobot/aloha_sim_transfer_cube_human", + "lerobot/aloha_sim_transfer_cube_scripted", ], - "pusht": ["pusht"], + "pusht": ["lerobot/pusht"], "xarm": [ - "xarm_lift_medium", - "xarm_lift_medium_replay", - "xarm_push_medium", - "xarm_push_medium_replay", + "lerobot/xarm_lift_medium", + "lerobot/xarm_lift_medium_replay", + "lerobot/xarm_push_medium", + "lerobot/xarm_push_medium_replay", ], } +available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets] available_policies = [ "act", @@ -71,10 +67,12 @@ available_policies_per_env = { } env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] -env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets] +env_dataset_pairs = [ + (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets +] env_dataset_policy_triplets = [ (env, dataset, policy) - for env, datasets in available_datasets.items() + for env, datasets in available_datasets_per_env.items() for dataset in datasets for policy in available_policies_per_env[env] ] diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py deleted file mode 100644 index f96d32b4..00000000 --- a/lerobot/common/datasets/aloha.py +++ /dev/null @@ -1,78 +0,0 @@ -from pathlib import Path - -import torch - -from lerobot.common.datasets.utils import ( - load_episode_data_index, - load_hf_dataset, - load_previous_and_future_frames, - load_stats, -) - - -class AlohaDataset(torch.utils.data.Dataset): - """ - https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human - https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted - https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human - https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted - """ - - # Copied from lerobot/__init__.py - available_datasets = [ - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", - ] - fps = 50 - image_keys = ["observation.images.top"] - - def __init__( - self, - dataset_id: str, - version: str | None = "v1.1", - root: Path | None = None, - split: str = "train", - transform: callable = None, - delta_timestamps: dict[list[float]] | None = None, - ): - super().__init__() - self.dataset_id = dataset_id - self.version = version - self.root = root - self.split = split - self.transform = transform - self.delta_timestamps = delta_timestamps - # load data from hub or locally when root is provided - self.hf_dataset = load_hf_dataset(dataset_id, version, root, split) - self.episode_data_index = load_episode_data_index(dataset_id, version, root) - self.stats = load_stats(dataset_id, version, root) - - @property - def num_samples(self) -> 
int: - return len(self.hf_dataset) - - @property - def num_episodes(self) -> int: - return len(self.hf_dataset.unique("episode_index")) - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - item = self.hf_dataset[idx] - - if self.delta_timestamps is not None: - item = load_previous_and_future_frames( - item, - self.hf_dataset, - self.episode_data_index, - self.delta_timestamps, - tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error - ) - - if self.transform is not None: - item = self.transform(item) - - return item diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 9753cde7..0da17b8e 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,9 +1,12 @@ +import logging import os from pathlib import Path import torch from omegaconf import OmegaConf +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None @@ -11,22 +14,10 @@ def make_dataset( cfg, split="train", ): - if cfg.env.name == "xarm": - from lerobot.common.datasets.xarm import XarmDataset - - clsfunc = XarmDataset - - elif cfg.env.name == "pusht": - from lerobot.common.datasets.pusht import PushtDataset - - clsfunc = PushtDataset - - elif cfg.env.name == "aloha": - from lerobot.common.datasets.aloha import AlohaDataset - - clsfunc = AlohaDataset - else: - raise ValueError(cfg.env.name) + if cfg.env.name not in cfg.dataset.repo_id: + logging.warning( + f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})." + ) delta_timestamps = cfg.policy.get("delta_timestamps") if delta_timestamps is not None: @@ -36,8 +27,8 @@ def make_dataset( # TODO(rcadene): add data augmentations - dataset = clsfunc( - dataset_id=cfg.dataset_id, + dataset = LeRobotDataset( + cfg.dataset.repo_id, split=split, root=DATA_DIR, delta_timestamps=delta_timestamps, diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/lerobot_dataset.py similarity index 62% rename from lerobot/common/datasets/xarm.py rename to lerobot/common/datasets/lerobot_dataset.py index 7e69e7d7..6eedb233 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -1,36 +1,21 @@ from pathlib import Path +import datasets import torch from lerobot.common.datasets.utils import ( load_episode_data_index, load_hf_dataset, + load_info, load_previous_and_future_frames, load_stats, ) -class XarmDataset(torch.utils.data.Dataset): - """ - https://huggingface.co/datasets/lerobot/xarm_lift_medium - https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay - https://huggingface.co/datasets/lerobot/xarm_push_medium - https://huggingface.co/datasets/lerobot/xarm_push_medium_replay - """ - - # Copied from lerobot/__init__.py - available_datasets = [ - "xarm_lift_medium", - "xarm_lift_medium_replay", - "xarm_push_medium", - "xarm_push_medium_replay", - ] - fps = 15 - image_keys = ["observation.image"] - +class LeRobotDataset(torch.utils.data.Dataset): def __init__( self, - dataset_id: str, + repo_id: str, version: str | None = "v1.1", root: Path | None = None, split: str = "train", @@ -38,16 +23,25 @@ class XarmDataset(torch.utils.data.Dataset): delta_timestamps: dict[list[float]] | None = None, ): super().__init__() - self.dataset_id = dataset_id + self.repo_id = repo_id self.version = version self.root = root self.split = split self.transform = transform 
self.delta_timestamps = delta_timestamps # load data from hub or locally when root is provided - self.hf_dataset = load_hf_dataset(dataset_id, version, root, split) - self.episode_data_index = load_episode_data_index(dataset_id, version, root) - self.stats = load_stats(dataset_id, version, root) + self.hf_dataset = load_hf_dataset(repo_id, version, root, split) + self.episode_data_index = load_episode_data_index(repo_id, version, root) + self.stats = load_stats(repo_id, version, root) + self.info = load_info(repo_id, version, root) + + @property + def fps(self) -> int: + return self.info["fps"] + + @property + def image_keys(self) -> list[str]: + return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)] @property def num_samples(self) -> int: diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py deleted file mode 100644 index bc978b7a..00000000 --- a/lerobot/common/datasets/pusht.py +++ /dev/null @@ -1,76 +0,0 @@ -from pathlib import Path - -import torch - -from lerobot.common.datasets.utils import ( - load_episode_data_index, - load_hf_dataset, - load_previous_and_future_frames, - load_stats, -) - - -class PushtDataset(torch.utils.data.Dataset): - """ - https://huggingface.co/datasets/lerobot/pusht - - Arguments - ---------- - delta_timestamps : dict[list[float]] | None, optional - Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image) - If `None`, no shift is applied to current timestamp and the data from the current frame is loaded. - """ - - # Copied from lerobot/__init__.py - available_datasets = ["pusht"] - fps = 10 - image_keys = ["observation.image"] - - def __init__( - self, - dataset_id: str = "pusht", - version: str | None = "v1.1", - root: Path | None = None, - split: str = "train", - transform: callable = None, - delta_timestamps: dict[list[float]] | None = None, - ): - super().__init__() - self.dataset_id = dataset_id - self.version = version - self.root = root - self.split = split - self.transform = transform - self.delta_timestamps = delta_timestamps - # load data from hub or locally when root is provided - self.hf_dataset = load_hf_dataset(dataset_id, version, root, split) - self.episode_data_index = load_episode_data_index(dataset_id, version, root) - self.stats = load_stats(dataset_id, version, root) - - @property - def num_samples(self) -> int: - return len(self.hf_dataset) - - @property - def num_episodes(self) -> int: - return len(self.episode_data_index["from"]) - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - item = self.hf_dataset[idx] - - if self.delta_timestamps is not None: - item = load_previous_and_future_frames( - item, - self.hf_dataset, - self.episode_data_index, - self.delta_timestamps, - tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error - ) - - if self.transform is not None: - item = self.transform(item) - - return item diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index f5246c74..f7186b6a 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,3 +1,4 @@ +import json from copy import deepcopy from math import ceil from pathlib import Path @@ -15,7 +16,7 @@ from torchvision import transforms def flatten_dict(d, parent_key="", sep="/"): """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. 
- + For example: ``` >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` @@ -61,19 +62,17 @@ def hf_transform_to_torch(items_dict): return items_dict -def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset: +def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" if root is not None: - hf_dataset = load_from_disk(str(Path(root) / dataset_id / split)) + hf_dataset = load_from_disk(str(Path(root) / repo_id / split)) else: - # TODO(rcadene): remove dataset_id everywhere and use repo_id instead - repo_id = f"lerobot/{dataset_id}" hf_dataset = load_dataset(repo_id, revision=version, split=split) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset -def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]: +def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]: """episode_data_index contains the range of indices for each episode Example: @@ -84,9 +83,8 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor ``` """ if root is not None: - path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors" + path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors" else: - repo_id = f"lerobot/{dataset_id}" path = hf_hub_download( repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version ) @@ -94,7 +92,7 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor return load_file(path) -def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]: +def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]: """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std Example: @@ -103,15 +101,32 @@ def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]: ``` """ if root is not None: - path = Path(root) / dataset_id / "meta_data" / "stats.safetensors" + path = Path(root) / repo_id / "meta_data" / "stats.safetensors" else: - repo_id = f"lerobot/{dataset_id}" path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) stats = load_file(path) return unflatten_dict(stats) +def load_info(repo_id, version, root) -> dict: + """info contains useful information regarding the dataset that is not stored elsewhere + + Example: + ```python + print("frames per second used to collect the video", info["fps"]) + ``` + """ + if root is not None: + path = Path(root) / repo_id / "meta_data" / "info.json" + else: + path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version) + + with open(path) as f: + info = json.load(f) + return info + + def load_previous_and_future_frames( item: dict[str, torch.Tensor], hf_dataset: datasets.Dataset, diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 7b6c129d..21370e4b 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -26,7 +26,8 @@ fps: ??? offline_prioritized_sampler: true -dataset_id: ??? +dataset: + repo_id: ??? n_action_steps: ??? n_obs_steps: ???
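With `load_info` in place, `fps` and `image_keys` become properties derived from the dataset itself (from `meta_data/info.json` and the Hugging Face dataset features) instead of hard-coded class attributes. A quick sanity check after loading might look like this (a sketch; the commented values are what the pusht dataset is expected to report, based on the old `PushtDataset` constants and `lerobot/configs/env/pusht.yaml`):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
print(dataset.fps)         # 10, read from meta_data/info.json rather than a per-class constant
print(dataset.image_keys)  # ["observation.image"], i.e. the features typed as datasets.Image
```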
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 26493711..41d44db8 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -10,7 +10,8 @@ online_steps: 25000 fps: 50 -dataset_id: aloha_sim_insertion_human +dataset: + repo_id: lerobot/aloha_sim_insertion_human env: name: aloha diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 92b6a33b..29c2a258 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -10,7 +10,8 @@ online_steps: 25000 fps: 10 -dataset_id: pusht +dataset: + repo_id: lerobot/pusht env: name: pusht diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 72ca12a0..00b8e2d5 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -9,7 +9,8 @@ online_steps: 25000 fps: 15 -dataset_id: xarm_lift_medium +dataset: + repo_id: lerobot/xarm_lift_medium env: name: xarm diff --git a/tests/data/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors diff --git a/tests/data/aloha_sim_insertion_human/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/meta_data/info.json rename to tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json diff --git a/tests/data/aloha_sim_insertion_human/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors similarity index 100% rename from tests/data/aloha_sim_insertion_human/meta_data/stats.safetensors rename to tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/lerobot/aloha_sim_insertion_human/stats.pth similarity index 100% rename from tests/data/aloha_sim_insertion_human/stats.pth rename to tests/data/lerobot/aloha_sim_insertion_human/stats.pth diff --git a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow rename to tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow diff --git a/tests/data/aloha_sim_insertion_human/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/train/dataset_info.json rename to tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json diff --git a/tests/data/aloha_sim_insertion_human/train/state.json b/tests/data/lerobot/aloha_sim_insertion_human/train/state.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/train/state.json rename to tests/data/lerobot/aloha_sim_insertion_human/train/state.json diff --git a/tests/data/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors rename to 
tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors diff --git a/tests/data/aloha_sim_insertion_scripted/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/meta_data/info.json rename to tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json diff --git a/tests/data/aloha_sim_insertion_scripted/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/meta_data/stats.safetensors rename to tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/lerobot/aloha_sim_insertion_scripted/stats.pth similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/stats.pth rename to tests/data/lerobot/aloha_sim_insertion_scripted/stats.pth diff --git a/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow rename to tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow diff --git a/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/train/dataset_info.json rename to tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json diff --git a/tests/data/aloha_sim_insertion_scripted/train/state.json b/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/train/state.json rename to tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json diff --git a/tests/data/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors diff --git a/tests/data/aloha_sim_transfer_cube_human/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/meta_data/info.json rename to tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json diff --git a/tests/data/aloha_sim_transfer_cube_human/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/meta_data/stats.safetensors rename to tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/lerobot/aloha_sim_transfer_cube_human/stats.pth similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/stats.pth rename to tests/data/lerobot/aloha_sim_transfer_cube_human/stats.pth diff --git a/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow similarity 
index 100% rename from tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow rename to tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow diff --git a/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json rename to tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json diff --git a/tests/data/aloha_sim_transfer_cube_human/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/train/state.json rename to tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/meta_data/info.json rename to tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors rename to tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/stats.pth similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/stats.pth rename to tests/data/lerobot/aloha_sim_transfer_cube_scripted/stats.pth diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow rename to tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json rename to tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/train/state.json rename to tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json diff --git a/tests/data/pusht/meta_data/episode_data_index.safetensors b/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/pusht/meta_data/episode_data_index.safetensors rename to 
tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors diff --git a/tests/data/pusht/meta_data/info.json b/tests/data/lerobot/pusht/meta_data/info.json similarity index 100% rename from tests/data/pusht/meta_data/info.json rename to tests/data/lerobot/pusht/meta_data/info.json diff --git a/tests/data/pusht/meta_data/stats.safetensors b/tests/data/lerobot/pusht/meta_data/stats.safetensors similarity index 100% rename from tests/data/pusht/meta_data/stats.safetensors rename to tests/data/lerobot/pusht/meta_data/stats.safetensors diff --git a/tests/data/pusht/stats.pth b/tests/data/lerobot/pusht/stats.pth similarity index 100% rename from tests/data/pusht/stats.pth rename to tests/data/lerobot/pusht/stats.pth diff --git a/tests/data/pusht/train/data-00000-of-00001.arrow b/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/pusht/train/data-00000-of-00001.arrow rename to tests/data/lerobot/pusht/train/data-00000-of-00001.arrow diff --git a/tests/data/pusht/train/dataset_info.json b/tests/data/lerobot/pusht/train/dataset_info.json similarity index 100% rename from tests/data/pusht/train/dataset_info.json rename to tests/data/lerobot/pusht/train/dataset_info.json diff --git a/tests/data/pusht/train/meta_data/episode_data_index.safetensors b/tests/data/lerobot/pusht/train/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/pusht/train/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/pusht/train/meta_data/episode_data_index.safetensors diff --git a/tests/data/pusht/train/meta_data/info.json b/tests/data/lerobot/pusht/train/meta_data/info.json similarity index 100% rename from tests/data/pusht/train/meta_data/info.json rename to tests/data/lerobot/pusht/train/meta_data/info.json diff --git a/tests/data/pusht/train/meta_data/stats_action.safetensors b/tests/data/lerobot/pusht/train/meta_data/stats_action.safetensors similarity index 100% rename from tests/data/pusht/train/meta_data/stats_action.safetensors rename to tests/data/lerobot/pusht/train/meta_data/stats_action.safetensors diff --git a/tests/data/pusht/train/meta_data/stats_observation.image.safetensors b/tests/data/lerobot/pusht/train/meta_data/stats_observation.image.safetensors similarity index 100% rename from tests/data/pusht/train/meta_data/stats_observation.image.safetensors rename to tests/data/lerobot/pusht/train/meta_data/stats_observation.image.safetensors diff --git a/tests/data/pusht/train/meta_data/stats_observation.state.safetensors b/tests/data/lerobot/pusht/train/meta_data/stats_observation.state.safetensors similarity index 100% rename from tests/data/pusht/train/meta_data/stats_observation.state.safetensors rename to tests/data/lerobot/pusht/train/meta_data/stats_observation.state.safetensors diff --git a/tests/data/pusht/train/state.json b/tests/data/lerobot/pusht/train/state.json similarity index 100% rename from tests/data/pusht/train/state.json rename to tests/data/lerobot/pusht/train/state.json diff --git a/tests/data/xarm_lift_medium/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/xarm_lift_medium/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors diff --git a/tests/data/xarm_lift_medium/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium/meta_data/info.json similarity index 100% rename from 
tests/data/xarm_lift_medium/meta_data/info.json rename to tests/data/lerobot/xarm_lift_medium/meta_data/info.json diff --git a/tests/data/xarm_lift_medium/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors similarity index 100% rename from tests/data/xarm_lift_medium/meta_data/stats.safetensors rename to tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors diff --git a/tests/data/xarm_lift_medium/stats.pth b/tests/data/lerobot/xarm_lift_medium/stats.pth similarity index 100% rename from tests/data/xarm_lift_medium/stats.pth rename to tests/data/lerobot/xarm_lift_medium/stats.pth diff --git a/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow rename to tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow diff --git a/tests/data/xarm_lift_medium/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json similarity index 100% rename from tests/data/xarm_lift_medium/train/dataset_info.json rename to tests/data/lerobot/xarm_lift_medium/train/dataset_info.json diff --git a/tests/data/xarm_lift_medium/train/state.json b/tests/data/lerobot/xarm_lift_medium/train/state.json similarity index 100% rename from tests/data/xarm_lift_medium/train/state.json rename to tests/data/lerobot/xarm_lift_medium/train/state.json diff --git a/tests/data/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors diff --git a/tests/data/xarm_lift_medium_replay/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json similarity index 100% rename from tests/data/xarm_lift_medium_replay/meta_data/info.json rename to tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json diff --git a/tests/data/xarm_lift_medium_replay/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors similarity index 100% rename from tests/data/xarm_lift_medium_replay/meta_data/stats.safetensors rename to tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors diff --git a/tests/data/xarm_lift_medium_replay/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/xarm_lift_medium_replay/train/data-00000-of-00001.arrow rename to tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow diff --git a/tests/data/xarm_lift_medium_replay/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json similarity index 100% rename from tests/data/xarm_lift_medium_replay/train/dataset_info.json rename to tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json diff --git a/tests/data/xarm_lift_medium_replay/train/state.json b/tests/data/lerobot/xarm_lift_medium_replay/train/state.json similarity index 100% rename from tests/data/xarm_lift_medium_replay/train/state.json rename to tests/data/lerobot/xarm_lift_medium_replay/train/state.json diff --git a/tests/data/xarm_push_medium/meta_data/episode_data_index.safetensors 
b/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/xarm_push_medium/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors diff --git a/tests/data/xarm_push_medium/meta_data/info.json b/tests/data/lerobot/xarm_push_medium/meta_data/info.json similarity index 100% rename from tests/data/xarm_push_medium/meta_data/info.json rename to tests/data/lerobot/xarm_push_medium/meta_data/info.json diff --git a/tests/data/xarm_push_medium/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors similarity index 100% rename from tests/data/xarm_push_medium/meta_data/stats.safetensors rename to tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors diff --git a/tests/data/xarm_push_medium/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/xarm_push_medium/train/data-00000-of-00001.arrow rename to tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow diff --git a/tests/data/xarm_push_medium/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium/train/dataset_info.json similarity index 100% rename from tests/data/xarm_push_medium/train/dataset_info.json rename to tests/data/lerobot/xarm_push_medium/train/dataset_info.json diff --git a/tests/data/xarm_push_medium/train/state.json b/tests/data/lerobot/xarm_push_medium/train/state.json similarity index 100% rename from tests/data/xarm_push_medium/train/state.json rename to tests/data/lerobot/xarm_push_medium/train/state.json diff --git a/tests/data/xarm_push_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors similarity index 100% rename from tests/data/xarm_push_medium_replay/meta_data/episode_data_index.safetensors rename to tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors diff --git a/tests/data/xarm_push_medium_replay/meta_data/info.json b/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json similarity index 100% rename from tests/data/xarm_push_medium_replay/meta_data/info.json rename to tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json diff --git a/tests/data/xarm_push_medium_replay/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors similarity index 100% rename from tests/data/xarm_push_medium_replay/meta_data/stats.safetensors rename to tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors diff --git a/tests/data/xarm_push_medium_replay/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow similarity index 100% rename from tests/data/xarm_push_medium_replay/train/data-00000-of-00001.arrow rename to tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow diff --git a/tests/data/xarm_push_medium_replay/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json similarity index 100% rename from tests/data/xarm_push_medium_replay/train/dataset_info.json rename to tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json diff --git a/tests/data/xarm_push_medium_replay/train/state.json b/tests/data/lerobot/xarm_push_medium_replay/train/state.json similarity index 100% rename from 
diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py
index 4f0875e2..a8ea1065 100644
--- a/tests/scripts/save_dataset_to_safetensors.py
+++ b/tests/scripts/save_dataset_to_safetensors.py
@@ -16,22 +16,18 @@ from pathlib import Path
 
 from safetensors.torch import save_file
 
-from lerobot.common.datasets.pusht import PushtDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
 
-def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
-    data_dir = Path(output_dir) / dataset_id
+def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
+    data_dir = Path(output_dir) / repo_id
 
     if data_dir.exists():
         shutil.rmtree(data_dir)
 
     data_dir.mkdir(parents=True, exist_ok=True)
 
-    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
-    dataset = PushtDataset(
-        dataset_id=dataset_id,
-        split="train",
-    )
+    dataset = LeRobotDataset(repo_id)
 
     # save 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
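After the refactor, the artifact-generation script needs only a repo id. A hedged sketch of addressing frames through episode_data_index, assuming a "to" key symmetric to the "from" key used in the hunk above (only "from" appears in this patch):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")

# index of the first frame of episode 0 ("from" is shown in the script above)
i_first = dataset.episode_data_index["from"][0].item()
frame = dataset[i_first]  # one timestep of the episode
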
diff --git a/tests/test_available.py b/tests/test_available.py
index 4328ec69..29f4f31e 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -4,9 +4,6 @@ import gymnasium as gym
 import pytest
 
 import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
-from lerobot.common.datasets.pusht import PushtDataset
-from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -27,25 +24,6 @@ def test_available_env_task(env_name: str, task_name: list):
         assert gym_handle in gym.envs.registry, gym_handle
 
 
-@pytest.mark.parametrize(
-    "env_name, dataset_class",
-    [
-        ("aloha", AlohaDataset),
-        ("pusht", PushtDataset),
-        ("xarm", XarmDataset),
-    ],
-)
-def test_available_datasets(env_name, dataset_class):
-    """
-    This test verifies that the class attribute `available_datasets` for all
-    dataset classes is consistent with those listed in `lerobot/__init__.py`.
-    """
-    available_env_datasets = lerobot.available_datasets[env_name]
-    assert set(available_env_datasets) == set(
-        dataset_class.available_datasets
-    ), f"{env_name=} {available_env_datasets=}"
-
-
 def test_available_policies():
     """
     This test verifies that the class attribute `name` for all policies is
@@ -58,3 +36,12 @@ def test_available_policies():
     ]
     policies = [pol_cls.name for pol_cls in policy_classes]
     assert set(policies) == set(lerobot.available_policies), policies
+
+
+def test_print():
+    print(lerobot.available_envs)
+    print(lerobot.available_tasks_per_env)
+    print(lerobot.available_datasets)
+    print(lerobot.available_datasets_per_env)
+    print(lerobot.available_policies)
+    print(lerobot.available_policies_per_env)
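The new test_print() exercises the module-level registries that replace the per-class available_datasets attributes removed above. A minimal sketch of querying them, assuming available_datasets_per_env maps env names to lists of repo ids (the exact container types are not shown in this patch):

import lerobot

for env_name in lerobot.available_envs:
    repo_ids = lerobot.available_datasets_per_env[env_name]
    print(f"{env_name}: {len(repo_ids)} dataset(s)")
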
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index ec459c58..bebc3479 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -12,7 +12,7 @@ from safetensors.torch import load_file
 
 import lerobot
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.pusht import PushtDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import (
     compute_stats,
     flatten_dict,
@@ -26,13 +26,13 @@ from lerobot.common.utils.utils import init_hydra_config
 from .utils import DEFAULT_CONFIG_PATH, DEVICE
 
 
-@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
-def test_factory(env_name, dataset_id, policy_name):
+@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
+def test_factory(env_name, repo_id, policy_name):
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
         overrides=[
             f"env={env_name}",
-            f"dataset_id={dataset_id}",
+            f"dataset.repo_id={repo_id}",
             f"policy={policy_name}",
             f"device={DEVICE}",
         ],
@@ -94,14 +94,13 @@ def test_compute_stats_on_xarm():
     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
     because we are working with a small dataset).
     """
-    # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
-    from lerobot.common.datasets.xarm import XarmDataset
-
-    dataset = XarmDataset(
-        dataset_id="xarm_lift_medium",
-        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    dataset = LeRobotDataset(
+        "lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
     )
 
+    # reduce size of dataset sample on which stats compute is tested to 10 frames
+    dataset.hf_dataset = dataset.hf_dataset.select(range(10))
+
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
@@ -241,16 +240,16 @@ def test_flatten_unflatten_dict():
 
 
 def test_backward_compatibility():
     """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
-    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
-    dataset_id = "pusht"
-    data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
-    dataset = PushtDataset(
-        dataset_id=dataset_id,
-        split="train",
+    repo_id = "lerobot/pusht"
+
+    dataset = LeRobotDataset(
+        repo_id,
         root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
     )
 
+    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+
     def load_and_compare(i):
         new_frame = dataset[i]
         old_frame = load_file(data_dir / f"frame_{i}.safetensors")
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 0e4ce654..3b1959d5 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -19,10 +19,22 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
         ("xarm", "tdmpc", ["policy.mpc=true"]),
         ("pusht", "tdmpc", ["policy.mpc=false"]),
         ("pusht", "diffusion", []),
-        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
-        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
-        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
-        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
+        (
+            "aloha",
+            "act",
+            ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"],
+        ),
+        (
+            "aloha",
+            "act",
+            ["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"],
+        ),
+        (
+            "aloha",
+            "act",
+            ["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
+        ),
     ],
 )
 @require_env
diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py
index 6787c463..3ed22970 100644
--- a/tests/test_visualize_dataset.py
+++ b/tests/test_visualize_dataset.py
@@ -7,12 +7,12 @@ from .utils import DEFAULT_CONFIG_PATH
 
 
 @pytest.mark.parametrize(
-    "dataset_id",
+    "repo_id",
     [
-        "aloha_sim_insertion_human",
+        "lerobot/aloha_sim_insertion_human",
     ],
 )
-def test_visualize_dataset(tmpdir, dataset_id):
+def test_visualize_dataset(tmpdir, repo_id):
     # TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
     # doesnt support multiple timesteps which requires delta_timestamps to None for images.
     cfg = init_hydra_config(
@@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, dataset_id):
         overrides=[
             "policy=act",
             "env=aloha",
-            f"dataset_id={dataset_id}",
+            f"dataset.repo_id={repo_id}",
         ],
     )
     video_paths = visualize_dataset(cfg, out_dir=tmpdir)
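Throughout these tests, the Hydra override moves from a top-level dataset_id=<id> to the nested dataset.repo_id=<namespace>/<name>. A hedged sketch of composing a config the new way (the literal config path is an assumption; the tests pass their own DEFAULT_CONFIG_PATH constant):

from lerobot.common.utils.utils import init_hydra_config

cfg = init_hydra_config(
    "lerobot/configs/default.yaml",  # assumed path; tests use DEFAULT_CONFIG_PATH
    overrides=[
        "env=aloha",
        "policy=act",
        "dataset.repo_id=lerobot/aloha_sim_insertion_human",
    ],
)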