Add UMI-gripper dataset (#83)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
parent
a4b6c5e3b1
commit
81e490d46f
|
@ -30,10 +30,31 @@ def download_and_upload(root, revision, dataset_id):
|
||||||
download_and_upload_xarm(root, revision, dataset_id)
|
download_and_upload_xarm(root, revision, dataset_id)
|
||||||
elif "aloha" in dataset_id:
|
elif "aloha" in dataset_id:
|
||||||
download_and_upload_aloha(root, revision, dataset_id)
|
download_and_upload_aloha(root, revision, dataset_id)
|
||||||
|
elif "umi" in dataset_id:
|
||||||
|
download_and_upload_umi(root, revision, dataset_id)
|
||||||
else:
|
else:
|
||||||
raise ValueError(dataset_id)
|
raise ValueError(dataset_id)
|
||||||
|
|
||||||
|
|
||||||
|
def concatenate_episodes(ep_dicts):
|
||||||
|
data_dict = {}
|
||||||
|
|
||||||
|
keys = ep_dicts[0].keys()
|
||||||
|
for key in keys:
|
||||||
|
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||||
|
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||||
|
else:
|
||||||
|
if key not in data_dict:
|
||||||
|
data_dict[key] = []
|
||||||
|
for ep_dict in ep_dicts:
|
||||||
|
for x in ep_dict[key]:
|
||||||
|
data_dict[key].append(x)
|
||||||
|
|
||||||
|
total_frames = data_dict["frame_index"].shape[0]
|
||||||
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
|
@ -62,25 +83,6 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def concatenate_episodes(ep_dicts):
|
|
||||||
data_dict = {}
|
|
||||||
|
|
||||||
keys = ep_dicts[0].keys()
|
|
||||||
for key in keys:
|
|
||||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
|
||||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
|
||||||
else:
|
|
||||||
if key not in data_dict:
|
|
||||||
data_dict[key] = []
|
|
||||||
for ep_dict in ep_dicts:
|
|
||||||
for x in ep_dict[key]:
|
|
||||||
data_dict[key].append(x)
|
|
||||||
|
|
||||||
total_frames = data_dict["frame_index"].shape[0]
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
|
||||||
return data_dict
|
|
||||||
|
|
||||||
|
|
||||||
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
|
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
|
||||||
# push to main to indicate latest version
|
# push to main to indicate latest version
|
||||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||||
|
@ -515,9 +517,9 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||||
"episode_index": Value(dtype="int64", id=None),
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
"frame_index": Value(dtype="int64", id=None),
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
#'next.reward': Value(dtype='float32', id=None),
|
# "next.reward": Value(dtype="float32", id=None),
|
||||||
"next.done": Value(dtype="bool", id=None),
|
"next.done": Value(dtype="bool", id=None),
|
||||||
#'next.success': Value(dtype='bool', id=None),
|
# "next.success": Value(dtype="bool", id=None),
|
||||||
"index": Value(dtype="int64", id=None),
|
"index": Value(dtype="int64", id=None),
|
||||||
}
|
}
|
||||||
features = Features(features)
|
features = Features(features)
|
||||||
|
@ -531,10 +533,236 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_upload_umi(root, revision, dataset_id, fps=10):
|
||||||
|
# fps is equal to 10 source:https://arxiv.org/pdf/2402.10329.pdf#table.caption.16
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import zarr
|
||||||
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
|
||||||
|
from lerobot.common.datasets._umi_imagecodecs_numcodecs import register_codecs
|
||||||
|
|
||||||
|
# NOTE: This is critical otherwise ValueError: codec not available: 'imagecodecs_jpegxl'
|
||||||
|
# will be raised
|
||||||
|
register_codecs()
|
||||||
|
|
||||||
|
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
|
||||||
|
cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr")
|
||||||
|
|
||||||
|
root = Path(root)
|
||||||
|
raw_dir = root / f"{dataset_id}_raw"
|
||||||
|
zarr_path = (raw_dir / cup_in_the_wild_zarr).resolve()
|
||||||
|
if not zarr_path.is_dir():
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
download_and_extract_zip(url_cup_in_the_wild, zarr_path)
|
||||||
|
zarr_data = zarr.open(zarr_path, mode="r")
|
||||||
|
|
||||||
|
# We process the image data separately because it is too large to fit in memory
|
||||||
|
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
|
||||||
|
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
|
||||||
|
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
|
||||||
|
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
|
||||||
|
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
|
||||||
|
|
||||||
|
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
|
||||||
|
states = torch.cat([states_pos, gripper_width], dim=1)
|
||||||
|
|
||||||
|
def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
|
||||||
|
# Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
|
||||||
|
from numba import jit
|
||||||
|
|
||||||
|
@jit(nopython=True)
|
||||||
|
def _get_episode_idxs(episode_ends):
|
||||||
|
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
||||||
|
start_idx = 0
|
||||||
|
for episode_number, end_idx in enumerate(episode_ends):
|
||||||
|
result[start_idx:end_idx] = episode_number
|
||||||
|
start_idx = end_idx
|
||||||
|
return result
|
||||||
|
|
||||||
|
return _get_episode_idxs(episode_ends)
|
||||||
|
|
||||||
|
episode_ends = zarr_data["meta/episode_ends"][:]
|
||||||
|
num_episodes: int = episode_ends.shape[0]
|
||||||
|
|
||||||
|
episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
|
||||||
|
|
||||||
|
# We convert it in torch tensor later because the jit function does not support torch tensors
|
||||||
|
episode_ends = torch.from_numpy(episode_ends)
|
||||||
|
|
||||||
|
ep_dicts = []
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
id_from = 0
|
||||||
|
|
||||||
|
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||||
|
id_to = episode_ends[episode_id]
|
||||||
|
|
||||||
|
num_frames = id_to - id_from
|
||||||
|
|
||||||
|
assert (
|
||||||
|
episode_ids[id_from:id_to] == episode_id
|
||||||
|
).all(), f"episode_ids[{id_from}:{id_to}] != {episode_id}"
|
||||||
|
|
||||||
|
state = states[id_from:id_to]
|
||||||
|
ep_dict = {
|
||||||
|
# observation.image will be filled later
|
||||||
|
"observation.state": state,
|
||||||
|
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
|
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||||
|
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||||
|
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||||
|
"end_pose": end_pose[id_from:id_to],
|
||||||
|
"start_pos": start_pos[id_from:id_to],
|
||||||
|
"gripper_width": gripper_width[id_from:id_to],
|
||||||
|
}
|
||||||
|
ep_dicts.append(ep_dict)
|
||||||
|
episode_data_index["from"].append(id_from)
|
||||||
|
episode_data_index["to"].append(id_from + num_frames)
|
||||||
|
id_from += num_frames
|
||||||
|
|
||||||
|
data_dict = concatenate_episodes(ep_dicts)
|
||||||
|
|
||||||
|
total_frames = id_from
|
||||||
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
||||||
|
print("Saving images to disk in temporary folder...")
|
||||||
|
# datasets.Image() can take a list of paths to images, so we save the images to a temporary folder
|
||||||
|
# to avoid loading them all in memory
|
||||||
|
_umi_save_images_concurrently(zarr_data, "tmp_umi_images", max_workers=12)
|
||||||
|
print("Saving images to disk in temporary folder... Done")
|
||||||
|
|
||||||
|
# Sort files by number eg. 1.png, 2.png, 3.png, 9.png, 10.png instead of 1.png, 10.png, 2.png, 3.png, 9.png
|
||||||
|
# to correctly match the images with the data
|
||||||
|
images_path = sorted(glob("tmp_umi_images/*"), key=lambda x: int(re.search(r"(\d+)\.png$", x).group(1)))
|
||||||
|
data_dict["observation.image"] = images_path
|
||||||
|
|
||||||
|
features = {
|
||||||
|
"observation.image": Image(),
|
||||||
|
"observation.state": Sequence(
|
||||||
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
|
"index": Value(dtype="int64", id=None),
|
||||||
|
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||||
|
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||||
|
# `start_pos` and `end_pos` respectively represent the positions of the end-effector
|
||||||
|
# at the beginning and the end of the episode.
|
||||||
|
# `gripper_width` indicates the distance between the grippers, and this value is included
|
||||||
|
# in the state vector, which comprises the concatenation of the end-effector position
|
||||||
|
# and gripper width.
|
||||||
|
"end_pose": Sequence(length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||||
|
"start_pos": Sequence(
|
||||||
|
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"gripper_width": Sequence(
|
||||||
|
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
features = Features(features)
|
||||||
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"fps": fps,
|
||||||
|
}
|
||||||
|
stats = compute_stats(hf_dataset)
|
||||||
|
push_to_hub(
|
||||||
|
hf_dataset=hf_dataset,
|
||||||
|
episode_data_index=episode_data_index,
|
||||||
|
info=info,
|
||||||
|
stats=stats,
|
||||||
|
root=root,
|
||||||
|
revision=revision,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
)
|
||||||
|
# Cleanup
|
||||||
|
if os.path.exists("tmp_umi_images"):
|
||||||
|
print("Removing temporary images folder")
|
||||||
|
shutil.rmtree("tmp_umi_images")
|
||||||
|
print("Cleanup done")
|
||||||
|
|
||||||
|
|
||||||
|
def _umi_clear_folder(folder_path: str):
|
||||||
|
import os
|
||||||
|
|
||||||
|
"""
|
||||||
|
Clears all the content of the specified folder. Creates the folder if it does not exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_path (str): Path to the folder to clear.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import os
|
||||||
|
>>> os.makedirs('example_folder', exist_ok=True)
|
||||||
|
>>> with open('example_folder/temp_file.txt', 'w') as f:
|
||||||
|
... f.write('example')
|
||||||
|
>>> clear_folder('example_folder')
|
||||||
|
>>> os.listdir('example_folder')
|
||||||
|
[]
|
||||||
|
"""
|
||||||
|
if os.path.exists(folder_path):
|
||||||
|
for filename in os.listdir(folder_path):
|
||||||
|
file_path = os.path.join(folder_path, filename)
|
||||||
|
try:
|
||||||
|
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||||
|
os.unlink(file_path)
|
||||||
|
elif os.path.isdir(file_path):
|
||||||
|
shutil.rmtree(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||||
|
else:
|
||||||
|
os.makedirs(folder_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _umi_save_image(img_array: np.array, i: int, folder_path: str):
|
||||||
|
import os
|
||||||
|
|
||||||
|
"""
|
||||||
|
Saves a single image to the specified folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_array (ndarray): The numpy array of the image.
|
||||||
|
i (int): Index of the image, used for naming.
|
||||||
|
folder_path (str): Path to the folder where the image will be saved.
|
||||||
|
"""
|
||||||
|
img = PILImage.fromarray(img_array)
|
||||||
|
img_format = "PNG" if img_array.dtype == np.uint8 else "JPEG"
|
||||||
|
img.save(os.path.join(folder_path, f"{i}.{img_format.lower()}"), quality=100)
|
||||||
|
|
||||||
|
|
||||||
|
def _umi_save_images_concurrently(zarr_data: dict, folder_path: str, max_workers: int = 4):
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
"""
|
||||||
|
Saves images from the zarr_data to the specified folder using multithreading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zarr_data (dict): A dictionary containing image data in an array format.
|
||||||
|
folder_path (str): Path to the folder where images will be saved.
|
||||||
|
max_workers (int): The maximum number of threads to use for saving images.
|
||||||
|
"""
|
||||||
|
num_images = len(zarr_data["data/camera0_rgb"])
|
||||||
|
_umi_clear_folder(folder_path) # Clear or create folder first
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
[
|
||||||
|
executor.submit(_umi_save_image, zarr_data["data/camera0_rgb"][i], i, folder_path)
|
||||||
|
for i in range(num_images)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
root = "data"
|
root = "data"
|
||||||
revision = "v1.1"
|
revision = "v1.1"
|
||||||
|
|
||||||
dataset_ids = [
|
dataset_ids = [
|
||||||
"pusht",
|
"pusht",
|
||||||
"xarm_lift_medium",
|
"xarm_lift_medium",
|
||||||
|
@ -545,6 +773,7 @@ if __name__ == "__main__":
|
||||||
"aloha_sim_insertion_scripted",
|
"aloha_sim_insertion_scripted",
|
||||||
"aloha_sim_transfer_cube_human",
|
"aloha_sim_transfer_cube_human",
|
||||||
"aloha_sim_transfer_cube_scripted",
|
"aloha_sim_transfer_cube_scripted",
|
||||||
|
"umi_cup_in_the_wild",
|
||||||
]
|
]
|
||||||
for dataset_id in dataset_ids:
|
for dataset_id in dataset_ids:
|
||||||
download_and_upload(root, revision, dataset_id)
|
download_and_upload(root, revision, dataset_id)
|
||||||
|
|
|
@ -25,6 +25,8 @@ When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps
|
||||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
from lerobot.__version__ import __version__ # noqa: F401
|
from lerobot.__version__ import __version__ # noqa: F401
|
||||||
|
|
||||||
available_tasks_per_env = {
|
available_tasks_per_env = {
|
||||||
|
@ -52,7 +54,12 @@ available_datasets_per_env = {
|
||||||
"lerobot/xarm_push_medium_replay",
|
"lerobot/xarm_push_medium_replay",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
|
|
||||||
|
available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
|
||||||
|
|
||||||
|
available_datasets = list(
|
||||||
|
itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
|
||||||
|
)
|
||||||
|
|
||||||
available_policies = [
|
available_policies = [
|
||||||
"act",
|
"act",
|
||||||
|
|
|
@ -0,0 +1,311 @@
|
||||||
|
# imagecodecs/numcodecs.py
|
||||||
|
|
||||||
|
# Copyright (c) 2021-2022, Christoph Gohlke
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
#
|
||||||
|
# 1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer.
|
||||||
|
#
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
#
|
||||||
|
# 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived from
|
||||||
|
# this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
# POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
# Copied from: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/codecs/imagecodecs_numcodecs.py#L1
|
||||||
|
"""Additional numcodecs implemented using imagecodecs."""
|
||||||
|
|
||||||
|
__version__ = "2022.9.26"
|
||||||
|
|
||||||
|
__all__ = ("register_codecs",)
|
||||||
|
|
||||||
|
import imagecodecs
|
||||||
|
import numpy
|
||||||
|
from numcodecs.abc import Codec
|
||||||
|
from numcodecs.registry import get_codec, register_codec
|
||||||
|
|
||||||
|
# TODO (azouitine): Remove useless codecs
|
||||||
|
|
||||||
|
|
||||||
|
def protective_squeeze(x: numpy.ndarray):
|
||||||
|
"""
|
||||||
|
Squeeze dim only if it's not the last dim.
|
||||||
|
Image dim expected to be *, H, W, C
|
||||||
|
"""
|
||||||
|
img_shape = x.shape[-3:]
|
||||||
|
if len(x.shape) > 3:
|
||||||
|
n_imgs = numpy.prod(x.shape[:-3])
|
||||||
|
if n_imgs > 1:
|
||||||
|
img_shape = (-1,) + img_shape
|
||||||
|
return x.reshape(img_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_image_compressor(**kwargs):
|
||||||
|
if imagecodecs.JPEGXL:
|
||||||
|
# has JPEGXL
|
||||||
|
this_kwargs = {
|
||||||
|
"effort": 3,
|
||||||
|
"distance": 0.3,
|
||||||
|
# bug in libjxl, invalid codestream for non-lossless
|
||||||
|
# when decoding speed > 1
|
||||||
|
"decodingspeed": 1,
|
||||||
|
}
|
||||||
|
this_kwargs.update(kwargs)
|
||||||
|
return JpegXl(**this_kwargs)
|
||||||
|
else:
|
||||||
|
this_kwargs = {"level": 50}
|
||||||
|
this_kwargs.update(kwargs)
|
||||||
|
return Jpeg2k(**this_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Jpeg2k(Codec):
|
||||||
|
"""JPEG 2000 codec for numcodecs."""
|
||||||
|
|
||||||
|
codec_id = "imagecodecs_jpeg2k"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
level=None,
|
||||||
|
codecformat=None,
|
||||||
|
colorspace=None,
|
||||||
|
tile=None,
|
||||||
|
reversible=None,
|
||||||
|
bitspersample=None,
|
||||||
|
resolutions=None,
|
||||||
|
numthreads=None,
|
||||||
|
verbose=0,
|
||||||
|
):
|
||||||
|
self.level = level
|
||||||
|
self.codecformat = codecformat
|
||||||
|
self.colorspace = colorspace
|
||||||
|
self.tile = None if tile is None else tuple(tile)
|
||||||
|
self.reversible = reversible
|
||||||
|
self.bitspersample = bitspersample
|
||||||
|
self.resolutions = resolutions
|
||||||
|
self.numthreads = numthreads
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
def encode(self, buf):
|
||||||
|
buf = protective_squeeze(numpy.asarray(buf))
|
||||||
|
return imagecodecs.jpeg2k_encode(
|
||||||
|
buf,
|
||||||
|
level=self.level,
|
||||||
|
codecformat=self.codecformat,
|
||||||
|
colorspace=self.colorspace,
|
||||||
|
tile=self.tile,
|
||||||
|
reversible=self.reversible,
|
||||||
|
bitspersample=self.bitspersample,
|
||||||
|
resolutions=self.resolutions,
|
||||||
|
numthreads=self.numthreads,
|
||||||
|
verbose=self.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, buf, out=None):
|
||||||
|
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
|
||||||
|
|
||||||
|
|
||||||
|
class JpegXl(Codec):
|
||||||
|
"""JPEG XL codec for numcodecs."""
|
||||||
|
|
||||||
|
codec_id = "imagecodecs_jpegxl"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# encode
|
||||||
|
level=None,
|
||||||
|
effort=None,
|
||||||
|
distance=None,
|
||||||
|
lossless=None,
|
||||||
|
decodingspeed=None,
|
||||||
|
photometric=None,
|
||||||
|
planar=None,
|
||||||
|
usecontainer=None,
|
||||||
|
# decode
|
||||||
|
index=None,
|
||||||
|
keeporientation=None,
|
||||||
|
# both
|
||||||
|
numthreads=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return JPEG XL image from numpy array.
|
||||||
|
Float must be in nominal range 0..1.
|
||||||
|
|
||||||
|
Currently L, LA, RGB, RGBA images are supported in contig mode.
|
||||||
|
Extra channels are only supported for grayscale images in planar mode.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
level : Default to None, i.e. not overwriting lossess and decodingspeed options.
|
||||||
|
When < 0: Use lossless compression
|
||||||
|
When in [0,1,2,3,4]: Sets the decoding speed tier for the provided options.
|
||||||
|
Minimum is 0 (slowest to decode, best quality/density), and maximum
|
||||||
|
is 4 (fastest to decode, at the cost of some quality/density).
|
||||||
|
effort : Default to 3.
|
||||||
|
Sets encoder effort/speed level without affecting decoding speed.
|
||||||
|
Valid values are, from faster to slower speed: 1:lightning 2:thunder
|
||||||
|
3:falcon 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten 9:tortoise.
|
||||||
|
Speed: lightning, thunder, falcon, cheetah, hare, wombat, squirrel, kitten, tortoise
|
||||||
|
control the encoder effort in ascending order.
|
||||||
|
This also affects memory usage: using lower effort will typically reduce memory
|
||||||
|
consumption during encoding.
|
||||||
|
lightning and thunder are fast modes useful for lossless mode (modular).
|
||||||
|
falcon disables all of the following tools.
|
||||||
|
cheetah enables coefficient reordering, context clustering, and heuristics for selecting DCT sizes and quantization steps.
|
||||||
|
hare enables Gaborish filtering, chroma from luma, and an initial estimate of quantization steps.
|
||||||
|
wombat enables error diffusion quantization and full DCT size selection heuristics.
|
||||||
|
squirrel (default) enables dots, patches, and spline detection, and full context clustering.
|
||||||
|
kitten optimizes the adaptive quantization for a psychovisual metric.
|
||||||
|
tortoise enables a more thorough adaptive quantization search.
|
||||||
|
distance : Default to 1.0
|
||||||
|
Sets the distance level for lossy compression: target max butteraugli distance,
|
||||||
|
lower = higher quality. Range: 0 .. 15. 0.0 = mathematically lossless
|
||||||
|
(however, use JxlEncoderSetFrameLossless instead to use true lossless,
|
||||||
|
as setting distance to 0 alone is not the only requirement).
|
||||||
|
1.0 = visually lossless. Recommended range: 0.5 .. 3.0.
|
||||||
|
lossess : Default to False.
|
||||||
|
Use lossess encoding.
|
||||||
|
decodingspeed : Default to 0.
|
||||||
|
Duplicate to level. [0,4]
|
||||||
|
photometric : Return JxlColorSpace value.
|
||||||
|
Default logic is quite complicated but works most of the time.
|
||||||
|
Accepted value:
|
||||||
|
int: [-1,3]
|
||||||
|
str: ['RGB',
|
||||||
|
'WHITEISZERO', 'MINISWHITE',
|
||||||
|
'BLACKISZERO', 'MINISBLACK', 'GRAY',
|
||||||
|
'XYB', 'KNOWN']
|
||||||
|
planar : Enable multi-channel mode.
|
||||||
|
Default to false.
|
||||||
|
usecontainer :
|
||||||
|
Forces the encoder to use the box-based container format (BMFF)
|
||||||
|
even when not necessary.
|
||||||
|
When using JxlEncoderUseBoxes, JxlEncoderStoreJPEGMetadata or
|
||||||
|
JxlEncoderSetCodestreamLevel with level 10, the encoder will
|
||||||
|
automatically also use the container format, it is not necessary
|
||||||
|
to use JxlEncoderUseContainer for those use cases.
|
||||||
|
By default this setting is disabled.
|
||||||
|
index : Selectively decode frames for animation.
|
||||||
|
Default to 0, decode all frames.
|
||||||
|
When set to > 0, decode that frame index only.
|
||||||
|
keeporientation :
|
||||||
|
Enables or disables preserving of as-in-bitstream pixeldata orientation.
|
||||||
|
Some images are encoded with an Orientation tag indicating that the
|
||||||
|
decoder must perform a rotation and/or mirroring to the encoded image data.
|
||||||
|
|
||||||
|
If skip_reorientation is JXL_FALSE (the default): the decoder will apply
|
||||||
|
the transformation from the orientation setting, hence rendering the image
|
||||||
|
according to its specified intent. When producing a JxlBasicInfo, the decoder
|
||||||
|
will always set the orientation field to JXL_ORIENT_IDENTITY (matching the
|
||||||
|
returned pixel data) and also align xsize and ysize so that they correspond
|
||||||
|
to the width and the height of the returned pixel data.
|
||||||
|
|
||||||
|
If skip_reorientation is JXL_TRUE: the decoder will skip applying the
|
||||||
|
transformation from the orientation setting, returning the image in
|
||||||
|
the as-in-bitstream pixeldata orientation. This may be faster to decode
|
||||||
|
since the decoder doesnt have to apply the transformation, but can
|
||||||
|
cause wrong display of the image if the orientation tag is not correctly
|
||||||
|
taken into account by the user.
|
||||||
|
|
||||||
|
By default, this option is disabled, and the returned pixel data is
|
||||||
|
re-oriented according to the images Orientation setting.
|
||||||
|
threads : Default to 1.
|
||||||
|
If <= 0, use all cores.
|
||||||
|
If > 32, clipped to 32.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.level = level
|
||||||
|
self.effort = effort
|
||||||
|
self.distance = distance
|
||||||
|
self.lossless = bool(lossless)
|
||||||
|
self.decodingspeed = decodingspeed
|
||||||
|
self.photometric = photometric
|
||||||
|
self.planar = planar
|
||||||
|
self.usecontainer = usecontainer
|
||||||
|
self.index = index
|
||||||
|
self.keeporientation = keeporientation
|
||||||
|
self.numthreads = numthreads
|
||||||
|
|
||||||
|
def encode(self, buf):
|
||||||
|
# TODO: only squeeze all but last dim
|
||||||
|
buf = protective_squeeze(numpy.asarray(buf))
|
||||||
|
return imagecodecs.jpegxl_encode(
|
||||||
|
buf,
|
||||||
|
level=self.level,
|
||||||
|
effort=self.effort,
|
||||||
|
distance=self.distance,
|
||||||
|
lossless=self.lossless,
|
||||||
|
decodingspeed=self.decodingspeed,
|
||||||
|
photometric=self.photometric,
|
||||||
|
planar=self.planar,
|
||||||
|
usecontainer=self.usecontainer,
|
||||||
|
numthreads=self.numthreads,
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, buf, out=None):
|
||||||
|
return imagecodecs.jpegxl_decode(
|
||||||
|
buf,
|
||||||
|
index=self.index,
|
||||||
|
keeporientation=self.keeporientation,
|
||||||
|
numthreads=self.numthreads,
|
||||||
|
out=out,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _flat(out):
|
||||||
|
"""Return numpy array as contiguous view of bytes if possible."""
|
||||||
|
if out is None:
|
||||||
|
return None
|
||||||
|
view = memoryview(out)
|
||||||
|
if view.readonly or not view.contiguous:
|
||||||
|
return None
|
||||||
|
return view.cast("B")
|
||||||
|
|
||||||
|
|
||||||
|
def register_codecs(codecs=None, force=False, verbose=True):
|
||||||
|
"""Register codecs in this module with numcodecs."""
|
||||||
|
for name, cls in globals().items():
|
||||||
|
if not hasattr(cls, "codec_id") or name == "Codec":
|
||||||
|
continue
|
||||||
|
if codecs is not None and cls.codec_id not in codecs:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
try: # noqa: SIM105
|
||||||
|
get_codec({"id": cls.codec_id})
|
||||||
|
except TypeError:
|
||||||
|
# registered, but failed
|
||||||
|
pass
|
||||||
|
except ValueError:
|
||||||
|
# not registered yet
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if not force:
|
||||||
|
if verbose:
|
||||||
|
log_warning(f"numcodec {cls.codec_id!r} already registered")
|
||||||
|
continue
|
||||||
|
if verbose:
|
||||||
|
log_warning(f"replacing registered numcodec {cls.codec_id!r}")
|
||||||
|
register_codec(cls)
|
||||||
|
|
||||||
|
|
||||||
|
def log_warning(msg, *args, **kwargs):
|
||||||
|
"""Log message with level WARNING."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.getLogger(__name__).warning(msg, *args, **kwargs)
|
|
@ -1326,6 +1326,49 @@ files = [
|
||||||
{file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
|
{file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "imagecodecs"
|
||||||
|
version = "2024.1.1"
|
||||||
|
description = "Image transformation, compression, and decompression codecs"
|
||||||
|
optional = true
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "imagecodecs-2024.1.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:4b787ffbe62e98e492ace10229f07ee309f42b93dd0fc602ba7595d06e326cef"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3b87b648b081fb938073b42729d3c2c56344242e1d67af4d53f408eb4c1e8c6"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6eb9cb4d1e8dabbbaaa9e257a421eaab2e9c9d37888dda31c15cf56648f67d"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5a2c70ee23bc1c59c76bafcdc79c6f61c8c231bad5789df22699d77eb3b6556"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp310-cp310-win32.whl", hash = "sha256:38b7abfdddc317fc44f69eaa82ca95e44be0392f8b6eefa55f7ff2bf912a2dda"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:a5153df7451f170dfd41c00a2686281bd0a73fcbf315f546689276b594cf97ae"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:32eb26ebb89dd56f1bb7984f71ce84bfe34d6c21fa573b113fb4f473a9aa7e6c"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3e10aea5e391f53f39cb07f1559c780f2292436e756afb7fdf0379b3eacef9cd"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2f5dfa0bc36d86ce3e3b4f14a99bc1cdb8d65898fbd316d7f2b1ff9fdc6f6eb"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6db283c41118cb66c4ec05f966d1b7cc9251f66fcba8cb472d71af2184902fb"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp311-cp311-win32.whl", hash = "sha256:ef7a13d09966b021c33e2dd726ee44b9cd3cb198b32474c7e48cca55abd52f6a"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:cfff3b3fae93d414ec851ada3fe6875857bc3c234eb9718c8baef8a3130c2ced"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:bf4ed0385973ce3e0b2e2c9d720310e2378b4801ae3f52e0c6cb1c94b116f711"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:896302c49aa3beae94b45b1d5411850eb129980d1ef6c3e12e0c4a413236866e"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:63fc197b091f4dd0f3d490570ec175dd6e276b3b4d7d2b3fa0e89548a6686a57"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:636c5e6a599df1a5168c32ed0063a7a98585f424da8679cc20af50a1a1c2185f"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1c30bf8be21a3e58720dfe86cf618be12c4ce5be0657268983caaf38a59368"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp312-cp312-win32.whl", hash = "sha256:12c2bb563bf173067b9ccd193af354fea7b78aa8bc43a61943445b0126b1bd4d"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:ccdca7a78ceab005aa3a92fd548e12d042c02862725b18f0328c15fa117c4638"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:51b6c1afe1d04dd8b6059d59d94a487097b2fb44c4a89e244a386ea7fe170f89"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:93fef596adec62bf49418f6df31c45b19f655af7c943d4ba38ee703a7bc5201d"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d969d2b04b24a3a75d747b830a5ac9e2f4fcdc492db565f65d51951a9587da22"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8804b1d0be1fbce21d18e26ebafd274e131155a242bfd5fbfd1f02b4df613e28"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1946d847e57e3b3a9735a7c68c674185d5061dc475071d177aa09335c667b687"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp39-cp39-win32.whl", hash = "sha256:d80b382d906152f4a5d1c93d4e589044836d8e6bb11246dabf72f1e74b1166fb"},
|
||||||
|
{file = "imagecodecs-2024.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:bd14e068fb78f291d1e27e548839a470faaf75dd02c9d1376d5194519544cf11"},
|
||||||
|
{file = "imagecodecs-2024.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1967ba770f780ab9aa560cf850268d601e189e4df3dd4109fa8a19cba19c9cda"},
|
||||||
|
{file = "imagecodecs-2024.1.1.tar.gz", hash = "sha256:fde46bd698d008255deef5411c59b35c0e875295e835bf6079f7e2ab22f216eb"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
numpy = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["matplotlib", "numcodecs", "tifffile"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "imageio"
|
name = "imageio"
|
||||||
version = "2.34.1"
|
version = "2.34.1"
|
||||||
|
@ -2855,13 +2898,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest"
|
name = "pytest"
|
||||||
version = "8.1.2"
|
version = "8.2.0"
|
||||||
description = "pytest: simple powerful testing with Python"
|
description = "pytest: simple powerful testing with Python"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "pytest-8.1.2-py3-none-any.whl", hash = "sha256:6c06dc309ff46a05721e6fd48e492a775ed8165d2ecdf57f156a80c7e95bb142"},
|
{file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"},
|
||||||
{file = "pytest-8.1.2.tar.gz", hash = "sha256:f3c45d1d5eed96b01a2aea70dee6a4a366d51d38f9957768083e4fecfc77f3ef"},
|
{file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -2869,11 +2912,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||||
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||||
iniconfig = "*"
|
iniconfig = "*"
|
||||||
packaging = "*"
|
packaging = "*"
|
||||||
pluggy = ">=1.4,<2.0"
|
pluggy = ">=1.5,<2.0"
|
||||||
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest-cov"
|
name = "pytest-cov"
|
||||||
|
@ -2943,6 +2986,7 @@ files = [
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||||
|
@ -4249,9 +4293,10 @@ aloha = ["gym-aloha"]
|
||||||
dev = ["debugpy", "pre-commit"]
|
dev = ["debugpy", "pre-commit"]
|
||||||
pusht = ["gym-pusht"]
|
pusht = ["gym-pusht"]
|
||||||
test = ["pytest", "pytest-cov"]
|
test = ["pytest", "pytest-cov"]
|
||||||
|
umi = ["imagecodecs"]
|
||||||
xarm = ["gym-xarm"]
|
xarm = ["gym-xarm"]
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "03c1d8598e376680c08f997fff37d3c6aa39efce5e3b6f799ea7bbd476a492f9"
|
content-hash = "8bd1352973c6104e52f50b68f7387d26ced9b07a52e889540b73d132865cda38"
|
||||||
|
|
|
@ -54,6 +54,7 @@ debugpy = {version = "^1.8.1", optional = true}
|
||||||
pytest = {version = "^8.1.0", optional = true}
|
pytest = {version = "^8.1.0", optional = true}
|
||||||
pytest-cov = {version = "^5.0.0", optional = true}
|
pytest-cov = {version = "^5.0.0", optional = true}
|
||||||
datasets = "^2.19.0"
|
datasets = "^2.19.0"
|
||||||
|
imagecodecs = { version = "^2024.1.1", optional = true }
|
||||||
torchaudio = "^2.3.0"
|
torchaudio = "^2.3.0"
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,6 +64,7 @@ xarm = ["gym-xarm"]
|
||||||
aloha = ["gym-aloha"]
|
aloha = ["gym-aloha"]
|
||||||
dev = ["pre-commit", "debugpy"]
|
dev = ["pre-commit", "debugpy"]
|
||||||
test = ["pytest", "pytest-cov"]
|
test = ["pytest", "pytest-cov"]
|
||||||
|
umi = ["imagecodecs"]
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"fps": 10
|
||||||
|
}
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,67 @@
|
||||||
|
{
|
||||||
|
"citation": "",
|
||||||
|
"description": "",
|
||||||
|
"features": {
|
||||||
|
"observation.state": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 7,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"episode_index": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"frame_index": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"timestamp": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"episode_data_index_from": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"episode_data_index_to": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"end_pose": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 6,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"start_pos": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 6,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"gripper_width": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 1,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"index": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"observation.image": {
|
||||||
|
"_type": "Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"homepage": "",
|
||||||
|
"license": ""
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
{
|
||||||
|
"_data_files": [
|
||||||
|
{
|
||||||
|
"filename": "data-00000-of-00001.arrow"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"_fingerprint": "fd95ee932cb1fce2",
|
||||||
|
"_format_columns": null,
|
||||||
|
"_format_kwargs": {},
|
||||||
|
"_format_type": "torch",
|
||||||
|
"_output_all_columns": false,
|
||||||
|
"_split": null
|
||||||
|
}
|
Loading…
Reference in New Issue