diff --git a/README.md b/README.md index 1051c8a6..dc51fec2 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ git lfs pull When adding a new dataset, mock it with ``` -python tests/scripts/mock_dataset.py --in-data-dir data/ --out-data-dir tests/data/ +python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET ``` Run tests @@ -148,22 +148,65 @@ DATA_DIR="tests/data" pytest -sx tests **Datasets** -To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: +To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: ``` -huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential +huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Then you can upload it to the hub with: ``` -HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \ +--repo-type dataset \ +--revision v1.0 ``` +You will need to set the corresponding version as a default argument in your dataset class: +```python + version: str | None = "v1.0", +``` +See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py) + For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used: ``` HF_USER=cadene DATASET=pusht ``` +If you want to improve an existing dataset, you can download it locally with: +``` +mkdir -p data/$DATASET +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \ +--repo-type dataset \ +--local-dir data/$DATASET \ +--local-dir-use-symlinks=False \ +--revision v1.0 +``` + +Iterate on your code and dataset with: +``` +DATA_DIR=data python train.py +``` + +Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant): +``` +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \ +--repo-type dataset \ +--revision v1.1 \ +--delete "*" +``` + +Then you will need to set the corresponding version as a default argument in your dataset class: +```python + version: str | None = "v1.1", +``` +See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py) + + +Finally, you might want to mock the dataset if you need to update the unit tests as well: +``` +python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET +``` + ## Acknowledgment - Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 34b33c2e..aec53877 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -19,6 +19,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def __init__( self, dataset_id: str, + version: str | None = None, batch_size: int = None, *, shuffle: bool = True, @@ -31,8 +32,15 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): transform: "torchrl.envs.Transform" = None, ): self.dataset_id = dataset_id + self.version = version self.shuffle = shuffle self.root = root + + if self.root is not None and self.version is not None: + logging.warning( + f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." + ) + storage = self._download_or_load_dataset() super().__init__( @@ -96,10 +104,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def _download_or_load_dataset(self) -> torch.StorageBase: if self.root is None: - self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")) + self.data_dir = Path( + snapshot_download( + repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version + ) + ) else: self.data_dir = self.root / self.dataset_id - return TensorStorage(TensorDict.load_memmap(self.data_dir)) + return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer")) def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 52a5676e..82989659 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -84,6 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, + version: str | None = "v1.0", batch_size: int = None, *, shuffle: bool = True, @@ -99,6 +100,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): super().__init__( dataset_id, + version, batch_size, shuffle=shuffle, root=root, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index f4f6d9ac..3ad6371f 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -87,6 +87,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, + version: str | None = "v1.0", batch_size: int = None, *, shuffle: bool = True, @@ -100,6 +101,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): ): super().__init__( dataset_id, + version, batch_size, shuffle=shuffle, root=root, diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 7bcb03fb..d7e2e18f 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -40,6 +40,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, + version: str | None = None, batch_size: int = None, *, shuffle: bool = True, @@ -53,6 +54,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): ): super().__init__( dataset_id, + version, batch_size, shuffle=shuffle, root=root, diff --git a/tests/data/aloha_sim_insertion_human/action.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/action.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_insertion_human/episode.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/episode.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_insertion_human/frame_id.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/frame_id.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap diff --git a/tests/data/aloha_sim_insertion_human/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/done.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/done.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_insertion_human/next/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/image/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_human/next/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_human/next/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/next/observation/state.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_insertion_human/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/image/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_human/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_human/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/meta.json rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_human/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_human/observation/state.memmap rename to tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/action.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/action.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/episode.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/episode.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/frame_id.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/frame_id.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/done.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/done.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/next/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/image/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_insertion_scripted/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/meta.json rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_insertion_scripted/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_insertion_scripted/observation/state.memmap rename to tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/action.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/action.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/episode.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/episode.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/frame_id.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/next/done.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/done.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/next/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_human/observation/state.memmap rename to tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/action.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/action.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/episode.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap rename to tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap diff --git a/tests/data/pusht/action.memmap b/tests/data/pusht/replay_buffer/action.memmap similarity index 100% rename from tests/data/pusht/action.memmap rename to tests/data/pusht/replay_buffer/action.memmap diff --git a/tests/data/pusht/episode.memmap b/tests/data/pusht/replay_buffer/episode.memmap similarity index 100% rename from tests/data/pusht/episode.memmap rename to tests/data/pusht/replay_buffer/episode.memmap diff --git a/tests/data/pusht/frame_id.memmap b/tests/data/pusht/replay_buffer/frame_id.memmap similarity index 100% rename from tests/data/pusht/frame_id.memmap rename to tests/data/pusht/replay_buffer/frame_id.memmap diff --git a/tests/data/pusht/meta.json b/tests/data/pusht/replay_buffer/meta.json similarity index 100% rename from tests/data/pusht/meta.json rename to tests/data/pusht/replay_buffer/meta.json diff --git a/tests/data/pusht/next/done.memmap b/tests/data/pusht/replay_buffer/next/done.memmap similarity index 100% rename from tests/data/pusht/next/done.memmap rename to tests/data/pusht/replay_buffer/next/done.memmap diff --git a/tests/data/pusht/next/meta.json b/tests/data/pusht/replay_buffer/next/meta.json similarity index 100% rename from tests/data/pusht/next/meta.json rename to tests/data/pusht/replay_buffer/next/meta.json diff --git a/tests/data/pusht/next/observation/image.memmap b/tests/data/pusht/replay_buffer/next/observation/image.memmap similarity index 100% rename from tests/data/pusht/next/observation/image.memmap rename to tests/data/pusht/replay_buffer/next/observation/image.memmap diff --git a/tests/data/pusht/next/observation/meta.json b/tests/data/pusht/replay_buffer/next/observation/meta.json similarity index 100% rename from tests/data/pusht/next/observation/meta.json rename to tests/data/pusht/replay_buffer/next/observation/meta.json diff --git a/tests/data/pusht/next/observation/state.memmap b/tests/data/pusht/replay_buffer/next/observation/state.memmap similarity index 100% rename from tests/data/pusht/next/observation/state.memmap rename to tests/data/pusht/replay_buffer/next/observation/state.memmap diff --git a/tests/data/pusht/next/reward.memmap b/tests/data/pusht/replay_buffer/next/reward.memmap similarity index 100% rename from tests/data/pusht/next/reward.memmap rename to tests/data/pusht/replay_buffer/next/reward.memmap diff --git a/tests/data/pusht/next/success.memmap b/tests/data/pusht/replay_buffer/next/success.memmap similarity index 100% rename from tests/data/pusht/next/success.memmap rename to tests/data/pusht/replay_buffer/next/success.memmap diff --git a/tests/data/pusht/observation/image.memmap b/tests/data/pusht/replay_buffer/observation/image.memmap similarity index 100% rename from tests/data/pusht/observation/image.memmap rename to tests/data/pusht/replay_buffer/observation/image.memmap diff --git a/tests/data/pusht/observation/meta.json b/tests/data/pusht/replay_buffer/observation/meta.json similarity index 100% rename from tests/data/pusht/observation/meta.json rename to tests/data/pusht/replay_buffer/observation/meta.json diff --git a/tests/data/pusht/observation/state.memmap b/tests/data/pusht/replay_buffer/observation/state.memmap similarity index 100% rename from tests/data/pusht/observation/state.memmap rename to tests/data/pusht/replay_buffer/observation/state.memmap diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py index c58280d7..d9c86464 100644 --- a/tests/scripts/mock_dataset.py +++ b/tests/scripts/mock_dataset.py @@ -1,5 +1,18 @@ """ - usage: `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht` +This script is designed to facilitate the creation of a subset of an existing dataset by selecting a specific number of frames from the original dataset. +This subset can then be used for running quick unit tests. +The script takes an input directory containing the original dataset and an output directory where the subset of the dataset will be saved. +Additionally, the number of frames to include in the subset can be specified. +The script ensures that the subset is a representative sample of the original dataset by copying the specified number of frames and retaining the structure and format of the data. + +Usage: + Run the script with the following command, specifying the path to the input data directory, + the path to the output data directory, and optionally the number of frames to include in the subset dataset: + + `python tests/scripts/mock_dataset.py --in-data-dir path/to/input_data --out-data-dir path/to/output_data` + +Example: + `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht` """ import argparse @@ -9,13 +22,16 @@ from tensordict import TensorDict from pathlib import Path -def mock_dataset(in_data_dir, out_data_dir, num_frames=50): +def mock_dataset(in_data_dir, out_data_dir, num_frames): + in_data_dir = Path(in_data_dir) + out_data_dir = Path(out_data_dir) + # load full dataset as a tensor dict - in_td_data = TensorDict.load_memmap(in_data_dir) + in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer") # use 1 frame to know the specification of the dataset # and copy it over `n` frames in the test artifact directory - out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir) + out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer") # copy the first `n` frames so that we have real data out_td_data[:num_frames] = in_td_data[:num_frames].clone() @@ -24,18 +40,19 @@ def mock_dataset(in_data_dir, out_data_dir, num_frames=50): out_td_data.lock_() # copy the full statistics of dataset since it's pretty small - in_stats_path = Path(in_data_dir) / "stats.pth" - out_stats_path = Path(out_data_dir) / "stats.pth" + in_stats_path = in_data_dir / "stats.pth" + out_stats_path = out_data_dir / "stats.pth" shutil.copy(in_stats_path, out_stats_path) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Create dataset") + parser = argparse.ArgumentParser(description="Create a dataset with a subset of frames for quick testing.") parser.add_argument("--in-data-dir", type=str, help="Path to input data") parser.add_argument("--out-data-dir", type=str, help="Path to save the output data") + parser.add_argument("--num-frames", type=int, default=50, help="Number of frames to copy over") args = parser.parse_args() - mock_dataset(args.in_data_dir, args.out_data_dir) \ No newline at end of file + mock_dataset(args.in_data_dir, args.out_data_dir, args.num_frames) \ No newline at end of file