Fixes for PR #4

This commit is contained in:
Simon Alibert 2024-03-01 14:59:05 +01:00
parent b862145e22
commit c1942d45d3
4 changed files with 38 additions and 39 deletions

View File

@ -73,7 +73,6 @@ def make_offline_buffer(cfg, sampler=None):
elif cfg.env.name == "pusht": elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay( offline_buffer = PushtExperienceReplay(
"pusht", "pusht",
download=True,
streaming=False, streaming=False,
root=DATA_PATH, root=DATA_PATH,
sampler=sampler, sampler=sampler,

View File

@ -18,7 +18,7 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common import utils from lerobot.common.datasets import utils
# as defined in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage, SUCCESS_THRESHOLD = 0.95 # 95% coverage,
@ -97,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
replacement: bool = None, replacement: bool = None,
streaming: bool = False, streaming: bool = False,
root: Path = None, root: Path = None,
download: bool | str = False,
sampler: Sampler = None, sampler: Sampler = None,
writer: Writer = None, writer: Writer = None,
collate_fn: Callable = None, collate_fn: Callable = None,
pin_memory: bool = False, pin_memory: bool = False,
prefetch: int = None, prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa-F821 transform: "torchrl.envs.Transform" = None, # noqa: F821
split_trajs: bool = False, split_trajs: bool = False,
strict_length: bool = True, strict_length: bool = True,
): ):
self.download = download
if streaming: if streaming:
raise NotImplementedError raise NotImplementedError
self.streaming = streaming self.streaming = streaming
@ -129,8 +127,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
self.root = root self.root = root
self.raw = self.root / "raw" if not self._is_downloaded():
if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc() storage = self._download_and_preproc()
else: else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
@ -192,9 +189,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
def _download_and_preproc(self): def _download_and_preproc(self):
# download # download
self.raw.mkdir(exist_ok=True) raw_dir = self.root / "raw"
utils.download_and_extract_zip(PUSHT_URL, self.raw) zarr_path = (raw_dir / PUSHT_ZARR).resolve()
zarr_path = (self.raw / PUSHT_ZARR).resolve() if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, raw_dir)
# load # load
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])

View File

@ -0,0 +1,30 @@
import io
import zipfile
from pathlib import Path
import requests
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download a zip archive from ``url`` and extract it into ``destination_folder``.

    The archive is buffered entirely in memory (``io.BytesIO``) before
    extraction, so this is only suitable for archives that fit in RAM.

    Args:
        url: HTTP(S) address of the zip file.
        destination_folder: Directory the archive contents are extracted into.
            NOTE(review): assumes the directory already exists — zipfile will
            create member paths, but confirm callers create the folder first.

    Returns:
        True when the download succeeded (HTTP 200) and extraction completed,
        False when the server returned any other status code.
    """
    print(f"downloading from {url}")
    # Context manager releases the streamed connection even on a non-200
    # status or if extraction raises (the original leaked the connection).
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            return False
        total_size = int(response.headers.get("content-length", 0))
        zip_file = io.BytesIO()
        # tqdm as a context manager guarantees the bar is closed on error.
        with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive chunks
                    zip_file.write(chunk)
                    progress_bar.update(len(chunk))
        zip_file.seek(0)
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(destination_folder)
    return True

View File

@ -1,38 +1,9 @@
import io
import logging import logging
import random import random
import zipfile
from datetime import datetime from datetime import datetime
from pathlib import Path
import numpy as np import numpy as np
import requests
import torch import torch
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
return True
else:
return False
def set_seed(seed): def set_seed(seed):