diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 784242cc..1d56850e 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -1,4 +1,5 @@ import pickle +import zipfile from pathlib import Path from typing import Callable @@ -15,6 +16,22 @@ from torchrl.data.replay_buffers.writers import Writer from lerobot.common.datasets.abstract import AbstractExperienceReplay +def download(): + raise NotImplementedError() + import gdown + + url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" + download_path = "data.zip" + gdown.download(url, download_path, quiet=False) + print("Extracting...") + with zipfile.ZipFile(download_path, "r") as zip_f: + for member in zip_f.namelist(): + if member.startswith("data/xarm") and member.endswith(".pkl"): + print(member) + zip_f.extract(member=member) + Path(download_path).unlink() + + class SimxarmExperienceReplay(AbstractExperienceReplay): available_datasets = [ "xarm_lift_medium", @@ -48,8 +65,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc(self): - # download - # TODO(rcadene) + # TODO(rcadene): finish download + download() dataset_path = self.data_dir / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'")