From b862145e229ba29d9aa15f61a0a345c667e5b31e Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 1 Mar 2024 14:31:54 +0100 Subject: [PATCH 1/2] Added pusht dataset auto-download --- lerobot/common/datasets/factory.py | 12 ++++---- lerobot/common/datasets/pusht.py | 44 ++++++++++++++---------------- lerobot/common/utils.py | 29 ++++++++++++++++++++ 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 9a129ba1..2d26c4cb 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,9 +1,13 @@ +from pathlib import Path + import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler from lerobot.common.datasets.pusht import PushtExperienceReplay from lerobot.common.datasets.simxarm import SimxarmExperienceReplay +DATA_PATH = Path("data/") + # TODO(rcadene): implement # dataset_d4rl = D4RLExperienceReplay( @@ -60,7 +64,7 @@ def make_offline_buffer(cfg, sampler=None): # download="force", download=True, streaming=False, - root="data", + root=str(DATA_PATH), sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, @@ -69,11 +73,9 @@ def make_offline_buffer(cfg, sampler=None): elif cfg.env.name == "pusht": offline_buffer = PushtExperienceReplay( "pusht", - # download="force", - # TODO(aliberts): automate download - download=False, + download=True, streaming=False, - root="data", + root=DATA_PATH, sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 372d9fab..e41b82f4 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -9,8 +9,6 @@ import pymunk import torch import torchrl import tqdm -from diffusion_policy.common.replay_buffer import ReplayBuffer -from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely from tensordict import TensorDict from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer @@ -18,10 +16,16 @@ from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer +from diffusion_policy.common.replay_buffer import ReplayBuffer +from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely +from lerobot.common import utils + # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS() +PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" +PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr") def get_goal_pose_body(pose): @@ -83,7 +87,7 @@ def add_tee( class PushtExperienceReplay(TensorDictReplayBuffer): def __init__( self, - dataset_id, + dataset_id: str, batch_size: int = None, *, shuffle: bool = True, @@ -93,7 +97,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): replacement: bool = None, streaming: bool = False, root: Path = None, - download: bool = False, + download: bool | str = False, sampler: Sampler = None, writer: Writer = None, collate_fn: Callable = None, @@ -120,13 +124,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer): if split_trajs: raise NotImplementedError - if self.download: - raise NotImplementedError() - if root is None: root = _get_root_dir("pusht") os.makedirs(root, exist_ok=True) - self.root = Path(root) + + self.root = root + self.raw = self.root / "raw" if self.download == "force" or (self.download and not self._is_downloaded()): storage = self._download_and_preproc() else: @@ -173,39 +176,34 @@ class PushtExperienceReplay(TensorDictReplayBuffer): ) @property - def num_samples(self): + def num_samples(self) -> int: return len(self) @property - def num_episodes(self): + def num_episodes(self) -> int: return len(self._storage._storage["episode"].unique()) @property - def data_path_root(self): - if self.streaming: - return None - return self.root / self.dataset_id + def data_path_root(self) -> Path: + return None if self.streaming else self.root / self.dataset_id - def _is_downloaded(self): - return os.path.exists(self.data_path_root) + def _is_downloaded(self) -> bool: + return self.data_path_root.is_dir() def _download_and_preproc(self): # download - # TODO(rcadene) + self.raw.mkdir(exist_ok=True) + utils.download_and_extract_zip(PUSHT_URL, self.raw) + zarr_path = (self.raw / PUSHT_ZARR).resolve() # load - # TODO(aliberts): Dynamic paths - zarr_path = ( - "/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr" - # "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr" - ) dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) episode_ids = dataset_dict.get_episode_idxs() num_episodes = dataset_dict.meta["episode_ends"].shape[0] total_frames = dataset_dict["action"].shape[0] assert len( - {dataset_dict[key].shape[0] for key in dataset_dict} + {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118 ), "Some data type dont have the same number of total frames." # TODO: verify that goal pose is expected to be fixed diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index d174d4b5..22c375c9 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -1,9 +1,38 @@ +import io import logging import random +import zipfile from datetime import datetime +from pathlib import Path import numpy as np +import requests import torch +import tqdm + + +def download_and_extract_zip(url: str, destination_folder: Path) -> bool: + print(f"downloading from {url}") + response = requests.get(url, stream=True) + if response.status_code == 200: + total_size = int(response.headers.get("content-length", 0)) + progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) + + zip_file = io.BytesIO() + for chunk in response.iter_content(chunk_size=1024): + if chunk: + zip_file.write(chunk) + progress_bar.update(len(chunk)) + + progress_bar.close() + + zip_file.seek(0) + + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(destination_folder) + return True + else: + return False def set_seed(seed): From c1942d45d311d286e6b6a7d6e9ed78af3f9adc84 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 1 Mar 2024 14:59:05 +0100 Subject: [PATCH 2/2] Fixes for PR #4 --- lerobot/common/datasets/factory.py | 1 - lerobot/common/datasets/pusht.py | 17 ++++++++--------- lerobot/common/datasets/utils.py | 30 ++++++++++++++++++++++++++++++ lerobot/common/utils.py | 29 ----------------------------- 4 files changed, 38 insertions(+), 39 deletions(-) create mode 100644 lerobot/common/datasets/utils.py diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 2d26c4cb..9fc0d2da 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -73,7 +73,6 @@ def make_offline_buffer(cfg, sampler=None): elif cfg.env.name == "pusht": offline_buffer = PushtExperienceReplay( "pusht", - download=True, streaming=False, root=DATA_PATH, sampler=sampler, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index e41b82f4..76f4c6cd 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -18,7 +18,7 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from diffusion_policy.common.replay_buffer import ReplayBuffer from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely -from lerobot.common import utils +from lerobot.common.datasets import utils # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, @@ -97,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer): replacement: bool = None, streaming: bool = False, root: Path = None, - download: bool | str = False, sampler: Sampler = None, writer: Writer = None, collate_fn: Callable = None, pin_memory: bool = False, prefetch: int = None, - transform: "torchrl.envs.Transform" = None, # noqa-F821 + transform: "torchrl.envs.Transform" = None, # noqa: F821 split_trajs: bool = False, strict_length: bool = True, ): - self.download = download if streaming: raise NotImplementedError self.streaming = streaming @@ -129,8 +127,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): os.makedirs(root, exist_ok=True) self.root = root - self.raw = self.root / "raw" - if self.download == "force" or (self.download and not self._is_downloaded()): + if not self._is_downloaded(): storage = self._download_and_preproc() else: storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) @@ -192,9 +189,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer): def _download_and_preproc(self): # download - self.raw.mkdir(exist_ok=True) - utils.download_and_extract_zip(PUSHT_URL, self.raw) - zarr_path = (self.raw / PUSHT_ZARR).resolve() + raw_dir = self.root / "raw" + zarr_path = (raw_dir / PUSHT_ZARR).resolve() + if not zarr_path.is_dir(): + raw_dir.mkdir(parents=True, exist_ok=True) + utils.download_and_extract_zip(PUSHT_URL, raw_dir) # load dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py new file mode 100644 index 00000000..0ad43a65 --- /dev/null +++ b/lerobot/common/datasets/utils.py @@ -0,0 +1,30 @@ +import io +import zipfile +from pathlib import Path + +import requests +import tqdm + + +def download_and_extract_zip(url: str, destination_folder: Path) -> bool: + print(f"downloading from {url}") + response = requests.get(url, stream=True) + if response.status_code == 200: + total_size = int(response.headers.get("content-length", 0)) + progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) + + zip_file = io.BytesIO() + for chunk in response.iter_content(chunk_size=1024): + if chunk: + zip_file.write(chunk) + progress_bar.update(len(chunk)) + + progress_bar.close() + + zip_file.seek(0) + + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(destination_folder) + return True + else: + return False diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 22c375c9..d174d4b5 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -1,38 +1,9 @@ -import io import logging import random -import zipfile from datetime import datetime -from pathlib import Path import numpy as np -import requests import torch -import tqdm - - -def download_and_extract_zip(url: str, destination_folder: Path) -> bool: - print(f"downloading from {url}") - response = requests.get(url, stream=True) - if response.status_code == 200: - total_size = int(response.headers.get("content-length", 0)) - progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) - - zip_file = io.BytesIO() - for chunk in response.iter_content(chunk_size=1024): - if chunk: - zip_file.write(chunk) - progress_bar.update(len(chunk)) - - progress_bar.close() - - zip_file.seek(0) - - with zipfile.ZipFile(zip_file, "r") as zip_ref: - zip_ref.extractall(destination_folder) - return True - else: - return False def set_seed(seed):