diff --git a/.gitattributes b/.gitattributes index 4135de8f..f12e709c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,6 @@ *.memmap filter=lfs diff=lfs merge=lfs -text *.stl filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.json filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 45feabdc..a466cff7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,6 +29,8 @@ jobs: MUJOCO_GL: egl steps: - uses: actions/checkout@v4 + with: + lfs: true # Ensure LFS files are pulled - name: Install EGL run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev @@ -65,6 +67,8 @@ jobs: MUJOCO_GL: egl steps: - uses: actions/checkout@v4 + with: + lfs: true # Ensure LFS files are pulled - name: Install poetry run: | @@ -97,6 +101,8 @@ jobs: MUJOCO_GL: egl steps: - uses: actions/checkout@v4 + with: + lfs: true # Ensure LFS files are pulled - name: Install EGL run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dd8f97e2..b20dede6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -195,6 +195,11 @@ Follow these steps to start contributing: git commit ``` + Note, if you already commited some changes that have a wrong formatting, you can use: + ```bash + pre-commit run --all-files + ``` + Please write [good commit messages](https://chris.beams.io/posts/git-commit/). It is a good idea to sync your copy of the code with the original diff --git a/Makefile b/Makefile index c4140cd0..2bab6199 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,8 @@ build-gpu: test-end-to-end: ${MAKE} test-act-ete-train ${MAKE} test-act-ete-eval + ${MAKE} test-act-ete-train-amp + ${MAKE} test-act-ete-eval-amp ${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-eval ${MAKE} test-tdmpc-ete-train @@ -30,6 +32,7 @@ test-end-to-end: test-act-ete-train: python lerobot/scripts/train.py \ policy=act \ + policy.dim_model=64 \ env=aloha \ wandb.enable=False \ training.offline_steps=2 \ @@ -52,9 +55,40 @@ test-act-ete-eval: env.episode_length=8 \ device=cpu \ +test-act-ete-train-amp: + python lerobot/scripts/train.py \ + policy=act \ + policy.dim_model=64 \ + env=aloha \ + wandb.enable=False \ + training.offline_steps=2 \ + training.online_steps=0 \ + eval.n_episodes=1 \ + eval.batch_size=1 \ + device=cpu \ + training.save_model=true \ + training.save_freq=2 \ + policy.n_action_steps=20 \ + policy.chunk_size=20 \ + training.batch_size=2 \ + hydra.run.dir=tests/outputs/act/ \ + use_amp=true + +test-act-ete-eval-amp: + python lerobot/scripts/eval.py \ + -p tests/outputs/act/checkpoints/000002 \ + eval.n_episodes=1 \ + eval.batch_size=1 \ + env.episode_length=8 \ + device=cpu \ + use_amp=true + test-diffusion-ete-train: python lerobot/scripts/train.py \ policy=diffusion \ + policy.down_dims=\[64,128,256\] \ + policy.diffusion_step_embed_dim=32 \ + policy.num_inference_steps=10 \ env=pusht \ wandb.enable=False \ training.offline_steps=2 \ @@ -75,6 +109,7 @@ test-diffusion-ete-eval: env.episode_length=8 \ device=cpu \ +# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated. test-tdmpc-ete-train: python lerobot/scripts/train.py \ policy=tdmpc \ @@ -83,7 +118,7 @@ test-tdmpc-ete-train: dataset_repo_id=lerobot/xarm_lift_medium \ wandb.enable=False \ training.offline_steps=2 \ - training.online_steps=2 \ + training.online_steps=0 \ eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=2 \ @@ -101,7 +136,6 @@ test-tdmpc-ete-eval: env.episode_length=8 \ device=cpu \ - test-default-ete-eval: python lerobot/scripts/eval.py \ --config lerobot/configs/default.yaml \ diff --git a/README.md b/README.md index db573121..337adb85 100644 --- a/README.md +++ b/README.md @@ -198,11 +198,11 @@ To add a dataset to the hub, you need to login using a write-access token, which huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` -Then move your dataset folder in `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with: +Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with: ```bash python lerobot/scripts/push_dataset_to_hub.py \ --data-dir data \ ---dataset-id aloha_ping_ping \ +--dataset-id aloha_static_pingpong_test \ --raw-format aloha_hdf5 \ --community-id lerobot ``` diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile index a2823dc2..9889114a 100644 --- a/docker/lerobot-gpu/Dockerfile +++ b/docker/lerobot-gpu/Dockerfile @@ -8,7 +8,7 @@ ARG DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential cmake \ git git-lfs openssh-client \ - nano vim \ + nano vim ffmpeg \ htop atop nvtop \ sed gawk grep curl wget \ tcpdump sysstat screen \ diff --git a/examples/4_calculate_validation_loss.py b/examples/4_calculate_validation_loss.py new file mode 100644 index 00000000..1428014b --- /dev/null +++ b/examples/4_calculate_validation_loss.py @@ -0,0 +1,90 @@ +"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data. + +This technique can be useful for debugging and testing purposes, as well as identifying whether a policy +is learning effectively. + +Furthermore, relying on validation loss to evaluate performance is generally not considered a good practice, +especially in the context of imitation learning. The most reliable approach is to evaluate the policy directly +on the target environment, whether that be in simulation or the real world. +""" + +import math +from pathlib import Path + +import torch +from huggingface_hub import snapshot_download + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy + +device = torch.device("cuda") + +# Download the diffusion policy for pusht environment +pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht")) +# OR uncomment the following to evaluate a policy from the local outputs/train folder. +# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") + +policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) +policy.eval() +policy.to(device) + +# Set up the dataset. +delta_timestamps = { + # Load the previous image and state at -0.1 seconds before current frame, + # then load current image and state corresponding to 0.0 second. + "observation.image": [-0.1, 0.0], + "observation.state": [-0.1, 0.0], + # Load the previous action (-0.1), the next action to be executed (0.0), + # and 14 future actions with a 0.1 seconds spacing. All these actions will be + # used to calculate the loss. + "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], +} + +# Load the last 10% of episodes of the dataset as a validation set. +# - Load full dataset +full_dataset = LeRobotDataset("lerobot/pusht", split="train") +# - Calculate train and val subsets +num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100) +num_val_episodes = full_dataset.num_episodes - num_train_episodes +print(f"Number of episodes in full dataset: {full_dataset.num_episodes}") +print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}") +print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}") +# - Get first frame index of the validation set +first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item() +# - Load frames subset belonging to validation set using the `split` argument. +# It utilizes the `datasets` library's syntax for slicing datasets. +# For more information on the Slice API, please see: +# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits +train_dataset = LeRobotDataset( + "lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps +) +val_dataset = LeRobotDataset( + "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps +) +print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}") +print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}") + +# Create dataloader for evaluation. +val_dataloader = torch.utils.data.DataLoader( + val_dataset, + num_workers=4, + batch_size=64, + shuffle=False, + pin_memory=device != torch.device("cpu"), + drop_last=False, +) + +# Run validation loop. +loss_cumsum = 0 +n_examples_evaluated = 0 +for batch in val_dataloader: + batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} + output_dict = policy.forward(batch) + + loss_cumsum += output_dict["loss"].item() + n_examples_evaluated += batch["index"].shape[0] + +# Calculate the average loss over the validation set. +average_loss = loss_cumsum / n_examples_evaluated + +print(f"Average loss on validation set: {average_loss:.4f}") diff --git a/lerobot/__init__.py b/lerobot/__init__.py index e188bc52..e0234f29 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -61,13 +61,21 @@ available_datasets_per_env = { "lerobot/aloha_sim_insertion_scripted", "lerobot/aloha_sim_transfer_cube_human", "lerobot/aloha_sim_transfer_cube_scripted", + "lerobot/aloha_sim_insertion_human_image", + "lerobot/aloha_sim_insertion_scripted_image", + "lerobot/aloha_sim_transfer_cube_human_image", + "lerobot/aloha_sim_transfer_cube_scripted_image", ], - "pusht": ["lerobot/pusht"], + "pusht": ["lerobot/pusht", "lerobot/pusht_image"], "xarm": [ "lerobot/xarm_lift_medium", "lerobot/xarm_lift_medium_replay", "lerobot/xarm_push_medium", "lerobot/xarm_push_medium_replay", + "lerobot/xarm_lift_medium_image", + "lerobot/xarm_lift_medium_replay_image", + "lerobot/xarm_push_medium_image", + "lerobot/xarm_push_medium_replay_image", ], } diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 21d09879..057e4770 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -20,17 +20,19 @@ import datasets import torch from lerobot.common.datasets.utils import ( + calculate_episode_data_index, load_episode_data_index, load_hf_dataset, load_info, load_previous_and_future_frames, load_stats, load_videos, + reset_episode_index, ) from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None -CODEBASE_VERSION = "v1.3" +CODEBASE_VERSION = "v1.4" class LeRobotDataset(torch.utils.data.Dataset): @@ -54,7 +56,11 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads self.hf_dataset = load_hf_dataset(repo_id, version, root, split) - self.episode_data_index = load_episode_data_index(repo_id, version, root) + if split == "train": + self.episode_data_index = load_episode_data_index(repo_id, version, root) + else: + self.episode_data_index = calculate_episode_data_index(self.hf_dataset) + self.hf_dataset = reset_episode_index(self.hf_dataset) self.stats = load_stats(repo_id, version, root) self.info = load_info(repo_id, version, root) if self.video: diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py index 232fd055..7074bcba 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -24,17 +24,16 @@ import shutil from pathlib import Path import tqdm - -ALOHA_RAW_URLS_DIR = "lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls" +from huggingface_hub import snapshot_download def download_raw(raw_dir, dataset_id): - if "pusht" in dataset_id: + if "aloha" in dataset_id or "image" in dataset_id: + download_hub(raw_dir, dataset_id) + elif "pusht" in dataset_id: download_pusht(raw_dir) elif "xarm" in dataset_id: download_xarm(raw_dir) - elif "aloha" in dataset_id: - download_aloha(raw_dir, dataset_id) elif "umi" in dataset_id: download_umi(raw_dir) else: @@ -103,37 +102,13 @@ def download_xarm(raw_dir: Path): zip_path.unlink() -def download_aloha(raw_dir: Path, dataset_id: str): - import gdown - - subset_id = dataset_id.replace("aloha_", "") - urls_path = Path(ALOHA_RAW_URLS_DIR) / f"{subset_id}.txt" - assert urls_path.exists(), f"{subset_id}.txt not found in '{ALOHA_RAW_URLS_DIR}' directory." - - with open(urls_path) as f: - # strip lines and ignore empty lines - urls = [url.strip() for url in f if url.strip()] - - # sanity check - for url in urls: - assert ( - "drive.google.com/drive/folders" in url or "drive.google.com/file" in url - ), f"Wrong url provided '{url}' in file '{urls_path}'." - +def download_hub(raw_dir: Path, dataset_id: str): raw_dir = Path(raw_dir) raw_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Start downloading from google drive for {dataset_id}") - for url in urls: - if "drive.google.com/drive/folders" in url: - # when a folder url is given, download up to 50 files from the folder - gdown.download_folder(url, output=str(raw_dir), remaining_ok=True) - - elif "drive.google.com/file" in url: - # because of the 50 files limit per folder, we download the remaining files (file by file) - gdown.download(url, output=str(raw_dir), fuzzy=True) - - logging.info(f"End downloading from google drive for {dataset_id}") + logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}") + snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir) + logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}") def download_umi(raw_dir: Path): @@ -148,21 +123,30 @@ def download_umi(raw_dir: Path): if __name__ == "__main__": data_dir = Path("data") dataset_ids = [ + "pusht_image", + "xarm_lift_medium_image", + "xarm_lift_medium_replay_image", + "xarm_push_medium_image", + "xarm_push_medium_replay_image", + "aloha_sim_insertion_human_image", + "aloha_sim_insertion_scripted_image", + "aloha_sim_transfer_cube_human_image", + "aloha_sim_transfer_cube_scripted_image", "pusht", "xarm_lift_medium", "xarm_lift_medium_replay", "xarm_push_medium", "xarm_push_medium_replay", + "aloha_sim_insertion_human", + "aloha_sim_insertion_scripted", + "aloha_sim_transfer_cube_human", + "aloha_sim_transfer_cube_scripted", "aloha_mobile_cabinet", "aloha_mobile_chair", "aloha_mobile_elevator", "aloha_mobile_shrimp", "aloha_mobile_wash_pan", "aloha_mobile_wipe_wine", - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", "aloha_static_battery", "aloha_static_candy", "aloha_static_coffee", diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index 4efadc9e..1c2f066e 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -17,7 +17,7 @@ Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act """ -import re +import gc import shutil from pathlib import Path @@ -79,10 +79,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): episode_data_index = {"from": [], "to": []} id_from = 0 - - for ep_path in tqdm.tqdm(hdf5_files, total=len(hdf5_files)): + for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)): with h5py.File(ep_path, "r") as ep: - ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1)) num_frames = ep["/action"].shape[0] # last step of demonstration is considered done @@ -91,6 +89,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): state = torch.from_numpy(ep["/observations/qpos"][:]) action = torch.from_numpy(ep["/action"][:]) + if "/observations/qvel" in ep: + velocity = torch.from_numpy(ep["/observations/qvel"][:]) + if "/observations/effort" in ep: + effort = torch.from_numpy(ep["/observations/effort"][:]) ep_dict = {} @@ -131,6 +133,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array] ep_dict["observation.state"] = state + if "/observations/velocity" in ep: + ep_dict["observation.velocity"] = velocity + if "/observations/effort" in ep: + ep_dict["observation.effort"] = effort ep_dict["action"] = action ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) @@ -146,6 +152,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): id_from += num_frames + gc.collect() + # process first episode only if debug: break @@ -167,6 +175,14 @@ def to_hf_dataset(data_dict, video) -> Dataset: features["observation.state"] = Sequence( length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) ) + if "observation.velocity" in data_dict: + features["observation.velocity"] = Sequence( + length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None) + ) + if "observation.effort" in data_dict: + features["observation.effort"] = Sequence( + length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None) + ) features["action"] = Sequence( length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None) ) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 5cdd5f7c..86fef8d4 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import re from pathlib import Path +from typing import Dict import datasets import torch @@ -79,7 +81,23 @@ def hf_transform_to_torch(items_dict): def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" if root is not None: - hf_dataset = load_from_disk(str(Path(root) / repo_id / split)) + hf_dataset = load_from_disk(str(Path(root) / repo_id / "train")) + # TODO(rcadene): clean this which enables getting a subset of dataset + if split != "train": + if "%" in split: + raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).") + match_from = re.search(r"train\[(\d+):\]", split) + match_to = re.search(r"train\[:(\d+)\]", split) + if match_from: + from_frame_index = int(match_from.group(1)) + hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset))) + elif match_to: + to_frame_index = int(match_to.group(1)) + hf_dataset = hf_dataset.select(range(to_frame_index)) + else: + raise ValueError( + f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"' + ) else: hf_dataset = load_dataset(repo_id, revision=version, split=split) hf_dataset.set_transform(hf_transform_to_torch) @@ -245,6 +263,84 @@ def load_previous_and_future_frames( return item +def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: + """ + Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. + + Parameters: + - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index. + + Returns: + - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys: + - "from": A tensor containing the starting index of each episode. + - "to": A tensor containing the ending index of each episode. + """ + episode_data_index = {"from": [], "to": []} + + current_episode = None + """ + The episode_index is a list of integers, each representing the episode index of the corresponding example. + For instance, the following is a valid episode_index: + [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] + + Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and + ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this: + { + "from": [0, 3, 7], + "to": [3, 7, 12] + } + """ + if len(hf_dataset) == 0: + episode_data_index = { + "from": torch.tensor([]), + "to": torch.tensor([]), + } + return episode_data_index + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + # We encountered a new episode, so we append its starting location to the "from" list + episode_data_index["from"].append(idx) + # If this is not the first episode, we append the ending location of the previous episode to the "to" list + if current_episode is not None: + episode_data_index["to"].append(idx) + # Let's keep track of the current episode index + current_episode = episode_idx + else: + # We are still in the same episode, so there is nothing for us to do here + pass + # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list + episode_data_index["to"].append(idx + 1) + + for k in ["from", "to"]: + episode_data_index[k] = torch.tensor(episode_data_index[k]) + + return episode_data_index + + +def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: + """ + Reset the `episode_index` of the provided HuggingFace Dataset. + + `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the + `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0. + + This brings the `episode_index` to the required format. + """ + if len(hf_dataset) == 0: + return hf_dataset + unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist() + episode_idx_to_reset_idx_mapping = { + ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs) + } + + def modify_ep_idx_func(example): + example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()] + return example + + hf_dataset = hf_dataset.map(modify_ep_idx_func) + return hf_dataset + + def cycle(iterable): """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index ae36b3e2..27329bc9 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -10,6 +10,9 @@ hydra: name: default device: cuda # cpu +# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP, +# automatic gradient scaling is used. +use_amp: false # `seed` is used for training (eg: model initialization, dataset shuffling) # AND for the evaluation environments. seed: ??? @@ -17,6 +20,7 @@ dataset_repo_id: lerobot/pusht training: offline_steps: ??? + # NOTE: `online_steps` is not implemented yet. It's here as a placeholder. online_steps: ??? online_steps_between_rollouts: ??? online_sampling_ratio: 0.5 diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 7e736850..09326ab4 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium training: offline_steps: 25000 - online_steps: 25000 + # TODO(alexander-soare): uncomment when online training gets reinstated + online_steps: 0 # 25000 not implemented yet eval_freq: 5000 online_steps_between_rollouts: 1 online_sampling_ratio: 0.5 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 9c95633a..7e4690d0 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -46,6 +46,7 @@ import json import logging import threading import time +from contextlib import nullcontext from copy import deepcopy from datetime import datetime as dt from pathlib import Path @@ -520,7 +521,7 @@ def eval( raise NotImplementedError() # Check device is available - get_safe_torch_device(hydra_cfg.device, log=True) + device = get_safe_torch_device(hydra_cfg.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -539,16 +540,17 @@ def eval( policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) policy.eval() - info = eval_policy( - env, - policy, - hydra_cfg.eval.n_episodes, - max_episodes_rendered=10, - video_dir=Path(out_dir) / "eval", - start_seed=hydra_cfg.seed, - enable_progbar=True, - enable_inner_progbar=True, - ) + with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(): + info = eval_policy( + env, + policy, + hydra_cfg.eval.n_episodes, + max_episodes_rendered=10, + video_dir=Path(out_dir) / "eval", + start_seed=hydra_cfg.seed, + enable_progbar=True, + enable_inner_progbar=True, + ) print(info["aggregated"]) # Save info diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 16d890a7..19af1cf8 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -25,7 +25,6 @@ python lerobot/scripts/push_dataset_to_hub.py \ --dataset-id pusht \ --raw-format pusht_zarr \ --community-id lerobot \ ---revision v1.2 \ --dry-run 1 \ --save-to-disk 1 \ --save-tests-to-disk 0 \ @@ -36,7 +35,6 @@ python lerobot/scripts/push_dataset_to_hub.py \ --dataset-id xarm_lift_medium \ --raw-format xarm_pkl \ --community-id lerobot \ ---revision v1.2 \ --dry-run 1 \ --save-to-disk 1 \ --save-tests-to-disk 0 \ @@ -47,7 +45,6 @@ python lerobot/scripts/push_dataset_to_hub.py \ --dataset-id aloha_sim_insertion_scripted \ --raw-format aloha_hdf5 \ --community-id lerobot \ ---revision v1.2 \ --dry-run 1 \ --save-to-disk 1 \ --save-tests-to-disk 0 \ @@ -58,7 +55,6 @@ python lerobot/scripts/push_dataset_to_hub.py \ --dataset-id umi_cup_in_the_wild \ --raw-format umi_zarr \ --community-id lerobot \ ---revision v1.2 \ --dry-run 1 \ --save-to-disk 1 \ --save-tests-to-disk 0 \ @@ -227,8 +223,7 @@ def push_dataset_to_hub( test_hf_dataset = test_hf_dataset.with_format(None) test_hf_dataset.save_to_disk(str(tests_out_dir / "train")) - # copy meta data to tests directory - shutil.copytree(meta_data_dir, tests_meta_data_dir) + save_meta_data(info, stats, episode_data_index, tests_meta_data_dir) # copy videos of first episode to tests directory episode_index = 0 @@ -237,6 +232,10 @@ def push_dataset_to_hub( fname = f"{key}_episode_{episode_index:06d}.mp4" shutil.copy(videos_dir / fname, tests_videos_dir / fname) + if not save_to_disk and out_dir.exists(): + # remove possible temporary files remaining in the output directory + shutil.rmtree(out_dir) + def main(): parser = argparse.ArgumentParser() @@ -314,7 +313,7 @@ def main(): parser.add_argument( "--num-workers", type=int, - default=16, + default=8, help="Number of processes of Dataloader for computing the dataset statistics.", ) parser.add_argument( diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 7ca7a0b3..2b28943d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -15,15 +15,14 @@ # limitations under the License. import logging import time +from contextlib import nullcontext from copy import deepcopy from pathlib import Path -import datasets import hydra import torch -from datasets import concatenate_datasets -from datasets.utils import disable_progress_bars, enable_progress_bars from omegaconf import DictConfig +from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle @@ -31,6 +30,7 @@ from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy from lerobot.common.policies.policy_protocol import PolicyWithUpdate +from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, @@ -69,7 +69,6 @@ def make_optimizer_and_scheduler(cfg, policy): cfg.training.adam_eps, cfg.training.adam_weight_decay, ) - assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training." from diffusers.optimization import get_scheduler lr_scheduler = get_scheduler( @@ -87,21 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy): return optimizer, lr_scheduler -def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): +def update_policy( + policy, + batch, + optimizer, + grad_clip_norm, + grad_scaler: GradScaler, + lr_scheduler=None, + use_amp: bool = False, +): """Returns a dictionary of items for logging.""" - start_time = time.time() + start_time = time.perf_counter() + device = get_device_from_parameters(policy) policy.train() - output_dict = policy.forward(batch) - # TODO(rcadene): policy.unnormalize_outputs(out_dict) - loss = output_dict["loss"] - loss.backward() + with torch.autocast(device_type=device.type) if use_amp else nullcontext(): + output_dict = policy.forward(batch) + # TODO(rcadene): policy.unnormalize_outputs(out_dict) + loss = output_dict["loss"] + grad_scaler.scale(loss).backward() + + # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**. + grad_scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( policy.parameters(), grad_clip_norm, error_if_nonfinite=False, ) - optimizer.step() + # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, + # although it still skips optimizer.step() if the gradients contain infs or NaNs. + grad_scaler.step(optimizer) + # Updates the scale for next iteration. + grad_scaler.update() + optimizer.zero_grad() if lr_scheduler is not None: @@ -115,7 +133,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): "loss": loss.item(), "grad_norm": float(grad_norm), "lr": optimizer.param_groups[0]["lr"], - "update_s": time.time() - start_time, + "update_s": time.perf_counter() - start_time, **{k: v for k, v in output_dict.items() if k != "loss"}, } @@ -211,103 +229,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): logger.log_dict(info, step, mode="eval") -def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): - """ - Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average). - - Parameters: - - n_off (int): Number of offline samples, each with a sampling weight of 1. - - n_on (int): Number of online samples. - - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5). - - The total weight of offline samples is n_off * 1.0. - The total weight of offline samples is n_on * w. - The total combined weight of all samples is n_off + n_on * w. - The fraction of the weight that is online is n_on * w / (n_off + n_on * w). - We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on. - The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1)) - """ - assert 0.0 <= pc_on <= 1.0 - return -(n_off * pc_on) / (n_on * (pc_on - 1)) - - -def add_episodes_inplace( - online_dataset: torch.utils.data.Dataset, - concat_dataset: torch.utils.data.ConcatDataset, - sampler: torch.utils.data.WeightedRandomSampler, - hf_dataset: datasets.Dataset, - episode_data_index: dict[str, torch.Tensor], - pc_online_samples: float, -): - """ - Modifies the online_dataset, concat_dataset, and sampler in place by integrating - new episodes from hf_dataset into the online_dataset, updating the concatenated - dataset's structure and adjusting the sampling strategy based on the specified - percentage of online samples. - - Parameters: - - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines - offline and online datasets, used for sampling purposes. - - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to - reflect changes in the dataset sizes and specified sampling weights. - - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. - - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. - They indicate the start index and end index of each episode in the dataset. - - pc_online_samples (float): The target percentage of samples that should come from - the online dataset during sampling operations. - - Raises: - - AssertionError: If the first episode_id or index in hf_dataset is not 0 - """ - first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() - last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item() - first_index = hf_dataset.select_columns("index")[0]["index"].item() - last_index = hf_dataset.select_columns("index")[-1]["index"].item() - # sanity check - assert first_episode_idx == 0, f"{first_episode_idx=} is not 0" - assert first_index == 0, f"{first_index=} is not 0" - assert first_index == episode_data_index["from"][first_episode_idx].item() - assert last_index == episode_data_index["to"][last_episode_idx].item() - 1 - - if len(online_dataset) == 0: - # initialize online dataset - online_dataset.hf_dataset = hf_dataset - online_dataset.episode_data_index = episode_data_index - else: - # get the starting indices of the new episodes and frames to be added - start_episode_idx = last_episode_idx + 1 - start_index = last_index + 1 - - def shift_indices(episode_index, index): - # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to - example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index} - return example - - disable_progress_bars() # map has a tqdm progress bar - hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"]) - enable_progress_bars() - - episode_data_index["from"] += start_index - episode_data_index["to"] += start_index - - # extend online dataset - online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset]) - - # update the concatenated dataset length used during sampling - concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) - - # update the sampling weights for each frame so that online frames get sampled a certain percentage of times - len_online = len(online_dataset) - len_offline = len(concat_dataset) - len_online - weight_offline = 1.0 - weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) - sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset)) - - # update the total number of samples used during sampling - sampler.num_samples = len(concat_dataset) - - def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() @@ -316,11 +237,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No init_logging() - if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1: - logging.warning("eval.batch_size > 1 not supported for online training steps") + if cfg.training.online_steps > 0: + raise NotImplementedError("Online training is not implemented yet.") # Check device is available - get_safe_torch_device(cfg.device, log=True) + device = get_safe_torch_device(cfg.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -338,6 +259,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Create optimizer and scheduler # Temporary hack to move optimizer out of policy optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) + grad_scaler = GradScaler(enabled=cfg.use_amp) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) @@ -358,14 +280,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No def evaluate_and_checkpoint_if_needed(step): if step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") - eval_info = eval_policy( - eval_env, - policy, - cfg.eval.n_episodes, - video_dir=Path(out_dir) / "eval", - max_episodes_rendered=4, - start_seed=cfg.seed, - ) + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): + eval_info = eval_policy( + eval_env, + policy, + cfg.eval.n_episodes, + video_dir=Path(out_dir) / "eval", + max_episodes_rendered=4, + start_seed=cfg.seed, + ) log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) if cfg.wandb.enable: logger.log_video(eval_info["video_paths"][0], step, mode="eval") @@ -389,23 +312,30 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No num_workers=4, batch_size=cfg.training.batch_size, shuffle=True, - pin_memory=cfg.device != "cpu", + pin_memory=device.type != "cpu", drop_last=False, ) dl_iter = cycle(dataloader) policy.train() - step = 0 # number of policy update (forward + backward + optim) is_offline = True - for offline_step in range(cfg.training.offline_steps): - if offline_step == 0: + for step in range(cfg.training.offline_steps): + if step == 0: logging.info("Start offline training on a fixed dataset") batch = next(dl_iter) for key in batch: - batch[key] = batch[key].to(cfg.device, non_blocking=True) + batch[key] = batch[key].to(device, non_blocking=True) - train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler) + train_info = update_policy( + policy, + batch, + optimizer, + cfg.training.grad_clip_norm, + grad_scaler=grad_scaler, + lr_scheduler=lr_scheduler, + use_amp=cfg.use_amp, + ) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.training.log_freq == 0: @@ -415,11 +345,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # so we pass in step + 1. evaluate_and_checkpoint_if_needed(step + 1) - step += 1 - - # create an env dedicated to online episodes collection from policy rollout - online_training_env = make_env(cfg, n_envs=1) - # create an empty online dataset similar to offline dataset online_dataset = deepcopy(offline_dataset) online_dataset.hf_dataset = {} @@ -436,58 +361,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No num_workers=4, batch_size=cfg.training.batch_size, sampler=sampler, - pin_memory=cfg.device != "cpu", + pin_memory=device.type != "cpu", drop_last=False, ) - dl_iter = cycle(dataloader) - - online_step = 0 - is_offline = False - for env_step in range(cfg.training.online_steps): - if env_step == 0: - logging.info("Start online training by interacting with environment") - - policy.eval() - with torch.no_grad(): - eval_info = eval_policy( - online_training_env, - policy, - n_episodes=1, - return_episode_data=True, - start_seed=cfg.training.online_env_seed, - enable_progbar=True, - ) - - add_episodes_inplace( - online_dataset, - concat_dataset, - sampler, - hf_dataset=eval_info["episodes"]["hf_dataset"], - episode_data_index=eval_info["episodes"]["episode_data_index"], - pc_online_samples=cfg.training.online_sampling_ratio, - ) - - policy.train() - for _ in range(cfg.training.online_steps_between_rollouts): - batch = next(dl_iter) - - for key in batch: - batch[key] = batch[key].to(cfg.device, non_blocking=True) - - train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler) - - if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) - - # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, - # so we pass in step + 1. - evaluate_and_checkpoint_if_needed(step + 1) - - step += 1 - online_step += 1 eval_env.close() - online_training_env.close() logging.info("End of training") diff --git a/poetry.lock b/poetry.lock index e0b27f15..bde0865e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -1177,6 +1177,78 @@ files = [ [package.dependencies] numpy = ">=1.17.3" +[[package]] +name = "hf-transfer" +version = "0.1.6" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, + {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, + {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, + {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, + {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, + {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, + {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, + {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, + {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, + {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, + {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, + {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, + {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, + {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, +] + [[package]] name = "huggingface-hub" version = "0.23.0" @@ -1191,6 +1263,7 @@ files = [ [package.dependencies] filelock = "*" fsspec = ">=2023.5.0" +hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf-transfer\""} packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -4175,4 +4248,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "e3e3c306a5519e4f716a1ac086ad9b734efedcac077a0ec71e5bc16349a1e559" +content-hash = "e4834d67df32c8c617c259b0e59bb33ddaccde08fe940d771e74046cbffe3399" diff --git a/pyproject.toml b/pyproject.toml index 5b80d06f..f043c9de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ opencv-python = ">=4.9.0" diffusers = "^0.27.2" torchvision = ">=0.18.0" h5py = ">=3.10.0" -huggingface-hub = ">=0.21.4" +huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"} gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" gym-pusht = { version = ">=0.1.3", optional = true} diff --git a/tests/data/lerobot/aloha_mobile_cabinet/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_mobile_cabinet/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..767dadb0 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f9347c8d9ac90ee44e6dd86f65043438168df6bbe4bab2d2b875e55ef7376ef +size 1488 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/meta_data/info.json b/tests/data/lerobot/aloha_mobile_cabinet/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/meta_data/stats.safetensors b/tests/data/lerobot/aloha_mobile_cabinet/meta_data/stats.safetensors new file mode 100644 index 00000000..0d260d7e --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02fc4ea25766269f65752a60b0594c43d799b0ae528cd773bf024b064b5aa329 +size 4344 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_mobile_cabinet/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..73f31bb5 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55d7b1a06fe3e3051482752740074348bdb5fc98fb2e305b06d6203994117b27 +size 592448 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/train/dataset_info.json b/tests/data/lerobot/aloha_mobile_cabinet/train/dataset_info.json new file mode 100644 index 00000000..55b885b6 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8 +size 1166 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/train/state.json b/tests/data/lerobot/aloha_mobile_cabinet/train/state.json new file mode 100644 index 00000000..393396e7 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98329e4b40e9be0d63f7d36da9d86c44bbe7eeeb1b10d3ba973c923f3be70867 +size 247 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..6b287d2c --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54e42cdfd016a0ced2ab1fe2966a8c15a2384e0dbe1a2fe87433a2d1b8209ac0 +size 5220057 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..cbebb0ef --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af1ded2a244cb47a96255b75f584a643edf6967e13bb5464b330ffdd9d7ad859 +size 5284692 diff --git a/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..c58387a0 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_cabinet/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13d1bebabd79984fd6715971be758ef9a354495adea5e8d33f4e7904365e112b +size 5258380 diff --git a/tests/data/lerobot/aloha_mobile_chair/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_mobile_chair/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..933c06e0 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f33bc6810f0b91817a42610364cb49ed1b99660f058f0f9407e6f5920d0aee02 +size 1008 diff --git a/tests/data/lerobot/aloha_mobile_chair/meta_data/info.json b/tests/data/lerobot/aloha_mobile_chair/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_mobile_chair/meta_data/stats.safetensors b/tests/data/lerobot/aloha_mobile_chair/meta_data/stats.safetensors new file mode 100644 index 00000000..0b037f94 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b58d6c89e936a781a307805ebecf0dd473fbc02d52a7094da62e54bffb9454a +size 4344 diff --git a/tests/data/lerobot/aloha_mobile_chair/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_mobile_chair/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..969fef0b --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a08be578285cbe2d35b78f150d464ff3e10604a9865398c976983e0d711774f9 +size 788528 diff --git a/tests/data/lerobot/aloha_mobile_chair/train/dataset_info.json b/tests/data/lerobot/aloha_mobile_chair/train/dataset_info.json new file mode 100644 index 00000000..55b885b6 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8 +size 1166 diff --git a/tests/data/lerobot/aloha_mobile_chair/train/state.json b/tests/data/lerobot/aloha_mobile_chair/train/state.json new file mode 100644 index 00000000..c59e8787 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34e36233477c8aa0b0840314ddace072062d4f486d06546bbd6550832c370065 +size 247 diff --git a/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..741645aa --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66e7349a4a82ca6042a7189608d01eb1cfa38d100d039b5445ae1a9e65d824ab +size 14470946 diff --git a/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..7f9a021c --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a2146f0c10c9f2611e57e617983aa4f91ad681b4fc50d91b992b97abd684f926 +size 11662185 diff --git a/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..68cfa02e --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_chair/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5affbaf1c48895ba3c626e0d8cf1309e5f4ec6bbaa135313096f52a22de66c05 +size 11410342 diff --git a/tests/data/lerobot/aloha_mobile_elevator/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_mobile_elevator/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..17839482 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c2b195ca91b88fd16422128d386d2cabd808a1862c6d127e6bf2e83e1fe819a +size 448 diff --git a/tests/data/lerobot/aloha_mobile_elevator/meta_data/info.json b/tests/data/lerobot/aloha_mobile_elevator/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_mobile_elevator/meta_data/stats.safetensors b/tests/data/lerobot/aloha_mobile_elevator/meta_data/stats.safetensors new file mode 100644 index 00000000..4f9629d1 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b360b6b956d2adcb20589947c553348ef1eb6b70743c989dcbe95243d8592ce5 +size 4344 diff --git a/tests/data/lerobot/aloha_mobile_elevator/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_mobile_elevator/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..f0ae94ac --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f5c3926b4d4da9271abefcdf6a8952bb1f13258a9c39fe0fd223f548dc89dcb +size 887728 diff --git a/tests/data/lerobot/aloha_mobile_elevator/train/dataset_info.json b/tests/data/lerobot/aloha_mobile_elevator/train/dataset_info.json new file mode 100644 index 00000000..55b885b6 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8 +size 1166 diff --git a/tests/data/lerobot/aloha_mobile_elevator/train/state.json b/tests/data/lerobot/aloha_mobile_elevator/train/state.json new file mode 100644 index 00000000..d7f90b56 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4993b05fb026619eec5eb70db8cadaa041ba4ab92d38b4a387167ace03b1018b +size 247 diff --git a/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..1663f8eb --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd25d17ef5b7500386761b5e32920879bbdcafe0e17a8a8845628525d861e644 +size 10231081 diff --git a/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..974db761 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b557acbfeb0681c0a38e47263d945f6cd3a03461298d8b17209c81e3fd0aae8 +size 9701371 diff --git a/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..61085310 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_elevator/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da8f3b4f9f965da63819652b2c042d4cf7e07d14631113ea072087d56370310e +size 10473741 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_mobile_shrimp/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..d0798d77 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a053506017d8a78cfd307b2912eeafa1ac1485a280cf90913985fcc40120b5ec +size 416 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/meta_data/info.json b/tests/data/lerobot/aloha_mobile_shrimp/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/meta_data/stats.safetensors b/tests/data/lerobot/aloha_mobile_shrimp/meta_data/stats.safetensors new file mode 100644 index 00000000..7fcce357 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6d172d1bca02face22ceb4c21ea2b054cf3463025485dce64711b6f36b31f8a +size 4344 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_mobile_shrimp/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..6b1275d9 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e5ce817a2c188041f57f8d4c465dab3b9c3e4e1aeb7a9fb270230d1b36df530 +size 1477064 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/train/dataset_info.json b/tests/data/lerobot/aloha_mobile_shrimp/train/dataset_info.json new file mode 100644 index 00000000..55b885b6 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8 +size 1166 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/train/state.json b/tests/data/lerobot/aloha_mobile_shrimp/train/state.json new file mode 100644 index 00000000..f9db9e15 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4eb2dc373e4ea7d474742590f9073d66a773f6ab94b9e73a8673df19f93fae6d +size 247 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..32348f9e --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2c55b146fabe78b18c8a28a7746ab56e1ee7a6918e9e3dad9bd196f97975895 +size 26158915 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..23bd8be1 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71e1958d77f56843acf1ec48da4f04311a5836c87a0e77dbe26aa47c27c6347e +size 18786848 diff --git a/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..6ea368c2 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_shrimp/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20780718399b5759ff9a3a79824986310524793066198e3b9a307222f11a93df +size 17769988 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..30b7978a --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:279916f7689ae46af90e92a46eba9486a71fc762e3e2679ab5441eb37126827b +size 928 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/info.json b/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/stats.safetensors b/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/stats.safetensors new file mode 100644 index 00000000..a8eb54cc --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a7731051b521694b52b5631470720a7f05331915f4ac4e7f8cd83f9ff459bce +size 4344 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_mobile_wash_pan/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..102f0a0d --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99608258e8c9fe5191f1a12edc29b47d307790104149dffb6d3046ddad6aeb1b +size 435600 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/train/dataset_info.json b/tests/data/lerobot/aloha_mobile_wash_pan/train/dataset_info.json new file mode 100644 index 00000000..55b885b6 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8 +size 1166 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/train/state.json b/tests/data/lerobot/aloha_mobile_wash_pan/train/state.json new file mode 100644 index 00000000..427a4ccd --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae6735b7b394914824e974a7461019373a10f9e2d84ddf834bec8ea268d9ec1e +size 247 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..bd734fa9 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:833e288c5fdacbbe10a5d048cb6f49fe1a396d91b2117b827e130ec11069256a +size 8397615 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..166efdda --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2cb870acb4855fef70f19c5f632d94e4c25eef59eeea92f4b1167a44b1b36b33 +size 5912007 diff --git a/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..53b721ca --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wash_pan/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8be36298141b455ea51d17a78e4bbc6619639302139fe2db605bdfa3ff5e91bd +size 4794018 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..32c783b1 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:348d0ee38a71929b2017d540de870b9dff6d79efdd0cbc5352fa9697e350134a +size 928 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/info.json b/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/stats.safetensors b/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/stats.safetensors new file mode 100644 index 00000000..afcf1857 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5c2996f58d5277fa19cf56ec143334fbee940d1de37530452496a6f0aa11f88 +size 4344 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_mobile_wipe_wine/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..e734adb9 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da3a8efea9ba60d1fdd209d45a3387df22a09f7c156904ecb03f10456736fb74 +size 514056 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/train/dataset_info.json b/tests/data/lerobot/aloha_mobile_wipe_wine/train/dataset_info.json new file mode 100644 index 00000000..55b885b6 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8 +size 1166 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/train/state.json b/tests/data/lerobot/aloha_mobile_wipe_wine/train/state.json new file mode 100644 index 00000000..a10185eb --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b7111ff1ef5c4d6a2990f5f39f42398f061da8c4e81adf46b9d9150ec2feeaf +size 247 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..5b98bbae --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ac8c2755d940534042595ecad33ebea358974ec67bc041c8675e53b7d2272ff +size 9182551 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..34677e98 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b35aaa37e66dd5563d93e6059d5b645e112e020e03bd398f7098a5289970953a +size 6378566 diff --git a/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..cee9add3 --- /dev/null +++ b/tests/data/lerobot/aloha_mobile_wipe_wine/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6138247ba7160a3de6c50111e6fcc5ae075044086d8527ae5d435b1f8a7c7a93 +size 6439183 diff --git a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors index 828c6720..bb503d5d 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json index 279cf2c2..8c5c4ee8 100644 --- a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json +++ b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 50, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors index 9b6a7c83..84516142 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow index d93d0e27..13fb4452 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json index c6f7b938..1c9122f7 100644 --- a/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json +++ b/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json @@ -1,47 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.images.top": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:3f44d13de5d5a417263bbd4984942ed42ed3fa0633405aa14d9a969a45274944 +size 842 diff --git a/tests/data/lerobot/aloha_sim_insertion_human/train/state.json b/tests/data/lerobot/aloha_sim_insertion_human/train/state.json index 6cd9158a..aa5f34da 100644 --- a/tests/data/lerobot/aloha_sim_insertion_human/train/state.json +++ b/tests/data/lerobot/aloha_sim_insertion_human/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "eb913a2b1a68aa74", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:71d6ba89bee5a4ee2761220452999e415bc838a44bebf1b5a2e4ba8622369798 +size 247 diff --git a/tests/data/lerobot/aloha_sim_insertion_human/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_insertion_human/videos/observation.images.top_episode_000000.mp4 index 56280d53..ef3660f2 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_human/videos/observation.images.top_episode_000000.mp4 and b/tests/data/lerobot/aloha_sim_insertion_human/videos/observation.images.top_episode_000000.mp4 differ diff --git a/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..6cd34f25 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7dbc214a415689ca7fb83b6f8e12ec7824dfe34a66024b0b24bfeb3aeefd0e4 +size 928 diff --git a/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/info.json new file mode 100644 index 00000000..5d86c44e --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:083db9efc5c9e3396c5e1159d020c2a3786f1f1a4b069719d327ed7fbc65c34d +size 33 diff --git a/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/stats.safetensors new file mode 100644 index 00000000..fa5c7586 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_human_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f98bd8f6347590aecdddaceed95d921f2d9f7bf35fbe742c37bdf12cba11dca6 +size 2904 diff --git a/tests/data/lerobot/aloha_sim_insertion_human_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_human_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..93a11cf8 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_human_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0013aea549ec290af94bddde1b559fb8d0967d4c43ef14319177c4e62ed1e91 +size 14545712 diff --git a/tests/data/lerobot/aloha_sim_insertion_human_image/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_human_image/train/dataset_info.json new file mode 100644 index 00000000..cf816b66 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_human_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c9545525dc1f4d550591bd5efb63b55c15b983ae0510fefda5a16d77c78b6ef +size 837 diff --git a/tests/data/lerobot/aloha_sim_insertion_human_image/train/state.json b/tests/data/lerobot/aloha_sim_insertion_human_image/train/state.json new file mode 100644 index 00000000..42d31e8e --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_human_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7aa033603dc90582516dbcdf3e71e4d3113b70ad49098535def0b282135b5f3 +size 247 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors index 1505d613..4195a89f 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json index 279cf2c2..8c5c4ee8 100644 --- a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json +++ b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 50, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors index 6cce9ffa..e4bae4af 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow index 65a231a6..82a39801 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json index c6f7b938..1c9122f7 100644 --- a/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json +++ b/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json @@ -1,47 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.images.top": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:3f44d13de5d5a417263bbd4984942ed42ed3fa0633405aa14d9a969a45274944 +size 842 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json b/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json index b96705cb..bb533378 100644 --- a/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json +++ b/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "d20c2acf1e107266", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:e48156ce4f71ac15d78732312fbc7e199f0ecdaac3604231e6be2e3e5b31a0ad +size 247 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_insertion_scripted/videos/observation.images.top_episode_000000.mp4 index f36a0c18..07c41ce4 100644 Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/videos/observation.images.top_episode_000000.mp4 and b/tests/data/lerobot/aloha_sim_insertion_scripted/videos/observation.images.top_episode_000000.mp4 differ diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..4195a89f --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4500f31e62f0928a837fa71783acacda0db516c7b00d0586a41ea5fd8fc5e772 +size 928 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/info.json new file mode 100644 index 00000000..5d86c44e --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:083db9efc5c9e3396c5e1159d020c2a3786f1f1a4b069719d327ed7fbc65c34d +size 33 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/stats.safetensors new file mode 100644 index 00000000..26f256f5 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_scripted_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0227d4e9e3b43a86bf33fbd68683ede537fdeab1b53f2ebf155620e10054352f +size 2904 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..735a420b --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1bdf02ecf7d5fc502f6dd9f520c636828a5988ad16a69a137780a824f94f8112 +size 10782640 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/dataset_info.json new file mode 100644 index 00000000..cf816b66 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c9545525dc1f4d550591bd5efb63b55c15b983ae0510fefda5a16d77c78b6ef +size 837 diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/state.json b/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/state.json new file mode 100644 index 00000000..8c42e0b2 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_insertion_scripted_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:752660d8fd884b33b7302a4a42ec7c680de2a3e5022d7d007586f4c6337ce08a +size 247 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors index 1505d613..4195a89f 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json index 279cf2c2..8c5c4ee8 100644 --- a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 50, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors index 2fe6aff2..aa4183d1 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow index a9f60d30..0e83c8db 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json index c6f7b938..1c9122f7 100644 --- a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json @@ -1,47 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.images.top": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:3f44d13de5d5a417263bbd4984942ed42ed3fa0633405aa14d9a969a45274944 +size 842 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json index eb74ba89..5e0a8a48 100644 --- a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "243b01eb8a4b184e", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:6ff72fd4f6f61309191a7f2829b73649d836c1ed10f00983093dc68599c92404 +size 247 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_transfer_cube_human/videos/observation.images.top_episode_000000.mp4 index 12a1e5be..e31121e9 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/videos/observation.images.top_episode_000000.mp4 and b/tests/data/lerobot/aloha_sim_transfer_cube_human/videos/observation.images.top_episode_000000.mp4 differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..4195a89f --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4500f31e62f0928a837fa71783acacda0db516c7b00d0586a41ea5fd8fc5e772 +size 928 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/info.json new file mode 100644 index 00000000..5d86c44e --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:083db9efc5c9e3396c5e1159d020c2a3786f1f1a4b069719d327ed7fbc65c34d +size 33 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/stats.safetensors new file mode 100644 index 00000000..544874c4 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09de36f2d6786e65e26d4602e00f9097f63a087a6a4f36e98c5367724acfc755 +size 2904 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..a8ee8a54 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f35b1ce169c6355536405718409041da7969cc351d62ef0d2c6f6351ac009e2 +size 10640376 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/dataset_info.json new file mode 100644 index 00000000..cf816b66 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c9545525dc1f4d550591bd5efb63b55c15b983ae0510fefda5a16d77c78b6ef +size 837 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/state.json new file mode 100644 index 00000000..35f278ca --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_human_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec153065a4e52f7d55a7f026d804c57a3ce05dc1faa255a1947369f83c70f1e7 +size 247 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors index 1505d613..4195a89f 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json index 279cf2c2..8c5c4ee8 100644 --- a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 50, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors index c2ab5b21..057dceca 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow index 405509d1..8f9ae056 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json index c6f7b938..1c9122f7 100644 --- a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json @@ -1,47 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.images.top": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 14, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:3f44d13de5d5a417263bbd4984942ed42ed3fa0633405aa14d9a969a45274944 +size 842 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json index 91c46511..706e48ee 100644 --- a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "eb759bbf60df7be9", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:4d5bfac4bd22cab6449b24e457719c6598b367f191160335cba81c3b416b1cd5 +size 247 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/videos/observation.images.top_episode_000000.mp4 index 2d25242c..dbd3dfab 100644 Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/videos/observation.images.top_episode_000000.mp4 and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/videos/observation.images.top_episode_000000.mp4 differ diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..4195a89f --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4500f31e62f0928a837fa71783acacda0db516c7b00d0586a41ea5fd8fc5e772 +size 928 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/info.json new file mode 100644 index 00000000..5d86c44e --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:083db9efc5c9e3396c5e1159d020c2a3786f1f1a4b069719d327ed7fbc65c34d +size 33 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/stats.safetensors new file mode 100644 index 00000000..54946e8c --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5dae4fa688991d97145fd975d317b24b177f674cc57e53ef4caba1413fe1aad8 +size 2904 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..61979acb --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3948750400c273b7e3ef2998fa84ca7a520d8972c8759c2428ff6fbdc2bd8fb7 +size 11505144 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/dataset_info.json new file mode 100644 index 00000000..cf816b66 --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c9545525dc1f4d550591bd5efb63b55c15b983ae0510fefda5a16d77c78b6ef +size 837 diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/state.json new file mode 100644 index 00000000..bda86ecb --- /dev/null +++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fda1fe75c9f987c065d4244594e4f6456b7ac6efd7fae2a7952fb48b044dbd30 +size 247 diff --git a/tests/data/lerobot/aloha_static_battery/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_battery/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..bcf6f38c --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a17b1cd612f06662c14e0a74f9bd4787ab97558d57d6e150d06a401a48eccba9 +size 912 diff --git a/tests/data/lerobot/aloha_static_battery/meta_data/info.json b/tests/data/lerobot/aloha_static_battery/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_battery/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_battery/meta_data/stats.safetensors new file mode 100644 index 00000000..59d07cc3 --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f724049d7bd23458f02c606fa59cc1d0f2a44f7fac9e5b4c4eff97c95ac132b3 +size 4208 diff --git a/tests/data/lerobot/aloha_static_battery/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_battery/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..56a3e72b --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52fb2e9785eb8d069def324272008ca829ebb32830165bdcbbfdfd3cc6bf42dd +size 240928 diff --git a/tests/data/lerobot/aloha_static_battery/train/dataset_info.json b/tests/data/lerobot/aloha_static_battery/train/dataset_info.json new file mode 100644 index 00000000..7fe0ac4c --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2e066afefdee57f3bc534085ab7af54e62d3ab2736d42863a89deb743cd0d04 +size 1075 diff --git a/tests/data/lerobot/aloha_static_battery/train/state.json b/tests/data/lerobot/aloha_static_battery/train/state.json new file mode 100644 index 00000000..04593ca1 --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e06118a224b4477c0aaf17d43930cd7f516f3de932c6d5547ee741f609a6228 +size 247 diff --git a/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..f1f71ae9 --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:551bb7677b875b263df5b60d6e258e8ca825572938b65c3eacfa7ae3ea325149 +size 4246212 diff --git a/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..c91a7999 --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7230024f74e30b3f0430790f95e4920968d7e1821355006b48c191487ec6ca4d +size 3712553 diff --git a/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..5991d529 --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:376486caa0033cc2b1345e0e838243c4fe09623ae6e655765f5ec92f90a70d4e +size 3358646 diff --git a/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..f5343155 --- /dev/null +++ b/tests/data/lerobot/aloha_static_battery/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfed128b3afe26ad99095944023765cfbe32b35f7ddc267ed4b8c19462211e7c +size 4183134 diff --git a/tests/data/lerobot/aloha_static_candy/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_candy/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..c4f52e96 --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed655b17ac00510afe57cc340cc8fa3b0b7e377f9de41fa53092425174b9730d +size 928 diff --git a/tests/data/lerobot/aloha_static_candy/meta_data/info.json b/tests/data/lerobot/aloha_static_candy/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_candy/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_candy/meta_data/stats.safetensors new file mode 100644 index 00000000..370fc150 --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:973576883a1dedcd1d8c21a8ab2879854d7cbdeff8f40787bc654298deeeaa2f +size 4208 diff --git a/tests/data/lerobot/aloha_static_candy/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_candy/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..806e0f0a --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b034ce6eadfd29cbc7eedc62eeb6a01c9ed7ffa411b773d4a1fe9f79bbd847d +size 280536 diff --git a/tests/data/lerobot/aloha_static_candy/train/dataset_info.json b/tests/data/lerobot/aloha_static_candy/train/dataset_info.json new file mode 100644 index 00000000..7fe0ac4c --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2e066afefdee57f3bc534085ab7af54e62d3ab2736d42863a89deb743cd0d04 +size 1075 diff --git a/tests/data/lerobot/aloha_static_candy/train/state.json b/tests/data/lerobot/aloha_static_candy/train/state.json new file mode 100644 index 00000000..171981a8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f94ca1a43797e9356f1d0d597557ff20f3ca23b7e1bd6a9155df65dcf9f434e2 +size 247 diff --git a/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..eaca1761 --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3b0bbc457659de53ee2b8787dc4e920fbd528488514f752e636c3a9c04899df +size 4911104 diff --git a/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..229ce6ba --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d660ec74fbf088baa8f34861f1c8fbba9f0909563e0768867daa86581502f90e +size 3401787 diff --git a/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..bcf4e716 --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebc983ce4021a4243fd635126a520e20ef83861a586b7f0672b6b608484b93ec +size 4035176 diff --git a/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..2da6290c --- /dev/null +++ b/tests/data/lerobot/aloha_static_candy/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:441479fdd37074272658400b6fb9c8e4c071e98b125320771d63cf4575aff2d2 +size 4215835 diff --git a/tests/data/lerobot/aloha_static_coffee/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_coffee/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..30b7978a --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:279916f7689ae46af90e92a46eba9486a71fc762e3e2679ab5441eb37126827b +size 928 diff --git a/tests/data/lerobot/aloha_static_coffee/meta_data/info.json b/tests/data/lerobot/aloha_static_coffee/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_coffee/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_coffee/meta_data/stats.safetensors new file mode 100644 index 00000000..0ca75755 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc624f3eb2a5b26dfe44c468a24ec1214a8816f50c08c9f946a7ea088bc43c1e +size 4752 diff --git a/tests/data/lerobot/aloha_static_coffee/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_coffee/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..4eb46312 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0efc7f82fb433744f77f86adabbba81df6e1d4a0f3d660725a10359549f75d62 +size 502200 diff --git a/tests/data/lerobot/aloha_static_coffee/train/dataset_info.json b/tests/data/lerobot/aloha_static_coffee/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_coffee/train/state.json b/tests/data/lerobot/aloha_static_coffee/train/state.json new file mode 100644 index 00000000..bfa74d36 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:acbefb679113a410d9b2a337d2d4b5393c60538e4bb5da97686e3abd2fa623b2 +size 247 diff --git a/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..80dac156 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:345cc5402794b1566fae73146ac3f0b36b18d1df4badcf1897a2ae3fae3f1e99 +size 9177059 diff --git a/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..ee068c57 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f30fba7ee9b10cb3f5eb50f117c742e47d5ec6c205036edfa62bec9fc8b7517c +size 6345170 diff --git a/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..72abaad5 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:494871573d44aa601d091f24a4cccd4f625c8290929575ff189636495c970684 +size 7119730 diff --git a/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..0c872310 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31dc487809c370dc9e1d740826da03c108740226e19306cf755c27ea623e3e95 +size 6395729 diff --git a/tests/data/lerobot/aloha_static_coffee_new/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_coffee_new/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..beea5395 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e2e06405e7a137c3285c74f9bae676ff935c694d446e7000772c8e10df777a8 +size 928 diff --git a/tests/data/lerobot/aloha_static_coffee_new/meta_data/info.json b/tests/data/lerobot/aloha_static_coffee_new/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_coffee_new/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_coffee_new/meta_data/stats.safetensors new file mode 100644 index 00000000..af12bb36 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b32796a498a506e7701ba91607a3dc59b9dfd8b1efde93ee423a921f05805663 +size 4752 diff --git a/tests/data/lerobot/aloha_static_coffee_new/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_coffee_new/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..7e026666 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5da09bf6e21a26299eddae1cf583ab6ce627dd3b6b3cfd488da2783a04672e58 +size 683048 diff --git a/tests/data/lerobot/aloha_static_coffee_new/train/dataset_info.json b/tests/data/lerobot/aloha_static_coffee_new/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_coffee_new/train/state.json b/tests/data/lerobot/aloha_static_coffee_new/train/state.json new file mode 100644 index 00000000..76b538c6 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de36f736a398cc5f11db8f196a92d48a7cba67f19012d529bf0b69ba0e07003d +size 247 diff --git a/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..fcd1135c --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a5057fc6e4bd535dffdb0e6faad693a2786936ffa666a75280d0db6ae201712 +size 12232643 diff --git a/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..2677fe66 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4928be2328d324bbb0281cc10307f8987a39a6f1b5c90d39631fdd5b5fa3fbde +size 7717890 diff --git a/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..8f7f21a3 --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4188937513012a3de8d6778f90b968f1fface275b82ee843a59d435a21f64bd7 +size 8759488 diff --git a/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..acb5814c --- /dev/null +++ b/tests/data/lerobot/aloha_static_coffee_new/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:773337beb37a7c463a729914eed354ccf6ce649a3978e4c257be7cd4861f67a0 +size 8968282 diff --git a/tests/data/lerobot/aloha_static_cups_open/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_cups_open/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..4195a89f --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4500f31e62f0928a837fa71783acacda0db516c7b00d0586a41ea5fd8fc5e772 +size 928 diff --git a/tests/data/lerobot/aloha_static_cups_open/meta_data/info.json b/tests/data/lerobot/aloha_static_cups_open/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_cups_open/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_cups_open/meta_data/stats.safetensors new file mode 100644 index 00000000..2e4012f6 --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2926f7125dab62160ba0c8fd19ce8586305344fc218069f45f73b6b672bfff00 +size 4208 diff --git a/tests/data/lerobot/aloha_static_cups_open/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_cups_open/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..3d387b61 --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ac5379f3dc5151c84dc9300f9384dc09beabf2f11138797410f5a3f53440a21 +size 161640 diff --git a/tests/data/lerobot/aloha_static_cups_open/train/dataset_info.json b/tests/data/lerobot/aloha_static_cups_open/train/dataset_info.json new file mode 100644 index 00000000..7fe0ac4c --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2e066afefdee57f3bc534085ab7af54e62d3ab2736d42863a89deb743cd0d04 +size 1075 diff --git a/tests/data/lerobot/aloha_static_cups_open/train/state.json b/tests/data/lerobot/aloha_static_cups_open/train/state.json new file mode 100644 index 00000000..ccd89115 --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e2a558362c1cc7842381b36550e23853cec672d20af98e97dc1fae064a758ed +size 247 diff --git a/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..4621cc81 --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2f3eb9c8f613dba69ed59f0d28778a04ce6f0eddbc4de86145692fe2f3bde40 +size 2787770 diff --git a/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..b13f506c --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:213eaa3d61b6867f86be0f580aa8860d1431ac280d0b27a794da5f49e98d97ac +size 1881751 diff --git a/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..513286d7 --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04ab62f752164281b398e377d9b0c6212d3dd5543459ecf9150d123def7425b7 +size 2188614 diff --git a/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..1b5dc13c --- /dev/null +++ b/tests/data/lerobot/aloha_static_cups_open/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:806e3eab682e7404fc5fe487ca24ad8f94b07ff161018be3944edcb6b97a6219 +size 2588560 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..4a634f47 --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df5cde6d9bffddea00f30b29d987c3f0dfbed30b78637d0758d8e107c438b7c1 +size 1736 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/info.json b/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/stats.safetensors new file mode 100644 index 00000000..3a89b9df --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91f7c8d82b012f8e78d5ea02360079f020337bd1c4bee2b8b500c58618f7dfa3 +size 4752 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_fork_pick_up/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..b1ff1173 --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:864372abf759e77b9632759455424eb753f190893db3ef301095f197f655da2f +size 274824 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/train/dataset_info.json b/tests/data/lerobot/aloha_static_fork_pick_up/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/train/state.json b/tests/data/lerobot/aloha_static_fork_pick_up/train/state.json new file mode 100644 index 00000000..c0688ecc --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de515786e9268bce42b97ae7e8090dfdb1bf865beead255ba28f0901b0adf8d5 +size 247 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..faa9dfea --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4274f639490f60855c4d58ccc0095839e8c6f07849ee986eeaedf90aeca5f8e +size 3859460 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..d76da528 --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1632488aaebcf20a1f1bc4889f422c8afad605521024ae0a569f4d3cba30e24e +size 3572791 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..84763450 --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a57d601ffd4ce86fe706d678c79ef20987d1875ba4126284b5f5dc093ce0e861 +size 2880255 diff --git a/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..992fdc9d --- /dev/null +++ b/tests/data/lerobot/aloha_static_fork_pick_up/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:409f4ce179fbba18920bf09c17bdeb11a8ed4b9fab080fd6b1d51f9c58d95c46 +size 3745968 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_pingpong_test/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..672d46d0 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63985365b49c12789a9a0e3dc330cc3fd492d9e2c70af55928ebdae3b0bcca8e +size 288 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/meta_data/info.json b/tests/data/lerobot/aloha_static_pingpong_test/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_pingpong_test/meta_data/stats.safetensors new file mode 100644 index 00000000..72661188 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea4df04d8c0a084209af94b60a9bed646c8fe5a6400ab22fa15661bab0fe5a89 +size 4752 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_pingpong_test/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..669d7d10 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0045279e29d55a7d686bee2cf5cd80993229710e1ede8c750703c56dcc573d56 +size 274824 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/train/dataset_info.json b/tests/data/lerobot/aloha_static_pingpong_test/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/train/state.json b/tests/data/lerobot/aloha_static_pingpong_test/train/state.json new file mode 100644 index 00000000..b9b97753 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:240fae9fa494f7557328dd5e80c5448ceb7271084217b3bc96e123b0e7f4d7bf +size 247 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..3bfcb67c --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15934e7d48d2914b1d89f262e59799175d2bb2e27a97401e51bc83d5b9fe3837 +size 4403990 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..d97364e4 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bf1fbd424c53914d485472ed2bc9d3ecedc1a6f126b83c90fe951deb3a13a65 +size 4025665 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..5bcdadb6 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8965a19d085ffff9de779de0eba79a6a51f4bb84325a2919604498753778e1c5 +size 3271435 diff --git a/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..71553951 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pingpong_test/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:961a12feea3870f45d4ae3fb9a20b8bdb8928020d96c556c24cb48c10a3c9200 +size 3972027 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_pro_pencil/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..0a3e7b18 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e36408014ca0be67de36d9d5d0910cb94734ac76421d8bdb5dbb9bd0707ef82f +size 528 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/meta_data/info.json b/tests/data/lerobot/aloha_static_pro_pencil/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_pro_pencil/meta_data/stats.safetensors new file mode 100644 index 00000000..6da8678b --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a54e6fbefceec2db7c850d872ba825cbddef0ddfeb7621beca2f26705344de4 +size 4752 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_pro_pencil/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..97b16748 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:783d2d14b19a81ded4e00226c20f1d24f6b52f0522a661d716afdb22d6346643 +size 161792 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/train/dataset_info.json b/tests/data/lerobot/aloha_static_pro_pencil/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/train/state.json b/tests/data/lerobot/aloha_static_pro_pencil/train/state.json new file mode 100644 index 00000000..b2e0f18a --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db661ba3cf29bf8e16fba3fb79041f987e4aa0f43305e4e81fe1fcc58925d5ca +size 247 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..965bd33b --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72437edfae1ae4f76ee1800b77c4f31f0e60a7f49d12447a07bcae8710217f26 +size 4493363 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..9443d090 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cfec31227cd9f6eec767dd14cf480da92ae17808b003531ee596ca37ae3f2c2 +size 4528034 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..6485a67a --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24971ddc33b11e012cf2e5d2ae89ac2f8d06f31ae82db5b8ade5583c812a96a3 +size 3243037 diff --git a/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..3fa03fb0 --- /dev/null +++ b/tests/data/lerobot/aloha_static_pro_pencil/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f8e887ee52d8ad4125dceea04aa8db83477c924a7d8bc3cc3fbeb7085814b3b +size 4517634 diff --git a/tests/data/lerobot/aloha_static_screw_driver/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_screw_driver/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..4195a89f --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4500f31e62f0928a837fa71783acacda0db516c7b00d0586a41ea5fd8fc5e772 +size 928 diff --git a/tests/data/lerobot/aloha_static_screw_driver/meta_data/info.json b/tests/data/lerobot/aloha_static_screw_driver/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_screw_driver/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_screw_driver/meta_data/stats.safetensors new file mode 100644 index 00000000..092cce00 --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df11137a7ff1c75eedeae7b6b15fbfbaf9d4c152583f34371b95add3a10bd2b1 +size 4752 diff --git a/tests/data/lerobot/aloha_static_screw_driver/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_screw_driver/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..ecb2c3e6 --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7480735fb1263858492ed6384b43606c8019f46ac06cc12b945170db72b6e487 +size 184336 diff --git a/tests/data/lerobot/aloha_static_screw_driver/train/dataset_info.json b/tests/data/lerobot/aloha_static_screw_driver/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_screw_driver/train/state.json b/tests/data/lerobot/aloha_static_screw_driver/train/state.json new file mode 100644 index 00000000..176af15f --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d842fe613f4f21d84bbf3e233ea637bce2b20e4145cc23f964a5b7a18f002051 +size 247 diff --git a/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..c35e8b8f --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b184062800c884b7f75d57d69619dc9a85fe2320acc0d6591f76cbab1924bbaf +size 2918070 diff --git a/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..e85a8339 --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fb27591eb299d2a452dca6dd56154149629597d9acb0f40b98b78794758f064 +size 2084920 diff --git a/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..7cb600d0 --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a49ab61d1c17bc3f69a157e292625885342b3f324b1b84c79d843572ce03c17e +size 2194056 diff --git a/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..1fc81239 --- /dev/null +++ b/tests/data/lerobot/aloha_static_screw_driver/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79765f8cb272841bed0992d15ba47622d73bb0c22970f8cd53382f178081d7cf +size 2364180 diff --git a/tests/data/lerobot/aloha_static_tape/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_tape/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..c4f52e96 --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed655b17ac00510afe57cc340cc8fa3b0b7e377f9de41fa53092425174b9730d +size 928 diff --git a/tests/data/lerobot/aloha_static_tape/meta_data/info.json b/tests/data/lerobot/aloha_static_tape/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_tape/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_tape/meta_data/stats.safetensors new file mode 100644 index 00000000..edaeec16 --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c7e81c07331d8caea03629ea375b2c5340a743df7d9eeafa09b91a8caa5a91d +size 4208 diff --git a/tests/data/lerobot/aloha_static_tape/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_tape/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..e0d1657e --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e93500bcef4b6e05090336376677e8448b360313cc64103d671efe56e0a1762c +size 280536 diff --git a/tests/data/lerobot/aloha_static_tape/train/dataset_info.json b/tests/data/lerobot/aloha_static_tape/train/dataset_info.json new file mode 100644 index 00000000..7fe0ac4c --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2e066afefdee57f3bc534085ab7af54e62d3ab2736d42863a89deb743cd0d04 +size 1075 diff --git a/tests/data/lerobot/aloha_static_tape/train/state.json b/tests/data/lerobot/aloha_static_tape/train/state.json new file mode 100644 index 00000000..c4f01bcc --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:592c54d3d35819519fab84710b9065fa4fec379fb5f6f0fcc16204bc8342afbc +size 247 diff --git a/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..7023e6ef --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:609ae43ed4cdbace153441f7b5c242a05a6d77b03bf95a54085d9d516f99177d +size 4952309 diff --git a/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..d327ad86 --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4fc833e1230e3bd76bebe05b8140cf8b7ddedcd9f92ecada24cf18797b1122a +size 3595177 diff --git a/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..41c475c3 --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b6ba12e71d62e794ada84ce6ed60d7d95b1966a019e08461935382869a934a0 +size 4198952 diff --git a/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..11fa8895 --- /dev/null +++ b/tests/data/lerobot/aloha_static_tape/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cc38c231137681dfbbc881545f022cbe8eaab31e0daa172c3cc0ce59b4a747e +size 3352048 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_thread_velcro/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..18fe71f2 --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e035ea351dc5a5456bd43d75acbe292ca27e03d340ab7fbb325a42abe4f7cacf +size 672 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/meta_data/info.json b/tests/data/lerobot/aloha_static_thread_velcro/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_thread_velcro/meta_data/stats.safetensors new file mode 100644 index 00000000..88472415 --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61d5deeefd272040d74b61ce07a08fd04ca9e2edd6c80f8a6b3e3e4eff91bfa8 +size 4208 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_thread_velcro/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..a77c1adf --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e11a8bd12e572d46f68bd5705822f659f3933009dba93beb15623c4433a12420 +size 240928 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/train/dataset_info.json b/tests/data/lerobot/aloha_static_thread_velcro/train/dataset_info.json new file mode 100644 index 00000000..7fe0ac4c --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2e066afefdee57f3bc534085ab7af54e62d3ab2736d42863a89deb743cd0d04 +size 1075 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/train/state.json b/tests/data/lerobot/aloha_static_thread_velcro/train/state.json new file mode 100644 index 00000000..fcd38aa8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d1ab3212e4f5b5d9873f466c99c37078651d54f087758d8c5cee452701bbcbc +size 247 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..6f078a3a --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebafc3dd30d15fcbf474269ede2f55f120bdf73c16e20e98504b5bbda3a57149 +size 4385458 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..b859b63b --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa07ba8bc6a4d4bed7681b7ce3298ccdcd519f0697f6435a59b7b7b7f3502234 +size 3215036 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..796250b6 --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca347f88364e5e9437f35213d7704629f4a327419bf9edaf0d2747536a254743 +size 3101902 diff --git a/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..10613bce --- /dev/null +++ b/tests/data/lerobot/aloha_static_thread_velcro/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e9336bba763a7ba923c1440e1c99a7e10bbcec9da514dd7426ba31057496d1b +size 3517470 diff --git a/tests/data/lerobot/aloha_static_towel/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_towel/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..6cd34f25 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7dbc214a415689ca7fb83b6f8e12ec7824dfe34a66024b0b24bfeb3aeefd0e4 +size 928 diff --git a/tests/data/lerobot/aloha_static_towel/meta_data/info.json b/tests/data/lerobot/aloha_static_towel/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_towel/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_towel/meta_data/stats.safetensors new file mode 100644 index 00000000..12f74dd1 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e409aa663d47e19875b2c20a6360f71b335285fa4f091471cb0fa07f82f0801 +size 4752 diff --git a/tests/data/lerobot/aloha_static_towel/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_towel/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..50829c6f --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e98de2ea8c8c838600b3bb8407c3eeeab3ab78e6632c49c7377b0491fdd0640 +size 229608 diff --git a/tests/data/lerobot/aloha_static_towel/train/dataset_info.json b/tests/data/lerobot/aloha_static_towel/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_towel/train/state.json b/tests/data/lerobot/aloha_static_towel/train/state.json new file mode 100644 index 00000000..6b123113 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbc5bf36c81ecd2f8acceaec82826ab1be434e79ca721e5b026841a8af66ea77 +size 247 diff --git a/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..f49c3c07 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e68d65e88e8374d3d274e1dc53e4e033de4cd9edcd92c91b6389491a59de3f2d +size 3405972 diff --git a/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..4e9fa0d7 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fcdeacda3ab3818eb7ef0a965bca879b5e7546c70cf567e61137e1d4d0b1939 +size 2846253 diff --git a/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..0ba050a0 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7f8b2fe954a90cdaccf69ca1e4650fc351a02113e6ad8fa4ff23814e24db7df +size 2933255 diff --git a/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..916a8999 --- /dev/null +++ b/tests/data/lerobot/aloha_static_towel/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63dc60583af5e6187c5edc947eafa1cc7b005585f4e4e3f70fd2d17ba3011e21 +size 2208369 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_vinh_cup/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..ab81e5db --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3fca9067e7bae957ed575e036a22570e352148c56c41827e289d66dd18b3edf +size 1752 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/meta_data/info.json b/tests/data/lerobot/aloha_static_vinh_cup/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_vinh_cup/meta_data/stats.safetensors new file mode 100644 index 00000000..a2ffb946 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e82e25831f4e1576bf7e1ae54edb1a4557b69012fb7cfca790b38d5f86d1f541 +size 4752 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_vinh_cup/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..3c92756b --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:327b9895d3bcedba90ccf93608cfc02258ed0fe37453cc3d8e129a882b7cce21 +size 229608 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/train/dataset_info.json b/tests/data/lerobot/aloha_static_vinh_cup/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/train/state.json b/tests/data/lerobot/aloha_static_vinh_cup/train/state.json new file mode 100644 index 00000000..07581a05 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f2bd3596d09a26ef6b3ec93bd2640a3aae8b4e57662a46641335c2f538e0d42 +size 247 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..6351f498 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b08ab3633d1b4f72f05b1606420454d10a0f5e54dc9e8862ccd7236383c521e +size 3117796 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..3f02b5d8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5079b973b7a0741423071fa0270127267424339be87ea03f629a92add56f7e5e +size 2444814 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..c599e415 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de35095f8a76f9fc4993322546b6308c8816b3bdbb1ba690f8c0ce71e24237c8 +size 2523272 diff --git a/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..2b262701 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc9652e2d4c95d45bead8bb79049b1bf36c8c7e177d2bcf906ee0eee0f8b47c7 +size 3278890 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..4496bc5c --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f447f38b8e6ae16234359701bf3463c0e8a732f332b1ee98fd7838f70030f0cb +size 1736 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/info.json b/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/stats.safetensors new file mode 100644 index 00000000..526e9a3a --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06288d8c4353940ffe283330ae99cdbb7f0513396983ee518470a2523203c442 +size 4752 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_vinh_cup_left/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..0433fbd5 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:433a22f0e6aea55befe841737eeb2f719fbdc769f1bd6cb8831182b3f09606df +size 229608 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/train/dataset_info.json b/tests/data/lerobot/aloha_static_vinh_cup_left/train/dataset_info.json new file mode 100644 index 00000000..144d793b --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55969ea876fc62b17afca0777816d93c1c90f23207a3562c907cfbd6502858f9 +size 1237 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/train/state.json b/tests/data/lerobot/aloha_static_vinh_cup_left/train/state.json new file mode 100644 index 00000000..7320e0c4 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0f697e633fcb09ac7384c27132ca9a6637782ab7e832de0dc8dcdc036a13e62 +size 247 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..9b92e67b --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e447647fe72ee8d8fc3e16669c468217660557cee92579c1bc72be43fe7308d +size 3157366 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..4288c349 --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:243268ba041e1cefe2118a0a996e99a986e59e44e19a1cd8787c5af24b777839 +size 3427057 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..4901a8ac --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86e7dca367bb1fc7ce4ab0e34d568d9e6fcd186c8a48f56fb8ccb814e0c0ac18 +size 2479380 diff --git a/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..b3538b3d --- /dev/null +++ b/tests/data/lerobot/aloha_static_vinh_cup_left/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87f1915600f6ffb86273f3c3d33406650f68bd563fe2fadb12827cf86a108b05 +size 3431257 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..31f6d028 --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88e43cf494635aa058483f6cc9953a11ab8269f78849c66f53dfccb3d0a7521b +size 1024 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/info.json b/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/info.json new file mode 100644 index 00000000..8c5c4ee8 --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb +size 33 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/stats.safetensors b/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/stats.safetensors new file mode 100644 index 00000000..fac49b2b --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f103a4914944258a996601d94515516691749772ff51e227a00b044b9352479b +size 4208 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_static_ziploc_slide/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..c0ba4810 --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de3f86519175aba9a03308e16f92b101c669c1c38d8ffb0703baf7ee0c70a2bc +size 122088 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/train/dataset_info.json b/tests/data/lerobot/aloha_static_ziploc_slide/train/dataset_info.json new file mode 100644 index 00000000..7fe0ac4c --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2e066afefdee57f3bc534085ab7af54e62d3ab2736d42863a89deb743cd0d04 +size 1075 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/train/state.json b/tests/data/lerobot/aloha_static_ziploc_slide/train/state.json new file mode 100644 index 00000000..75907f4d --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bb2830cf2074a67f74027da45417c6e78c1a6476e4d6311ac245ad56a9aeb7a +size 247 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_high_episode_000000.mp4 b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_high_episode_000000.mp4 new file mode 100644 index 00000000..9d4b355e --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_high_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f829ea261a361c974db69b358f4bff54aabe1ab7290499dc1cd10dab07f8759 +size 2054720 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_left_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_left_wrist_episode_000000.mp4 new file mode 100644 index 00000000..47d2a043 --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_left_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2598c5ae7ba614e2eb729dc12220aa3618b60cbb1d69736d4d1b7e53ca548c2c +size 1422844 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_low_episode_000000.mp4 b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_low_episode_000000.mp4 new file mode 100644 index 00000000..82a07ab2 --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_low_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36f59dd411e43e34629e6191d2d7d96eedd9ec5225ee1f66ac99d413a2a08639 +size 2649400 diff --git a/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_right_wrist_episode_000000.mp4 b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_right_wrist_episode_000000.mp4 new file mode 100644 index 00000000..989ab175 --- /dev/null +++ b/tests/data/lerobot/aloha_static_ziploc_slide/videos/observation.images.cam_right_wrist_episode_000000.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6a99787d215fb56532b542efae97f6cd04046e8bdc8df3baab2542de48a6c83 +size 1543621 diff --git a/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors b/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors index 3511c266..600f8e0f 100644 Binary files a/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/pusht/meta_data/info.json b/tests/data/lerobot/pusht/meta_data/info.json index b7f39715..1df647c6 100644 --- a/tests/data/lerobot/pusht/meta_data/info.json +++ b/tests/data/lerobot/pusht/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 10, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:c306c34ef0ae885bb607026a819b0eeddbc664cba1254f67d0615f90675ca485 +size 33 diff --git a/tests/data/lerobot/pusht/meta_data/stats.safetensors b/tests/data/lerobot/pusht/meta_data/stats.safetensors index e4ebbefe..6202887c 100644 Binary files a/tests/data/lerobot/pusht/meta_data/stats.safetensors and b/tests/data/lerobot/pusht/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow b/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow index b99aa290..c46989b3 100644 Binary files a/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow and b/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/pusht/train/dataset_info.json b/tests/data/lerobot/pusht/train/dataset_info.json index a0db336b..e6cac4dc 100644 --- a/tests/data/lerobot/pusht/train/dataset_info.json +++ b/tests/data/lerobot/pusht/train/dataset_info.json @@ -1,55 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.image": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 2, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 2, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.reward": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "next.success": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:30d9830b3ef9fc452a1988cd8bcc9ccac35cf8a8081e2a33049369224053cc1b +size 987 diff --git a/tests/data/lerobot/pusht/train/state.json b/tests/data/lerobot/pusht/train/state.json index 776f29ff..a37ac848 100644 --- a/tests/data/lerobot/pusht/train/state.json +++ b/tests/data/lerobot/pusht/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "3e02d7879f423c56", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:8c1a72239bb56a6c5714f18d849557c89feb858840e8f86689d017bb49551379 +size 247 diff --git a/tests/data/lerobot/pusht/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/pusht/videos/observation.image_episode_000000.mp4 index b2040bdd..f6cfaadc 100644 Binary files a/tests/data/lerobot/pusht/videos/observation.image_episode_000000.mp4 and b/tests/data/lerobot/pusht/videos/observation.image_episode_000000.mp4 differ diff --git a/tests/data/lerobot/pusht_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/pusht_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..529898ff --- /dev/null +++ b/tests/data/lerobot/pusht_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ab094bd9de1a90273cefa1e02a84b3c3916e7812b9c73932f8f92ab2f1b9ba9 +size 3432 diff --git a/tests/data/lerobot/pusht_image/meta_data/info.json b/tests/data/lerobot/pusht_image/meta_data/info.json new file mode 100644 index 00000000..1c234978 --- /dev/null +++ b/tests/data/lerobot/pusht_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbf25de102227dd2d8c3b6c61e1fc25a026d44f151161b88bc9a9eb101e942e4 +size 33 diff --git a/tests/data/lerobot/pusht_image/meta_data/stats.safetensors b/tests/data/lerobot/pusht_image/meta_data/stats.safetensors new file mode 100644 index 00000000..d788079d --- /dev/null +++ b/tests/data/lerobot/pusht_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ceff2650ebbba2ee024be9ec083b7f8f20a69cd2c5bd6624382fe5fb974697 +size 3056 diff --git a/tests/data/lerobot/pusht_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/pusht_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..01a08dad --- /dev/null +++ b/tests/data/lerobot/pusht_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c453a47e3852f3ca09eedfa0f47fabacccff676809678fc3bded273d51ff25e3 +size 197792 diff --git a/tests/data/lerobot/pusht_image/train/dataset_info.json b/tests/data/lerobot/pusht_image/train/dataset_info.json new file mode 100644 index 00000000..7d38d37b --- /dev/null +++ b/tests/data/lerobot/pusht_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af0c9dfe9d1e8caa0f9b85ea64895ac7898d6a39414b00b6ced19955e8eceef6 +size 982 diff --git a/tests/data/lerobot/pusht_image/train/state.json b/tests/data/lerobot/pusht_image/train/state.json new file mode 100644 index 00000000..5c5e609a --- /dev/null +++ b/tests/data/lerobot/pusht_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87d7942da6c4fe7c7a6caefa9a170c0929ac1c57a08889dec8edfe1904e85d42 +size 247 diff --git a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors index 1505d613..7359c867 100644 Binary files a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json index b7f39715..1df647c6 100644 --- a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json +++ b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 10, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:c306c34ef0ae885bb607026a819b0eeddbc664cba1254f67d0615f90675ca485 +size 33 diff --git a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors index d936f449..9519460e 100644 Binary files a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors and b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow b/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow index 11f45a5d..ab8c75f1 100644 Binary files a/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow and b/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json b/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json index f590f3e4..822a028e 100644 --- a/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json +++ b/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json @@ -1,67 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.image": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 7, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "episode_data_index_from": { - "dtype": "int64", - "_type": "Value" - }, - "episode_data_index_to": { - "dtype": "int64", - "_type": "Value" - }, - "end_pose": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 6, - "_type": "Sequence" - }, - "start_pos": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 6, - "_type": "Sequence" - }, - "gripper_width": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 1, - "_type": "Sequence" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:ffb9f0f0c5001460cfe1d49ba84ad2722be5ee7885de97dd733cf9b36104a702 +size 1245 diff --git a/tests/data/lerobot/umi_cup_in_the_wild/train/state.json b/tests/data/lerobot/umi_cup_in_the_wild/train/state.json index 80e610ba..ec2c1c94 100644 --- a/tests/data/lerobot/umi_cup_in_the_wild/train/state.json +++ b/tests/data/lerobot/umi_cup_in_the_wild/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "c8b78ec1bbf7a579", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:2d810eb9bdf2f3d2529a17d1a07d7437780383792747299a5d256fc67f991450 +size 247 diff --git a/tests/data/lerobot/umi_cup_in_the_wild/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/umi_cup_in_the_wild/videos/observation.image_episode_000000.mp4 index 3266cf76..3bf18f72 100644 Binary files a/tests/data/lerobot/umi_cup_in_the_wild/videos/observation.image_episode_000000.mp4 and b/tests/data/lerobot/umi_cup_in_the_wild/videos/observation.image_episode_000000.mp4 differ diff --git a/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors index f5e09ec5..3cf01178 100644 Binary files a/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/xarm_lift_medium/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium/meta_data/info.json index d73052c5..f161e07e 100644 --- a/tests/data/lerobot/xarm_lift_medium/meta_data/info.json +++ b/tests/data/lerobot/xarm_lift_medium/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 15, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:0c6e3c50bb65c5e47ac24bc32e53ab533c78b82c6debab45da6ecea4ce067e37 +size 33 diff --git a/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors index 712c6252..3687f501 100644 Binary files a/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow index 9625a747..6976bc35 100644 Binary files a/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json index 3791deef..979b529c 100644 --- a/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json +++ b/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json @@ -1,51 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.image": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 4, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 4, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.reward": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:a219f973d6535f40737265fd15d81944aabf8eb7527384d28c507926bfa89f25 +size 912 diff --git a/tests/data/lerobot/xarm_lift_medium/train/state.json b/tests/data/lerobot/xarm_lift_medium/train/state.json index 3989b594..c476db98 100644 --- a/tests/data/lerobot/xarm_lift_medium/train/state.json +++ b/tests/data/lerobot/xarm_lift_medium/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "720072274a55db4d", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:814a0694efe765a72beb63e3d21af715835e085cbbda768955c5c204502a7607 +size 247 diff --git a/tests/data/lerobot/xarm_lift_medium/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_lift_medium/videos/observation.image_episode_000000.mp4 index 618b888d..2ede48f4 100644 Binary files a/tests/data/lerobot/xarm_lift_medium/videos/observation.image_episode_000000.mp4 and b/tests/data/lerobot/xarm_lift_medium/videos/observation.image_episode_000000.mp4 differ diff --git a/tests/data/lerobot/xarm_lift_medium_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..e7e90dc3 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ddb52e362094cc1469f34f7d723e235abccd24713b962f9765b4f910e85cebd +size 12936 diff --git a/tests/data/lerobot/xarm_lift_medium_image/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium_image/meta_data/info.json new file mode 100644 index 00000000..f3b70e14 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1cdc5343e413f5cc546079201b1cfb3f49a46e2bfcee67912f9eb5420c00ce6 +size 33 diff --git a/tests/data/lerobot/xarm_lift_medium_image/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium_image/meta_data/stats.safetensors new file mode 100644 index 00000000..c5e2b21b --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:625213579edb7c82580ded28b3bb3dd6894bf27502b401a98de2908a65a2a15d +size 2808 diff --git a/tests/data/lerobot/xarm_lift_medium_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..53cd0088 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c5814b892ceeb0869ec2f1821ab2c5040cbc0f1351a89f5369c67bd137cfe48 +size 105064 diff --git a/tests/data/lerobot/xarm_lift_medium_image/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium_image/train/dataset_info.json new file mode 100644 index 00000000..e4b66000 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba5e9fec857b85adfa021d06c8679f6f7e030c2bfc5d3be3f6590fb250a17d7c +size 907 diff --git a/tests/data/lerobot/xarm_lift_medium_image/train/state.json b/tests/data/lerobot/xarm_lift_medium_image/train/state.json new file mode 100644 index 00000000..401c77f8 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:939e376b61bed55ef489fdb209444bc22e7e220a3390faa5415b2b39a2612cb5 +size 247 diff --git a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors index f5e09ec5..3cf01178 100644 Binary files a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json index d73052c5..f161e07e 100644 --- a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json +++ b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 15, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:0c6e3c50bb65c5e47ac24bc32e53ab533c78b82c6debab45da6ecea4ce067e37 +size 33 diff --git a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors index a7548bad..9e4aed95 100644 Binary files a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow index 102a6154..27899ba9 100644 Binary files a/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json index 69bf84eb..bd1e76b9 100644 --- a/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json +++ b/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json @@ -1,51 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.image": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 4, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 3, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.reward": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:8133fe105b6e35182c4e24d8ac092730cb8f684f6591e4f1b3a4a2adaf224c46 +size 912 diff --git a/tests/data/lerobot/xarm_lift_medium_replay/train/state.json b/tests/data/lerobot/xarm_lift_medium_replay/train/state.json index 6522dcbd..8436b889 100644 --- a/tests/data/lerobot/xarm_lift_medium_replay/train/state.json +++ b/tests/data/lerobot/xarm_lift_medium_replay/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "9f3d8cbb0b2e74a2", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:71f84c95ef32e00060d1658cb601be1bc640a85b9e7200a3af481ba93c145de0 +size 247 diff --git a/tests/data/lerobot/xarm_lift_medium_replay/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_lift_medium_replay/videos/observation.image_episode_000000.mp4 index f1089c15..e0dc68fa 100644 Binary files a/tests/data/lerobot/xarm_lift_medium_replay/videos/observation.image_episode_000000.mp4 and b/tests/data/lerobot/xarm_lift_medium_replay/videos/observation.image_episode_000000.mp4 differ diff --git a/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..e7e90dc3 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ddb52e362094cc1469f34f7d723e235abccd24713b962f9765b4f910e85cebd +size 12936 diff --git a/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/info.json new file mode 100644 index 00000000..f3b70e14 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1cdc5343e413f5cc546079201b1cfb3f49a46e2bfcee67912f9eb5420c00ce6 +size 33 diff --git a/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/stats.safetensors new file mode 100644 index 00000000..c5e2b21b --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_replay_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:625213579edb7c82580ded28b3bb3dd6894bf27502b401a98de2908a65a2a15d +size 2808 diff --git a/tests/data/lerobot/xarm_lift_medium_replay_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium_replay_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..53cd0088 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_replay_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c5814b892ceeb0869ec2f1821ab2c5040cbc0f1351a89f5369c67bd137cfe48 +size 105064 diff --git a/tests/data/lerobot/xarm_lift_medium_replay_image/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium_replay_image/train/dataset_info.json new file mode 100644 index 00000000..e4b66000 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_replay_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba5e9fec857b85adfa021d06c8679f6f7e030c2bfc5d3be3f6590fb250a17d7c +size 907 diff --git a/tests/data/lerobot/xarm_lift_medium_replay_image/train/state.json b/tests/data/lerobot/xarm_lift_medium_replay_image/train/state.json new file mode 100644 index 00000000..401c77f8 --- /dev/null +++ b/tests/data/lerobot/xarm_lift_medium_replay_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:939e376b61bed55ef489fdb209444bc22e7e220a3390faa5415b2b39a2612cb5 +size 247 diff --git a/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors index f5e09ec5..3cf01178 100644 Binary files a/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/xarm_push_medium/meta_data/info.json b/tests/data/lerobot/xarm_push_medium/meta_data/info.json index d73052c5..f161e07e 100644 --- a/tests/data/lerobot/xarm_push_medium/meta_data/info.json +++ b/tests/data/lerobot/xarm_push_medium/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 15, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:0c6e3c50bb65c5e47ac24bc32e53ab533c78b82c6debab45da6ecea4ce067e37 +size 33 diff --git a/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors index a7548bad..9e4aed95 100644 Binary files a/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow index 102a6154..27899ba9 100644 Binary files a/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/xarm_push_medium/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium/train/dataset_info.json index 69bf84eb..bd1e76b9 100644 --- a/tests/data/lerobot/xarm_push_medium/train/dataset_info.json +++ b/tests/data/lerobot/xarm_push_medium/train/dataset_info.json @@ -1,51 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.image": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 4, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 3, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.reward": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:8133fe105b6e35182c4e24d8ac092730cb8f684f6591e4f1b3a4a2adaf224c46 +size 912 diff --git a/tests/data/lerobot/xarm_push_medium/train/state.json b/tests/data/lerobot/xarm_push_medium/train/state.json index 6522dcbd..8436b889 100644 --- a/tests/data/lerobot/xarm_push_medium/train/state.json +++ b/tests/data/lerobot/xarm_push_medium/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "9f3d8cbb0b2e74a2", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:71f84c95ef32e00060d1658cb601be1bc640a85b9e7200a3af481ba93c145de0 +size 247 diff --git a/tests/data/lerobot/xarm_push_medium/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_push_medium/videos/observation.image_episode_000000.mp4 index f1089c15..e0dc68fa 100644 Binary files a/tests/data/lerobot/xarm_push_medium/videos/observation.image_episode_000000.mp4 and b/tests/data/lerobot/xarm_push_medium/videos/observation.image_episode_000000.mp4 differ diff --git a/tests/data/lerobot/xarm_push_medium_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_push_medium_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..e7e90dc3 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ddb52e362094cc1469f34f7d723e235abccd24713b962f9765b4f910e85cebd +size 12936 diff --git a/tests/data/lerobot/xarm_push_medium_image/meta_data/info.json b/tests/data/lerobot/xarm_push_medium_image/meta_data/info.json new file mode 100644 index 00000000..f3b70e14 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1cdc5343e413f5cc546079201b1cfb3f49a46e2bfcee67912f9eb5420c00ce6 +size 33 diff --git a/tests/data/lerobot/xarm_push_medium_image/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium_image/meta_data/stats.safetensors new file mode 100644 index 00000000..c5e2b21b --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:625213579edb7c82580ded28b3bb3dd6894bf27502b401a98de2908a65a2a15d +size 2808 diff --git a/tests/data/lerobot/xarm_push_medium_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..53cd0088 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c5814b892ceeb0869ec2f1821ab2c5040cbc0f1351a89f5369c67bd137cfe48 +size 105064 diff --git a/tests/data/lerobot/xarm_push_medium_image/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium_image/train/dataset_info.json new file mode 100644 index 00000000..e4b66000 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba5e9fec857b85adfa021d06c8679f6f7e030c2bfc5d3be3f6590fb250a17d7c +size 907 diff --git a/tests/data/lerobot/xarm_push_medium_image/train/state.json b/tests/data/lerobot/xarm_push_medium_image/train/state.json new file mode 100644 index 00000000..401c77f8 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:939e376b61bed55ef489fdb209444bc22e7e220a3390faa5415b2b39a2612cb5 +size 247 diff --git a/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors index f5e09ec5..3cf01178 100644 Binary files a/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors differ diff --git a/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json b/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json index d73052c5..f161e07e 100644 --- a/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json +++ b/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json @@ -1,4 +1,3 @@ -{ - "fps": 15, - "video": 1 -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:0c6e3c50bb65c5e47ac24bc32e53ab533c78b82c6debab45da6ecea4ce067e37 +size 33 diff --git a/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors index a7548bad..9e4aed95 100644 Binary files a/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors differ diff --git a/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow index 102a6154..27899ba9 100644 Binary files a/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow differ diff --git a/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json index 69bf84eb..bd1e76b9 100644 --- a/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json +++ b/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json @@ -1,51 +1,3 @@ -{ - "citation": "", - "description": "", - "features": { - "observation.image": { - "_type": "VideoFrame" - }, - "observation.state": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 4, - "_type": "Sequence" - }, - "action": { - "feature": { - "dtype": "float32", - "_type": "Value" - }, - "length": 3, - "_type": "Sequence" - }, - "episode_index": { - "dtype": "int64", - "_type": "Value" - }, - "frame_index": { - "dtype": "int64", - "_type": "Value" - }, - "timestamp": { - "dtype": "float32", - "_type": "Value" - }, - "next.reward": { - "dtype": "float32", - "_type": "Value" - }, - "next.done": { - "dtype": "bool", - "_type": "Value" - }, - "index": { - "dtype": "int64", - "_type": "Value" - } - }, - "homepage": "", - "license": "" -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:8133fe105b6e35182c4e24d8ac092730cb8f684f6591e4f1b3a4a2adaf224c46 +size 912 diff --git a/tests/data/lerobot/xarm_push_medium_replay/train/state.json b/tests/data/lerobot/xarm_push_medium_replay/train/state.json index 6522dcbd..8436b889 100644 --- a/tests/data/lerobot/xarm_push_medium_replay/train/state.json +++ b/tests/data/lerobot/xarm_push_medium_replay/train/state.json @@ -1,13 +1,3 @@ -{ - "_data_files": [ - { - "filename": "data-00000-of-00001.arrow" - } - ], - "_fingerprint": "9f3d8cbb0b2e74a2", - "_format_columns": null, - "_format_kwargs": {}, - "_format_type": null, - "_output_all_columns": false, - "_split": null -} \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:71f84c95ef32e00060d1658cb601be1bc640a85b9e7200a3af481ba93c145de0 +size 247 diff --git a/tests/data/lerobot/xarm_push_medium_replay/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_push_medium_replay/videos/observation.image_episode_000000.mp4 index f1089c15..e0dc68fa 100644 Binary files a/tests/data/lerobot/xarm_push_medium_replay/videos/observation.image_episode_000000.mp4 and b/tests/data/lerobot/xarm_push_medium_replay/videos/observation.image_episode_000000.mp4 differ diff --git a/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/episode_data_index.safetensors new file mode 100644 index 00000000..e7e90dc3 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/episode_data_index.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ddb52e362094cc1469f34f7d723e235abccd24713b962f9765b4f910e85cebd +size 12936 diff --git a/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/info.json b/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/info.json new file mode 100644 index 00000000..f3b70e14 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1cdc5343e413f5cc546079201b1cfb3f49a46e2bfcee67912f9eb5420c00ce6 +size 33 diff --git a/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/stats.safetensors new file mode 100644 index 00000000..c5e2b21b --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_replay_image/meta_data/stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:625213579edb7c82580ded28b3bb3dd6894bf27502b401a98de2908a65a2a15d +size 2808 diff --git a/tests/data/lerobot/xarm_push_medium_replay_image/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium_replay_image/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..53cd0088 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_replay_image/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c5814b892ceeb0869ec2f1821ab2c5040cbc0f1351a89f5369c67bd137cfe48 +size 105064 diff --git a/tests/data/lerobot/xarm_push_medium_replay_image/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium_replay_image/train/dataset_info.json new file mode 100644 index 00000000..e4b66000 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_replay_image/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba5e9fec857b85adfa021d06c8679f6f7e030c2bfc5d3be3f6590fb250a17d7c +size 907 diff --git a/tests/data/lerobot/xarm_push_medium_replay_image/train/state.json b/tests/data/lerobot/xarm_push_medium_replay_image/train/state.json new file mode 100644 index 00000000..401c77f8 --- /dev/null +++ b/tests/data/lerobot/xarm_push_medium_replay_image/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:939e376b61bed55ef489fdb209444bc22e7e220a3390faa5415b2b39a2612cb5 +size 247 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_0.safetensors new file mode 100644 index 00000000..0b3f80c6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32a3b53455e9364264dce5ee98eacefcd32624fbeb0c1ee7c951a45dafe6590e +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1.safetensors new file mode 100644 index 00000000..350e36f0 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4808cc1a1af8b5ebe8fbabbd6d7be70cd9378321062045f359f14547c71973f1 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1498.safetensors new file mode 100644 index 00000000..b638bda9 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1498.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c73ae67e3c0858679efc3c83b63a2d8971fd370d57d7880262186848565b7ac9 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1499.safetensors new file mode 100644 index 00000000..8bdd6de6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_1499.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0a32a9adea43ad5f117cd0a9d9079c8e12df7ae608edd16356e82b0067e2d5a +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_750.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_750.safetensors new file mode 100644 index 00000000..69d53e5c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_750.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17696b496824421cf2d260a302750ee00fae0ce9b6cf6820f84f211d52ac6c92 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_751.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_751.safetensors new file mode 100644 index 00000000..7702715f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_cabinet/frame_751.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91fb996ff9935cbd3e1d06d72eae719411551863520f98381b269b7754c83c9f +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_0.safetensors new file mode 100644 index 00000000..911996d2 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d6a3c984ac0777ed7d6b7db277a1623c19affbe7365a104d90f6fc59f39927f +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1.safetensors new file mode 100644 index 00000000..59975441 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a03cf3bd6d225129b81167f3f671a0034f52646ce2cd6f97fdd4c1ba7aa16e14 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1000.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1000.safetensors new file mode 100644 index 00000000..f3615c6b --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1000.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d1616690cc28bbb635b503e156d488d2eeb20333b2c8a2cd848eaea98f835a9 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1001.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1001.safetensors new file mode 100644 index 00000000..79fd57a1 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1001.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:746e739e50d0e29c6f37b62f030aa849568b36a72dec60b8dc537d13d640679b +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1998.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1998.safetensors new file mode 100644 index 00000000..92a2c420 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1998.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7be1b4a25c29d39fe4d0503000b66d164e73c33e4adbe054e202896c17a17525 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1999.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1999.safetensors new file mode 100644 index 00000000..fa1453c0 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_chair/frame_1999.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:129ba85086c93df80be8166390d0364ce93a485ed0831d2581619194d2b500e5 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_0.safetensors new file mode 100644 index 00000000..8c66eac6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f989c9938006e36b65793cf81504e62b665a73926e8d5bdcd675bdd23a6b9f9 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1.safetensors new file mode 100644 index 00000000..586b5634 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:661cf3f5bfa068c6e5318eab8abe943d1674a7d2e738759a8f0dc64f99d3ec38 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1125.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1125.safetensors new file mode 100644 index 00000000..faca3291 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1125.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c8b10705606794fbccd986db80584f8aff9e0aa0b94a93a676765198a6ec941 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1126.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1126.safetensors new file mode 100644 index 00000000..cd9033ed --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_1126.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:037ca2aa98af6c16acc1ee53342038783d9f6f12e6d5dca8538baede4a7cac3c +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_2248.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_2248.safetensors new file mode 100644 index 00000000..a7e42cbb --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_2248.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04084eafdde395698b6413506c5bb97b703afca8ffbebf985562531caafb69aa +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_2249.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_2249.safetensors new file mode 100644 index 00000000..ec6f9c69 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_elevator/frame_2249.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18abee589c474bd628896e89c1c950d4e49f2098039cc7c7399a03eef97e9a30 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_0.safetensors new file mode 100644 index 00000000..b716bcc2 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d213757efff680cbeec99e112913804f0d6a006681b52333175e53d69bc35b49 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1.safetensors new file mode 100644 index 00000000..fed98889 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29f66531c8d9ed4f0d96f48b700c459e11eba22ae143cb3aa91fcf4efcb37116 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1875.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1875.safetensors new file mode 100644 index 00000000..769a9e2d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1875.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc913ed5a8a0fae6cd77dee1d8ed2eb0d2d405e3a390561d5eb78343f19131a1 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1876.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1876.safetensors new file mode 100644 index 00000000..ad24ed8d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_1876.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87fe3188e72c2cb7f1378a828454d3979ec290f257608a19d81301adc3686ca4 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_3748.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_3748.safetensors new file mode 100644 index 00000000..ebd4f3de --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_3748.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a46d04275eec0a121288b5d4a0473a42157ab8a804ff304c21405af388f2e6a +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_3749.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_3749.safetensors new file mode 100644 index 00000000..eb2290f0 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_shrimp/frame_3749.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ff968d0cbff4bb271ddb45f90094ba37041536d3f29d0a2a94ec31760532b05 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_0.safetensors new file mode 100644 index 00000000..5099bcbe --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a82151149bd9b0049cc798c3157e778318207e714b2d7161c7740da99aefdfc3 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1.safetensors new file mode 100644 index 00000000..7bc05d1d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6bcd8dce72039bc4863ec26c194f508a12b82ce09ed1bc9fcee09b937c93591e +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1098.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1098.safetensors new file mode 100644 index 00000000..c5fe4e44 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1098.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:933dbe688cb9e5286c1b451f2835fe8b9056453d36757c4284f9dc1c9e583131 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1099.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1099.safetensors new file mode 100644 index 00000000..a0824577 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_1099.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3720e6b410e637992fb3ec6ffbc915a262a2b4c728aff258daa523e338479ecd +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_550.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_550.safetensors new file mode 100644 index 00000000..6f9c5756 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_550.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef6a3fd3d4120bf757f6249ff5ecda07a45b3e90f4ea348d9885eedd871eabdf +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_551.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_551.safetensors new file mode 100644 index 00000000..fad4b28b --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wash_pan/frame_551.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:400455865132b2ac1b14b41418867d2c0b93e7e300605e122b161cac3a714e81 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_0.safetensors new file mode 100644 index 00000000..f5eee1f2 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dcb1c2338a29918e824d261ecb0816a7752459e733f03042bc4c6e161911d818 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1.safetensors new file mode 100644 index 00000000..36bf9969 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de1dcead740d6767926f80037ee117ba3184d28058c52e92a172feb6e4a0a38f +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1298.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1298.safetensors new file mode 100644 index 00000000..d9f2277a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1298.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e84653faf1326fa0eb79379b7bce982f7bd59662d7855a9fc6212e502f345a5 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1299.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1299.safetensors new file mode 100644 index 00000000..4c150d35 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_1299.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:952e1a841dcdac4aed009cdaf65c7cc9eb2bf77746579e04656433d820e3b735 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_650.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_650.safetensors new file mode 100644 index 00000000..a5fd854d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_650.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:011b6e2846cc3a66f7b6d38d7046bfb6f9c5a58889d6f552ddf843c70e07baa9 +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_651.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_651.safetensors new file mode 100644 index 00000000..a1b211c3 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_mobile_wipe_wine/frame_651.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8940f9734ee2c1058394b44db84e327647d9b8205e4b9f1c9c671f5fa001db3a +size 11060269 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors index 2a896887..6f46ec88 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors index b144d764..df6c16e4 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors index 9c1ab2f3..4d3add53 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors index b631637b..413d370a 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors index f61e9d20..2df39dbb 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors index 80a3642c..866b1f32 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_0.safetensors new file mode 100644 index 00000000..b07da829 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2af6d705aa80d5f2c6c31908caaf0fdb0ebceb8fbf36346b71413e81a481e87d +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_1.safetensors new file mode 100644 index 00000000..17ec5399 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffd2b58881159b00fc3cf34d5bf1d410c1c1f233eb9a5b47e3bbeff3eff23e91 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_250.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_250.safetensors new file mode 100644 index 00000000..ee25fa75 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_250.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93c3b99038aa736f627726002ab86f08becf17d947636144c430e37150007036 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_251.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_251.safetensors new file mode 100644 index 00000000..2d61a88f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_251.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d918a8acd304895d6c94a871cf8653992265147b49b84f39d217a7f3123fe559 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_498.safetensors new file mode 100644 index 00000000..b188b710 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_498.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4067ddb4a061b4cbdd8a30432edd9d99821c9e765e2aa8c47e123a859836d15c +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_499.safetensors new file mode 100644 index 00000000..4c5aa1cb --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human_image/frame_499.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d24f81c1c6e9addaea10474804080bdeed09fde586281dd6394cba1a17496ac +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_0.safetensors new file mode 100644 index 00000000..47a1382b --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b505e6ce5e901d29f45959788c4b73b12aff600dd417df25e61bd11a8ad39ca1 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_1.safetensors new file mode 100644 index 00000000..ee124112 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c2f7a1d9331716395acd62cce6a012dbf70a7cdea02fb8bd7cbce288ce08f2c +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_200.safetensors new file mode 100644 index 00000000..4f7548f8 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d224db93c92cc1f77cd05cef1b29cbe6bea9686c69753ce53afabdd426a6d01 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_201.safetensors new file mode 100644 index 00000000..ff708af0 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96f7529aeff338e33108b805d2fe45c04fa8900b21d445e12e42bd5b505b9d85 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_398.safetensors new file mode 100644 index 00000000..82e13902 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ea66e7d84c7a55c2bec476265f9e46c3f04bed749deb21c525fcf1495d25026 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_399.safetensors new file mode 100644 index 00000000..ef623676 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7abc89987da95800f162b7932f4cab48a91ef7cd9c3fc2d65e57319e9613791 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_0.safetensors new file mode 100644 index 00000000..11574835 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12b287bd4e15918ceccbdbf3bfdb8790504daa9cc1edcefb8c371b124931dab3 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_1.safetensors new file mode 100644 index 00000000..5493531c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f50e076fd71eceec087c8d94ceaf26a33e7f0a9578a6955693549f44a4699585 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_200.safetensors new file mode 100644 index 00000000..7e1652b4 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09007af2e8b05545b366df5c1d9e628fc5ca311b6ecd2859994239d69d0feac6 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_201.safetensors new file mode 100644 index 00000000..c7a6d8df --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bffaa9e7e4c75c770a09a068110dd8c3d6f6f547bf90aae4a268deb0817b32a0 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_398.safetensors new file mode 100644 index 00000000..a83fa3a6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4328b4c2cd212828344368de7147b0ce4e59f69a3908ff3aa7982ccafdd254d +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_399.safetensors new file mode 100644 index 00000000..d6f18b75 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_scripted_image/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91bf8d4aad329d4e39b534ebb86ae2b8785af57279e539fc44be17ba89d21f17 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_0.safetensors new file mode 100644 index 00000000..036582db --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4da8e6125f4e59ab3bd7462e5f0a56da927577c3f0d7aef31d89d54def347b44 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_1.safetensors new file mode 100644 index 00000000..a47e1ccd --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9713af3b6c3af368b605c9b9d7f3c4d37ff071c06aa703c795a37d2a9760697 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_200.safetensors new file mode 100644 index 00000000..cc04ab7a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:692e0c6559658c42bb279ab366111d3ad2e1d0c8a78e57ebc2ca71fbc5fc98e5 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_201.safetensors new file mode 100644 index 00000000..9fe62316 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c99ce2d4e089f75e5ae94a0e40037b94ef6afd252967070b1d8b36762eb2bef +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_398.safetensors new file mode 100644 index 00000000..5d1bb8ee --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6c2575e1dd0d1dc5c4fee10156c6cdcff2a03a77d66780d0013902ae139246d +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_399.safetensors new file mode 100644 index 00000000..d96099aa --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1172add204cc5996d7ec364dbb571dfa455c0bc3d53bb4ba582b1ed85bf20b21 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_0.safetensors new file mode 100644 index 00000000..98d61908 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:695c6351107b8fdd0b5773653a8f033a9e63991e0234260ca9b859c3ca669e43 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_1.safetensors new file mode 100644 index 00000000..d5459df0 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6315b8a699e72dac4104f1410a48041469804c8b8d24501b1bfd6bf95ea0be4 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_200.safetensors new file mode 100644 index 00000000..68231cbf --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:445e8afca0b92f7b084e41df803aa1bf2c0783040fbc810df48bc6d9d57675ac +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_201.safetensors new file mode 100644 index 00000000..cc5c7be4 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0e25939a57ce02b4fadc4f2b982936ecc420e40c85f460015d4ea707f9ce04f +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_398.safetensors new file mode 100644 index 00000000..bcec5334 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da0c3e8bc471310a6426b6e459a07fa44c7776ad9209a78963f401e33db1bfed +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_399.safetensors new file mode 100644 index 00000000..64dccd29 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_human_image/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:158257d4890ed2d633ebb66c0a1fd9dafc58565b5b04ff75717b97ffddf13972 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_0.safetensors new file mode 100644 index 00000000..4cc74d4f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fb4aa159232636e8b6fa2eb0572be11293609e8e9c4ddf28f2cd5e4b7e969af +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_1.safetensors new file mode 100644 index 00000000..eb2c94ce --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51421e68efd2503f08eead4cb20088260f041bb481fb6e42ddd016ece40bde0d +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_200.safetensors new file mode 100644 index 00000000..8798afab --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0595d669fd9a8cf6917583e295a8d72a8c620d10a7fac0ed17bc7fdbd0e1f08 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_201.safetensors new file mode 100644 index 00000000..f3a845ab --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ed7e661e6c852c8dde6c2a9dafac523440aa89d84ddec1a74de2a48a90a30b8 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_398.safetensors new file mode 100644 index 00000000..5017e228 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a01e4c8ac11c1d279e76cda7ec5d555301b7d6069a2fbfc0ab97aa6bba0991e9 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_399.safetensors new file mode 100644 index 00000000..768a09e2 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb5d1c5fed18ba2df5cfb64febc1ea0e567878fc6bb881da179693c02843941f +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_0.safetensors new file mode 100644 index 00000000..4c2d6857 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d80263ab4343c9e7974e537563992b0dc56d618a751877f62e2fea4cc57befad +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_1.safetensors new file mode 100644 index 00000000..0c4beef8 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb7fc4b1ced815abb3bbb32664914c3f4840ace37deabcfaec436615fcd19bf1 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_200.safetensors new file mode 100644 index 00000000..d335f0da --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1340061f2747abb10c4bc76c5ff7bf0f214f5bf7e61c18660b29f377a6167247 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_201.safetensors new file mode 100644 index 00000000..9f0b1f1c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c6e4aaedd3011ef5482e1f37f0019faf60080f4dbf722ee5d3446204eb70a76 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_398.safetensors new file mode 100644 index 00000000..167dce76 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4847a373c47f2f7385f1e08dcd6eb36fc9d2322cdb6a04051aad6fcb4f593497 +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_399.safetensors new file mode 100644 index 00000000..0c45ff90 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_transfer_cube_scripted_image/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71a92b3ed9ebc4c914911ae21e6fd653fe926e8d834f48e6d522e47fd8d710cb +size 3687117 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_0.safetensors new file mode 100644 index 00000000..c94599ab --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d1910016a994ff85c17ebf000037e22e526d5dd0ba64bbafc87962e86dd2b2e +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_1.safetensors new file mode 100644 index 00000000..8918b76f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b596e0196ad9fcb144e32e59b28f18f9825b220737ed49aa725b19fa74b9b625 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_300.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_300.safetensors new file mode 100644 index 00000000..a04b3cbd --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_300.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52cc6e8d8893f291f6659f05effc2dd416bb971b702d27afab05e1480746ba87 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_301.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_301.safetensors new file mode 100644 index 00000000..ad31ce2a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_301.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95c2ae9651355de406bd43ad537164783b650bbfc5230793aad7caae87ac0196 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_598.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_598.safetensors new file mode 100644 index 00000000..688c23eb --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_598.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f4ca6734fdd442e75b01fd5d2ea322f9e3bc1f9c37b11092c39bb8789fbe0af +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_599.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_599.safetensors new file mode 100644 index 00000000..8004e64c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_battery/frame_599.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:925eba693dda7469f79af510d3cff17e24bbf65f75cc4ca1ed745469fe41860a +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_0.safetensors new file mode 100644 index 00000000..6ac1c693 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9a5423a2e6ae916616b9777bc688f9c2eee2fffd376a84eddfe8614d00daf55 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_1.safetensors new file mode 100644 index 00000000..7054ee47 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c9ea83939b17228347bf8a8cfbe9d94499936668bc8252f0b6643a6918f9cd3 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_350.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_350.safetensors new file mode 100644 index 00000000..17df84f9 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_350.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee0dc1da74261be897ab51dab269d60b90b40796caf6b2d6217994d54b1d2a9a +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_351.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_351.safetensors new file mode 100644 index 00000000..511a7021 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_351.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c64df28a98d92348eb5f4d0f7ddcb900cc8541a83ef8b1dcbab84259927df3e +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_698.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_698.safetensors new file mode 100644 index 00000000..e4debd2a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_698.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ea4f809dbd9d063d3067b5a58934fc26dc7a7e111a8bfb3d71970cb0f0801d6 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_699.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_699.safetensors new file mode 100644 index 00000000..641c9cd4 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_candy/frame_699.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e78227acd8bbb7c5cafb39681cc1b815135ac5eb72bf47c074caf882855cccc +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_0.safetensors new file mode 100644 index 00000000..4ea9921f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2af3c8de0b648af094dced6a8e33294371554ba632913823cd03b711bc45ec7 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1.safetensors new file mode 100644 index 00000000..5cb0ec61 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:787b0630cfa91de73c5c8b6954c5ce83978e0f0144a0905bb996eab276a48edc +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1098.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1098.safetensors new file mode 100644 index 00000000..4f0d0d8d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1098.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50c38571b46a10d7401267b9a1d3ca51d3f694f5611581d67b4db8b2990507d9 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1099.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1099.safetensors new file mode 100644 index 00000000..6f358a26 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_1099.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6498ff75f3cf6dca41757f65ae59d5f692dd4a35c649cb160dfc23f88923c69d +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_550.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_550.safetensors new file mode 100644 index 00000000..55480a69 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_550.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ac3083e8c3f7258e7e846907c34f631fa0d9a39548c4addd26bceb26765d859 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_551.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_551.safetensors new file mode 100644 index 00000000..4677aeac --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee/frame_551.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a75d71cb4a1dbc79b67c9d2c70e40a03983788ce5e829e7c8bd23c12fdc81a56 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_0.safetensors new file mode 100644 index 00000000..bbd2e786 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61c977c9b0250ad9072c10eb103797d16820e6594b9fc85d1ff4af751fc967b4 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1.safetensors new file mode 100644 index 00000000..dce2c769 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c79db975e6a0a3b7fedadb6a752e009dd174dc218651234ec1b9bdcfdf160326 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1498.safetensors new file mode 100644 index 00000000..622a4c68 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1498.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55016a32eb8bcd248bf48cb8b77094c2ba328a7854be69687225b520ceb7df34 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1499.safetensors new file mode 100644 index 00000000..a398a705 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_1499.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd287bf2e4453f31569f38a1e4708c9b833e27483079e7ee380003c89eb540f4 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_750.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_750.safetensors new file mode 100644 index 00000000..e2934d74 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_750.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f5bfe4dba4ecf566b1f31fb7d0fff906e15b809a9a707eebec2c4782cab5fd8 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_751.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_751.safetensors new file mode 100644 index 00000000..18da5226 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_coffee_new/frame_751.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a2e14c339168deb3c06ac5a39514031de0f1c4bc2b7f40a470a8a0585e8ba71 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_0.safetensors new file mode 100644 index 00000000..dfce98fc --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b42a23eee3fb1fd9eb80c742c2ce48036a08080cbe8bb0dd59e0aa8cfc3f2735 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_1.safetensors new file mode 100644 index 00000000..3c6f2e34 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9492c14c40845dbcf0e94f2631c52c074b2fb3264c7ef6db45015298dc921ec +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_200.safetensors new file mode 100644 index 00000000..6e41b200 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c6f6ff40fdf3121e57088ae8c5be91666354e140eaa003b8fb4ef7c806dd77a +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_201.safetensors new file mode 100644 index 00000000..e52ebe25 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01b43229fb0eb11f4a3e9a4e9854eec7addbb371c9a3f305979513593751dfa4 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_398.safetensors new file mode 100644 index 00000000..63bd9df1 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7ecb2c7db34b1d91f35c50383e44327c7922c0067af85962e69bbc72ca2d5fe +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_399.safetensors new file mode 100644 index 00000000..1e206ab6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_cups_open/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:460c1229d0fc8b302758e26316e66f7e2533aed1c0ab027179e14afd1184ab87 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_0.safetensors new file mode 100644 index 00000000..7ebdff9c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a18c1dd56cb3eb5af665b42043fa376943f0334c4d3a95eb2b342ab1881b91e +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_1.safetensors new file mode 100644 index 00000000..5b153487 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f79ebaeb70ec5ac28fa01705ec7c4652d4423b75900f79919a863456cfd34c54 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_300.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_300.safetensors new file mode 100644 index 00000000..d5af77b8 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_300.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1df0f8d4ad74137fd4f1af54fe84d7711a3dc70d22a0afce9a81fb93db45fb01 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_301.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_301.safetensors new file mode 100644 index 00000000..89f07ac6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_301.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a15c3351341aeee097e29e55b795e767cafcf0eb0d360188417825340cefc5d +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_598.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_598.safetensors new file mode 100644 index 00000000..66177b87 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_598.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d4576e7d2897a51a6d52e3b8447f62d105e6c3fe6c15e0ec1ce19a6ea3a30d3 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_599.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_599.safetensors new file mode 100644 index 00000000..22b4746e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_fork_pick_up/frame_599.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b86f3dcb0279f869cd9ffbeb94c16c98ae78b8474e399494d904293b83cde15 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_0.safetensors new file mode 100644 index 00000000..0c9ea184 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:705661cf02ce74899a6198eda019331ad9c45f9f996388e1b22acc9170c29a92 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_1.safetensors new file mode 100644 index 00000000..f70515e6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d71a11ae3c076bfec917cc9c352017084274680de457bdc11030b427f0ad078 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_300.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_300.safetensors new file mode 100644 index 00000000..7ba68a27 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_300.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d03cb79bd2b7ddb2cd2897776a0389ad19ee5d1b4fdd97d85645b9748a5ea645 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_301.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_301.safetensors new file mode 100644 index 00000000..19c70fa9 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_301.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5229c4d1eb9507d16b265601f7222cadb82aece19893f85a70635a0d59d1d858 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_598.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_598.safetensors new file mode 100644 index 00000000..cbfa8510 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_598.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fdfa27751cbd0ecb220af5d66c6641113f16d1505bbdcafedc4824e71fa4c5f +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_599.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_599.safetensors new file mode 100644 index 00000000..266f44ef --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pingpong_test/frame_599.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56141ee6728ca9629b48ff6d41ca5fa0ef998b216455413ebb4a11354db30607 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_0.safetensors new file mode 100644 index 00000000..5a6351af --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:538268195367087cd2c608b94214555c65ff8123dfefe3299836e30c6f3fcbc1 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_1.safetensors new file mode 100644 index 00000000..799e98ba --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3415761128e633d112738489841613811f91020757f9276fc475c6526ed279f +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_175.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_175.safetensors new file mode 100644 index 00000000..cfe4472c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_175.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14edcf4dde7404e03802bae4003fdc0baa2b77ebbfe5f443474cad3e9b8cbe03 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_176.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_176.safetensors new file mode 100644 index 00000000..9a19b01c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_176.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:233063f6b91e2bcbd81f5d52a1a28c8e13be9665f363783ce769c4042bbb148d +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_348.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_348.safetensors new file mode 100644 index 00000000..226288b4 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_348.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:996379c7d90070cd840ed73a9858948860d3e714b41a4f04b2d71ebd59854e53 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_349.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_349.safetensors new file mode 100644 index 00000000..1035b29c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_pro_pencil/frame_349.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae0c867e3fe842bc8280ef8a30631b2c382933d4814ab056f8badc2c22f66e64 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_0.safetensors new file mode 100644 index 00000000..7723d800 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4999aa699ab2a4f94d407b1a875836e7c87c188ae7aedc960df4b5855ac054a +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_1.safetensors new file mode 100644 index 00000000..37233acd --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:801f252ffb4e994976806c88c9b4c3da5558cc3f5f15cc216755cd6345dbc3e0 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_200.safetensors new file mode 100644 index 00000000..c5c6aab0 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d218fee28c33b2e44085826eb46316e18480cc9d0d653936b0b000cd7119d4d +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_201.safetensors new file mode 100644 index 00000000..75d10e48 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14973b0ca734e585ecc43a955eb2a977ce6fc71e0a5dc2032146bac13faf64e9 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_398.safetensors new file mode 100644 index 00000000..59a8517e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bfe8fc36006d7a01506cd5d92f5427da872c5b962596af15a0c699781b0c2d5e +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_399.safetensors new file mode 100644 index 00000000..dcaa7b8a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_screw_driver/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f92e973197fc8f1aac048ebceef32ca0f70989412121dce3984f5088fe9c477e +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_0.safetensors new file mode 100644 index 00000000..3b5dd212 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:803d805ba87c055d46652b20120514ae1d29d5d46fe3b89f4ceb3f4c44dcd565 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_1.safetensors new file mode 100644 index 00000000..a5bb5938 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:209c369f461123105eb57c2f5531816401e5e61f894ef3c8e15f085ee33d2b3d +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_350.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_350.safetensors new file mode 100644 index 00000000..f6b81841 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_350.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:038888662ad22df8a86887d7d42fb06cfe262b5e133647d258ed63c76594f839 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_351.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_351.safetensors new file mode 100644 index 00000000..aab2f3f7 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_351.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:048f8e0ea4b4bb39509dd373f7fc88e5f817180bf0eb6c61c77a861e4a2d0b25 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_698.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_698.safetensors new file mode 100644 index 00000000..2a6aac2b --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_698.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c52973799fb2ad23d851abfc766c82b87a090104cb168dc39804ab38ae1d355d +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_699.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_699.safetensors new file mode 100644 index 00000000..ecc44a8d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_tape/frame_699.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d5fa86fa4414b996f6a7c43539dd53cf998176ccbb43f5bfc7db6954de8cf7e +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_0.safetensors new file mode 100644 index 00000000..0212d8d1 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd7ff1e0a15e03a57817f7ce85b7f0621ae448be289b1123a9f594fe9b724968 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_1.safetensors new file mode 100644 index 00000000..ee446c83 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9690de98f24d445fbd355b6ed1a78438c631ef83cab470ded3b37801fa4a1cbd +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_300.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_300.safetensors new file mode 100644 index 00000000..0683ed0c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_300.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71509a6d4ff3296fdd5d28be4fb385fb7c3d2e4f1fdd4beeb1cc05643804240f +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_301.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_301.safetensors new file mode 100644 index 00000000..da64bdc7 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_301.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8ff39601b5d1394b5d706945a23ddb1bd10c179503b37e5907305adc1f9b55b +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_598.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_598.safetensors new file mode 100644 index 00000000..74dce7e5 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_598.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d7883ea5f9661a58dbdcf585d01eba301bb53d79a62fe36132cfae74d89d52f +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_599.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_599.safetensors new file mode 100644 index 00000000..64deeb2b --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_thread_velcro/frame_599.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28903f7bea923a0577b37af20eea359d7b7237f0680c8c73150ed6f769a297f8 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_0.safetensors new file mode 100644 index 00000000..36d8212d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bccbbf15d20c5f9643eb2e2286f1ddcd4a08297d7f3b7829f7e9c32fe2ef2a46 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_1.safetensors new file mode 100644 index 00000000..5d9d4572 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2679c48910bfcf9e29571690dbc355ce65ad4ae7d35ee6ab9ee35f48c00aa04 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_250.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_250.safetensors new file mode 100644 index 00000000..27e4dbef --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_250.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16074bdca2bac28982f3823c7e9ccaeeaaae0c2676baee952ac22a2bb0bcf535 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_251.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_251.safetensors new file mode 100644 index 00000000..db2f009f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_251.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6e1ff5fe6690bf20c59bb802383881665def949912d224fc7dd4ef1686bcc81 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_498.safetensors new file mode 100644 index 00000000..d77a0295 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_498.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7deab7a2f918721a92881ea9e338b7ca093209c40a42a0ebadf47cda0c5c534 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_499.safetensors new file mode 100644 index 00000000..36efd61c --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_towel/frame_499.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d7f1cf5764b5d1b7ae6784fecf8372eba3a6e3f08b39e2731874ca26459ca8c +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_0.safetensors new file mode 100644 index 00000000..206e6874 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:092dbf1d1a7b576ff491aef010cc06daea435813f42a87668334b05b0ecd9c8e +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_1.safetensors new file mode 100644 index 00000000..d778fc1d --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1504807c8df2f255f06250c43d98cd057d0486946629f441d64a979e5181046 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_250.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_250.safetensors new file mode 100644 index 00000000..3ba20828 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_250.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a0bb3e21109bdaf2ee7f10fa19df8b168dc470d574d105e58caaf662d851680 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_251.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_251.safetensors new file mode 100644 index 00000000..55a95ac3 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_251.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fcf3cda237cec2f9599b0866b0c1e9cb5054a2426e5cd39273b7310d8e03a44 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_498.safetensors new file mode 100644 index 00000000..a4f3a365 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_498.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f25504ce7ff734bb9fbf786de06bc3dc939ba176fb5c94e9475ad1d895c2d35 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_499.safetensors new file mode 100644 index 00000000..ead2aa7b --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup/frame_499.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83f268f1f69ec78f1ad23dd5db67bad32c6b6f58f9609378054a1e25b21185e0 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_0.safetensors new file mode 100644 index 00000000..790e8df4 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc73ba199acc90e43493768864749fdf834deb4690b102ab47ec262e70fa0f74 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_1.safetensors new file mode 100644 index 00000000..e8e70c32 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:559f400e967d153286698f2cdd37b6e5cf60f27af46720584bbfffe1188bc9ae +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_250.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_250.safetensors new file mode 100644 index 00000000..7ef4384e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_250.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:266aa2de91c860bc5de71b8439fb371b72d2a96e103eb47203437274e9508cfb +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_251.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_251.safetensors new file mode 100644 index 00000000..bb527db3 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_251.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3ca2a8f76f97bf02540ebfaac9775fe2aa17f890f94e24b3827f233e813f794 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_498.safetensors new file mode 100644 index 00000000..4fd98837 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_498.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c455abecfbada4b98ad32197abb030cbab158d2d6ad247c6f24022b39318546 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_499.safetensors new file mode 100644 index 00000000..837609fe --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_vinh_cup_left/frame_499.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa9ee297d7557f5cce6030277b230e95248ddcfdc53dc76787470ffa5b9342a1 +size 14746773 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_0.safetensors new file mode 100644 index 00000000..3c0f704f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f48e8d0c3c0e243918f11440e102d88dbfbba9a5b03f4f9bcdcd1ee93ec18759 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_1.safetensors new file mode 100644 index 00000000..4056b339 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e306f7d873028a2a547806e5da39b111ade93e3fda51b9dec9d8c47f197bff9b +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_150.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_150.safetensors new file mode 100644 index 00000000..b40a52c3 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_150.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c96e924cc3447a11bb19677c771f64750bca1221df77994ffba028ce8470967a +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_151.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_151.safetensors new file mode 100644 index 00000000..6f65d342 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_151.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58dbfcdd88e8ef0e84ae74e6dfcc5ec5d3f7ce7196ddc178562fbfb179b4da3f +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_298.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_298.safetensors new file mode 100644 index 00000000..4815ae90 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_298.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0b6d115b33462877d29f67d71fe8ac7a793c37b0c6c622206b3c087ddf8ccb4 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_299.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_299.safetensors new file mode 100644 index 00000000..35caadf6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/aloha_static_ziploc_slide/frame_299.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46d63645910b87e19cb843232adde7632669a2e3764503a8325d70cdd5f81e37 +size 14746637 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors index 6a3f3e5e..2f690b32 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors index 69c89958..761b2f07 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors index cbf1f1fa..d411d8e3 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors index 2107611b..6eda2231 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors index 94c55feb..574d0ef7 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors index 7f63f83f..1c3dc8e4 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_0.safetensors new file mode 100644 index 00000000..955f1461 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15a7bd80a58044ce91bbc17859b6ee59677782223dfef55cfd34a6f9a3d8f253 +size 111338 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_1.safetensors new file mode 100644 index 00000000..b83e75fa --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08ffcac00f0b03ecbc28ee7112fedb6a7a9379a7c6066bc9af7509ac49247284 +size 111338 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_159.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_159.safetensors new file mode 100644 index 00000000..0a0ecf9e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_159.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74956a949d34ef4d40d547a520271488540153d0f107dcc89025baefd59ab46b +size 111338 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_160.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_160.safetensors new file mode 100644 index 00000000..56e0715e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_160.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:555215357be71d2cecf7f7cb2837a155af7c7d797a74eb5f73313297706d1cdd +size 111338 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_80.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_80.safetensors new file mode 100644 index 00000000..12aaaa49 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_80.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c87df1a1e7d793fcb3b89343257e107fff855109ad2894cb53c24aba5d9f37fb +size 111338 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_81.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_81.safetensors new file mode 100644 index 00000000..4bf026fc --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/pusht_image/frame_81.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:550aa1a6b32c21f711430c4ce677a3fdb64e0312cd909ecc094a96593d6d25b9 +size 111338 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_0.safetensors new file mode 100644 index 00000000..bea1ca1e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d0a7f85da5e523fdfc6176aee32933f18c2273435f97c9f0068611a3e04058f +size 603012 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_1.safetensors new file mode 100644 index 00000000..95b1cb4e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3dfa76aa9e2f06074f142f1da95d3c993f44af9129b6e416f23d2dc7b89b384 +size 603012 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_200.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_200.safetensors new file mode 100644 index 00000000..7630b216 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_200.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b18700564cfcae431d1a7272deedca34793cd4c52b98c2c1a3ebfecca3fd8834 +size 603012 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_201.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_201.safetensors new file mode 100644 index 00000000..da361d3a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_201.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60b735c470e3ab4404bd008d1ce2f9cb5d2b9878bedbd5070e3e6c5210dac446 +size 603012 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_398.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_398.safetensors new file mode 100644 index 00000000..6fc2a9f1 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_398.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbca24ea09aa2d9ce08d826f0365349890fcc2ff26d25a0de7de501f77b25730 +size 603012 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_399.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_399.safetensors new file mode 100644 index 00000000..bdd1408b --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/umi_cup_in_the_wild/frame_399.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77c138571a4c07ac244f50d6c1d8a6c93a75eaee434185b5e3561ba27fdd7727 +size 603012 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors index d256267b..3b4a15ca 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors index 5d0e8001..95cdbb2f 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors index 4c3be9fc..0ffb01d8 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors index bc3f3a35..04fe6479 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors index 9683bb8e..6b5b7836 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors index 0777e9ad..4ef5f6a2 100644 Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors differ diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_0.safetensors new file mode 100644 index 00000000..83eb81ae --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c64420d96e2aff6e34075174dd17a190e3a9b5b69d1efb03ec34d2e4a6a2a9e6 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_1.safetensors new file mode 100644 index 00000000..3e5f163e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee543bd812f134438bd9051eb243263dcb47782b849498fbb42bdbb9c8fe58f5 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_12.safetensors new file mode 100644 index 00000000..70cc3ecc --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_12.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83458a1aaa5ce7bf3eb99fccbebac10661e18c008de4cc1c0b6b4fc52b3f07ef +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_13.safetensors new file mode 100644 index 00000000..3efd7c0f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_13.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ad2bff1fb6d4a580228c08375d01d0987c376b9c7bedc1b591588e3b72ad14 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_23.safetensors new file mode 100644 index 00000000..0b8f00d5 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_23.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9daa07f20f117452d95a4d1ae646037184e92a1eba7a065076ab31001fccfb07 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_24.safetensors new file mode 100644 index 00000000..8ca052c6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_image/frame_24.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f059ca9dd4eec03cebc728726f51fe2e174a520d4ebb5c9fbc2c6564998b73e +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_0.safetensors new file mode 100644 index 00000000..f04f8190 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9de0a9b57bdf8789dfe18ad4232025c1531774d2a1da2606fd1b10f313120b19 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_1.safetensors new file mode 100644 index 00000000..64e56e52 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a0989c5cee82141bdb8c878661d43e82e40d8bc8a860588ee869db15e628a71 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_12.safetensors new file mode 100644 index 00000000..d6d98f8a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_12.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27f5fcf7e2c4b9d201a5f07d30eaed6047c7e48b17ee9c7dd89ed19ebcc098f8 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_13.safetensors new file mode 100644 index 00000000..fac57a56 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_13.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd04d30f3b234a022cb157e301443bedec553824ba720b6965c38fed263b0e70 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_23.safetensors new file mode 100644 index 00000000..ea384b10 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_23.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9b15ad185fcd060fc3941c56b662001fcb945588f4819d6ae001b002d969fd3 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_24.safetensors new file mode 100644 index 00000000..5e50bf50 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay/frame_24.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef8af5274142f5606e04df0fa228bcf60181810c3c31e1de147d24097a599497 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_0.safetensors new file mode 100644 index 00000000..83eb81ae --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c64420d96e2aff6e34075174dd17a190e3a9b5b69d1efb03ec34d2e4a6a2a9e6 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_1.safetensors new file mode 100644 index 00000000..3e5f163e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee543bd812f134438bd9051eb243263dcb47782b849498fbb42bdbb9c8fe58f5 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_12.safetensors new file mode 100644 index 00000000..70cc3ecc --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_12.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83458a1aaa5ce7bf3eb99fccbebac10661e18c008de4cc1c0b6b4fc52b3f07ef +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_13.safetensors new file mode 100644 index 00000000..3efd7c0f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_13.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ad2bff1fb6d4a580228c08375d01d0987c376b9c7bedc1b591588e3b72ad14 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_23.safetensors new file mode 100644 index 00000000..0b8f00d5 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_23.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9daa07f20f117452d95a4d1ae646037184e92a1eba7a065076ab31001fccfb07 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_24.safetensors new file mode 100644 index 00000000..8ca052c6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium_replay_image/frame_24.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f059ca9dd4eec03cebc728726f51fe2e174a520d4ebb5c9fbc2c6564998b73e +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_0.safetensors new file mode 100644 index 00000000..f04f8190 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9de0a9b57bdf8789dfe18ad4232025c1531774d2a1da2606fd1b10f313120b19 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_1.safetensors new file mode 100644 index 00000000..64e56e52 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a0989c5cee82141bdb8c878661d43e82e40d8bc8a860588ee869db15e628a71 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_12.safetensors new file mode 100644 index 00000000..d6d98f8a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_12.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27f5fcf7e2c4b9d201a5f07d30eaed6047c7e48b17ee9c7dd89ed19ebcc098f8 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_13.safetensors new file mode 100644 index 00000000..fac57a56 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_13.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd04d30f3b234a022cb157e301443bedec553824ba720b6965c38fed263b0e70 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_23.safetensors new file mode 100644 index 00000000..ea384b10 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_23.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9b15ad185fcd060fc3941c56b662001fcb945588f4819d6ae001b002d969fd3 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_24.safetensors new file mode 100644 index 00000000..5e50bf50 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium/frame_24.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef8af5274142f5606e04df0fa228bcf60181810c3c31e1de147d24097a599497 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_0.safetensors new file mode 100644 index 00000000..83eb81ae --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c64420d96e2aff6e34075174dd17a190e3a9b5b69d1efb03ec34d2e4a6a2a9e6 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_1.safetensors new file mode 100644 index 00000000..3e5f163e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee543bd812f134438bd9051eb243263dcb47782b849498fbb42bdbb9c8fe58f5 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_12.safetensors new file mode 100644 index 00000000..70cc3ecc --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_12.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83458a1aaa5ce7bf3eb99fccbebac10661e18c008de4cc1c0b6b4fc52b3f07ef +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_13.safetensors new file mode 100644 index 00000000..3efd7c0f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_13.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ad2bff1fb6d4a580228c08375d01d0987c376b9c7bedc1b591588e3b72ad14 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_23.safetensors new file mode 100644 index 00000000..0b8f00d5 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_23.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9daa07f20f117452d95a4d1ae646037184e92a1eba7a065076ab31001fccfb07 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_24.safetensors new file mode 100644 index 00000000..8ca052c6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_image/frame_24.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f059ca9dd4eec03cebc728726f51fe2e174a520d4ebb5c9fbc2c6564998b73e +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_0.safetensors new file mode 100644 index 00000000..f04f8190 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9de0a9b57bdf8789dfe18ad4232025c1531774d2a1da2606fd1b10f313120b19 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_1.safetensors new file mode 100644 index 00000000..64e56e52 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a0989c5cee82141bdb8c878661d43e82e40d8bc8a860588ee869db15e628a71 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_12.safetensors new file mode 100644 index 00000000..d6d98f8a --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_12.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27f5fcf7e2c4b9d201a5f07d30eaed6047c7e48b17ee9c7dd89ed19ebcc098f8 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_13.safetensors new file mode 100644 index 00000000..fac57a56 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_13.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd04d30f3b234a022cb157e301443bedec553824ba720b6965c38fed263b0e70 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_23.safetensors new file mode 100644 index 00000000..ea384b10 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_23.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9b15ad185fcd060fc3941c56b662001fcb945588f4819d6ae001b002d969fd3 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_24.safetensors new file mode 100644 index 00000000..5e50bf50 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay/frame_24.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef8af5274142f5606e04df0fa228bcf60181810c3c31e1de147d24097a599497 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_0.safetensors new file mode 100644 index 00000000..83eb81ae --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_0.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c64420d96e2aff6e34075174dd17a190e3a9b5b69d1efb03ec34d2e4a6a2a9e6 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_1.safetensors new file mode 100644 index 00000000..3e5f163e --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_1.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee543bd812f134438bd9051eb243263dcb47782b849498fbb42bdbb9c8fe58f5 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_12.safetensors new file mode 100644 index 00000000..70cc3ecc --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_12.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83458a1aaa5ce7bf3eb99fccbebac10661e18c008de4cc1c0b6b4fc52b3f07ef +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_13.safetensors new file mode 100644 index 00000000..3efd7c0f --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_13.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ad2bff1fb6d4a580228c08375d01d0987c376b9c7bedc1b591588e3b72ad14 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_23.safetensors new file mode 100644 index 00000000..0b8f00d5 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_23.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9daa07f20f117452d95a4d1ae646037184e92a1eba7a065076ab31001fccfb07 +size 85349 diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_24.safetensors new file mode 100644 index 00000000..8ca052c6 --- /dev/null +++ b/tests/data/save_dataset_to_safetensors/lerobot/xarm_push_medium_replay_image/frame_24.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f059ca9dd4eec03cebc728726f51fe2e174a520d4ebb5c9fbc2c6564998b73e +size 85349 diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors index 3c9447d7..c5176423 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors index 7dfbc3b3..bdecb18b 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors index 4c738f39..641771c6 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors index 7a2e0e70..26d91924 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors index 8f039903..538a06a5 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors index 2b659396..74bcbd39 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors index a9f61b36..37125b65 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors index a9f4608f..45efd3eb 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors index 0339ca0e..49179928 100644 Binary files a/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors index 5520c643..a9320466 100644 Binary files a/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors index 2321f31c..ef49ce97 100644 Binary files a/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors index 5e8a6947..9b399a4c 100644 Binary files a/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors differ diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index 554efe75..4aa8131f 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -81,11 +81,5 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): if __name__ == "__main__": - available_datasets = [ - "lerobot/pusht", - "lerobot/xarm_lift_medium", - "lerobot/aloha_sim_insertion_human", - # "lerobot/umi_cup_in_the_wild", - ] for dataset in available_datasets: save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset) diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index e79a94ff..89f33374 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -26,7 +26,7 @@ from lerobot.scripts.train import make_optimizer_and_scheduler from tests.utils import DEFAULT_CONFIG_PATH -def get_policy_stats(env_name, policy_name, extra_overrides=None): +def get_policy_stats(env_name, policy_name, extra_overrides): cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=[ @@ -92,6 +92,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}" if env_policy_dir.exists(): + print(f"Overwrite existing safetensors in '{env_policy_dir}':") + print(f" - Validate with: `git add {env_policy_dir}`") + print(f" - Revert with: `git checkout -- {env_policy_dir}`") shutil.rmtree(env_policy_dir) env_policy_dir.mkdir(parents=True, exist_ok=True) @@ -103,8 +106,14 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": - # Instructions: include the policies that you want to save artifacts for here. Please make sure to revert - # your changes when you are done. - env_policies = [] + env_policies = [ + ("xarm", "tdmpc", []), + ( + "pusht", + "diffusion", + ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], + ), + ("aloha", "act", ["policy.n_action_steps=10"]), + ] for env, policy, extra_overrides in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) diff --git a/tests/test_examples.py b/tests/test_examples.py index de95a991..a0c60b7e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # TODO(aliberts): Mute logging for these tests +import io import subprocess import sys from pathlib import Path @@ -32,6 +33,11 @@ def _run_script(path): subprocess.run([sys.executable, path], check=True) +def _read_file(path): + with open(path) as file: + return file.read() + + def test_example_1(): path = "examples/1_load_lerobot_dataset.py" _run_script(path) @@ -39,18 +45,17 @@ def test_example_1(): @require_package("gym_pusht") -def test_examples_3_and_2(): +def test_examples_2_through_4(): """ Train a model with example 3, check the outputs. Evaluate the trained model with example 2, check the outputs. + Calculate the validation loss with example 4, check the outputs. """ - path = "examples/3_train_policy.py" + ### Test example 3 + file_contents = _read_file("examples/3_train_policy.py") - with open(path) as file: - file_contents = file.read() - - # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. + # Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. file_contents = _find_and_replace( file_contents, [ @@ -67,16 +72,17 @@ def test_examples_3_and_2(): for file_name in ["model.safetensors", "config.json"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() - path = "examples/2_evaluate_pretrained_policy.py" + ### Test example 2 + file_contents = _read_file("examples/2_evaluate_pretrained_policy.py") - with open(path) as file: - file_contents = file.read() - - # Do less evals, use CPU, and use the local model. + # Do fewer evals, use CPU, and use the local model. file_contents = _find_and_replace( file_contents, [ - ('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""), + ( + 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', + "", + ), ( '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', @@ -89,3 +95,34 @@ def test_examples_3_and_2(): exec(file_contents, {}) assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists() + + ## Test example 4 + file_contents = _read_file("examples/4_calculate_validation_loss.py") + + # Run on a single example from the last episode, use CPU, and use the local model. + file_contents = _find_and_replace( + file_contents, + [ + ( + 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', + "", + ), + ( + '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', + 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', + ), + ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'), + ("num_workers=4", "num_workers=0"), + ('device = torch.device("cuda")', 'device = torch.device("cpu")'), + ("batch_size=64", "batch_size=1"), + ], + ) + + # Capture the output of the script + output_buffer = io.StringIO() + sys.stdout = output_buffer + exec(file_contents, {}) + printed_output = output_buffer.getvalue() + # Restore stdout to its original state + sys.stdout = sys.__stdout__ + assert "Average loss on validation set" in printed_output diff --git a/tests/test_policies.py b/tests/test_policies.py index 75633fe6..bb0c7b80 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -31,7 +31,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config from tests.scripts.save_policy_to_safetensor import get_policy_stats -from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel +from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel @pytest.mark.parametrize("policy_name", available_policies) @@ -296,16 +296,17 @@ def test_normalize(insert_temporal_dim): # As artifacts have been generated on an x86_64 kernel, this test won't # pass if it's run on another platform due to floating point errors @require_x86_64_kernel +@require_cpu def test_backward_compatibility(env_name, policy_name, extra_overrides): """ NOTE: If this test does not pass, and you have intentionally changed something in the policy: 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should include a report on what changed and how that affected the outputs. - 2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and + 2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and add the policies you want to update the test artifacts for. - 3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. + 3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. 4. Check that this test now passes. - 5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state. + 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 6. Remember to stage and commit the resulting changes to `tests/data`. """ env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" diff --git a/tests/test_utils.py b/tests/test_utils.py index bcdd95b4..a7f770fb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,20 +4,28 @@ from typing import Callable import numpy as np import pytest import torch +from datasets import Dataset +from lerobot.common.datasets.utils import ( + calculate_episode_data_index, + hf_transform_to_torch, + reset_episode_index, +) from lerobot.common.utils.utils import seeded_context, set_global_seed @pytest.mark.parametrize( "rand_fn", - [ - random.random, - np.random.random, - lambda: torch.rand(1).item(), - ] - + [lambda: torch.rand(1, device="cuda")] - if torch.cuda.is_available() - else [], + ( + [ + random.random, + np.random.random, + lambda: torch.rand(1).item(), + ] + + [lambda: torch.rand(1, device="cuda")] + if torch.cuda.is_available() + else [] + ), ) def test_seeding(rand_fn: Callable[[], int]): set_global_seed(0) @@ -36,3 +44,31 @@ def test_seeding(rand_fn: Callable[[], int]): c_ = rand_fn() # Check that `seeded_context` and `global_seed` give the same reproducibility. assert c_ == c + + +def test_calculate_episode_data_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) + assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) + + +def test_reset_episode_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [10, 10, 11, 12, 12, 12], + }, + ) + dataset.set_transform(hf_transform_to_torch) + correct_episode_index = [0, 0, 1, 2, 2, 2] + dataset = reset_episode_index(dataset) + assert dataset["episode_index"] == correct_episode_index