diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
index 52c4bba3..e2973ef8 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
@@ -30,12 +30,12 @@ from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
     concatenate_episodes,
     get_default_encoding,
     save_images_concurrently,
 )
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
diff --git a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py
index be20c92c..26492576 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py
@@ -24,8 +24,11 @@ from datasets import Dataset, Features, Image, Value
 from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
-from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
+    concatenate_episodes,
+)
+from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.datasets.video_utils import VideoFrame
 
 
diff --git a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
index 72be130e..95f9c007 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
@@ -26,8 +26,8 @@ import torch
 from datasets import Dataset, Features, Image, Sequence, Value
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
+from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame
diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py b/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
index f5744c52..cfe11503 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
@@ -42,12 +42,12 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
 from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
     concatenate_episodes,
     get_default_encoding,
     save_images_concurrently,
 )
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
index 13d6c837..27b31ba2 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
     concatenate_episodes,
     get_default_encoding,
     save_images_concurrently,
 )
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
index d724cf33..fec893a7 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
@@ -28,12 +28,12 @@ from PIL import Image as PILImage
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
 from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
     concatenate_episodes,
     get_default_encoding,
     save_images_concurrently,
 )
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py
index 97b54e45..ebcf87f7 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/utils.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py
@@ -16,7 +16,9 @@
 import inspect
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
+from typing import Dict
 
+import datasets
 import numpy
 import PIL
 import torch
@@ -72,3 +74,58 @@ def check_repo_id(repo_id: str) -> None:
             f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
             (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
         )
+
+
+# TODO(aliberts): remove
+def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
+    """
+    Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
+
+    Parameters:
+    - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
+
+    Returns:
+    - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
+        - "from": A tensor containing the starting index of each episode.
+        - "to": A tensor containing the ending index of each episode.
+    """
+    episode_data_index = {"from": [], "to": []}
+
+    current_episode = None
+    """
+    The episode_index is a list of integers, each representing the episode index of the corresponding example.
+    For instance, the following is a valid episode_index:
+      [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
+
+    Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
+    ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
+        {
+            "from": [0, 3, 7],
+            "to": [3, 7, 12]
+        }
+    """
+    if len(hf_dataset) == 0:
+        episode_data_index = {
+            "from": torch.tensor([]),
+            "to": torch.tensor([]),
+        }
+        return episode_data_index
+    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
+        if episode_idx != current_episode:
+            # We encountered a new episode, so we append its starting location to the "from" list
+            episode_data_index["from"].append(idx)
+            # If this is not the first episode, we append the ending location of the previous episode to the "to" list
+            if current_episode is not None:
+                episode_data_index["to"].append(idx)
+            # Let's keep track of the current episode index
+            current_episode = episode_idx
+        else:
+            # We are still in the same episode, so there is nothing for us to do here
+            pass
+    # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
+    episode_data_index["to"].append(idx + 1)
+
+    for k in ["from", "to"]:
+        episode_data_index[k] = torch.tensor(episode_data_index[k])
+
+    return episode_data_index
diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
index ad1cb560..0047e48c 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
     concatenate_episodes,
     get_default_encoding,
     save_images_concurrently,
 )
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index e21c0128..daebb505 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -18,7 +18,7 @@ import warnings
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
-from typing import Any, Dict
+from typing import Any
 
 import datasets
 import jsonlines
@@ -368,61 +368,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
     return delta_indices
 
 
-# TODO(aliberts): remove
-def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
-    """
-    Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
-
-    Parameters:
-    - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
-
-    Returns:
-    - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
-        - "from": A tensor containing the starting index of each episode.
-        - "to": A tensor containing the ending index of each episode.
-    """
-    episode_data_index = {"from": [], "to": []}
-
-    current_episode = None
-    """
-    The episode_index is a list of integers, each representing the episode index of the corresponding example.
-    For instance, the following is a valid episode_index:
-      [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
-
-    Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
-    ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
-        {
-            "from": [0, 3, 7],
-            "to": [3, 7, 12]
-        }
-    """
-    if len(hf_dataset) == 0:
-        episode_data_index = {
-            "from": torch.tensor([]),
-            "to": torch.tensor([]),
-        }
-        return episode_data_index
-    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
-        if episode_idx != current_episode:
-            # We encountered a new episode, so we append its starting location to the "from" list
-            episode_data_index["from"].append(idx)
-            # If this is not the first episode, we append the ending location of the previous episode to the "to" list
-            if current_episode is not None:
-                episode_data_index["to"].append(idx)
-            # Let's keep track of the current episode index
-            current_episode = episode_idx
-        else:
-            # We are still in the same episode, so there is nothing for us to do here
-            pass
-    # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
-    episode_data_index["to"].append(idx + 1)
-
-    for k in ["from", "to"]:
-        episode_data_index[k] = torch.tensor(episode_data_index[k])
-
-    return episode_data_index
-
-
 def cycle(iterable):
     """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
 
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 635e7f11..ee143f37 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -15,9 +15,9 @@
 # limitations under the License.
 from datasets import Dataset
 
+from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 42715e00..8880d28c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -7,8 +7,8 @@ import pytest
 import torch
 from datasets import Dataset
 
+from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
     hf_transform_to_torch,
 )
 from lerobot.common.utils.utils import (
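
Below is a minimal usage sketch of the relocated helper, assuming a toy in-memory dataset; the episode_index values simply mirror the example in the function's own docstring and are not taken from any real dataset.

# Minimal sketch (assumed toy data): after this change, calculate_episode_data_index is
# imported from lerobot.common.datasets.push_dataset_to_hub.utils rather than
# lerobot.common.datasets.utils.
from datasets import Dataset

from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index

# Three episodes of lengths 3, 4 and 5, matching the docstring example.
hf_dataset = Dataset.from_dict({"episode_index": [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]})

episode_data_index = calculate_episode_data_index(hf_dataset)
print(episode_data_index["from"])  # tensor([0, 3, 7])
print(episode_data_index["to"])  # tensor([3, 7, 12])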