From c1942d45d311d286e6b6a7d6e9ed78af3f9adc84 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 1 Mar 2024 14:59:05 +0100 Subject: [PATCH] Fixes for PR #4 --- lerobot/common/datasets/factory.py | 1 - lerobot/common/datasets/pusht.py | 17 ++++++++--------- lerobot/common/datasets/utils.py | 30 ++++++++++++++++++++++++++++++ lerobot/common/utils.py | 29 ----------------------------- 4 files changed, 38 insertions(+), 39 deletions(-) create mode 100644 lerobot/common/datasets/utils.py diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 2d26c4cb..9fc0d2da 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -73,7 +73,6 @@ def make_offline_buffer(cfg, sampler=None): elif cfg.env.name == "pusht": offline_buffer = PushtExperienceReplay( "pusht", - download=True, streaming=False, root=DATA_PATH, sampler=sampler, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index e41b82f4..76f4c6cd 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -18,7 +18,7 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from diffusion_policy.common.replay_buffer import ReplayBuffer from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely -from lerobot.common import utils +from lerobot.common.datasets import utils # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, @@ -97,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer): replacement: bool = None, streaming: bool = False, root: Path = None, - download: bool | str = False, sampler: Sampler = None, writer: Writer = None, collate_fn: Callable = None, pin_memory: bool = False, prefetch: int = None, - transform: "torchrl.envs.Transform" = None, # noqa-F821 + transform: "torchrl.envs.Transform" = None, # noqa: F821 split_trajs: bool = False, strict_length: bool = True, ): - self.download = download if streaming: raise NotImplementedError self.streaming = streaming @@ -129,8 +127,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): os.makedirs(root, exist_ok=True) self.root = root - self.raw = self.root / "raw" - if self.download == "force" or (self.download and not self._is_downloaded()): + if not self._is_downloaded(): storage = self._download_and_preproc() else: storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) @@ -192,9 +189,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer): def _download_and_preproc(self): # download - self.raw.mkdir(exist_ok=True) - utils.download_and_extract_zip(PUSHT_URL, self.raw) - zarr_path = (self.raw / PUSHT_ZARR).resolve() + raw_dir = self.root / "raw" + zarr_path = (raw_dir / PUSHT_ZARR).resolve() + if not zarr_path.is_dir(): + raw_dir.mkdir(parents=True, exist_ok=True) + utils.download_and_extract_zip(PUSHT_URL, raw_dir) # load dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py new file mode 100644 index 00000000..0ad43a65 --- /dev/null +++ b/lerobot/common/datasets/utils.py @@ -0,0 +1,30 @@ +import io +import zipfile +from pathlib import Path + +import requests +import tqdm + + +def download_and_extract_zip(url: str, destination_folder: Path) -> bool: + print(f"downloading from {url}") + response = requests.get(url, stream=True) + if response.status_code == 200: + total_size = int(response.headers.get("content-length", 0)) + progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) + + zip_file = io.BytesIO() + for chunk in response.iter_content(chunk_size=1024): + if chunk: + zip_file.write(chunk) + progress_bar.update(len(chunk)) + + progress_bar.close() + + zip_file.seek(0) + + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(destination_folder) + return True + else: + return False diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 22c375c9..d174d4b5 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -1,38 +1,9 @@ -import io import logging import random -import zipfile from datetime import datetime -from pathlib import Path import numpy as np -import requests import torch -import tqdm - - -def download_and_extract_zip(url: str, destination_folder: Path) -> bool: - print(f"downloading from {url}") - response = requests.get(url, stream=True) - if response.status_code == 200: - total_size = int(response.headers.get("content-length", 0)) - progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) - - zip_file = io.BytesIO() - for chunk in response.iter_content(chunk_size=1024): - if chunk: - zip_file.write(chunk) - progress_bar.update(len(chunk)) - - progress_bar.close() - - zip_file.seek(0) - - with zipfile.ZipFile(zip_file, "r") as zip_ref: - zip_ref.extractall(destination_folder) - return True - else: - return False def set_seed(seed):