From 19730b34127a6a39e31a25a8edb3bd711df9ed90 Mon Sep 17 00:00:00 2001 From: Cadene Date: Thu, 14 Mar 2024 16:59:37 +0000 Subject: [PATCH] Add pusht on hf dataset (WIP) --- lerobot/common/datasets/pusht.py | 5 +++++ poetry.lock | 2 +- pyproject.toml | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index ae987ad1..c72bc9c1 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -8,6 +8,7 @@ import pymunk import torch import torchrl import tqdm +from huggingface_hub import snapshot_download from tensordict import TensorDict from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.storages import TensorStorage @@ -112,6 +113,10 @@ class PushtExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc(self): + snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir) + return TensorStorage(TensorDict.load_memmap(self.data_dir)) + + def _download_and_preproc_obsolete(self): raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): diff --git a/poetry.lock b/poetry.lock index 59de0ec5..b2c8cc36 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3254,4 +3254,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8" +content-hash = "0794a87fd309dffa0ad2982b6902bed7f35ae9e2a82433420516798da04c7197" diff --git a/pyproject.toml b/pyproject.toml index 85af7f82..8542383e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" dm-control = "1.0.14" +huggingface-hub = "^0.21.4" [tool.poetry.group.dev.dependencies]