Added pusht dataset auto-download

This commit is contained in:
Simon Alibert 2024-03-01 14:31:54 +01:00
parent ca948c1e5b
commit b862145e22
3 changed files with 57 additions and 28 deletions

View File

@ -1,9 +1,13 @@
from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
DATA_PATH = Path("data/")
# TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay(
@ -60,7 +64,7 @@ def make_offline_buffer(cfg, sampler=None):
# download="force",
download=True,
streaming=False,
root="data",
root=str(DATA_PATH),
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
@ -69,11 +73,9 @@ def make_offline_buffer(cfg, sampler=None):
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
"pusht",
# download="force",
# TODO(aliberts): automate download
download=False,
download=True,
streaming=False,
root="data",
root=DATA_PATH,
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,

View File

@ -9,8 +9,6 @@ import pymunk
import torch
import torchrl
import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
@ -18,10 +16,16 @@ from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common import utils
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
def get_goal_pose_body(pose):
@ -83,7 +87,7 @@ def add_tee(
class PushtExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
@ -93,7 +97,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
replacement: bool = None,
streaming: bool = False,
root: Path = None,
download: bool = False,
download: bool | str = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
@ -120,13 +124,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if split_trajs:
raise NotImplementedError
if self.download:
raise NotImplementedError()
if root is None:
root = _get_root_dir("pusht")
os.makedirs(root, exist_ok=True)
self.root = Path(root)
self.root = root
self.raw = self.root / "raw"
if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc()
else:
@ -173,39 +176,34 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
)
@property
def num_samples(self):
def num_samples(self) -> int:
return len(self)
@property
def num_episodes(self):
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def data_path_root(self):
if self.streaming:
return None
return self.root / self.dataset_id
def data_path_root(self) -> Path:
return None if self.streaming else self.root / self.dataset_id
def _is_downloaded(self):
return os.path.exists(self.data_path_root)
def _is_downloaded(self) -> bool:
return self.data_path_root.is_dir()
def _download_and_preproc(self):
# download
# TODO(rcadene)
self.raw.mkdir(exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, self.raw)
zarr_path = (self.raw / PUSHT_ZARR).resolve()
# load
# TODO(aliberts): Dynamic paths
zarr_path = (
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
# "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
)
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
episode_ids = dataset_dict.get_episode_idxs()
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict}
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed

View File

@ -1,9 +1,38 @@
import io
import logging
import random
import zipfile
from datetime import datetime
from pathlib import Path
import numpy as np
import requests
import torch
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
return True
else:
return False
def set_seed(seed):