Add UMI-gripper dataset (#83)

Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Adil Zouitine 2024-04-28 18:41:07 +02:00 committed by GitHub
parent a4b6c5e3b1
commit 81e490d46f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 706 additions and 29 deletions

View File

@ -30,10 +30,31 @@ def download_and_upload(root, revision, dataset_id):
download_and_upload_xarm(root, revision, dataset_id)
elif "aloha" in dataset_id:
download_and_upload_aloha(root, revision, dataset_id)
elif "umi" in dataset_id:
download_and_upload_umi(root, revision, dataset_id)
else:
raise ValueError(dataset_id)
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
import zipfile
@ -62,25 +83,6 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
# push to main to indicate latest version
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
@ -515,9 +517,9 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
# "next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
# "next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None),
}
features = Features(features)
@ -531,10 +533,236 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_umi(root, revision, dataset_id, fps=10):
# fps is equal to 10 source:https://arxiv.org/pdf/2402.10329.pdf#table.caption.16
import os
import re
import shutil
from glob import glob
import numpy as np
import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets._umi_imagecodecs_numcodecs import register_codecs
# NOTE: This is critical otherwise ValueError: codec not available: 'imagecodecs_jpegxl'
# will be raised
register_codecs()
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr")
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
zarr_path = (raw_dir / cup_in_the_wild_zarr).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(url_cup_in_the_wild, zarr_path)
zarr_data = zarr.open(zarr_path, mode="r")
# We process the image data separately because it is too large to fit in memory
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
states = torch.cat([states_pos, gripper_width], dim=1)
def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
# Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
from numba import jit
@jit(nopython=True)
def _get_episode_idxs(episode_ends):
result = np.zeros((episode_ends[-1],), dtype=np.int64)
start_idx = 0
for episode_number, end_idx in enumerate(episode_ends):
result[start_idx:end_idx] = episode_number
start_idx = end_idx
return result
return _get_episode_idxs(episode_ends)
episode_ends = zarr_data["meta/episode_ends"][:]
num_episodes: int = episode_ends.shape[0]
episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
# We convert it in torch tensor later because the jit function does not support torch tensors
episode_ends = torch.from_numpy(episode_ends)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
id_to = episode_ends[episode_id]
num_frames = id_to - id_from
assert (
episode_ids[id_from:id_to] == episode_id
).all(), f"episode_ids[{id_from}:{id_to}] != {episode_id}"
state = states[id_from:id_to]
ep_dict = {
# observation.image will be filled later
"observation.state": state,
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
"end_pose": end_pose[id_from:id_to],
"start_pos": start_pos[id_from:id_to],
"gripper_width": gripper_width[id_from:id_to],
}
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
data_dict = concatenate_episodes(ep_dicts)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
print("Saving images to disk in temporary folder...")
# datasets.Image() can take a list of paths to images, so we save the images to a temporary folder
# to avoid loading them all in memory
_umi_save_images_concurrently(zarr_data, "tmp_umi_images", max_workers=12)
print("Saving images to disk in temporary folder... Done")
# Sort files by number eg. 1.png, 2.png, 3.png, 9.png, 10.png instead of 1.png, 10.png, 2.png, 3.png, 9.png
# to correctly match the images with the data
images_path = sorted(glob("tmp_umi_images/*"), key=lambda x: int(re.search(r"(\d+)\.png$", x).group(1)))
data_dict["observation.image"] = images_path
features = {
"observation.image": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"index": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
# `start_pos` and `end_pos` respectively represent the positions of the end-effector
# at the beginning and the end of the episode.
# `gripper_width` indicates the distance between the grippers, and this value is included
# in the state vector, which comprises the concatenation of the end-effector position
# and gripper width.
"end_pose": Sequence(length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)),
"start_pos": Sequence(
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
),
"gripper_width": Sequence(
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
stats=stats,
root=root,
revision=revision,
dataset_id=dataset_id,
)
# Cleanup
if os.path.exists("tmp_umi_images"):
print("Removing temporary images folder")
shutil.rmtree("tmp_umi_images")
print("Cleanup done")
def _umi_clear_folder(folder_path: str):
import os
"""
Clears all the content of the specified folder. Creates the folder if it does not exist.
Args:
folder_path (str): Path to the folder to clear.
Examples:
>>> import os
>>> os.makedirs('example_folder', exist_ok=True)
>>> with open('example_folder/temp_file.txt', 'w') as f:
... f.write('example')
>>> clear_folder('example_folder')
>>> os.listdir('example_folder')
[]
"""
if os.path.exists(folder_path):
for filename in os.listdir(folder_path):
file_path = os.path.join(folder_path, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
os.makedirs(folder_path)
def _umi_save_image(img_array: np.array, i: int, folder_path: str):
import os
"""
Saves a single image to the specified folder.
Args:
img_array (ndarray): The numpy array of the image.
i (int): Index of the image, used for naming.
folder_path (str): Path to the folder where the image will be saved.
"""
img = PILImage.fromarray(img_array)
img_format = "PNG" if img_array.dtype == np.uint8 else "JPEG"
img.save(os.path.join(folder_path, f"{i}.{img_format.lower()}"), quality=100)
def _umi_save_images_concurrently(zarr_data: dict, folder_path: str, max_workers: int = 4):
from concurrent.futures import ThreadPoolExecutor
"""
Saves images from the zarr_data to the specified folder using multithreading.
Args:
zarr_data (dict): A dictionary containing image data in an array format.
folder_path (str): Path to the folder where images will be saved.
max_workers (int): The maximum number of threads to use for saving images.
"""
num_images = len(zarr_data["data/camera0_rgb"])
_umi_clear_folder(folder_path) # Clear or create folder first
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[
executor.submit(_umi_save_image, zarr_data["data/camera0_rgb"][i], i, folder_path)
for i in range(num_images)
]
if __name__ == "__main__":
root = "data"
revision = "v1.1"
dataset_ids = [
"pusht",
"xarm_lift_medium",
@ -545,6 +773,7 @@ if __name__ == "__main__":
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
"umi_cup_in_the_wild",
]
for dataset_id in dataset_ids:
download_and_upload(root, revision, dataset_id)

View File

@ -25,6 +25,8 @@ When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps
- Update variables in `tests/test_available.py` by importing your new Policy class
"""
import itertools
from lerobot.__version__ import __version__ # noqa: F401
available_tasks_per_env = {
@ -52,7 +54,12 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_replay",
],
}
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
)
available_policies = [
"act",

View File

@ -0,0 +1,311 @@
# imagecodecs/numcodecs.py
# Copyright (c) 2021-2022, Christoph Gohlke
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# Copied from: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/codecs/imagecodecs_numcodecs.py#L1
"""Additional numcodecs implemented using imagecodecs."""
__version__ = "2022.9.26"
__all__ = ("register_codecs",)
import imagecodecs
import numpy
from numcodecs.abc import Codec
from numcodecs.registry import get_codec, register_codec
# TODO (azouitine): Remove useless codecs
def protective_squeeze(x: numpy.ndarray):
"""
Squeeze dim only if it's not the last dim.
Image dim expected to be *, H, W, C
"""
img_shape = x.shape[-3:]
if len(x.shape) > 3:
n_imgs = numpy.prod(x.shape[:-3])
if n_imgs > 1:
img_shape = (-1,) + img_shape
return x.reshape(img_shape)
def get_default_image_compressor(**kwargs):
if imagecodecs.JPEGXL:
# has JPEGXL
this_kwargs = {
"effort": 3,
"distance": 0.3,
# bug in libjxl, invalid codestream for non-lossless
# when decoding speed > 1
"decodingspeed": 1,
}
this_kwargs.update(kwargs)
return JpegXl(**this_kwargs)
else:
this_kwargs = {"level": 50}
this_kwargs.update(kwargs)
return Jpeg2k(**this_kwargs)
class Jpeg2k(Codec):
"""JPEG 2000 codec for numcodecs."""
codec_id = "imagecodecs_jpeg2k"
def __init__(
self,
level=None,
codecformat=None,
colorspace=None,
tile=None,
reversible=None,
bitspersample=None,
resolutions=None,
numthreads=None,
verbose=0,
):
self.level = level
self.codecformat = codecformat
self.colorspace = colorspace
self.tile = None if tile is None else tuple(tile)
self.reversible = reversible
self.bitspersample = bitspersample
self.resolutions = resolutions
self.numthreads = numthreads
self.verbose = verbose
def encode(self, buf):
buf = protective_squeeze(numpy.asarray(buf))
return imagecodecs.jpeg2k_encode(
buf,
level=self.level,
codecformat=self.codecformat,
colorspace=self.colorspace,
tile=self.tile,
reversible=self.reversible,
bitspersample=self.bitspersample,
resolutions=self.resolutions,
numthreads=self.numthreads,
verbose=self.verbose,
)
def decode(self, buf, out=None):
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
class JpegXl(Codec):
"""JPEG XL codec for numcodecs."""
codec_id = "imagecodecs_jpegxl"
def __init__(
self,
# encode
level=None,
effort=None,
distance=None,
lossless=None,
decodingspeed=None,
photometric=None,
planar=None,
usecontainer=None,
# decode
index=None,
keeporientation=None,
# both
numthreads=None,
):
"""
Return JPEG XL image from numpy array.
Float must be in nominal range 0..1.
Currently L, LA, RGB, RGBA images are supported in contig mode.
Extra channels are only supported for grayscale images in planar mode.
Parameters
----------
level : Default to None, i.e. not overwriting lossess and decodingspeed options.
When < 0: Use lossless compression
When in [0,1,2,3,4]: Sets the decoding speed tier for the provided options.
Minimum is 0 (slowest to decode, best quality/density), and maximum
is 4 (fastest to decode, at the cost of some quality/density).
effort : Default to 3.
Sets encoder effort/speed level without affecting decoding speed.
Valid values are, from faster to slower speed: 1:lightning 2:thunder
3:falcon 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten 9:tortoise.
Speed: lightning, thunder, falcon, cheetah, hare, wombat, squirrel, kitten, tortoise
control the encoder effort in ascending order.
This also affects memory usage: using lower effort will typically reduce memory
consumption during encoding.
lightning and thunder are fast modes useful for lossless mode (modular).
falcon disables all of the following tools.
cheetah enables coefficient reordering, context clustering, and heuristics for selecting DCT sizes and quantization steps.
hare enables Gaborish filtering, chroma from luma, and an initial estimate of quantization steps.
wombat enables error diffusion quantization and full DCT size selection heuristics.
squirrel (default) enables dots, patches, and spline detection, and full context clustering.
kitten optimizes the adaptive quantization for a psychovisual metric.
tortoise enables a more thorough adaptive quantization search.
distance : Default to 1.0
Sets the distance level for lossy compression: target max butteraugli distance,
lower = higher quality. Range: 0 .. 15. 0.0 = mathematically lossless
(however, use JxlEncoderSetFrameLossless instead to use true lossless,
as setting distance to 0 alone is not the only requirement).
1.0 = visually lossless. Recommended range: 0.5 .. 3.0.
lossess : Default to False.
Use lossess encoding.
decodingspeed : Default to 0.
Duplicate to level. [0,4]
photometric : Return JxlColorSpace value.
Default logic is quite complicated but works most of the time.
Accepted value:
int: [-1,3]
str: ['RGB',
'WHITEISZERO', 'MINISWHITE',
'BLACKISZERO', 'MINISBLACK', 'GRAY',
'XYB', 'KNOWN']
planar : Enable multi-channel mode.
Default to false.
usecontainer :
Forces the encoder to use the box-based container format (BMFF)
even when not necessary.
When using JxlEncoderUseBoxes, JxlEncoderStoreJPEGMetadata or
JxlEncoderSetCodestreamLevel with level 10, the encoder will
automatically also use the container format, it is not necessary
to use JxlEncoderUseContainer for those use cases.
By default this setting is disabled.
index : Selectively decode frames for animation.
Default to 0, decode all frames.
When set to > 0, decode that frame index only.
keeporientation :
Enables or disables preserving of as-in-bitstream pixeldata orientation.
Some images are encoded with an Orientation tag indicating that the
decoder must perform a rotation and/or mirroring to the encoded image data.
If skip_reorientation is JXL_FALSE (the default): the decoder will apply
the transformation from the orientation setting, hence rendering the image
according to its specified intent. When producing a JxlBasicInfo, the decoder
will always set the orientation field to JXL_ORIENT_IDENTITY (matching the
returned pixel data) and also align xsize and ysize so that they correspond
to the width and the height of the returned pixel data.
If skip_reorientation is JXL_TRUE: the decoder will skip applying the
transformation from the orientation setting, returning the image in
the as-in-bitstream pixeldata orientation. This may be faster to decode
since the decoder doesnt have to apply the transformation, but can
cause wrong display of the image if the orientation tag is not correctly
taken into account by the user.
By default, this option is disabled, and the returned pixel data is
re-oriented according to the images Orientation setting.
threads : Default to 1.
If <= 0, use all cores.
If > 32, clipped to 32.
"""
self.level = level
self.effort = effort
self.distance = distance
self.lossless = bool(lossless)
self.decodingspeed = decodingspeed
self.photometric = photometric
self.planar = planar
self.usecontainer = usecontainer
self.index = index
self.keeporientation = keeporientation
self.numthreads = numthreads
def encode(self, buf):
# TODO: only squeeze all but last dim
buf = protective_squeeze(numpy.asarray(buf))
return imagecodecs.jpegxl_encode(
buf,
level=self.level,
effort=self.effort,
distance=self.distance,
lossless=self.lossless,
decodingspeed=self.decodingspeed,
photometric=self.photometric,
planar=self.planar,
usecontainer=self.usecontainer,
numthreads=self.numthreads,
)
def decode(self, buf, out=None):
return imagecodecs.jpegxl_decode(
buf,
index=self.index,
keeporientation=self.keeporientation,
numthreads=self.numthreads,
out=out,
)
def _flat(out):
"""Return numpy array as contiguous view of bytes if possible."""
if out is None:
return None
view = memoryview(out)
if view.readonly or not view.contiguous:
return None
return view.cast("B")
def register_codecs(codecs=None, force=False, verbose=True):
"""Register codecs in this module with numcodecs."""
for name, cls in globals().items():
if not hasattr(cls, "codec_id") or name == "Codec":
continue
if codecs is not None and cls.codec_id not in codecs:
continue
try:
try: # noqa: SIM105
get_codec({"id": cls.codec_id})
except TypeError:
# registered, but failed
pass
except ValueError:
# not registered yet
pass
else:
if not force:
if verbose:
log_warning(f"numcodec {cls.codec_id!r} already registered")
continue
if verbose:
log_warning(f"replacing registered numcodec {cls.codec_id!r}")
register_codec(cls)
def log_warning(msg, *args, **kwargs):
"""Log message with level WARNING."""
import logging
logging.getLogger(__name__).warning(msg, *args, **kwargs)

57
poetry.lock generated
View File

@ -1326,6 +1326,49 @@ files = [
{file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
]
[[package]]
name = "imagecodecs"
version = "2024.1.1"
description = "Image transformation, compression, and decompression codecs"
optional = true
python-versions = ">=3.9"
files = [
{file = "imagecodecs-2024.1.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:4b787ffbe62e98e492ace10229f07ee309f42b93dd0fc602ba7595d06e326cef"},
{file = "imagecodecs-2024.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3b87b648b081fb938073b42729d3c2c56344242e1d67af4d53f408eb4c1e8c6"},
{file = "imagecodecs-2024.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6eb9cb4d1e8dabbbaaa9e257a421eaab2e9c9d37888dda31c15cf56648f67d"},
{file = "imagecodecs-2024.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5a2c70ee23bc1c59c76bafcdc79c6f61c8c231bad5789df22699d77eb3b6556"},
{file = "imagecodecs-2024.1.1-cp310-cp310-win32.whl", hash = "sha256:38b7abfdddc317fc44f69eaa82ca95e44be0392f8b6eefa55f7ff2bf912a2dda"},
{file = "imagecodecs-2024.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:a5153df7451f170dfd41c00a2686281bd0a73fcbf315f546689276b594cf97ae"},
{file = "imagecodecs-2024.1.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:32eb26ebb89dd56f1bb7984f71ce84bfe34d6c21fa573b113fb4f473a9aa7e6c"},
{file = "imagecodecs-2024.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3e10aea5e391f53f39cb07f1559c780f2292436e756afb7fdf0379b3eacef9cd"},
{file = "imagecodecs-2024.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2f5dfa0bc36d86ce3e3b4f14a99bc1cdb8d65898fbd316d7f2b1ff9fdc6f6eb"},
{file = "imagecodecs-2024.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6db283c41118cb66c4ec05f966d1b7cc9251f66fcba8cb472d71af2184902fb"},
{file = "imagecodecs-2024.1.1-cp311-cp311-win32.whl", hash = "sha256:ef7a13d09966b021c33e2dd726ee44b9cd3cb198b32474c7e48cca55abd52f6a"},
{file = "imagecodecs-2024.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:cfff3b3fae93d414ec851ada3fe6875857bc3c234eb9718c8baef8a3130c2ced"},
{file = "imagecodecs-2024.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:bf4ed0385973ce3e0b2e2c9d720310e2378b4801ae3f52e0c6cb1c94b116f711"},
{file = "imagecodecs-2024.1.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:896302c49aa3beae94b45b1d5411850eb129980d1ef6c3e12e0c4a413236866e"},
{file = "imagecodecs-2024.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:63fc197b091f4dd0f3d490570ec175dd6e276b3b4d7d2b3fa0e89548a6686a57"},
{file = "imagecodecs-2024.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:636c5e6a599df1a5168c32ed0063a7a98585f424da8679cc20af50a1a1c2185f"},
{file = "imagecodecs-2024.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1c30bf8be21a3e58720dfe86cf618be12c4ce5be0657268983caaf38a59368"},
{file = "imagecodecs-2024.1.1-cp312-cp312-win32.whl", hash = "sha256:12c2bb563bf173067b9ccd193af354fea7b78aa8bc43a61943445b0126b1bd4d"},
{file = "imagecodecs-2024.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:ccdca7a78ceab005aa3a92fd548e12d042c02862725b18f0328c15fa117c4638"},
{file = "imagecodecs-2024.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:51b6c1afe1d04dd8b6059d59d94a487097b2fb44c4a89e244a386ea7fe170f89"},
{file = "imagecodecs-2024.1.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:93fef596adec62bf49418f6df31c45b19f655af7c943d4ba38ee703a7bc5201d"},
{file = "imagecodecs-2024.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d969d2b04b24a3a75d747b830a5ac9e2f4fcdc492db565f65d51951a9587da22"},
{file = "imagecodecs-2024.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8804b1d0be1fbce21d18e26ebafd274e131155a242bfd5fbfd1f02b4df613e28"},
{file = "imagecodecs-2024.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1946d847e57e3b3a9735a7c68c674185d5061dc475071d177aa09335c667b687"},
{file = "imagecodecs-2024.1.1-cp39-cp39-win32.whl", hash = "sha256:d80b382d906152f4a5d1c93d4e589044836d8e6bb11246dabf72f1e74b1166fb"},
{file = "imagecodecs-2024.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:bd14e068fb78f291d1e27e548839a470faaf75dd02c9d1376d5194519544cf11"},
{file = "imagecodecs-2024.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1967ba770f780ab9aa560cf850268d601e189e4df3dd4109fa8a19cba19c9cda"},
{file = "imagecodecs-2024.1.1.tar.gz", hash = "sha256:fde46bd698d008255deef5411c59b35c0e875295e835bf6079f7e2ab22f216eb"},
]
[package.dependencies]
numpy = "*"
[package.extras]
all = ["matplotlib", "numcodecs", "tifffile"]
[[package]]
name = "imageio"
version = "2.34.1"
@ -2855,13 +2898,13 @@ files = [
[[package]]
name = "pytest"
version = "8.1.2"
version = "8.2.0"
description = "pytest: simple powerful testing with Python"
optional = true
python-versions = ">=3.8"
files = [
{file = "pytest-8.1.2-py3-none-any.whl", hash = "sha256:6c06dc309ff46a05721e6fd48e492a775ed8165d2ecdf57f156a80c7e95bb142"},
{file = "pytest-8.1.2.tar.gz", hash = "sha256:f3c45d1d5eed96b01a2aea70dee6a4a366d51d38f9957768083e4fecfc77f3ef"},
{file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"},
{file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"},
]
[package.dependencies]
@ -2869,11 +2912,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=1.4,<2.0"
pluggy = ">=1.5,<2.0"
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
@ -2943,6 +2986,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -4249,9 +4293,10 @@ aloha = ["gym-aloha"]
dev = ["debugpy", "pre-commit"]
pusht = ["gym-pusht"]
test = ["pytest", "pytest-cov"]
umi = ["imagecodecs"]
xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "03c1d8598e376680c08f997fff37d3c6aa39efce5e3b6f799ea7bbd476a492f9"
content-hash = "8bd1352973c6104e52f50b68f7387d26ced9b07a52e889540b73d132865cda38"

View File

@ -54,6 +54,7 @@ debugpy = {version = "^1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.19.0"
imagecodecs = { version = "^2024.1.1", optional = true }
torchaudio = "^2.3.0"
@ -63,6 +64,7 @@ xarm = ["gym-xarm"]
aloha = ["gym-aloha"]
dev = ["pre-commit", "debugpy"]
test = ["pytest", "pytest-cov"]
umi = ["imagecodecs"]
[tool.ruff]

View File

@ -0,0 +1,3 @@
{
"fps": 10
}

View File

@ -0,0 +1,67 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 7,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"end_pose": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 6,
"_type": "Sequence"
},
"start_pos": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 6,
"_type": "Sequence"
},
"gripper_width": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 1,
"_type": "Sequence"
},
"index": {
"dtype": "int64",
"_type": "Value"
},
"observation.image": {
"_type": "Image"
}
},
"homepage": "",
"license": ""
}

View File

@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "fd95ee932cb1fce2",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}