Add replay_buffer directory in pusht datasets + aloha (WIP)
parent 099a465367 · commit 6a1a29386a
README.md: 34 lines changed
@@ -138,7 +138,7 @@ git lfs pull
 When adding a new dataset, mock it with
 ```
-python tests/scripts/mock_dataset.py --in-data-dir data/<dataset_id> --out-data-dir tests/data/<dataset_id>
+python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
 ```

 Run tests
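Assuming the script exposes `mock_dataset` as a plain function (its diff appears at the bottom of this commit), the same mocking step can be driven from Python; a minimal sketch, with `pusht` standing in for any `$DATASET`:

```python
from tests.scripts.mock_dataset import mock_dataset

# Copy the first 50 frames of the real dataset into the test fixtures,
# mirroring the CLI invocation above.
mock_dataset("data/pusht", "tests/data/pusht", num_frames=50)
```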
@@ -155,7 +155,9 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
 Then you can upload it to the hub with:
 ```
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET
+HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
+--repo-type dataset \
+--revision v1.0
 ```

 For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
 ```
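For reference, the same versioned upload can be expressed with the `huggingface_hub` Python API; this is a sketch rather than part of the commit, with the `$HF_USER`/`$DATASET` placeholders filled in for pusht:

```python
from huggingface_hub import HfApi

api = HfApi()
# Ensure the target revision (a branch) exists, then mirror the CLI call:
# upload data/$DATASET to $HF_USER/$DATASET at revision v1.0.
api.create_branch("cadene/pusht", repo_type="dataset", branch="v1.0", exist_ok=True)
api.upload_folder(
    folder_path="data/pusht",
    repo_id="cadene/pusht",
    repo_type="dataset",
    revision="v1.0",
)
```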
@@ -164,6 +166,34 @@ HF_USER=cadene
 DATASET=pusht
 ```

+If you want to improve an existing dataset, you can download it locally with:
+```
+mkdir -p data/$DATASET
+HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download $HF_USER/$DATASET \
+--repo-type dataset \
+--local-dir data/$DATASET \
+--local-dir-use-symlinks=False \
+--revision v1.0
+```
+
+Iterate on your code and dataset with:
+```
+DATA_DIR=data python train.py
+```
+
+Then upload a new version:
+```
+HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
+--repo-type dataset \
+--revision v1.1 \
+--delete "*"
+```
+
+And you might want to mock the dataset if you need to update the unit tests as well:
+```
+python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
+```
+
 ## Acknowledgment
 - Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
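The download half of this improve-and-reupload loop maps onto `snapshot_download`; again a sketch, not part of the commit, with the placeholders filled in:

```python
from huggingface_hub import snapshot_download

# Mirror the CLI download above: fetch revision v1.0 of the dataset repo
# into data/pusht as plain files rather than cache symlinks.
snapshot_download(
    repo_id="cadene/pusht",
    repo_type="dataset",
    revision="v1.0",
    local_dir="data/pusht",
    local_dir_use_symlinks=False,
)
```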
@@ -19,6 +19,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def __init__(
         self,
         dataset_id: str,
+        version: str | None = None,
         batch_size: int = None,
         *,
         shuffle: bool = True,
@@ -31,6 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         transform: "torchrl.envs.Transform" = None,
     ):
         self.dataset_id = dataset_id
+        self.version = version
         self.shuffle = shuffle
         self.root = root
         storage = self._download_or_load_dataset()
@@ -96,10 +98,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):

     def _download_or_load_dataset(self) -> torch.StorageBase:
         if self.root is None:
-            self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset"))
+            self.data_dir = Path(
+                snapshot_download(
+                    repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version
+                )
+            )
         else:
             self.data_dir = self.root / self.dataset_id
-        return TensorStorage(TensorDict.load_memmap(self.data_dir))
+        return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))

     def _compute_stats(self, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
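Standalone, the new loading path in `_download_or_load_dataset` boils down to the following sketch (values are illustrative; `cadene/` is the namespace hard-coded above): `revision=self.version` pins the Hub snapshot, and the memory-mapped `TensorDict` is now read from the `replay_buffer/` subdirectory instead of the dataset root.

```python
from pathlib import Path

from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data import TensorStorage

# Pin the dataset snapshot to a tagged revision on the Hub...
data_dir = Path(
    snapshot_download(repo_id="cadene/pusht", repo_type="dataset", revision="v1.0")
)
# ...and memory-map the tensors from the new replay_buffer/ layout.
storage = TensorStorage(TensorDict.load_memmap(data_dir / "replay_buffer"))
```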
@@ -84,6 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def __init__(
         self,
         dataset_id: str,
+        version: str | None = None,
         batch_size: int = None,
         *,
         shuffle: bool = True,
@@ -99,6 +100,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):

         super().__init__(
             dataset_id,
+            version,
             batch_size,
             shuffle=shuffle,
             root=root,
@@ -87,6 +87,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
     def __init__(
         self,
         dataset_id: str,
+        version: str | None = "v1.0",
         batch_size: int = None,
         *,
         shuffle: bool = True,
@@ -100,6 +101,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
     ):
         super().__init__(
             dataset_id,
+            version,
             batch_size,
             shuffle=shuffle,
             root=root,
@@ -40,6 +40,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
     def __init__(
         self,
         dataset_id: str,
+        version: str | None = None,
         batch_size: int = None,
         *,
         shuffle: bool = True,
@@ -53,6 +54,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
     ):
         super().__init__(
             dataset_id,
+            version,
             batch_size,
             shuffle=shuffle,
             root=root,
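The three subclass diffs (aloha, pusht, simxarm) are the same mechanical edit: accept `version` and forward it to the base class, with only pusht pinning a `"v1.0"` default so far, consistent with the WIP note in the commit title. A hypothetical usage sketch follows; the module path and extra constructor arguments are assumptions, not shown in this diff:

```python
# Module path assumed for illustration; it does not appear in this diff.
from lerobot.common.datasets.pusht import PushtExperienceReplay

# version defaults to "v1.0" for pusht; pass another value to pin a
# different Hub revision of the dataset.
dataset = PushtExperienceReplay("pusht", version="v1.0", batch_size=32)
batch = dataset.sample()  # standard TensorDictReplayBuffer sampling
```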
@@ -10,12 +10,15 @@ from pathlib import Path


 def mock_dataset(in_data_dir, out_data_dir, num_frames=50):
+    in_data_dir = Path(in_data_dir)
+    out_data_dir = Path(out_data_dir)
+
     # load full dataset as a tensor dict
-    in_td_data = TensorDict.load_memmap(in_data_dir)
+    in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer")

     # use 1 frame to know the specification of the dataset
     # and copy it over `n` frames in the test artifact directory
-    out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir)
+    out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer")

     # copy the first `n` frames so that we have real data
     out_td_data[:num_frames] = in_td_data[:num_frames].clone()
@@ -24,8 +27,8 @@ def mock_dataset(in_data_dir, out_data_dir, num_frames=50):
     out_td_data.lock_()

     # copy the full statistics of dataset since it's pretty small
-    in_stats_path = Path(in_data_dir) / "stats.pth"
-    out_stats_path = Path(out_data_dir) / "stats.pth"
+    in_stats_path = in_data_dir / "stats.pth"
+    out_stats_path = out_data_dir / "stats.pth"
     shutil.copy(in_stats_path, out_stats_path)

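Given the new on-disk layout (memory-mapped tensors under `replay_buffer/`, statistics in `stats.pth` at the dataset root), a quick sanity check of a mocked artifact might look like this sketch (paths illustrative):

```python
from pathlib import Path

import torch
from tensordict import TensorDict

out_data_dir = Path("tests/data/pusht")

# The mocked buffer should hold exactly num_frames (default 50) frames.
td = TensorDict.load_memmap(out_data_dir / "replay_buffer")
assert td.shape[0] == 50

# stats.pth is copied verbatim from the source dataset.
stats = torch.load(out_data_dir / "stats.pth")
```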