Fixes for PR #4
This commit is contained in:
parent
b862145e22
commit
c1942d45d3
|
@ -73,7 +73,6 @@ def make_offline_buffer(cfg, sampler=None):
|
||||||
elif cfg.env.name == "pusht":
|
elif cfg.env.name == "pusht":
|
||||||
offline_buffer = PushtExperienceReplay(
|
offline_buffer = PushtExperienceReplay(
|
||||||
"pusht",
|
"pusht",
|
||||||
download=True,
|
|
||||||
streaming=False,
|
streaming=False,
|
||||||
root=DATA_PATH,
|
root=DATA_PATH,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
|
|
|
@ -18,7 +18,7 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||||
|
|
||||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||||
from lerobot.common import utils
|
from lerobot.common.datasets import utils
|
||||||
|
|
||||||
# as defined in env
|
# as defined in env
|
||||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||||
|
@ -97,17 +97,15 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
replacement: bool = None,
|
replacement: bool = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
root: Path = None,
|
root: Path = None,
|
||||||
download: bool | str = False,
|
|
||||||
sampler: Sampler = None,
|
sampler: Sampler = None,
|
||||||
writer: Writer = None,
|
writer: Writer = None,
|
||||||
collate_fn: Callable = None,
|
collate_fn: Callable = None,
|
||||||
pin_memory: bool = False,
|
pin_memory: bool = False,
|
||||||
prefetch: int = None,
|
prefetch: int = None,
|
||||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
transform: "torchrl.envs.Transform" = None, # noqa: F821
|
||||||
split_trajs: bool = False,
|
split_trajs: bool = False,
|
||||||
strict_length: bool = True,
|
strict_length: bool = True,
|
||||||
):
|
):
|
||||||
self.download = download
|
|
||||||
if streaming:
|
if streaming:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
|
@ -129,8 +127,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
os.makedirs(root, exist_ok=True)
|
os.makedirs(root, exist_ok=True)
|
||||||
|
|
||||||
self.root = root
|
self.root = root
|
||||||
self.raw = self.root / "raw"
|
if not self._is_downloaded():
|
||||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
|
||||||
storage = self._download_and_preproc()
|
storage = self._download_and_preproc()
|
||||||
else:
|
else:
|
||||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||||
|
@ -192,9 +189,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
|
|
||||||
def _download_and_preproc(self):
|
def _download_and_preproc(self):
|
||||||
# download
|
# download
|
||||||
self.raw.mkdir(exist_ok=True)
|
raw_dir = self.root / "raw"
|
||||||
utils.download_and_extract_zip(PUSHT_URL, self.raw)
|
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||||
zarr_path = (self.raw / PUSHT_ZARR).resolve()
|
if not zarr_path.is_dir():
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
utils.download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||||
|
|
||||||
# load
|
# load
|
||||||
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
import io
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download the zip archive at ``url`` and extract it into ``destination_folder``.

    The whole archive is buffered in memory before it is extracted.

    Args:
        url: Address of the zip archive to fetch.
        destination_folder: Directory the archive members are extracted into.

    Returns:
        True when the server answered with HTTP 200 and the archive was
        extracted; False for any other status code, in which case nothing
        is written to ``destination_folder``.
    """
    print(f"downloading from {url}")

    # The response is opened with stream=True, so the connection stays checked
    # out until it is explicitly released. The context manager guarantees that
    # on every path: the non-200 early return and any error raised mid-download.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            return False

        # A missing content-length header yields total=0, i.e. a bar with no
        # known total.
        total_size = int(response.headers.get("content-length", 0))

        zip_file = io.BytesIO()
        # The bar is closed by the context manager even if iteration raises.
        with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    zip_file.write(chunk)
                    progress_bar.update(len(chunk))

    # Rewind so ZipFile reads the buffered archive from its start.
    zip_file.seek(0)

    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        zip_ref.extractall(destination_folder)
    return True
|
|
@ -1,38 +1,9 @@
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import zipfile
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
|
||||||
print(f"downloading from {url}")
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
if response.status_code == 200:
|
|
||||||
total_size = int(response.headers.get("content-length", 0))
|
|
||||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
|
||||||
|
|
||||||
zip_file = io.BytesIO()
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
zip_file.write(chunk)
|
|
||||||
progress_bar.update(len(chunk))
|
|
||||||
|
|
||||||
progress_bar.close()
|
|
||||||
|
|
||||||
zip_file.seek(0)
|
|
||||||
|
|
||||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
|
||||||
zip_ref.extractall(destination_folder)
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed):
|
def set_seed(seed):
|
||||||
|
|
Loading…
Reference in New Issue