Improve `push_dataset_to_hub` API + Add unit tests (#231)

Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Thomas Wolf 2024-06-13 15:18:02 +02:00 committed by GitHub
parent c38f535c9f
commit 125bd93e29
11 changed files with 750 additions and 419 deletions

View File

@@ -34,8 +34,8 @@ jobs:
       with:
         lfs: true  # Ensure LFS files are pulled
-    - name: Install EGL
-      run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
+    - name: Install apt dependencies
+      run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg
     - name: Install poetry
       run: |
@@ -72,6 +72,9 @@ jobs:
       with:
         lfs: true  # Ensure LFS files are pulled
+    - name: Install apt dependencies
+      run: sudo apt-get update && sudo apt-get install -y ffmpeg
     - name: Install poetry
       run: |
        pipx install poetry && poetry config virtualenvs.in-project true
@@ -106,7 +109,7 @@ jobs:
       with:
         lfs: true  # Ensure LFS files are pulled
-    - name: Install EGL
+    - name: Install apt dependencies
       run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
     - name: Install poetry

View File

@@ -228,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which
 huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```
-Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
+Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
 ```bash
 python lerobot/scripts/push_dataset_to_hub.py \
---data-dir data \
---dataset-id aloha_static_pingpong_test \
---raw-format aloha_hdf5 \
---community-id lerobot
+--raw-dir data/aloha_static_pingpong_test_raw \
+--out-dir data \
+--repo-id lerobot/aloha_static_pingpong_test \
+--raw-format aloha_hdf5
 ```
 See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
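For reference, the reworked converter can also be driven from Python; a minimal sketch assuming the `push_dataset_to_hub()` signature introduced in this commit (not part of the README itself):

```python
from pathlib import Path

from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub

# Convert a raw Aloha HDF5 dataset locally; set push_to_hub=True to upload it.
push_dataset_to_hub(
    raw_dir=Path("data/aloha_static_pingpong_test_raw"),
    raw_format="aloha_hdf5",
    repo_id="lerobot/aloha_static_pingpong_test",
    local_dir=Path("data/lerobot/aloha_static_pingpong_test"),
    push_to_hub=False,
)
```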

View File

@@ -14,156 +14,119 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This file contains all obsolete download scripts. They are centralized here to not have to load
-useless dependencies when using datasets.
+This file contains download scripts for raw datasets.
+
+Example of usage:
+```
+python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
+--raw-dir data/cadene/pusht_raw \
+--repo-id cadene/pusht_raw
+```
 """
 
-import io
+import argparse
 import logging
-import shutil
+import warnings
 from pathlib import Path
 
-import tqdm
 from huggingface_hub import snapshot_download
 
 
-def download_raw(raw_dir, dataset_id):
-    if "aloha" in dataset_id or "image" in dataset_id:
-        download_hub(raw_dir, dataset_id)
-    elif "pusht" in dataset_id:
-        download_pusht(raw_dir)
-    elif "xarm" in dataset_id:
-        download_xarm(raw_dir)
-    elif "umi" in dataset_id:
-        download_umi(raw_dir)
-    else:
-        raise ValueError(dataset_id)
-
-
-def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
-    import zipfile
-
-    import requests
-
-    print(f"downloading from {url}")
-    response = requests.get(url, stream=True)
-    if response.status_code == 200:
-        total_size = int(response.headers.get("content-length", 0))
-        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
-
-        zip_file = io.BytesIO()
-        for chunk in response.iter_content(chunk_size=1024):
-            if chunk:
-                zip_file.write(chunk)
-                progress_bar.update(len(chunk))
-        progress_bar.close()
-
-        zip_file.seek(0)
-
-        with zipfile.ZipFile(zip_file, "r") as zip_ref:
-            zip_ref.extractall(destination_folder)
-
-
-def download_pusht(raw_dir: str):
-    pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(pusht_url, raw_dir)
-    # file is created inside a useful "pusht" directory, so we move it out and delete the dir
-    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
-    shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
-    shutil.rmtree(raw_dir / "pusht")
-
-
-def download_xarm(raw_dir: Path):
-    """Download all xarm datasets at once"""
-    import zipfile
-
-    import gdown
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
-    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
-    zip_path = raw_dir / "data.zip"
-    gdown.download(url, str(zip_path), quiet=False)
-    print("Extracting...")
-    with zipfile.ZipFile(str(zip_path), "r") as zip_f:
-        for pkl_path in zip_f.namelist():
-            if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
-                zip_f.extract(member=pkl_path)
-                # move to corresponding raw directory
-                extract_dir = pkl_path.replace("/buffer.pkl", "")
-                raw_pkl_path = raw_dir / "buffer.pkl"
-                shutil.move(pkl_path, raw_pkl_path)
-                shutil.rmtree(extract_dir)
-    zip_path.unlink()
-
-
-def download_hub(raw_dir: Path, dataset_id: str):
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-
-    logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
-    snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
-    logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
-
-
-def download_umi(raw_dir: Path):
-    url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
-    zarr_path = raw_dir / "cup_in_the_wild.zarr"
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(url_cup_in_the_wild, zarr_path)
+def download_raw(raw_dir: Path, repo_id: str):
+    # Check repo_id is well formated
+    if len(repo_id.split("/")) != 2:
+        raise ValueError(
+            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
+        )
+    user_id, dataset_id = repo_id.split("/")
+
+    if not dataset_id.endswith("_raw"):
+        warnings.warn(
+            f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
+            stacklevel=1,
+        )
+
+    raw_dir = Path(raw_dir)
+    # Send warning if raw_dir isn't well formated
+    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
+        warnings.warn(
+            f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
+            stacklevel=1,
+        )
+    raw_dir.mkdir(parents=True, exist_ok=True)
+
+    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
+    snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
+    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
+
+
+def download_all_raw_datasets():
+    data_dir = Path("data")
+    repo_ids = [
+        "cadene/pusht_image_raw",
+        "cadene/xarm_lift_medium_image_raw",
+        "cadene/xarm_lift_medium_replay_image_raw",
+        "cadene/xarm_push_medium_image_raw",
+        "cadene/xarm_push_medium_replay_image_raw",
+        "cadene/aloha_sim_insertion_human_image_raw",
+        "cadene/aloha_sim_insertion_scripted_image_raw",
+        "cadene/aloha_sim_transfer_cube_human_image_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_image_raw",
+        "cadene/pusht_raw",
+        "cadene/xarm_lift_medium_raw",
+        "cadene/xarm_lift_medium_replay_raw",
+        "cadene/xarm_push_medium_raw",
+        "cadene/xarm_push_medium_replay_raw",
+        "cadene/aloha_sim_insertion_human_raw",
+        "cadene/aloha_sim_insertion_scripted_raw",
+        "cadene/aloha_sim_transfer_cube_human_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_raw",
+        "cadene/aloha_mobile_cabinet_raw",
+        "cadene/aloha_mobile_chair_raw",
+        "cadene/aloha_mobile_elevator_raw",
+        "cadene/aloha_mobile_shrimp_raw",
+        "cadene/aloha_mobile_wash_pan_raw",
+        "cadene/aloha_mobile_wipe_wine_raw",
+        "cadene/aloha_static_battery_raw",
+        "cadene/aloha_static_candy_raw",
+        "cadene/aloha_static_coffee_raw",
+        "cadene/aloha_static_coffee_new_raw",
+        "cadene/aloha_static_cups_open_raw",
+        "cadene/aloha_static_fork_pick_up_raw",
+        "cadene/aloha_static_pingpong_test_raw",
+        "cadene/aloha_static_pro_pencil_raw",
+        "cadene/aloha_static_screw_driver_raw",
+        "cadene/aloha_static_tape_raw",
+        "cadene/aloha_static_thread_velcro_raw",
+        "cadene/aloha_static_towel_raw",
+        "cadene/aloha_static_vinh_cup_raw",
+        "cadene/aloha_static_vinh_cup_left_raw",
+        "cadene/aloha_static_ziploc_slide_raw",
+        "cadene/umi_cup_in_the_wild_raw",
+    ]
+    for repo_id in repo_ids:
+        raw_dir = data_dir / repo_id
+        download_raw(raw_dir, repo_id)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--raw-dir",
+        type=Path,
+        required=True,
+        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
+    )
+    args = parser.parse_args()
+    download_raw(**vars(args))
 
 
 if __name__ == "__main__":
-    data_dir = Path("data")
-    dataset_ids = [
-        "pusht_image",
-        "xarm_lift_medium_image",
-        "xarm_lift_medium_replay_image",
-        "xarm_push_medium_image",
-        "xarm_push_medium_replay_image",
-        "aloha_sim_insertion_human_image",
-        "aloha_sim_insertion_scripted_image",
-        "aloha_sim_transfer_cube_human_image",
-        "aloha_sim_transfer_cube_scripted_image",
-        "pusht",
-        "xarm_lift_medium",
-        "xarm_lift_medium_replay",
-        "xarm_push_medium",
-        "xarm_push_medium_replay",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
-        "aloha_mobile_cabinet",
-        "aloha_mobile_chair",
-        "aloha_mobile_elevator",
-        "aloha_mobile_shrimp",
-        "aloha_mobile_wash_pan",
-        "aloha_mobile_wipe_wine",
-        "aloha_static_battery",
-        "aloha_static_candy",
-        "aloha_static_coffee",
-        "aloha_static_coffee_new",
-        "aloha_static_cups_open",
-        "aloha_static_fork_pick_up",
-        "aloha_static_pingpong_test",
-        "aloha_static_pro_pencil",
-        "aloha_static_screw_driver",
-        "aloha_static_tape",
-        "aloha_static_thread_velcro",
-        "aloha_static_towel",
-        "aloha_static_vinh_cup",
-        "aloha_static_vinh_cup_left",
-        "aloha_static_ziploc_slide",
-        "umi_cup_in_the_wild",
-    ]
-    for dataset_id in dataset_ids:
-        raw_dir = data_dir / f"{dataset_id}_raw"
-        download_raw(raw_dir, dataset_id)
+    main()
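For completeness, a minimal sketch of calling the reworked `download_raw()` from Python rather than the CLI (the path below is a placeholder that follows the naming convention the warnings above encourage):

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

# "data/<user>/<dataset>_raw" mirrors the repo id, so no naming warning is emitted.
download_raw(raw_dir=Path("data/cadene/pusht_raw"), repo_id="cadene/pusht_raw")
```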

View File

@@ -30,6 +30,7 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool:
         assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     # only frames from simulation are uncompressed
     compressed_images = "sim" not in raw_dir.name

-    hdf5_files = list(raw_dir.glob("*.hdf5"))
-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
+    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
+    num_episodes = len(hdf5_files)
+
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx in tqdm.tqdm(ep_ids):
+        ep_path = hdf5_files[ep_idx]
         with h5py.File(ep_path, "r") as ep:
             num_frames = ep["/action"].shape[0]
@@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             if video:
                 # save png images in temporary directory
-                tmp_imgs_dir = out_dir / "tmp_images"
+                tmp_imgs_dir = videos_dir / "tmp_images"
                 save_images_concurrently(imgs_array, tmp_imgs_dir)

                 # encode images to a mp4 video
                 fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-                video_path = out_dir / "videos" / fname
+                video_path = videos_dir / fname
                 encode_video_frames(tmp_imgs_dir, video_path, fps)

                 # clean temporary images directory
@@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         assert isinstance(ep_idx, int)
         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from += num_frames
-
         gc.collect()

-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict
@@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 50

-    data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
-    hf_dataset = to_hf_dataset(data_dir, video)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
+    hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)

     info = {
         "fps": fps,
         "video": video,

View File

@@ -17,7 +17,6 @@
 Contains utilities to process raw data format from dora-record
 """

-import logging
 import re
 from pathlib import Path
@@ -26,10 +25,10 @@ import torch
 from datasets import Dataset, Features, Image, Sequence, Value

 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame
-from lerobot.common.utils.utils import init_logging
@@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
     return True


-def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     # Load data stream that will be used as reference for the timestamps synchronization
     reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
     if len(reference_files) == 0:
@@ -122,8 +121,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
         raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")

     # Create symlink to raw videos directory (that needs to be absolute not relative)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    videos_dir = out_dir / "videos"
+    videos_dir.parent.mkdir(parents=True, exist_ok=True)
     videos_dir.symlink_to((raw_dir / "videos").absolute())

     # sanity check the video paths are well formated
@@ -156,16 +154,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
         else:
             raise ValueError(key)

-    # Get the episode index containing for each unique episode index
-    first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
-    from_ = first_ep_index_df["start_index"].tolist()
-    to_ = from_[1:] + [len(df)]
-    episode_data_index = {
-        "from": from_,
-        "to": to_,
-    }
-
-    return data_dict, episode_data_index
+    return data_dict
@@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
-    init_logging()
-
-    if debug:
-        logging.warning("debug=True not implemented. Falling back to debug=False.")
-
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)
@@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
     if not video:
         raise NotImplementedError()

-    data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
+    data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
     hf_dataset = to_hf_dataset(data_df, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)

     info = {
         "fps": fps,
         "video": video,

View File

@@ -27,6 +27,7 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -53,7 +54,7 @@ def check_format(raw_dir):
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     try:
         import pymunk
         from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

     episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
-    num_episodes = zarr_data.meta["episode_ends"].shape[0]
     assert len(
         {zarr_data[key].shape[0] for key in zarr_data.keys()}  # noqa: SIM118
     ), "Some data type dont have the same number of total frames."
@@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     states = torch.from_numpy(zarr_data["state"])
     actions = torch.from_numpy(zarr_data["action"])

-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    for ep_idx in tqdm.tqdm(range(num_episodes)):
-        id_to = zarr_data.meta["episode_ends"][ep_idx]
-        num_frames = id_to - id_from
+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx = 0
+    for to_idx in zarr_data.meta["episode_ends"]:
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx
+
+    num_episodes = len(from_ids)
+
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx

         # sanity check
-        assert (episode_ids[id_from:id_to] == ep_idx).all()
+        assert (episode_ids[from_idx:to_idx] == ep_idx).all()

         # get image
-        image = imgs[id_from:id_to]
+        image = imgs[from_idx:to_idx]
         assert image.min() >= 0.0
         assert image.max() <= 255.0
         image = image.type(torch.uint8)

         # get state
-        state = states[id_from:id_to]
+        state = states[from_idx:to_idx]
         agent_pos = state[:, :2]
         block_pos = state[:, 2:4]
         block_angle = state[:, 4]
@@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

         ep_dict["observation.state"] = agent_pos
-        ep_dict["action"] = actions[id_from:id_to]
+        ep_dict["action"] = actions[from_idx:to_idx]
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])

         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from += num_frames
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict
@@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 10

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)

     info = {
         "fps": fps,
         "video": video,

View File

@@ -19,7 +19,6 @@ import logging
 import shutil
 from pathlib import Path

-import numpy as np
 import torch
 import tqdm
 import zarr
@@ -29,6 +28,7 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool:
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
-    # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
-    from numba import jit
-
-    @jit(nopython=True)
-    def _get_episode_idxs(episode_ends):
-        result = np.zeros((episode_ends[-1],), dtype=np.int64)
-        start_idx = 0
-        for episode_number, end_idx in enumerate(episode_ends):
-            result[start_idx:end_idx] = episode_number
-            start_idx = end_idx
-        return result
-
-    return _get_episode_idxs(episode_ends)
-
-
-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     zarr_path = raw_dir / "cup_in_the_wild.zarr"
     zarr_data = zarr.open(zarr_path, mode="r")
@@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     episode_ends = zarr_data["meta/episode_ends"][:]
     num_episodes = episode_ends.shape[0]

-    episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
-
-    # We convert it in torch tensor later because the jit function does not support torch tensors
     episode_ends = torch.from_numpy(episode_ends)

+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx = 0
+    for to_idx in episode_ends:
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx
+
     ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    for ep_idx in tqdm.tqdm(range(num_episodes)):
-        id_to = episode_ends[ep_idx]
-        num_frames = id_to - id_from
-
-        # sanity heck
-        assert (episode_ids[id_from:id_to] == ep_idx).all()
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx

         # TODO(rcadene): save temporary images of the episode?

-        state = states[id_from:id_to]
+        state = states[from_idx:to_idx]

         ep_dict = {}

         # load 57MB of images in RAM (400x224x224x3 uint8)
-        imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
+        imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
-        ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
-        ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
-        ep_dict["end_pose"] = end_pose[id_from:id_to]
-        ep_dict["start_pos"] = start_pos[id_from:id_to]
-        ep_dict["gripper_width"] = gripper_width[id_from:id_to]
+        ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
+        ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
+        ep_dict["end_pose"] = end_pose[from_idx:to_idx]
+        ep_dict["start_pos"] = start_pos[from_idx:to_idx]
+        ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]

         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from += num_frames
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    total_frames = id_from
+
+    total_frames = data_dict["frame_index"].shape[0]
     data_dict["index"] = torch.arange(0, total_frames, 1)
-
-    return data_dict, episode_data_index
+    return data_dict
@@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)
@@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
             "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
         )

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)

     info = {
         "fps": fps,
         "video": video,

View File

@@ -27,6 +27,7 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -54,37 +55,42 @@ def check_format(raw_dir):
     assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)


-def load_from_raw(raw_dir, out_dir, fps, video, debug):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     pkl_path = raw_dir / "buffer.pkl"

     with open(pkl_path, "rb") as f:
         pkl_data = pickle.load(f)

-    ep_dicts = []
-    episode_data_index = {"from": [], "to": []}
-
-    id_from = 0
-    id_to = 0
-    ep_idx = 0
-    total_frames = pkl_data["actions"].shape[0]
-    for i in tqdm.tqdm(range(total_frames)):
-        id_to += 1
-
-        if not pkl_data["dones"][i]:
+    # load data indices from which each episode starts and ends
+    from_ids, to_ids = [], []
+    from_idx, to_idx = 0, 0
+    for done in pkl_data["dones"]:
+        to_idx += 1
+        if not done:
             continue
+        from_ids.append(from_idx)
+        to_ids.append(to_idx)
+        from_idx = to_idx

-        num_frames = id_to - id_from
-
-        image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
+    num_episodes = len(from_ids)
+
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
+        from_idx = from_ids[selected_ep_idx]
+        to_idx = to_ids[selected_ep_idx]
+        num_frames = to_idx - from_idx
+
+        image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
         image = einops.rearrange(image, "b c h w -> b h w c")
-        state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
-        action = torch.tensor(pkl_data["actions"][id_from:id_to])
+        state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
+        action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
         # TODO(rcadene): we have a missing last frame which is the observation when the env is done
         # it is critical to have this frame for tdmpc to predict a "done observation/state"
-        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
-        # next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
-        next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
-        next_done = torch.tensor(pkl_data["dones"][id_from:id_to])
+        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
+        # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
+        next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
+        next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])

         ep_dict = {}
@@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         img_key = "observation.image"
         if video:
             # save png images in temporary directory
-            tmp_imgs_dir = out_dir / "tmp_images"
+            tmp_imgs_dir = videos_dir / "tmp_images"
             save_images_concurrently(imgs_array, tmp_imgs_dir)

             # encode images to a mp4 video
             fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = out_dir / "videos" / fname
+            video_path = videos_dir / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)

             # clean temporary images directory
@@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
         ep_dict["next.done"] = next_done

         ep_dicts.append(ep_dict)

-        episode_data_index["from"].append(id_from)
-        episode_data_index["to"].append(id_from + num_frames)
-
-        id_from = id_to
-        ep_idx += 1
-
-        # process first episode only
-        if debug:
-            break
-
     data_dict = concatenate_episodes(ep_dicts)
-    return data_dict, episode_data_index
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict
@@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video):
     return hf_dataset


-def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+):
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 15

-    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
     hf_dataset = to_hf_dataset(data_dict, video)
+    episode_data_index = calculate_episode_data_index(hf_dataset)

     info = {
         "fps": fps,
         "video": video,

View File

@@ -18,58 +18,39 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
 or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
 installation of neural net specific packages like pytorch, tensorflow, jax.

-Example:
+Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
 ```
 python lerobot/scripts/push_dataset_to_hub.py \
---data-dir data \
---dataset-id pusht \
+--raw-dir data/pusht_raw \
 --raw-format pusht_zarr \
---community-id lerobot \
---dry-run 1 \
---save-to-disk 1 \
---save-tests-to-disk 0 \
---debug 1
+--repo-id lerobot/pusht

 python lerobot/scripts/push_dataset_to_hub.py \
---data-dir data \
---dataset-id xarm_lift_medium \
+--raw-dir data/xarm_lift_medium_raw \
 --raw-format xarm_pkl \
---community-id lerobot \
---dry-run 1 \
---save-to-disk 1 \
---save-tests-to-disk 0 \
---debug 1
+--repo-id lerobot/xarm_lift_medium

 python lerobot/scripts/push_dataset_to_hub.py \
---data-dir data \
---dataset-id aloha_sim_insertion_scripted \
+--raw-dir data/aloha_sim_insertion_scripted_raw \
 --raw-format aloha_hdf5 \
---community-id lerobot \
---dry-run 1 \
---save-to-disk 1 \
---save-tests-to-disk 0 \
---debug 1
+--repo-id lerobot/aloha_sim_insertion_scripted

 python lerobot/scripts/push_dataset_to_hub.py \
---data-dir data \
---dataset-id umi_cup_in_the_wild \
+--raw-dir data/umi_cup_in_the_wild_raw \
 --raw-format umi_zarr \
---community-id lerobot \
---dry-run 1 \
---save-to-disk 1 \
---save-tests-to-disk 0 \
---debug 1
+--repo-id lerobot/umi_cup_in_the_wild
 ```
 """

 import argparse
 import json
 import shutil
+import warnings
 from pathlib import Path
 from typing import Any

 import torch
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, create_branch
 from safetensors.torch import save_file

 from lerobot.common.datasets.compute_stats import compute_stats
@@ -85,8 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
         from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
     elif raw_format == "aloha_hdf5":
         from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
-    elif raw_format == "aloha_dora":
-        from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
+    elif raw_format == "dora_parquet":
+        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
     elif raw_format == "xarm_pkl":
         from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
     else:
@@ -147,39 +128,61 @@ def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | Non
 def push_dataset_to_hub(
-    data_dir: Path,
-    dataset_id: str,
-    raw_format: str | None,
-    community_id: str,
-    revision: str,
-    dry_run: bool,
-    save_to_disk: bool,
-    tests_data_dir: Path,
-    save_tests_to_disk: bool,
-    fps: int | None,
-    video: bool,
-    batch_size: int,
-    num_workers: int,
-    debug: bool,
+    raw_dir: Path,
+    raw_format: str,
+    repo_id: str,
+    push_to_hub: bool = True,
+    local_dir: Path | None = None,
+    fps: int | None = None,
+    video: bool = True,
+    batch_size: int = 32,
+    num_workers: int = 8,
+    episodes: list[int] | None = None,
+    force_override: bool = False,
+    cache_dir: Path = Path("/tmp"),
+    tests_data_dir: Path | None = None,
 ):
-    repo_id = f"{community_id}/{dataset_id}"
-
-    raw_dir = data_dir / f"{dataset_id}_raw"
-
-    out_dir = data_dir / repo_id
-    meta_data_dir = out_dir / "meta_data"
-    videos_dir = out_dir / "videos"
-
-    tests_out_dir = tests_data_dir / repo_id
-    tests_meta_data_dir = tests_out_dir / "meta_data"
-    tests_videos_dir = tests_out_dir / "videos"
-
-    if out_dir.exists():
-        shutil.rmtree(out_dir)
-
-    if tests_out_dir.exists() and save_tests_to_disk:
-        shutil.rmtree(tests_out_dir)
-
+    # Check repo_id is well formated
+    if len(repo_id.split("/")) != 2:
+        raise ValueError(
+            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
+        )
+    user_id, dataset_id = repo_id.split("/")
+
+    # Robustify when `raw_dir` is str instead of Path
+    raw_dir = Path(raw_dir)
+    if not raw_dir.exists():
+        raise NotADirectoryError(
+            f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
+            f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
+        )
+
+    if local_dir:
+        # Robustify when `local_dir` is str instead of Path
+        local_dir = Path(local_dir)
+
+        # Send warning if local_dir isn't well formated
+        if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
+            warnings.warn(
+                f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
+                stacklevel=1,
+            )
+
+        # Check we don't override an existing `local_dir` by mistake
+        if local_dir.exists():
+            if force_override:
+                shutil.rmtree(local_dir)
+            else:
+                raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
+
+        meta_data_dir = local_dir / "meta_data"
+        videos_dir = local_dir / "videos"
+    else:
+        # Temporary directory used to store images, videos, meta_data
+        meta_data_dir = Path(cache_dir) / "meta_data"
+        videos_dir = Path(cache_dir) / "videos"
+
+    # Download the raw dataset if available
     if not raw_dir.exists():
         download_raw(raw_dir, dataset_id)
@@ -188,14 +191,14 @@ def push_dataset_to_hub(
         raise NotImplementedError()
         # raw_format = auto_find_raw_format(raw_dir)

-    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
-
     # convert dataset from original raw format to LeRobot format
-    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
+    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
+    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
+        raw_dir, videos_dir, fps, video, episodes
+    )

     lerobot_dataset = LeRobotDataset.from_preloaded(
         repo_id=repo_id,
-        version=revision,
         hf_dataset=hf_dataset,
         episode_data_index=episode_data_index,
         info=info,
@@ -203,103 +206,80 @@ def push_dataset_to_hub(
     )
     stats = compute_stats(lerobot_dataset, batch_size, num_workers)

-    if save_to_disk:
+    if local_dir:
         hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
-        hf_dataset.save_to_disk(str(out_dir / "train"))
+        hf_dataset.save_to_disk(str(local_dir / "train"))

-    if not dry_run or save_to_disk:
+    if push_to_hub or local_dir:
         # mandatory for upload
         save_meta_data(info, stats, episode_data_index, meta_data_dir)

-    if not dry_run:
-        # TODO(rcadene): token needs to be a str | None
-        hf_dataset.push_to_hub(repo_id, token=True, revision="main")
-        hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
-
+    if push_to_hub:
+        hf_dataset.push_to_hub(repo_id, revision="main")
         push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
-        push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
-
         if video:
             push_videos_to_hub(repo_id, videos_dir, revision="main")
-            push_videos_to_hub(repo_id, videos_dir, revision=revision)
+        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)

-    if save_tests_to_disk:
+    if tests_data_dir:
         # get the first episode
         num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
         test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
         test_hf_dataset = test_hf_dataset.with_format(None)
-        test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
+        test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))

-        save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
+        tests_meta_data = tests_data_dir / repo_id / "meta_data"
+        save_meta_data(info, stats, episode_data_index, tests_meta_data)

         # copy videos of first episode to tests directory
         episode_index = 0
+        tests_videos_dir = tests_data_dir / repo_id / "videos"
         tests_videos_dir.mkdir(parents=True, exist_ok=True)
         for key in lerobot_dataset.video_frame_keys:
             fname = f"{key}_episode_{episode_index:06d}.mp4"
             shutil.copy(videos_dir / fname, tests_videos_dir / fname)

-    if not save_to_disk and out_dir.exists():
-        # remove possible temporary files remaining in the output directory
-        shutil.rmtree(out_dir)
+    if local_dir is None:
+        # clear cache
+        shutil.rmtree(meta_data_dir)
+        shutil.rmtree(videos_dir)
+
+    return lerobot_dataset


 def main():
     parser = argparse.ArgumentParser()

     parser.add_argument(
-        "--data-dir",
+        "--raw-dir",
         type=Path,
         required=True,
-        help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
-    )
-    parser.add_argument(
-        "--dataset-id",
-        type=str,
-        required=True,
-        help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
+        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
     )
-    # TODO(rcadene): add automatic detection of the format
     parser.add_argument(
         "--raw-format",
         type=str,
-        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
+        required=True,
+        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
     )
     parser.add_argument(
-        "--community-id",
+        "--repo-id",
         type=str,
-        default="lerobot",
-        help="Community or user ID under which the dataset will be hosted on the Hub.",
-    )
-    parser.add_argument(
-        "--revision",
-        type=str,
-        default=CODEBASE_VERSION,
-        help="Codebase version used to generate the dataset.",
-    )
-    parser.add_argument(
-        "--dry-run",
-        type=int,
-        default=0,
-        help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
-    )
-    parser.add_argument(
-        "--save-to-disk",
-        type=int,
-        default=1,
-        help="Save the dataset in the directory specified by `--data-dir`.",
+        required=True,
+        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
     parser.add_argument(
-        "--tests-data-dir",
+        "--local-dir",
         type=Path,
-        default="tests/data",
-        help="Directory containing tests artifacts datasets.",
+        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
     )
     parser.add_argument(
-        "--save-tests-to-disk",
+        "--push-to-hub",
         type=int,
         default=1,
-        help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
+        help="Upload to hub.",
     )
     parser.add_argument(
         "--fps",
@@ -325,10 +305,21 @@ def main():
         help="Number of processes of Dataloader for computing the dataset statistics.",
     )
     parser.add_argument(
-        "--debug",
+        "--episodes",
         type=int,
-        default=0,
-        help="Debug mode process the first episode only.",
+        nargs="*",
+        help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
+    )
+    parser.add_argument(
+        "--force-override",
+        type=int,
+        default=0,
+        help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
+    )
+    parser.add_argument(
+        "--tests-data-dir",
+        type=Path,
+        help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
     )

     args = parser.parse_args()
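Since the reworked function now returns the in-memory `LeRobotDataset`, a conversion can be smoke-tested from Python on a single episode without touching the Hub; a minimal sketch under that assumption (paths and repo id are placeholders):

```python
from pathlib import Path

from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub

# Convert only episode 0 and skip the upload; intermediate files land in the cache dir.
dataset = push_dataset_to_hub(
    raw_dir=Path("data/pusht_raw"),
    raw_format="pusht_zarr",
    repo_id="lerobot/pusht",
    episodes=[0],
    push_to_hub=False,
)
print(len(dataset))  # number of frames in the converted episode
```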

View File

@@ -0,0 +1,352 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
we skip them for now in our CI.
Example to run backward compatiblity tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
from pathlib import Path
import numpy as np
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg
def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/img",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "cup_in_the_wild.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/camera0_rgb",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_end_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_start_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/robot0_eef_rot_axis_angle",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_gripper_width",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_xarm(raw_dir, num_frames=4):
import pickle
dataset_dict = {
"observations": {
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
"state": np.random.randn(num_frames, 4),
},
"actions": np.random.randn(num_frames, 3),
"rewards": np.random.randn(num_frames),
"masks": np.random.randn(num_frames),
"dones": np.array([False, True, True, True]),
}
raw_dir.mkdir(parents=True, exist_ok=True)
pkl_path = raw_dir / "buffer.pkl"
with open(pkl_path, "wb") as f:
pickle.dump(dataset_dict, f)
def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
import h5py
for ep_idx in range(num_episodes):
raw_dir.mkdir(parents=True, exist_ok=True)
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
with h5py.File(str(path_h5), "w") as f:
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset(
"observations/images/top",
data=np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
),
)
def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
from datetime import datetime, timedelta, timezone
import pandas
def write_parquet(key, timestamps, values):
data = {
"timestamp_utc": timestamps,
key: values,
}
df = pandas.DataFrame(data)
raw_dir.mkdir(parents=True, exist_ok=True)
df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")
episode_indices = [None, None, -1, None, None, -1, None, None, -1]
episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]
cam_key = "observation.images.cam_high"
timestamps = []
actions = []
states = []
frames = []
# the index lists above have num_frames + num_episodes entries: each episode is followed by one buffer frame marked episode_index=-1
for i, frame_idx in enumerate(frame_indices):
t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
action = np.random.randn(21).tolist()
state = np.random.randn(21).tolist()
ep_idx = episode_indices_mapping[i]
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
timestamps.append(t_utc)
actions.append(action)
states.append(state)
frames.append(frame)
write_parquet(cam_key, timestamps, frames)
write_parquet("observation.state", timestamps, states)
write_parquet("action", timestamps, actions)
write_parquet("episode_index", timestamps, episode_indices)
# write a fake mp4 video file for each episode (encode_video_frames invokes ffmpeg, installed as an apt dependency in CI)
for ep_idx in range(num_episodes):
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
tmp_imgs_dir = raw_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
video_path = raw_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
def _mock_download_raw(raw_dir, repo_id):
if "wrist_gripper" in repo_id:
_mock_download_raw_dora(raw_dir)
elif "aloha" in repo_id:
_mock_download_raw_aloha(raw_dir)
elif "pusht" in repo_id:
_mock_download_raw_pusht(raw_dir)
elif "xarm" in repo_id:
_mock_download_raw_xarm(raw_dir)
elif "umi" in repo_id:
_mock_download_raw_umi(raw_dir)
else:
raise ValueError(repo_id)
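# The tests below exercise the conversion through _mock_download_raw, which dispatches on
# substrings of repo_id, so no real raw dataset is downloaded in CI.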
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
with pytest.raises(ValueError):
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
tmpdir = Path(tmpdir)
out_dir = tmpdir / "out"
raw_dir = tmpdir / "raw"
# mkdir to skip download
raw_dir.mkdir(parents=True, exist_ok=True)
with pytest.raises(ValueError):
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format="some_format",
repo_id="user/dataset",
local_dir=out_dir,
force_override=False,
)
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id",
[
(["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
(None, "xarm_pkl", "lerobot/xarm_lift_medium"),
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
(None, "dora_parquet", "cadene/wrist_gripper"),
],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
num_episodes = 3
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{repo_id}_raw"
_mock_download_raw(raw_dir, repo_id)
local_dir = tmpdir / repo_id
lerobot_dataset = push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
)
# minimal generic tests on the local directory containing LeRobotDataset
assert (local_dir / "meta_data" / "info.json").exists()
assert (local_dir / "meta_data" / "stats.safetensors").exists()
assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
for i in range(num_episodes):
for cam_key in lerobot_dataset.camera_keys:
assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
assert (local_dir / "train" / "dataset_info.json").exists()
assert (local_dir / "train" / "state.json").exists()
assert len(list((local_dir / "train").glob("*.arrow"))) > 0
# minimal generic tests on the item
item = lerobot_dataset[0]
assert "index" in item
assert "episode_index" in item
assert "timestamp" in item
for cam_key in lerobot_dataset.camera_keys:
assert cam_key in item
@pytest.mark.parametrize(
"raw_format, repo_id",
[
# TODO(rcadene): add raw dataset test artifacts
("pusht_zarr", "lerobot/pusht"),
("xarm_pkl", "lerobot/xarm_lift_medium"),
("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
("umi_zarr", "lerobot/umi_cup_in_the_wild"),
("dora_parquet", "cadene/wrist_gripper"),
],
)
@pytest.mark.skip(
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{dataset_id}_raw"
local_dir = tmpdir / repo_id
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
episodes=[0],
)
ds_actual = LeRobotDataset(repo_id, root=tmpdir)
ds_reference = LeRobotDataset(repo_id)
assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)
def check_same_items(item1, item2):
assert item1.keys() == item2.keys(), "Keys mismatch"
for key in item1:
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
else:
assert item1[key] == item2[key], f"Mismatch found in key: {key}"
for i in range(len(ds_reference.hf_dataset)):
item_reference = ds_reference.hf_dataset[i]
item_actual = ds_actual.hf_dataset[i]
check_same_items(item_reference, item_actual)
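For reference, the conversion exercised by these tests is the same `push_dataset_to_hub` function imported at the top of this test file. A minimal sketch of a direct call, with hypothetical local paths and only parameters that appear in the tests above:

```python
# Minimal sketch (hypothetical paths): convert a local raw dataset without uploading it.
from pathlib import Path

from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub

lerobot_dataset = push_dataset_to_hub(
    raw_dir=Path("data/pusht_raw"),        # local raw dataset folder
    raw_format="pusht_zarr",               # one of the raw formats parametrized in the tests
    repo_id="lerobot/pusht",
    push_to_hub=False,                     # keep the converted dataset local
    local_dir=Path("data/lerobot/pusht"),  # meta_data/, videos/ and train/ are written here
)
print(lerobot_dataset.camera_keys)
```

Setting `push_to_hub=False` mirrors the tests, which only validate the locally written `meta_data/`, `videos/` and `train/` layout.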

View File

@@ -76,6 +76,7 @@ def require_env(func):
     """
     Decorator that skips the test if the required environment package is not installed.
     As it need 'env_name' in args, it also checks whether it is provided as an argument.
+    If 'env_name' is None, this check is skipped.
     """
 
     @wraps(func)
@@ -91,7 +92,7 @@ def require_env(func):
         # Perform the package check
         package_name = f"gym_{env_name}"
-        if not is_package_available(package_name):
+        if env_name is not None and not is_package_available(package_name):
             pytest.skip(f"gym-{env_name} not installed")
 
         return func(*args, **kwargs)
@@ -99,6 +100,38 @@ def require_env(func):
     return wrapper
 
 
+def require_package_arg(func):
+    """
+    Decorator that skips the test if the required package is not installed.
+    This is similar to `require_env` but more general in that it can check any package (not just environments).
+    As it need 'required_packages' in args, it also checks whether it is provided as an argument.
+    If 'required_packages' is None, this check is skipped.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Determine if 'required_packages' is provided and extract its value
+        arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
+        if "required_packages" in arg_names:
+            # Get the index of 'required_packages' and retrieve the value from args
+            index = arg_names.index("required_packages")
+            required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
+        else:
+            raise ValueError("Function does not have 'required_packages' as an argument.")
+
+        if required_packages is None:
+            return func(*args, **kwargs)
+
+        # Perform the package check
+        for package in required_packages:
+            if not is_package_available(package):
+                pytest.skip(f"{package} not installed")
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def require_package(package_name):
     """
     Decorator that skips the test if the specified package is not installed.