Fixes for PR #4
This commit is contained in:
parent
b862145e22
commit
c1942d45d3
|
@ -73,7 +73,6 @@ def make_offline_buffer(cfg, sampler=None):
|
|||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root=DATA_PATH,
|
||||
sampler=sampler,
|
||||
|
|
|
@ -18,7 +18,7 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
|||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common import utils
|
||||
from lerobot.common.datasets import utils
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
@ -97,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
download: bool | str = False,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
transform: "torchrl.envs.Transform" = None, # noqa: F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
):
|
||||
self.download = download
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
|
@ -129,8 +127,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
self.root = root
|
||||
self.raw = self.root / "raw"
|
||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
||||
if not self._is_downloaded():
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
@ -192,9 +189,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
self.raw.mkdir(exist_ok=True)
|
||||
utils.download_and_extract_zip(PUSHT_URL, self.raw)
|
||||
zarr_path = (self.raw / PUSHT_ZARR).resolve()
|
||||
raw_dir = self.root / "raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
utils.download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
import io
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import tqdm
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
zip_file.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
return True
|
||||
else:
|
||||
return False
|
|
@ -1,38 +1,9 @@
|
|||
import io
|
||||
import logging
|
||||
import random
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
zip_file.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
|
|
Loading…
Reference in New Issue