Improve `push_dataset_to_hub` API + Add unit tests (#231)

Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

parent c38f535c9f
commit 125bd93e29
@@ -34,8 +34,8 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

-      - name: Install EGL
-        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg

       - name: Install poetry
         run: |
@@ -72,6 +72,9 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y ffmpeg
+
       - name: Install poetry
         run: |
           pipx install poetry && poetry config virtualenvs.in-project true
@@ -106,7 +109,7 @@ jobs:
         with:
           lfs: true  # Ensure LFS files are pulled

-      - name: Install EGL
+      - name: Install apt dependencies
        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev

       - name: Install poetry
README.md

@@ -228,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which
 huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

-Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
+Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
 ```bash
 python lerobot/scripts/push_dataset_to_hub.py \
---data-dir data \
---dataset-id aloha_static_pingpong_test \
---raw-format aloha_hdf5 \
---community-id lerobot
+--raw-dir data/aloha_static_pingpong_test_raw \
+--out-dir data \
+--repo-id lerobot/aloha_static_pingpong_test \
+--raw-format aloha_hdf5
 ```

 See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
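The same conversion can also be driven from Python. The following is a minimal sketch based on the new `push_dataset_to_hub()` signature introduced later in this commit; the import path and the concrete paths are assumptions for illustration, not part of the diff:

```python
from pathlib import Path

# Assumed import path; the function is defined in lerobot/scripts/push_dataset_to_hub.py.
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub

# Convert a raw Aloha dataset and push it to the hub, keeping a local copy under data/.
push_dataset_to_hub(
    raw_dir=Path("data/aloha_static_pingpong_test_raw"),
    raw_format="aloha_hdf5",
    repo_id="lerobot/aloha_static_pingpong_test",
    local_dir=Path("data/lerobot/aloha_static_pingpong_test"),
    push_to_hub=True,
)
```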
lerobot/common/datasets/push_dataset_to_hub/_download_raw.py

@@ -14,156 +14,119 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This file contains all obsolete download scripts. They are centralized here to not have to load
-useless dependencies when using datasets.
+This file contains download scripts for raw datasets.
+
+Example of usage:
+```
+python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
+--raw-dir data/cadene/pusht_raw \
+--repo-id cadene/pusht_raw
+```
 """

-import io
+import argparse
 import logging
-import shutil
+import warnings
 from pathlib import Path

-import tqdm
 from huggingface_hub import snapshot_download


-def download_raw(raw_dir, dataset_id):
-    if "aloha" in dataset_id or "image" in dataset_id:
-        download_hub(raw_dir, dataset_id)
-    elif "pusht" in dataset_id:
-        download_pusht(raw_dir)
-    elif "xarm" in dataset_id:
-        download_xarm(raw_dir)
-    elif "umi" in dataset_id:
-        download_umi(raw_dir)
-    else:
-        raise ValueError(dataset_id)
-
-
-def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
-    import zipfile
-
-    import requests
-
-    print(f"downloading from {url}")
-    response = requests.get(url, stream=True)
-    if response.status_code == 200:
-        total_size = int(response.headers.get("content-length", 0))
-        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
-
-        zip_file = io.BytesIO()
-        for chunk in response.iter_content(chunk_size=1024):
-            if chunk:
-                zip_file.write(chunk)
-                progress_bar.update(len(chunk))
-
-        progress_bar.close()
-
-        zip_file.seek(0)
-
-        with zipfile.ZipFile(zip_file, "r") as zip_ref:
-            zip_ref.extractall(destination_folder)
-
-
-def download_pusht(raw_dir: str):
-    pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
+def download_raw(raw_dir: Path, repo_id: str):
+    # Check repo_id is well formated
+    if len(repo_id.split("/")) != 2:
+        raise ValueError(
+            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
+        )
+    user_id, dataset_id = repo_id.split("/")
+
+    if not dataset_id.endswith("_raw"):
+        warnings.warn(
+            f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
+            stacklevel=1,
+        )

     raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(pusht_url, raw_dir)
-    # file is created inside a useful "pusht" directory, so we move it out and delete the dir
-    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
-    shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
-    shutil.rmtree(raw_dir / "pusht")
-
-
-def download_xarm(raw_dir: Path):
-    """Download all xarm datasets at once"""
-    import zipfile
-
-    import gdown
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
-    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
-    zip_path = raw_dir / "data.zip"
-    gdown.download(url, str(zip_path), quiet=False)
-    print("Extracting...")
-    with zipfile.ZipFile(str(zip_path), "r") as zip_f:
-        for pkl_path in zip_f.namelist():
-            if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
-                zip_f.extract(member=pkl_path)
-                # move to corresponding raw directory
-                extract_dir = pkl_path.replace("/buffer.pkl", "")
-                raw_pkl_path = raw_dir / "buffer.pkl"
-                shutil.move(pkl_path, raw_pkl_path)
-                shutil.rmtree(extract_dir)
-    zip_path.unlink()
-
-
-def download_hub(raw_dir: Path, dataset_id: str):
-    raw_dir = Path(raw_dir)
+    # Send warning if raw_dir isn't well formated
+    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
+        warnings.warn(
+            f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
+            stacklevel=1,
+        )
     raw_dir.mkdir(parents=True, exist_ok=True)

-    logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
-    snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
-    logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
+    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
+    snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
+    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")


-def download_umi(raw_dir: Path):
-    url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
-    zarr_path = raw_dir / "cup_in_the_wild.zarr"
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(url_cup_in_the_wild, zarr_path)
+def download_all_raw_datasets():
+    data_dir = Path("data")
+    repo_ids = [
+        "cadene/pusht_image_raw",
+        "cadene/xarm_lift_medium_image_raw",
+        "cadene/xarm_lift_medium_replay_image_raw",
+        "cadene/xarm_push_medium_image_raw",
+        "cadene/xarm_push_medium_replay_image_raw",
+        "cadene/aloha_sim_insertion_human_image_raw",
+        "cadene/aloha_sim_insertion_scripted_image_raw",
+        "cadene/aloha_sim_transfer_cube_human_image_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_image_raw",
+        "cadene/pusht_raw",
+        "cadene/xarm_lift_medium_raw",
+        "cadene/xarm_lift_medium_replay_raw",
+        "cadene/xarm_push_medium_raw",
+        "cadene/xarm_push_medium_replay_raw",
+        "cadene/aloha_sim_insertion_human_raw",
+        "cadene/aloha_sim_insertion_scripted_raw",
+        "cadene/aloha_sim_transfer_cube_human_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_raw",
+        "cadene/aloha_mobile_cabinet_raw",
+        "cadene/aloha_mobile_chair_raw",
+        "cadene/aloha_mobile_elevator_raw",
+        "cadene/aloha_mobile_shrimp_raw",
+        "cadene/aloha_mobile_wash_pan_raw",
+        "cadene/aloha_mobile_wipe_wine_raw",
+        "cadene/aloha_static_battery_raw",
+        "cadene/aloha_static_candy_raw",
+        "cadene/aloha_static_coffee_raw",
+        "cadene/aloha_static_coffee_new_raw",
+        "cadene/aloha_static_cups_open_raw",
+        "cadene/aloha_static_fork_pick_up_raw",
+        "cadene/aloha_static_pingpong_test_raw",
+        "cadene/aloha_static_pro_pencil_raw",
+        "cadene/aloha_static_screw_driver_raw",
+        "cadene/aloha_static_tape_raw",
+        "cadene/aloha_static_thread_velcro_raw",
+        "cadene/aloha_static_towel_raw",
+        "cadene/aloha_static_vinh_cup_raw",
+        "cadene/aloha_static_vinh_cup_left_raw",
+        "cadene/aloha_static_ziploc_slide_raw",
+        "cadene/umi_cup_in_the_wild_raw",
+    ]
+    for repo_id in repo_ids:
+        raw_dir = data_dir / repo_id
+        download_raw(raw_dir, repo_id)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--raw-dir",
+        type=Path,
+        required=True,
+        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
+    )
+    args = parser.parse_args()
+    download_raw(**vars(args))


 if __name__ == "__main__":
-    data_dir = Path("data")
-    dataset_ids = [
-        "pusht_image",
-        "xarm_lift_medium_image",
-        "xarm_lift_medium_replay_image",
-        "xarm_push_medium_image",
-        "xarm_push_medium_replay_image",
-        "aloha_sim_insertion_human_image",
-        "aloha_sim_insertion_scripted_image",
-        "aloha_sim_transfer_cube_human_image",
-        "aloha_sim_transfer_cube_scripted_image",
-        "pusht",
-        "xarm_lift_medium",
-        "xarm_lift_medium_replay",
-        "xarm_push_medium",
-        "xarm_push_medium_replay",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
-        "aloha_mobile_cabinet",
-        "aloha_mobile_chair",
-        "aloha_mobile_elevator",
-        "aloha_mobile_shrimp",
-        "aloha_mobile_wash_pan",
-        "aloha_mobile_wipe_wine",
-        "aloha_static_battery",
-        "aloha_static_candy",
-        "aloha_static_coffee",
-        "aloha_static_coffee_new",
-        "aloha_static_cups_open",
-        "aloha_static_fork_pick_up",
-        "aloha_static_pingpong_test",
-        "aloha_static_pro_pencil",
-        "aloha_static_screw_driver",
-        "aloha_static_tape",
-        "aloha_static_thread_velcro",
-        "aloha_static_towel",
-        "aloha_static_vinh_cup",
-        "aloha_static_vinh_cup_left",
-        "aloha_static_ziploc_slide",
-        "umi_cup_in_the_wild",
-    ]
-    for dataset_id in dataset_ids:
-        raw_dir = data_dir / f"{dataset_id}_raw"
-        download_raw(raw_dir, dataset_id)
+    main()
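A minimal sketch of calling the new download entry point programmatically, mirroring the CLI example from the docstring above (the import path matches the file path given there; the target directory follows the `user_id/dataset_id` layout the warning recommends):

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

# Download the raw PushT snapshot from the hub into data/cadene/pusht_raw.
download_raw(raw_dir=Path("data/cadene/pusht_raw"), repo_id="cadene/pusht_raw")
```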
lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py

@@ -30,6 +30,7 @@ from PIL import Image as PILImage

 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool:
         assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     # only frames from simulation are uncompressed
     compressed_images = "sim" not in raw_dir.name

-    hdf5_files = list(raw_dir.glob("*.hdf5"))
-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
+    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
+    num_episodes = len(hdf5_files)

-    id_from = 0
-    for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx in tqdm.tqdm(ep_ids):
+        ep_path = hdf5_files[ep_idx]
         with h5py.File(ep_path, "r") as ep:
             num_frames = ep["/action"].shape[0]
@@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):

             if video:
                 # save png images in temporary directory
-                tmp_imgs_dir = out_dir / "tmp_images"
+                tmp_imgs_dir = videos_dir / "tmp_images"
                 save_images_concurrently(imgs_array, tmp_imgs_dir)

                 # encode images to a mp4 video
                 fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-                video_path = out_dir / "videos" / fname
+                video_path = videos_dir / fname
                 encode_video_frames(tmp_imgs_dir, video_path, fps)

                 # clean temporary images directory
@@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             assert isinstance(ep_idx, int)
             ep_dicts.append(ep_dict)

-            episode_data_index["from"].append(id_from)
-            episode_data_index["to"].append(id_from + num_frames)
-
-            id_from += num_frames
-
             gc.collect()

-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict


 def to_hf_dataset(data_dict, video) -> Dataset:
@@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 50

-    data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
-    hf_dataset = to_hf_dataset(data_dir, video)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
+    hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
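Across all conversion modules in this commit, the inline `episode_data_index` bookkeeping is replaced by one call to `calculate_episode_data_index(hf_dataset)`. Its implementation is not part of this excerpt; the following is only a plausible equivalent, consistent with the per-episode `from`/`to` ranges the removed code built by hand:

```python
from datasets import Dataset

def calculate_episode_data_index_sketch(hf_dataset: Dataset) -> dict[str, list[int]]:
    """Recover per-episode [from, to) frame ranges from the `episode_index` column.

    Illustrative only: the real helper lives in lerobot.common.datasets.utils.
    """
    episode_data_index = {"from": [], "to": []}
    current_episode = None
    for idx, ep_idx in enumerate(hf_dataset["episode_index"]):
        ep_idx = int(ep_idx)
        if ep_idx != current_episode:
            if current_episode is not None:
                episode_data_index["to"].append(idx)
            episode_data_index["from"].append(idx)
            current_episode = ep_idx
    episode_data_index["to"].append(len(hf_dataset))
    return episode_data_index
```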
lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py (renamed from aloha_dora_format.py)

@@ -17,7 +17,6 @@
 Contains utilities to process raw data format from dora-record
 """

-import logging
 import re
 from pathlib import Path

@@ -26,10 +25,10 @@ import torch
 from datasets import Dataset, Features, Image, Sequence, Value

 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame
-from lerobot.common.utils.utils import init_logging


 def check_format(raw_dir) -> bool:
@@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
     return True


-def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     # Load data stream that will be used as reference for the timestamps synchronization
     reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
     if len(reference_files) == 0:
@@ -122,8 +121,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
         raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")

     # Create symlink to raw videos directory (that needs to be absolute not relative)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    videos_dir = out_dir / "videos"
+    videos_dir.parent.mkdir(parents=True, exist_ok=True)
     videos_dir.symlink_to((raw_dir / "videos").absolute())

     # sanity check the video paths are well formated
@@ -156,16 +154,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
         else:
             raise ValueError(key)

-    # Get the episode index containing for each unique episode index
-    first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
-    from_ = first_ep_index_df["start_index"].tolist()
-    to_ = from_[1:] + [len(df)]
-    episode_data_index = {
-        "from": from_,
-        "to": to_,
-    }
-
-    return data_dict, episode_data_index
+    return data_dict


 def to_hf_dataset(data_dict, video) -> Dataset:
@@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
-    init_logging()
-
-    if debug:
-        logging.warning("debug=True not implemented. Falling back to debug=False.")
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

@@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
     if not video:
         raise NotImplementedError()

-    data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
+    data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
     hf_dataset = to_hf_dataset(data_df, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py

@@ -27,6 +27,7 @@ from PIL import Image as PILImage

 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -53,7 +54,7 @@ def check_format(raw_dir):
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     try:
         import pymunk
         from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

     episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
-    num_episodes = zarr_data.meta["episode_ends"].shape[0]
     assert len(
         {zarr_data[key].shape[0] for key in zarr_data.keys()}  # noqa: SIM118
     ), "Some data type dont have the same number of total frames."
@@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     states = torch.from_numpy(zarr_data["state"])
     actions = torch.from_numpy(zarr_data["action"])

-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    for ep_idx in tqdm.tqdm(range(num_episodes)):
-        id_to = zarr_data.meta["episode_ends"][ep_idx]
-        num_frames = id_to - id_from
+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx = 0
+    for to_idx in zarr_data.meta["episode_ends"]:
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx
+
+    num_episodes = len(from_ids)
+
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx

         # sanity check
-        assert (episode_ids[id_from:id_to] == ep_idx).all()
+        assert (episode_ids[from_idx:to_idx] == ep_idx).all()

         # get image
-        image = imgs[id_from:id_to]
+        image = imgs[from_idx:to_idx]
         assert image.min() >= 0.0
         assert image.max() <= 255.0
         image = image.type(torch.uint8)

         # get state
-        state = states[id_from:id_to]
+        state = states[from_idx:to_idx]
         agent_pos = state[:, :2]
         block_pos = state[:, 2:4]
         block_angle = state[:, 4]
@@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

         ep_dict["observation.state"] = agent_pos
-        ep_dict["action"] = actions[id_from:id_to]
+        ep_dict["action"] = actions[from_idx:to_idx]
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from += num_frames
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict


 def to_hf_dataset(data_dict, video):
@@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 10

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
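The new `episodes` argument threaded through `load_from_raw` and `from_raw_to_lerobot_format` makes it possible to convert only a subset of episodes, e.g. to build small test fixtures. A hedged usage sketch, assuming the module path follows the pattern of the imports shown in `push_dataset_to_hub.py` and with illustrative paths:

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format

# Convert only the first two episodes of the raw PushT dataset, writing encoded
# videos into a scratch directory.
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("data/lerobot/pusht_raw"),
    videos_dir=Path("/tmp/pusht_videos"),
    fps=None,      # defaults to 10 for this format
    video=True,
    episodes=[0, 1],
)
```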
lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py

@@ -19,7 +19,6 @@ import logging
 import shutil
 from pathlib import Path

-import numpy as np
 import torch
 import tqdm
 import zarr
@@ -29,6 +28,7 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool:
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
-    # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
-    from numba import jit
-
-    @jit(nopython=True)
-    def _get_episode_idxs(episode_ends):
-        result = np.zeros((episode_ends[-1],), dtype=np.int64)
-        start_idx = 0
-        for episode_number, end_idx in enumerate(episode_ends):
-            result[start_idx:end_idx] = episode_number
-            start_idx = end_idx
-        return result
-
-    return _get_episode_idxs(episode_ends)
-
-
-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     zarr_path = raw_dir / "cup_in_the_wild.zarr"
     zarr_data = zarr.open(zarr_path, mode="r")
@@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     episode_ends = zarr_data["meta/episode_ends"][:]
     num_episodes = episode_ends.shape[0]

-    episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
-
     # We convert it in torch tensor later because the jit function does not support torch tensors
     episode_ends = torch.from_numpy(episode_ends)

+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx = 0
+    for to_idx in episode_ends:
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx
+
     ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    for ep_idx in tqdm.tqdm(range(num_episodes)):
-        id_to = episode_ends[ep_idx]
-        num_frames = id_to - id_from
-
-        # sanity heck
-        assert (episode_ids[id_from:id_to] == ep_idx).all()
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx

         # TODO(rcadene): save temporary images of the episode?

-        state = states[id_from:id_to]
+        state = states[from_idx:to_idx]

         ep_dict = {}

         # load 57MB of images in RAM (400x224x224x3 uint8)
-        imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
+        imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
-        ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
-        ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
-        ep_dict["end_pose"] = end_pose[id_from:id_to]
-        ep_dict["start_pos"] = start_pos[id_from:id_to]
-        ep_dict["gripper_width"] = gripper_width[id_from:id_to]
+        ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
+        ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
+        ep_dict["end_pose"] = end_pose[from_idx:to_idx]
+        ep_dict["start_pos"] = start_pos[from_idx:to_idx]
+        ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-        id_from += num_frames
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)

-    total_frames = id_from
+    total_frames = data_dict["frame_index"].shape[0]
     data_dict["index"] = torch.arange(0, total_frames, 1)
-
-    return data_dict, episode_data_index
+    return data_dict


 def to_hf_dataset(data_dict, video):
@@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

@@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
         "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
     )

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
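The numba-jitted `get_episode_idxs` helper is removed above; where a per-frame episode id is still needed, the same mapping from `episode_ends` can be written in plain NumPy. A small sketch, not part of the commit:

```python
import numpy as np

def episode_ids_from_ends(episode_ends: np.ndarray) -> np.ndarray:
    """Map each frame index to its episode number, given cumulative episode end indices."""
    lengths = np.diff(np.concatenate(([0], episode_ends)))
    return np.repeat(np.arange(len(episode_ends)), lengths)

# e.g. episode_ends = [3, 5] -> array([0, 0, 0, 1, 1])
```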
lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py

@@ -27,6 +27,7 @@ from PIL import Image as PILImage

 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -54,37 +55,42 @@ def check_format(raw_dir):
     assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     pkl_path = raw_dir / "buffer.pkl"

     with open(pkl_path, "rb") as f:
         pkl_data = pickle.load(f)

-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    id_to = 0
-    ep_idx = 0
-    total_frames = pkl_data["actions"].shape[0]
-    for i in tqdm.tqdm(range(total_frames)):
-        id_to += 1
-
-        if not pkl_data["dones"][i]:
+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx, to_idx = 0, 0
+    for done in pkl_data["dones"]:
+        to_idx += 1
+        if not done:
             continue
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx

-        num_frames = id_to - id_from
+    num_episodes = len(from_ids)

-        image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx
+
+        image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
         image = einops.rearrange(image, "b c h w -> b h w c")
-        state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
-        action = torch.tensor(pkl_data["actions"][id_from:id_to])
+        state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
+        action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
         # TODO(rcadene): we have a missing last frame which is the observation when the env is done
         # it is critical to have this frame for tdmpc to predict a "done observation/state"
-        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
-        # next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
-        next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
-        next_done = torch.tensor(pkl_data["dones"][id_from:id_to])
+        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
+        # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
+        next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
+        next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])

         ep_dict = {}

@@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["next.done"] = next_done
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from = id_to
-        ep_idx += 1
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict


 def to_hf_dataset(data_dict, video):
@@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 15

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
         "video": video,
@ -18,58 +18,39 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
|
||||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||||
installation of neural net specific packages like pytorch, tensorflow, jax.
|
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||||
|
|
||||||
Example:
|
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
|
||||||
```
|
```
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/pusht_raw \
|
||||||
--dataset-id pusht \
|
|
||||||
--raw-format pusht_zarr \
|
--raw-format pusht_zarr \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/pusht
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
|
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/xarm_lift_medium_raw \
|
||||||
--dataset-id xarm_lift_medium \
|
|
||||||
--raw-format xarm_pkl \
|
--raw-format xarm_pkl \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/xarm_lift_medium
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
|
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/aloha_sim_insertion_scripted_raw \
|
||||||
--dataset-id aloha_sim_insertion_scripted \
|
|
||||||
--raw-format aloha_hdf5 \
|
--raw-format aloha_hdf5 \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/aloha_sim_insertion_scripted
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
|
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--raw-dir data/umi_cup_in_the_wild_raw \
|
||||||
--dataset-id umi_cup_in_the_wild \
|
|
||||||
--raw-format umi_zarr \
|
--raw-format umi_zarr \
|
||||||
--community-id lerobot \
|
--repo-id lerobot/umi_cup_in_the_wild
|
||||||
--dry-run 1 \
|
|
||||||
--save-to-disk 1 \
|
|
||||||
--save-tests-to-disk 0 \
|
|
||||||
--debug 1
|
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi, create_branch
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
|
@ -85,8 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "aloha_hdf5":
|
elif raw_format == "aloha_hdf5":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "aloha_dora":
|
elif raw_format == "dora_parquet":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "xarm_pkl":
|
elif raw_format == "xarm_pkl":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||||
else:
|
else:
|
||||||
|
@ -147,39 +128,61 @@ def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | Non
|
||||||
|
|
||||||
|
|
||||||
def push_dataset_to_hub(
|
def push_dataset_to_hub(
|
||||||
data_dir: Path,
|
raw_dir: Path,
|
||||||
dataset_id: str,
|
raw_format: str,
|
||||||
raw_format: str | None,
|
repo_id: str,
|
||||||
community_id: str,
|
push_to_hub: bool = True,
|
||||||
revision: str,
|
local_dir: Path | None = None,
|
||||||
dry_run: bool,
|
fps: int | None = None,
|
||||||
save_to_disk: bool,
|
video: bool = True,
|
||||||
tests_data_dir: Path,
|
batch_size: int = 32,
|
||||||
save_tests_to_disk: bool,
|
num_workers: int = 8,
|
||||||
fps: int | None,
|
episodes: list[int] | None = None,
|
||||||
video: bool,
|
force_override: bool = False,
|
||||||
batch_size: int,
|
cache_dir: Path = Path("/tmp"),
|
||||||
num_workers: int,
|
tests_data_dir: Path | None = None,
|
||||||
debug: bool,
|
|
||||||
):
|
):
|
||||||
repo_id = f"{community_id}/{dataset_id}"
|
# Check repo_id is well formated
|
||||||
|
if len(repo_id.split("/")) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
|
||||||
|
)
|
||||||
|
user_id, dataset_id = repo_id.split("/")
|
||||||
|
|
||||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
# Robustify when `raw_dir` is str instead of Path
|
||||||
|
raw_dir = Path(raw_dir)
|
||||||
|
if not raw_dir.exists():
|
||||||
|
raise NotADirectoryError(
|
||||||
|
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
|
||||||
|
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
|
||||||
|
)
|
||||||
|
|
||||||
out_dir = data_dir / repo_id
|
if local_dir:
|
||||||
meta_data_dir = out_dir / "meta_data"
|
# Robustify when `local_dir` is str instead of Path
|
||||||
videos_dir = out_dir / "videos"
|
local_dir = Path(local_dir)
|
||||||
|
|
||||||
tests_out_dir = tests_data_dir / repo_id
|
# Send warning if local_dir isn't well formated
|
||||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
|
||||||
tests_videos_dir = tests_out_dir / "videos"
|
warnings.warn(
|
||||||
|
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
|
||||||
if out_dir.exists():
|
# Check we don't override an existing `local_dir` by mistake
|
||||||
shutil.rmtree(out_dir)
|
if local_dir.exists():
|
||||||
|
if force_override:
|
||||||
|
shutil.rmtree(local_dir)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||||
|
|
||||||
if tests_out_dir.exists() and save_tests_to_disk:
|
meta_data_dir = local_dir / "meta_data"
|
||||||
shutil.rmtree(tests_out_dir)
|
videos_dir = local_dir / "videos"
|
||||||
|
else:
|
||||||
|
# Temporary directory used to store images, videos, meta_data
|
||||||
|
meta_data_dir = Path(cache_dir) / "meta_data"
|
||||||
|
videos_dir = Path(cache_dir) / "videos"
|
||||||
|
|
||||||
|
# Download the raw dataset if available
|
||||||
if not raw_dir.exists():
|
if not raw_dir.exists():
|
||||||
download_raw(raw_dir, dataset_id)
|
download_raw(raw_dir, dataset_id)
|
||||||
|
@@ -188,14 +191,14 @@ def push_dataset_to_hub(
         raise NotImplementedError()
         # raw_format = auto_find_raw_format(raw_dir)
 
-    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
-
     # convert dataset from original raw format to LeRobot format
-    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
+    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
+    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
+        raw_dir, videos_dir, fps, video, episodes
+    )
 
     lerobot_dataset = LeRobotDataset.from_preloaded(
         repo_id=repo_id,
-        version=revision,
         hf_dataset=hf_dataset,
         episode_data_index=episode_data_index,
         info=info,

@@ -203,103 +206,80 @@ def push_dataset_to_hub(
     )
     stats = compute_stats(lerobot_dataset, batch_size, num_workers)
 
-    if save_to_disk:
+    if local_dir:
         hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
-        hf_dataset.save_to_disk(str(out_dir / "train"))
+        hf_dataset.save_to_disk(str(local_dir / "train"))
 
-    if not dry_run or save_to_disk:
+    if push_to_hub or local_dir:
         # mandatory for upload
         save_meta_data(info, stats, episode_data_index, meta_data_dir)
 
-    if not dry_run:
-        hf_dataset.push_to_hub(repo_id, token=True, revision="main")
-        hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
+    if push_to_hub:
+        # TODO(rcadene): token needs to be a str | None
+        hf_dataset.push_to_hub(repo_id, revision="main")
 
         push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
-        push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
 
         if video:
             push_videos_to_hub(repo_id, videos_dir, revision="main")
-            push_videos_to_hub(repo_id, videos_dir, revision=revision)
+        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
 
-    if save_tests_to_disk:
+    if tests_data_dir:
         # get the first episode
         num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
         test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
 
         test_hf_dataset = test_hf_dataset.with_format(None)
-        test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
+        test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
 
-        save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
+        tests_meta_data = tests_data_dir / repo_id / "meta_data"
+        save_meta_data(info, stats, episode_data_index, tests_meta_data)
 
         # copy videos of first episode to tests directory
         episode_index = 0
+        tests_videos_dir = tests_data_dir / repo_id / "videos"
         tests_videos_dir.mkdir(parents=True, exist_ok=True)
         for key in lerobot_dataset.video_frame_keys:
             fname = f"{key}_episode_{episode_index:06d}.mp4"
             shutil.copy(videos_dir / fname, tests_videos_dir / fname)
 
-    if not save_to_disk and out_dir.exists():
-        # remove possible temporary files remaining in the output directory
-        shutil.rmtree(out_dir)
+    if local_dir is None:
+        # clear cache
+        shutil.rmtree(meta_data_dir)
+        shutil.rmtree(videos_dir)
+
+    return lerobot_dataset
 
 
 def main():
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "--data-dir",
+        "--raw-dir",
         type=Path,
         required=True,
-        help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
-    )
-    parser.add_argument(
-        "--dataset-id",
-        type=str,
-        required=True,
-        help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
+        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
     )
+    # TODO(rcadene): add automatic detection of the format
     parser.add_argument(
         "--raw-format",
         type=str,
-        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
+        required=True,
+        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
     )
     parser.add_argument(
-        "--community-id",
+        "--repo-id",
         type=str,
-        default="lerobot",
-        help="Community or user ID under which the dataset will be hosted on the Hub.",
+        required=True,
+        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
     parser.add_argument(
-        "--revision",
-        type=str,
-        default=CODEBASE_VERSION,
-        help="Codebase version used to generate the dataset.",
-    )
-    parser.add_argument(
-        "--dry-run",
-        type=int,
-        default=0,
-        help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
-    )
-    parser.add_argument(
-        "--save-to-disk",
-        type=int,
-        default=1,
-        help="Save the dataset in the directory specified by `--data-dir`.",
-    )
-    parser.add_argument(
-        "--tests-data-dir",
+        "--local-dir",
         type=Path,
-        default="tests/data",
-        help="Directory containing tests artifacts datasets.",
+        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
     )
     parser.add_argument(
-        "--save-tests-to-disk",
+        "--push-to-hub",
         type=int,
         default=1,
-        help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
+        help="Upload to hub.",
     )
     parser.add_argument(
         "--fps",

@@ -325,10 +305,21 @@ def main():
         help="Number of processes of Dataloader for computing the dataset statistics.",
     )
     parser.add_argument(
-        "--debug",
+        "--episodes",
+        type=int,
+        nargs="*",
+        help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
+    )
+    parser.add_argument(
+        "--force-override",
         type=int,
         default=0,
-        help="Debug mode process the first episode only.",
+        help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
+    )
+    parser.add_argument(
+        "--tests-data-dir",
+        type=Path,
+        help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
     )
 
     args = parser.parse_args()
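Taken together, the revised entry point replaces the old `--data-dir`/`--dataset-id`/`--dry-run` combination with an explicit `raw_dir`/`repo_id`/`push_to_hub`/`local_dir` signature. As a rough sketch of calling it from Python (keyword names mirror the unit tests below; the paths and the choice of `pusht_zarr` are only an example), a purely local conversion could look like:

```python
from pathlib import Path

from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub

# Convert a raw PushT zarr recording into LeRobotDataset format locally,
# without uploading anything to the hub (push_to_hub=False).
lerobot_dataset = push_dataset_to_hub(
    raw_dir=Path("data/pusht_raw"),        # hypothetical raw dataset location
    raw_format="pusht_zarr",
    repo_id="lerobot/pusht",
    push_to_hub=False,
    local_dir=Path("data/lerobot/pusht"),  # follows the advised naming convention
    force_override=False,
    episodes=[0],                          # optional: convert only the first episode
)

item = lerobot_dataset[0]  # the returned dataset is directly indexable, as in the tests below
```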
@@ -0,0 +1,352 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
we skip them for now in our CI.

Example to run backward compatiblity tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""

from pathlib import Path

import numpy as np
import pytest
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg


def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/img",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
    zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
    zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/camera0_rgb",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_end_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_start_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/robot0_eef_rot_axis_angle",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_gripper_width",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_xarm(raw_dir, num_frames=4):
    import pickle

    dataset_dict = {
        "observations": {
            "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
            "state": np.random.randn(num_frames, 4),
        },
        "actions": np.random.randn(num_frames, 3),
        "rewards": np.random.randn(num_frames),
        "masks": np.random.randn(num_frames),
        "dones": np.array([False, True, True, True]),
    }

    raw_dir.mkdir(parents=True, exist_ok=True)
    pkl_path = raw_dir / "buffer.pkl"
    with open(pkl_path, "wb") as f:
        pickle.dump(dataset_dict, f)


def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
    import h5py

    for ep_idx in range(num_episodes):
        raw_dir.mkdir(parents=True, exist_ok=True)
        path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
        with h5py.File(str(path_h5), "w") as f:
            f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset(
                "observations/images/top",
                data=np.random.randint(
                    0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
                ),
            )


def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
    from datetime import datetime, timedelta, timezone

    import pandas

    def write_parquet(key, timestamps, values):
        data = {
            "timestamp_utc": timestamps,
            key: values,
        }
        df = pandas.DataFrame(data)
        raw_dir.mkdir(parents=True, exist_ok=True)
        df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")

    episode_indices = [None, None, -1, None, None, -1, None, None, -1]
    episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]

    cam_key = "observation.images.cam_high"
    timestamps = []
    actions = []
    states = []
    frames = []
    # `+ num_episodes`` for buffer frames associated to episode_index=-1
    for i, frame_idx in enumerate(frame_indices):
        t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
        action = np.random.randn(21).tolist()
        state = np.random.randn(21).tolist()
        ep_idx = episode_indices_mapping[i]
        frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
        timestamps.append(t_utc)
        actions.append(action)
        states.append(state)
        frames.append(frame)

    write_parquet(cam_key, timestamps, frames)
    write_parquet("observation.state", timestamps, states)
    write_parquet("action", timestamps, actions)
    write_parquet("episode_index", timestamps, episode_indices)

    # write fake mp4 file for each episode
    for ep_idx in range(num_episodes):
        imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)

        tmp_imgs_dir = raw_dir / "tmp_images"
        save_images_concurrently(imgs_array, tmp_imgs_dir)

        fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
        video_path = raw_dir / "videos" / fname
        encode_video_frames(tmp_imgs_dir, video_path, fps)
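The `dora_parquet` mock above writes one parquet file per key, each sharing a `timestamp_utc` column. A quick way to eyeball one of those files (a sketch, assuming `pandas` and `pyarrow` are installed; the raw directory name is hypothetical) is:

```python
import pandas as pd

# One file per key, e.g. action.parquet, observation.state.parquet, episode_index.parquet
df = pd.read_parquet("data/wrist_gripper_raw/action.parquet", engine="pyarrow")
print(df.columns.tolist())  # ['timestamp_utc', 'action']
print(len(df))              # 9 rows: 3 episodes x (2 frames + 1 buffer frame with episode_index=-1)
```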
def _mock_download_raw(raw_dir, repo_id):
    if "wrist_gripper" in repo_id:
        _mock_download_raw_dora(raw_dir)
    elif "aloha" in repo_id:
        _mock_download_raw_aloha(raw_dir)
    elif "pusht" in repo_id:
        _mock_download_raw_pusht(raw_dir)
    elif "xarm" in repo_id:
        _mock_download_raw_xarm(raw_dir)
    elif "umi" in repo_id:
        _mock_download_raw_umi(raw_dir)
    else:
        raise ValueError(repo_id)


def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
    with pytest.raises(ValueError):
        push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")


def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
    tmpdir = Path(tmpdir)
    out_dir = tmpdir / "out"
    raw_dir = tmpdir / "raw"
    # mkdir to skip download
    raw_dir.mkdir(parents=True, exist_ok=True)
    with pytest.raises(ValueError):
        push_dataset_to_hub(
            raw_dir=raw_dir,
            raw_format="some_format",
            repo_id="user/dataset",
            local_dir=out_dir,
            force_override=False,
        )


@pytest.mark.parametrize(
    "required_packages, raw_format, repo_id",
    [
        (["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
        (None, "xarm_pkl", "lerobot/xarm_lift_medium"),
        (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
        (None, "dora_parquet", "cadene/wrist_gripper"),
    ],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
    num_episodes = 3
    tmpdir = Path(tmpdir)

    raw_dir = tmpdir / f"{repo_id}_raw"
    _mock_download_raw(raw_dir, repo_id)

    local_dir = tmpdir / repo_id

    lerobot_dataset = push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
    )

    # minimal generic tests on the local directory containing LeRobotDataset
    assert (local_dir / "meta_data" / "info.json").exists()
    assert (local_dir / "meta_data" / "stats.safetensors").exists()
    assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
    for i in range(num_episodes):
        for cam_key in lerobot_dataset.camera_keys:
            assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
    assert (local_dir / "train" / "dataset_info.json").exists()
    assert (local_dir / "train" / "state.json").exists()
    assert len(list((local_dir / "train").glob("*.arrow"))) > 0

    # minimal generic tests on the item
    item = lerobot_dataset[0]
    assert "index" in item
    assert "episode_index" in item
    assert "timestamp" in item
    for cam_key in lerobot_dataset.camera_keys:
        assert cam_key in item
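The format tests above rely only on the mocks, so they run without downloading any raw dataset. A sketch of invoking just that group from a Python session (roughly equivalent to the usual `python -m pytest` call) is:

```python
import pytest

# Run only the generic format tests; optional extras such as gym-pusht or
# imagecodecs are skipped automatically thanks to @require_package_arg.
exit_code = pytest.main(["-x", "tests/test_push_dataset_to_hub.py", "-k", "format"])
print(exit_code)
```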
@pytest.mark.parametrize(
    "raw_format, repo_id",
    [
        # TODO(rcadene): add raw dataset test artifacts
        ("pusht_zarr", "lerobot/pusht"),
        ("xarm_pkl", "lerobot/xarm_lift_medium"),
        ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        ("umi_zarr", "lerobot/umi_cup_in_the_wild"),
        ("dora_parquet", "cadene/wrist_gripper"),
    ],
)
@pytest.mark.skip(
    "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
    _, dataset_id = repo_id.split("/")

    tmpdir = Path(tmpdir)
    raw_dir = tmpdir / f"{dataset_id}_raw"
    local_dir = tmpdir / repo_id

    push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
        episodes=[0],
    )

    ds_actual = LeRobotDataset(repo_id, root=tmpdir)
    ds_reference = LeRobotDataset(repo_id)

    assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)

    def check_same_items(item1, item2):
        assert item1.keys() == item2.keys(), "Keys mismatch"

        for key in item1:
            if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
                assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
            else:
                assert item1[key] == item2[key], f"Mismatch found in key: {key}"

    for i in range(len(ds_reference.hf_dataset)):
        item_reference = ds_reference.hf_dataset[i]
        item_actual = ds_actual.hf_dataset[i]
        check_same_items(item_reference, item_actual)

@@ -76,6 +76,7 @@ def require_env(func):
    """
    Decorator that skips the test if the required environment package is not installed.
    As it need 'env_name' in args, it also checks whether it is provided as an argument.
+    If 'env_name' is None, this check is skipped.
    """

    @wraps(func)

@@ -91,7 +92,7 @@ def require_env(func):

        # Perform the package check
        package_name = f"gym_{env_name}"
-        if not is_package_available(package_name):
+        if env_name is not None and not is_package_available(package_name):
            pytest.skip(f"gym-{env_name} not installed")

        return func(*args, **kwargs)

@@ -99,6 +100,38 @@ def require_env(func):
    return wrapper


+def require_package_arg(func):
+    """
+    Decorator that skips the test if the required package is not installed.
+    This is similar to `require_env` but more general in that it can check any package (not just environments).
+    As it need 'required_packages' in args, it also checks whether it is provided as an argument.
+    If 'required_packages' is None, this check is skipped.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Determine if 'required_packages' is provided and extract its value
+        arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
+        if "required_packages" in arg_names:
+            # Get the index of 'required_packages' and retrieve the value from args
+            index = arg_names.index("required_packages")
+            required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
+        else:
+            raise ValueError("Function does not have 'required_packages' as an argument.")
+
+        if required_packages is None:
+            return func(*args, **kwargs)
+
+        # Perform the package check
+        for package in required_packages:
+            if not is_package_available(package):
+                pytest.skip(f"{package} not installed")
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def require_package(package_name):
    """
    Decorator that skips the test if the specified package is not installed.
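Since `require_package_arg` looks up a `required_packages` argument on the decorated test, it is meant to be combined with a parametrize that supplies that argument, as in the new test file above. A minimal hypothetical usage outside of it could look like:

```python
import pytest

from tests.utils import require_package_arg


@pytest.mark.parametrize("required_packages", [None, ["zarr"]])
@require_package_arg
def test_optional_dependency(required_packages):
    # With None the check is a no-op; with ["zarr"] the test is skipped
    # automatically when zarr is not installed.
    assert True
```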