Fixes for PR #4

This commit is contained in:
Simon Alibert 2024-03-01 14:59:05 +01:00
parent b862145e22
commit c1942d45d3
4 changed files with 38 additions and 39 deletions

View File

@@ -73,7 +73,6 @@ def make_offline_buffer(cfg, sampler=None):
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
"pusht",
download=True,
streaming=False,
root=DATA_PATH,
sampler=sampler,

View File

@@ -18,7 +18,7 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common import utils
from lerobot.common.datasets import utils
# as defined in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
@@ -97,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
replacement: bool = None,
streaming: bool = False,
root: Path = None,
download: bool | str = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False,
prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa-F821
transform: "torchrl.envs.Transform" = None, # noqa: F821
split_trajs: bool = False,
strict_length: bool = True,
):
self.download = download
if streaming:
raise NotImplementedError
self.streaming = streaming
@@ -129,8 +127,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
os.makedirs(root, exist_ok=True)
self.root = root
self.raw = self.root / "raw"
if self.download == "force" or (self.download and not self._is_downloaded()):
if not self._is_downloaded():
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
@@ -192,9 +189,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
def _download_and_preproc(self):
# download
self.raw.mkdir(exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, self.raw)
zarr_path = (self.raw / PUSHT_ZARR).resolve()
raw_dir = self.root / "raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, raw_dir)
# load
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])

View File

@@ -0,0 +1,30 @@
import io
import zipfile
from pathlib import Path
import requests
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download the zip archive at *url* and unpack it into *destination_folder*.

    The archive is streamed into an in-memory buffer with a progress bar,
    then extracted. Returns True on success and False when the server
    responds with a non-200 status code (the failure is not raised).
    """
    print(f"downloading from {url}")
    response = requests.get(url, stream=True)
    # Bail out early on any non-OK response; nothing is written in that case.
    if response.status_code != 200:
        return False

    # content-length may be absent; tqdm then shows an unbounded bar.
    content_length = int(response.headers.get("content-length", 0))
    pbar = tqdm.tqdm(total=content_length, unit="B", unit_scale=True)
    buffer = io.BytesIO()
    # Stream the payload into memory 1 KiB at a time, reporting progress.
    for part in response.iter_content(chunk_size=1024):
        if part:  # filter out keep-alive chunks
            buffer.write(part)
            pbar.update(len(part))
    pbar.close()

    # Rewind the in-memory archive before handing it to ZipFile.
    buffer.seek(0)
    with zipfile.ZipFile(buffer, "r") as archive:
        archive.extractall(destination_folder)
    return True

View File

@@ -1,38 +1,9 @@
import io
import logging
import random
import zipfile
from datetime import datetime
from pathlib import Path
import numpy as np
import requests
import torch
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download the zip archive at *url* and extract it into *destination_folder*.

    The whole archive is buffered in memory (``io.BytesIO``) while a tqdm
    progress bar tracks the download. Returns True on success; returns
    False — without raising — when the server answers with a non-200
    status, so callers must check the return value to detect failure.
    """
    print(f"downloading from {url}")
    # stream=True defers the body download so it can be read in chunks below.
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        # content-length may be missing; tqdm then runs without a fixed total.
        total_size = int(response.headers.get("content-length", 0))
        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
        zip_file = io.BytesIO()
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:  # skip empty keep-alive chunks
                zip_file.write(chunk)
                progress_bar.update(len(chunk))
        progress_bar.close()
        # Rewind the in-memory buffer before ZipFile reads it.
        zip_file.seek(0)
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(destination_folder)
        return True
    else:
        return False
def set_seed(seed):