Added pusht dataset auto-download

This commit is contained in:
Simon Alibert 2024-03-01 14:31:54 +01:00
parent ca948c1e5b
commit b862145e22
3 changed files with 57 additions and 28 deletions
lerobot/common

View File

@ -1,9 +1,13 @@
from pathlib import Path
import torch import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
DATA_PATH = Path("data/")
# TODO(rcadene): implement # TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay( # dataset_d4rl = D4RLExperienceReplay(
@ -60,7 +64,7 @@ def make_offline_buffer(cfg, sampler=None):
# download="force", # download="force",
download=True, download=True,
streaming=False, streaming=False,
root="data", root=str(DATA_PATH),
sampler=sampler, sampler=sampler,
batch_size=batch_size, batch_size=batch_size,
pin_memory=pin_memory, pin_memory=pin_memory,
@ -69,11 +73,9 @@ def make_offline_buffer(cfg, sampler=None):
elif cfg.env.name == "pusht": elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay( offline_buffer = PushtExperienceReplay(
"pusht", "pusht",
# download="force", download=True,
# TODO(aliberts): automate download
download=False,
streaming=False, streaming=False,
root="data", root=DATA_PATH,
sampler=sampler, sampler=sampler,
batch_size=batch_size, batch_size=batch_size,
pin_memory=pin_memory, pin_memory=pin_memory,

View File

@ -9,8 +9,6 @@ import pymunk
import torch import torch
import torchrl import torchrl
import tqdm import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
@ -18,10 +16,16 @@ from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common import utils
# as define in env # as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage, SUCCESS_THRESHOLD = 0.95 # 95% coverage,
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS() DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
def get_goal_pose_body(pose): def get_goal_pose_body(pose):
@ -83,7 +87,7 @@ def add_tee(
class PushtExperienceReplay(TensorDictReplayBuffer): class PushtExperienceReplay(TensorDictReplayBuffer):
def __init__( def __init__(
self, self,
dataset_id, dataset_id: str,
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
@ -93,7 +97,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
replacement: bool = None, replacement: bool = None,
streaming: bool = False, streaming: bool = False,
root: Path = None, root: Path = None,
download: bool = False, download: bool | str = False,
sampler: Sampler = None, sampler: Sampler = None,
writer: Writer = None, writer: Writer = None,
collate_fn: Callable = None, collate_fn: Callable = None,
@ -120,13 +124,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if split_trajs: if split_trajs:
raise NotImplementedError raise NotImplementedError
if self.download:
raise NotImplementedError()
if root is None: if root is None:
root = _get_root_dir("pusht") root = _get_root_dir("pusht")
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
self.root = Path(root)
self.root = root
self.raw = self.root / "raw"
if self.download == "force" or (self.download and not self._is_downloaded()): if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc() storage = self._download_and_preproc()
else: else:
@ -173,39 +176,34 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
) )
@property @property
def num_samples(self): def num_samples(self) -> int:
return len(self) return len(self)
@property @property
def num_episodes(self): def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique()) return len(self._storage._storage["episode"].unique())
@property @property
def data_path_root(self): def data_path_root(self) -> Path:
if self.streaming: return None if self.streaming else self.root / self.dataset_id
return None
return self.root / self.dataset_id
def _is_downloaded(self): def _is_downloaded(self) -> bool:
return os.path.exists(self.data_path_root) return self.data_path_root.is_dir()
def _download_and_preproc(self): def _download_and_preproc(self):
# download # download
# TODO(rcadene) self.raw.mkdir(exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, self.raw)
zarr_path = (self.raw / PUSHT_ZARR).resolve()
# load # load
# TODO(aliberts): Dynamic paths
zarr_path = (
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
# "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
)
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
episode_ids = dataset_dict.get_episode_idxs() episode_ids = dataset_dict.get_episode_idxs()
num_episodes = dataset_dict.meta["episode_ends"].shape[0] num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0] total_frames = dataset_dict["action"].shape[0]
assert len( assert len(
{dataset_dict[key].shape[0] for key in dataset_dict} {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data types don't have the same number of total frames." ), "Some data types don't have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed # TODO: verify that goal pose is expected to be fixed

View File

@ -1,9 +1,38 @@
import io
import logging import logging
import random import random
import zipfile
from datetime import datetime from datetime import datetime
from pathlib import Path
import numpy as np import numpy as np
import requests
import torch import torch
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download a zip archive from `url` and extract it into `destination_folder`.

    The archive is buffered fully in memory before extraction, with a tqdm
    progress bar reporting download progress.

    Args:
        url: HTTP(S) address of the zip archive.
        destination_folder: Directory into which the archive contents are extracted.

    Returns:
        True if the download succeeded (HTTP 200) and the archive was extracted,
        False otherwise.
    """
    print(f"downloading from {url}")
    # Stream the body so large archives are read chunk by chunk; the context
    # manager releases the connection even on early return or exception.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            return False
        total_size = int(response.headers.get("content-length", 0))
        zip_file = io.BytesIO()
        # Using tqdm as a context manager guarantees the bar is closed even if
        # the download raises partway through.
        with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # skip keep-alive chunks
                    zip_file.write(chunk)
                    progress_bar.update(len(chunk))
        zip_file.seek(0)
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(destination_folder)
        return True
def set_seed(seed): def set_seed(seed):