Merge pull request #12 from Cadene/user/aliberts/2024_03_08_test_data

Add pusht test artifact
This commit is contained in:
Simon Alibert 2024-03-09 16:00:39 +01:00 committed by GitHub
commit fa7a947acc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 109 additions and 1 deletions

1
.gitattributes vendored Normal file
View File

@ -0,0 +1 @@
*.memmap filter=lfs diff=lfs merge=lfs -text

View File

@ -17,6 +17,7 @@ jobs:
runs-on: ubuntu-latest
env:
POETRY_VERSION: 1.8.1
DATA_DIR: tests/data
steps:
#----------------------------------------------
# check-out repo and set-up python

1
.gitignore vendored
View File

@ -54,6 +54,7 @@ pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
!tests/data
htmlcov/
.tox/
.nox/

View File

@ -115,6 +115,31 @@ pre-commit run -a
```
**Tests**
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
```
pytest -sx tests
brew install git-lfs
git lfs install
```
On Ubuntu:
```
sudo apt-get install git-lfs
git lfs install
```
Pull artifacts if they're not in [tests/data](tests/data)
```
git lfs pull
```
When adding a new dataset, mock it with
```
python tests/scripts/mock_dataset.py --in-data-dir data/<dataset_id> --out-data-dir tests/data/<dataset_id>
```
Run tests
```
DATA_DIR="tests/data" pytest -sx tests
```

View File

@ -125,6 +125,9 @@ class PushtExperienceReplay(AbstractExperienceReplay):
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
@ -142,6 +145,8 @@ class PushtExperienceReplay(AbstractExperienceReplay):
idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
# to create test artifact
# idx1 = 51
num_frames = idx1 - idx0

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba17d8e5c30151ea5f7f6fc31f19e12a68ce2113774b74c8aca0c7ef962a75f4
size 400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
size 400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
size 400

View File

@ -0,0 +1 @@
{"action": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
size 50

View File

@ -0,0 +1 @@
{"reward": {"device": "cpu", "shape": [50, 1], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "success": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff6a3748c8223a82e54c61442df7b8baf478a20497ee2353645a1e9ccd765162
size 5529600

View File

@ -0,0 +1 @@
{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fad4ece6d5fd66bbafa34f6ff383c483410082b8d7d4f4616808c3c458ce1d43
size 400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d9c54dee5660c46886f32d80e57e9dd0ffa57ee0cd2a762b036d9c8e0c3a33a
size 200

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
size 50

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bbde5cfd8cff9fd9fc6c9a57177f6fd31c8a03cf853b7d2234312f38380b0ba
size 5529600

View File

@ -0,0 +1 @@
{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:67c7e39090a16546fb1eade833d704f26464d574d7e431415f828159a154d2bf
size 400

BIN
tests/data/pusht/stats.pth Normal file

Binary file not shown.

View File

@ -0,0 +1,41 @@
"""
usage: `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
"""
import argparse
import shutil
from tensordict import TensorDict
from pathlib import Path
def mock_dataset(in_data_dir, out_data_dir, num_frames=50):
# load full dataset as a tensor dict
in_td_data = TensorDict.load_memmap(in_data_dir)
# use 1 frame to know the specification of the dataset
# and copy it over `n` frames in the test artifact directory
out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir)
# copy the first `n` frames so that we have real data
out_td_data[:num_frames] = in_td_data[:num_frames].clone()
# make sure everything has been properly written
out_td_data.lock_()
# copy the full statistics of dataset since it's pretty small
in_stats_path = Path(in_data_dir) / "stats.pth"
out_stats_path = Path(out_data_dir) / "stats.pth"
shutil.copy(in_stats_path, out_stats_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create dataset")
parser.add_argument("--in-data-dir", type=str, help="Path to input data")
parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
args = parser.parse_args()
mock_dataset(args.in_data_dir, args.out_data_dir)