Merge pull request #4 from Cadene/user/aliberts/pusht_buffer_auto_download
Added pusht dataset auto-download
Commit fa7f473142
@@ -1,9 +1,13 @@
+from pathlib import Path
+
 import torch
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
 from lerobot.common.datasets.pusht import PushtExperienceReplay
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 
+DATA_PATH = Path("data/")
+
 # TODO(rcadene): implement
 # dataset_d4rl = D4RLExperienceReplay(
@@ -60,7 +64,7 @@ def make_offline_buffer(cfg, sampler=None):
             # download="force",
             download=True,
             streaming=False,
-            root="data",
+            root=str(DATA_PATH),
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
@@ -69,11 +73,8 @@ def make_offline_buffer(cfg, sampler=None):
     elif cfg.env.name == "pusht":
         offline_buffer = PushtExperienceReplay(
             "pusht",
-            # download="force",
-            # TODO(aliberts): automate download
-            download=False,
             streaming=False,
-            root="data",
+            root=DATA_PATH,
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
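
Note: the two call sites above pass the new constant differently: the simxarm buffer gets root=str(DATA_PATH), while the pusht buffer gets root=DATA_PATH. A minimal sketch of the likely reason (inferred from this diff, not stated in it): PushtExperienceReplay keeps the value and later joins sub-paths with the / operator, which requires a pathlib.Path rather than a plain string.

from pathlib import Path

DATA_PATH = Path("data/")

# Path objects compose with "/", plain strings do not:
raw_dir = DATA_PATH / "raw"    # works: Path("data/raw"), as in pusht.py below
# str(DATA_PATH) / "raw"       # TypeError: unsupported operand type(s)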
lerobot/common/datasets/pusht.py
@@ -9,8 +9,6 @@ import pymunk
 import torch
 import torchrl
 import tqdm
-from diffusion_policy.common.replay_buffer import ReplayBuffer
-from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
 from tensordict import TensorDict
 from torchrl.data.datasets.utils import _get_root_dir
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
@@ -18,10 +16,16 @@ from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
+from diffusion_policy.common.replay_buffer import ReplayBuffer
+from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
+from lerobot.common.datasets import utils
+
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
 
 DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
+PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
+PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
 
 
 def get_goal_pose_body(pose):
@@ -83,7 +87,7 @@ def add_tee(
 class PushtExperienceReplay(TensorDictReplayBuffer):
     def __init__(
         self,
-        dataset_id,
+        dataset_id: str,
         batch_size: int = None,
         *,
         shuffle: bool = True,
@@ -93,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         replacement: bool = None,
         streaming: bool = False,
         root: Path = None,
-        download: bool = False,
         sampler: Sampler = None,
         writer: Writer = None,
         collate_fn: Callable = None,
         pin_memory: bool = False,
         prefetch: int = None,
-        transform: "torchrl.envs.Transform" = None,  # noqa-F821
+        transform: "torchrl.envs.Transform" = None,  # noqa: F821
         split_trajs: bool = False,
         strict_length: bool = True,
     ):
-        self.download = download
         if streaming:
             raise NotImplementedError
         self.streaming = streaming
@@ -120,14 +122,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         if split_trajs:
             raise NotImplementedError
 
-        if self.download:
-            raise NotImplementedError()
-
         if root is None:
             root = _get_root_dir("pusht")
             os.makedirs(root, exist_ok=True)
-        self.root = Path(root)
-        if self.download == "force" or (self.download and not self._is_downloaded()):
+
+        self.root = root
+        if not self._is_downloaded():
             storage = self._download_and_preproc()
         else:
             storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
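
With the download keyword gone, fetching is no longer something callers opt into: constructing the buffer checks _is_downloaded() and runs _download_and_preproc() on a cache miss. A hedged sketch of the call-site change (other keyword arguments left at their defaults):

from pathlib import Path

from lerobot.common.datasets.pusht import PushtExperienceReplay

# before: offline_buffer = PushtExperienceReplay("pusht", download=False, root="data")
# after: the first construction downloads and preprocesses automatically
offline_buffer = PushtExperienceReplay("pusht", root=Path("data/"))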
@@ -173,39 +173,36 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         )
 
     @property
-    def num_samples(self):
+    def num_samples(self) -> int:
         return len(self)
 
     @property
-    def num_episodes(self):
+    def num_episodes(self) -> int:
         return len(self._storage._storage["episode"].unique())
 
     @property
-    def data_path_root(self):
-        if self.streaming:
-            return None
-        return self.root / self.dataset_id
+    def data_path_root(self) -> Path:
+        return None if self.streaming else self.root / self.dataset_id
 
-    def _is_downloaded(self):
-        return os.path.exists(self.data_path_root)
+    def _is_downloaded(self) -> bool:
+        return self.data_path_root.is_dir()
 
     def _download_and_preproc(self):
         # download
-        # TODO(rcadene)
+        raw_dir = self.root / "raw"
+        zarr_path = (raw_dir / PUSHT_ZARR).resolve()
+        if not zarr_path.is_dir():
+            raw_dir.mkdir(parents=True, exist_ok=True)
+            utils.download_and_extract_zip(PUSHT_URL, raw_dir)
 
         # load
-        # TODO(aliberts): Dynamic paths
-        zarr_path = (
-            "/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
-            # "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
-        )
         dataset_dict = ReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
 
         episode_ids = dataset_dict.get_episode_idxs()
         num_episodes = dataset_dict.meta["episode_ends"].shape[0]
         total_frames = dataset_dict["action"].shape[0]
         assert len(
-            {dataset_dict[key].shape[0] for key in dataset_dict}
+            {dataset_dict[key].shape[0] for key in dataset_dict.keys()}  # noqa: SIM118
         ), "Some data type dont have the same number of total frames."
 
         # TODO: verify that goal pose is expected to be fixed
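
For reference, the new download step boils down to an idempotent fetch keyed on the extracted zarr directory. A standalone sketch mirroring the diff (ensure_raw_zarr is a hypothetical name, not part of the PR):

from pathlib import Path

from lerobot.common.datasets import utils

PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")

def ensure_raw_zarr(root: Path) -> Path:
    """Download and extract only when the zarr store is absent; no-op otherwise."""
    raw_dir = root / "raw"
    zarr_path = (raw_dir / PUSHT_ZARR).resolve()
    if not zarr_path.is_dir():  # zarr stores are directories, hence is_dir()
        raw_dir.mkdir(parents=True, exist_ok=True)
        utils.download_and_extract_zip(PUSHT_URL, raw_dir)
    return zarr_path

Re-running construction is therefore cheap: once the archive has been extracted (and later memmapped under self.root / dataset_id), both checks short-circuit.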
lerobot/common/datasets/utils.py (new file)
@@ -0,0 +1,30 @@
+import io
+import zipfile
+from pathlib import Path
+
+import requests
+import tqdm
+
+
+def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
+    print(f"downloading from {url}")
+    response = requests.get(url, stream=True)
+    if response.status_code == 200:
+        total_size = int(response.headers.get("content-length", 0))
+        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
+
+        zip_file = io.BytesIO()
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                zip_file.write(chunk)
+                progress_bar.update(len(chunk))
+
+        progress_bar.close()
+
+        zip_file.seek(0)
+
+        with zipfile.ZipFile(zip_file, "r") as zip_ref:
+            zip_ref.extractall(destination_folder)
+        return True
+    else:
+        return False
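
The helper streams the archive into an in-memory buffer with a tqdm progress bar, then extracts it, returning False on any non-200 response. A quick standalone check (sketch; the destination mirrors the raw_dir used by pusht.py):

from pathlib import Path

from lerobot.common.datasets.utils import download_and_extract_zip

PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"

ok = download_and_extract_zip(PUSHT_URL, Path("data/raw"))
if not ok:
    raise RuntimeError(f"download failed (non-200 response) for {PUSHT_URL}")

One design note: buffering the whole zip in io.BytesIO keeps the code simple but holds the entire archive in RAM, acceptable for a dataset-sized zip but worth revisiting for much larger archives.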