From 6a1a29386a45307e7eddee4ab24585b5808809cb Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 19 Mar 2024 15:49:45 +0000 Subject: [PATCH 1/5] Add replay_buffer directory in pusht datasets + aloha (WIP) --- README.md | 34 +++++++++++++++++-- lerobot/common/datasets/abstract.py | 10 ++++-- lerobot/common/datasets/aloha.py | 2 ++ lerobot/common/datasets/pusht.py | 2 ++ lerobot/common/datasets/simxarm.py | 2 ++ .../pusht/{ => replay_buffer}/action.memmap | 0 .../pusht/{ => replay_buffer}/episode.memmap | 0 .../pusht/{ => replay_buffer}/frame_id.memmap | 0 .../data/pusht/{ => replay_buffer}/meta.json | 0 .../{ => replay_buffer}/next/done.memmap | 0 .../pusht/{ => replay_buffer}/next/meta.json | 0 .../next/observation/image.memmap | 0 .../next/observation/meta.json | 0 .../next/observation/state.memmap | 0 .../{ => replay_buffer}/next/reward.memmap | 0 .../{ => replay_buffer}/next/success.memmap | 0 .../observation/image.memmap | 0 .../{ => replay_buffer}/observation/meta.json | 0 .../observation/state.memmap | 0 tests/scripts/mock_dataset.py | 11 +++--- 20 files changed, 53 insertions(+), 8 deletions(-) rename tests/data/pusht/{ => replay_buffer}/action.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/episode.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/frame_id.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/meta.json (100%) rename tests/data/pusht/{ => replay_buffer}/next/done.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/next/meta.json (100%) rename tests/data/pusht/{ => replay_buffer}/next/observation/image.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/next/observation/meta.json (100%) rename tests/data/pusht/{ => replay_buffer}/next/observation/state.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/next/reward.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/next/success.memmap (100%) rename tests/data/pusht/{ => replay_buffer}/observation/image.memmap (100%) rename tests/data/pusht/{ => 
replay_buffer}/observation/meta.json (100%) rename tests/data/pusht/{ => replay_buffer}/observation/state.memmap (100%) diff --git a/README.md b/README.md index 1051c8a6..71990e28 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ git lfs pull When adding a new dataset, mock it with ``` -python tests/scripts/mock_dataset.py --in-data-dir data/ --out-data-dir tests/data/ +python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET ``` Run tests @@ -155,7 +155,9 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential Then you can upload it to the hub with: ``` -HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \ +--repo-type dataset \ +--revision v1.0 ``` For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used: @@ -164,6 +166,34 @@ HF_USER=cadene DATASET=pusht ``` +If you want to improve an existing dataset, you can download it locally with: +``` +mkdir -p data/$DATASET +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download $HF_USER/$DATASET \ +--repo-type dataset \ +--local-dir data/$DATASET \ +--local-dir-use-symlinks=False \ +--revision v1.0 +``` + +Iterate on your code and dataset with: +``` +DATA_DIR=data python train.py +``` + +Then upload a new version: +``` +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \ +--repo-type dataset \ +--revision v1.1 \ +--delete "*" +``` + +And you might want to mock the dataset if you need to update the unit tests as well: +``` +python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET +``` + ## Acknowledgment - Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py 
index 34b33c2e..9127d887 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -19,6 +19,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def __init__( self, dataset_id: str, + version: str | None = None, batch_size: int = None, *, shuffle: bool = True, @@ -31,6 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): transform: "torchrl.envs.Transform" = None, ): self.dataset_id = dataset_id + self.version = version self.shuffle = shuffle self.root = root storage = self._download_or_load_dataset() @@ -96,10 +98,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def _download_or_load_dataset(self) -> torch.StorageBase: if self.root is None: - self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")) + self.data_dir = Path( + snapshot_download( + repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version + ) + ) else: self.data_dir = self.root / self.dataset_id - return TensorStorage(TensorDict.load_memmap(self.data_dir)) + return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer")) def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 52a5676e..2af98cd8 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -84,6 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, + version: str | None = None, batch_size: int = None, *, shuffle: bool = True, @@ -99,6 +100,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): super().__init__( dataset_id, + version, batch_size, shuffle=shuffle, root=root, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index f4f6d9ac..3ad6371f 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -87,6 +87,7 @@ class 
PushtExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, + version: str | None = "v1.0", batch_size: int = None, *, shuffle: bool = True, @@ -100,6 +101,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): ): super().__init__( dataset_id, + version, batch_size, shuffle=shuffle, root=root, diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 7bcb03fb..d7e2e18f 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -40,6 +40,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, + version: str | None = None, batch_size: int = None, *, shuffle: bool = True, @@ -53,6 +54,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): ): super().__init__( dataset_id, + version, batch_size, shuffle=shuffle, root=root, diff --git a/tests/data/pusht/action.memmap b/tests/data/pusht/replay_buffer/action.memmap similarity index 100% rename from tests/data/pusht/action.memmap rename to tests/data/pusht/replay_buffer/action.memmap diff --git a/tests/data/pusht/episode.memmap b/tests/data/pusht/replay_buffer/episode.memmap similarity index 100% rename from tests/data/pusht/episode.memmap rename to tests/data/pusht/replay_buffer/episode.memmap diff --git a/tests/data/pusht/frame_id.memmap b/tests/data/pusht/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/pusht/frame_id.memmap rename to tests/data/pusht/replay_buffer/frame_id.memmap diff --git a/tests/data/pusht/meta.json b/tests/data/pusht/replay_buffer/meta.json similarity index 100% rename from tests/data/pusht/meta.json rename to tests/data/pusht/replay_buffer/meta.json diff --git a/tests/data/pusht/next/done.memmap b/tests/data/pusht/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/pusht/next/done.memmap rename to tests/data/pusht/replay_buffer/next/done.memmap diff --git a/tests/data/pusht/next/meta.json 
b/tests/data/pusht/replay_buffer/next/meta.json similarity index 100% rename from tests/data/pusht/next/meta.json rename to tests/data/pusht/replay_buffer/next/meta.json diff --git a/tests/data/pusht/next/observation/image.memmap b/tests/data/pusht/replay_buffer/next/observation/image.memmap similarity index 100% rename from tests/data/pusht/next/observation/image.memmap rename to tests/data/pusht/replay_buffer/next/observation/image.memmap diff --git a/tests/data/pusht/next/observation/meta.json b/tests/data/pusht/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/pusht/next/observation/meta.json rename to tests/data/pusht/replay_buffer/next/observation/meta.json diff --git a/tests/data/pusht/next/observation/state.memmap b/tests/data/pusht/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/pusht/next/observation/state.memmap rename to tests/data/pusht/replay_buffer/next/observation/state.memmap diff --git a/tests/data/pusht/next/reward.memmap b/tests/data/pusht/replay_buffer/next/reward.memmap similarity index 100% rename from tests/data/pusht/next/reward.memmap rename to tests/data/pusht/replay_buffer/next/reward.memmap diff --git a/tests/data/pusht/next/success.memmap b/tests/data/pusht/replay_buffer/next/success.memmap similarity index 100% rename from tests/data/pusht/next/success.memmap rename to tests/data/pusht/replay_buffer/next/success.memmap diff --git a/tests/data/pusht/observation/image.memmap b/tests/data/pusht/replay_buffer/observation/image.memmap similarity index 100% rename from tests/data/pusht/observation/image.memmap rename to tests/data/pusht/replay_buffer/observation/image.memmap diff --git a/tests/data/pusht/observation/meta.json b/tests/data/pusht/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/pusht/observation/meta.json rename to tests/data/pusht/replay_buffer/observation/meta.json diff --git 
a/tests/data/pusht/observation/state.memmap b/tests/data/pusht/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/pusht/observation/state.memmap rename to tests/data/pusht/replay_buffer/observation/state.memmap diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py index c58280d7..2200b644 100644 --- a/tests/scripts/mock_dataset.py +++ b/tests/scripts/mock_dataset.py @@ -10,12 +10,15 @@ from pathlib import Path def mock_dataset(in_data_dir, out_data_dir, num_frames=50): + in_data_dir = Path(in_data_dir) + out_data_dir = Path(out_data_dir) + # load full dataset as a tensor dict - in_td_data = TensorDict.load_memmap(in_data_dir) + in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer") # use 1 frame to know the specification of the dataset # and copy it over `n` frames in the test artifact directory - out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir) + out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer") # copy the first `n` frames so that we have real data out_td_data[:num_frames] = in_td_data[:num_frames].clone() @@ -24,8 +27,8 @@ def mock_dataset(in_data_dir, out_data_dir, num_frames=50): out_td_data.lock_() # copy the full statistics of dataset since it's pretty small - in_stats_path = Path(in_data_dir) / "stats.pth" - out_stats_path = Path(out_data_dir) / "stats.pth" + in_stats_path = in_data_dir / "stats.pth" + out_stats_path = out_data_dir / "stats.pth" shutil.copy(in_stats_path, out_stats_path) From 10034e85c4df5411c60940373b75cb4ae565282b Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 19 Mar 2024 16:03:42 +0000 Subject: [PATCH 2/5] Aloha done --- lerobot/common/datasets/aloha.py | 2 +- .../aloha_sim_insertion_human/{ => replay_buffer}/action.memmap | 0 .../{ => replay_buffer}/episode.memmap | 0 .../{ => replay_buffer}/frame_id.memmap | 0 .../aloha_sim_insertion_human/{ => replay_buffer}/meta.json | 0 .../{ => 
replay_buffer}/next/done.memmap | 0 .../{ => replay_buffer}/next/meta.json | 0 .../{ => replay_buffer}/next/observation/image/meta.json | 0 .../{ => replay_buffer}/next/observation/image/top.memmap | 0 .../{ => replay_buffer}/next/observation/meta.json | 0 .../{ => replay_buffer}/next/observation/state.memmap | 0 .../{ => replay_buffer}/observation/image/meta.json | 0 .../{ => replay_buffer}/observation/image/top.memmap | 0 .../{ => replay_buffer}/observation/meta.json | 0 .../{ => replay_buffer}/observation/state.memmap | 0 .../{ => replay_buffer}/action.memmap | 0 .../{ => replay_buffer}/episode.memmap | 0 .../{ => replay_buffer}/frame_id.memmap | 0 .../aloha_sim_insertion_scripted/{ => replay_buffer}/meta.json | 0 .../{ => replay_buffer}/next/done.memmap | 0 .../{ => replay_buffer}/next/meta.json | 0 .../{ => replay_buffer}/next/observation/image/meta.json | 0 .../{ => replay_buffer}/next/observation/image/top.memmap | 0 .../{ => replay_buffer}/next/observation/meta.json | 0 .../{ => replay_buffer}/next/observation/state.memmap | 0 .../{ => replay_buffer}/observation/image/meta.json | 0 .../{ => replay_buffer}/observation/image/top.memmap | 0 .../{ => replay_buffer}/observation/meta.json | 0 .../{ => replay_buffer}/observation/state.memmap | 0 .../{ => replay_buffer}/action.memmap | 0 .../{ => replay_buffer}/episode.memmap | 0 .../{ => replay_buffer}/frame_id.memmap | 0 .../aloha_sim_transfer_cube_human/{ => replay_buffer}/meta.json | 0 .../{ => replay_buffer}/next/done.memmap | 0 .../{ => replay_buffer}/next/meta.json | 0 .../{ => replay_buffer}/next/observation/image/meta.json | 0 .../{ => replay_buffer}/next/observation/image/top.memmap | 0 .../{ => replay_buffer}/next/observation/meta.json | 0 .../{ => replay_buffer}/next/observation/state.memmap | 0 .../{ => replay_buffer}/observation/image/meta.json | 0 .../{ => replay_buffer}/observation/image/top.memmap | 0 .../{ => replay_buffer}/observation/meta.json | 0 .../{ => replay_buffer}/observation/state.memmap 
| 0 .../{ => replay_buffer}/action.memmap | 0 .../{ => replay_buffer}/episode.memmap | 0 .../{ => replay_buffer}/frame_id.memmap | 0 .../{ => replay_buffer}/meta.json | 0 .../{ => replay_buffer}/next/done.memmap | 0 .../{ => replay_buffer}/next/meta.json | 0 .../{ => replay_buffer}/next/observation/image/meta.json | 0 .../{ => replay_buffer}/next/observation/image/top.memmap | 0 .../{ => replay_buffer}/next/observation/meta.json | 0 .../{ => replay_buffer}/next/observation/state.memmap | 0 .../{ => replay_buffer}/observation/image/meta.json | 0 .../{ => replay_buffer}/observation/image/top.memmap | 0 .../{ => replay_buffer}/observation/meta.json | 0 .../{ => replay_buffer}/observation/state.memmap | 0 57 files changed, 1 insertion(+), 1 deletion(-) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/action.memmap (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/episode.memmap (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/frame_id.memmap (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/meta.json (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/next/done.memmap (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/next/meta.json (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/next/observation/image/meta.json (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/next/observation/image/top.memmap (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/next/observation/meta.json (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/next/observation/state.memmap (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/observation/image/meta.json (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/observation/image/top.memmap (100%) rename tests/data/aloha_sim_insertion_human/{ => replay_buffer}/observation/meta.json (100%) rename 
tests/data/aloha_sim_insertion_human/{ => replay_buffer}/observation/state.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/action.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/episode.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/frame_id.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/meta.json (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/next/done.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/next/meta.json (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/next/observation/image/meta.json (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/next/observation/image/top.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/next/observation/meta.json (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/next/observation/state.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/observation/image/meta.json (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/observation/image/top.memmap (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/observation/meta.json (100%) rename tests/data/aloha_sim_insertion_scripted/{ => replay_buffer}/observation/state.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/action.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/episode.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/frame_id.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/next/done.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/next/meta.json (100%) rename 
tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/next/observation/image/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/next/observation/image/top.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/next/observation/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/next/observation/state.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/observation/image/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/observation/image/top.memmap (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/observation/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_human/{ => replay_buffer}/observation/state.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/action.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/episode.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/frame_id.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/next/done.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/next/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/next/observation/image/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/next/observation/image/top.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/next/observation/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/next/observation/state.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/observation/image/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => 
replay_buffer}/observation/image/top.memmap (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/observation/meta.json (100%) rename tests/data/aloha_sim_transfer_cube_scripted/{ => replay_buffer}/observation/state.memmap (100%) diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 2af98cd8..82989659 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -84,7 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, - version: str | None = None, + version: str | None = "v1.0", batch_size: int = None, *, shuffle: bool = True, diff --git a/tests/data/aloha_sim_insertion_human/action.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/action.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_insertion_human/episode.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/episode.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_insertion_human/frame_id.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/frame_id.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap diff --git a/tests/data/aloha_sim_insertion_human/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/done.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap similarity index 100% rename from 
tests/data/aloha_sim_insertion_human/next/done.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_insertion_human/next/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/image/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_human/next/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/state.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_insertion_human/observation/image/meta.json 
b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/image/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_human/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_human/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_human/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/state.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/action.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/action.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/episode.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/episode.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/frame_id.memmap 
b/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/frame_id.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/done.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/done.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/next/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json 
b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/image/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/state.memmap rename to 
tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/action.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/action.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/episode.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/episode.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/frame_id.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/next/done.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/done.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/next/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json diff --git 
a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap 
b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/state.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/action.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/action.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/episode.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap diff --git 
a/tests/data/aloha_sim_transfer_cube_scripted/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json similarity index 100% rename from 
tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap rename to 
tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap From e799dc5e3fee333f731ef5f50e79c2bb6b20ea85 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 19 Mar 2024 16:38:07 +0000 Subject: [PATCH 3/5] Improve mock_dataset --- tests/scripts/mock_dataset.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py index 2200b644..d9c86464 100644 --- a/tests/scripts/mock_dataset.py +++ b/tests/scripts/mock_dataset.py @@ -1,5 +1,18 @@ """ - usage: `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht` +This script is designed to facilitate the creation of a subset of an existing dataset by selecting a specific number of frames from the original dataset. +This subset can then be used for running quick unit tests. +The script takes an input directory containing the original dataset and an output directory where the subset of the dataset will be saved. +Additionally, the number of frames to include in the subset can be specified. +The script ensures that the subset is a representative sample of the original dataset by copying the specified number of frames and retaining the structure and format of the data. 
+ +Usage: + Run the script with the following command, specifying the path to the input data directory, + the path to the output data directory, and optionally the number of frames to include in the subset dataset: + + `python tests/scripts/mock_dataset.py --in-data-dir path/to/input_data --out-data-dir path/to/output_data [--num-frames N]` + +Example: + `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht` """ import argparse @@ -9,7 +22,7 @@ from tensordict import TensorDict from pathlib import Path -def mock_dataset(in_data_dir, out_data_dir, num_frames=50): +def mock_dataset(in_data_dir, out_data_dir, num_frames): in_data_dir = Path(in_data_dir) out_data_dir = Path(out_data_dir) @@ -34,11 +47,12 @@ def mock_dataset(in_data_dir, out_data_dir, num_frames=50): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Create dataset") + parser = argparse.ArgumentParser(description="Create a dataset with a subset of frames for quick testing.") parser.add_argument("--in-data-dir", type=str, help="Path to input data") parser.add_argument("--out-data-dir", type=str, help="Path to save the output data") + parser.add_argument("--num-frames", type=int, default=50, help="Number of frames to copy over") args = parser.parse_args() - mock_dataset(args.in_data_dir, args.out_data_dir) \ No newline at end of file + mock_dataset(args.in_data_dir, args.out_data_dir, args.num_frames) \ No newline at end of file From b420ab88f46a94a22f40191c903a2a4f038d8cdc Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 19 Mar 2024 16:44:19 +0000 Subject: [PATCH 4/5] version naming conventions --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 71990e28..7e4baad8 100644 --- a/README.md +++ b/README.md @@ -181,7 +181,7 @@ Iterate on your code and dataset with: DATA_DIR=data python train.py ``` -Then upload a new version: +Then upload a new version (v2.0 or v1.1 if the changes are respectively more or
less significant): ``` HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \ --repo-type dataset \ From 7d5d99e036c7eccd973a0ced0084fdcc18cf3a85 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 19 Mar 2024 16:53:07 +0000 Subject: [PATCH 5/5] Address more comments --- README.md | 23 ++++++++++++++++++----- lerobot/common/datasets/abstract.py | 6 ++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 7e4baad8..dc51fec2 100644 --- a/README.md +++ b/README.md @@ -148,9 +148,9 @@ DATA_DIR="tests/data" pytest -sx tests **Datasets** -To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: +To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: ``` -huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential +huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Then you can upload it to the hub with: @@ -160,6 +160,12 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS --revision v1.0 ``` +You will need to set the corresponding version as a default argument in your dataset class: +```python + version: str | None = "v1.0", +``` +See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py) + For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used: ``` HF_USER=cadene @@ -169,7 +175,7 @@ DATASET=pusht If you want to improve an existing dataset, you can download it locally with: ``` mkdir -p data/$DATASET -HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download $HF_USER/$DATASET \ +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \ --repo-type dataset \ --local-dir data/$DATASET \ --local-dir-use-symlinks=False \ @@ 
-181,7 +187,7 @@ Iterate on your code and dataset with: DATA_DIR=data python train.py ``` -Then upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant): +Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant): ``` HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \ --repo-type dataset \ @@ -189,7 +195,14 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS --delete "*" ``` -And you might want to mock the dataset if you need to update the unit tests as well: +Then you will need to set the corresponding version as a default argument in your dataset class: +```python + version: str | None = "v1.1", +``` +See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py) + + +Finally, you might want to mock the dataset if you need to update the unit tests as well: ``` python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET ``` diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 9127d887..aec53877 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -35,6 +35,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): self.version = version self.shuffle = shuffle self.root = root + + if self.root is not None and self.version is not None: + logging.warning( + f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." + ) + storage = self._download_or_load_dataset() super().__init__(