Move calculate_episode_data_index

This commit is contained in:
Simon Alibert 2024-11-03 19:13:00 +01:00
parent 74270c8c91
commit 7b159a6b22
11 changed files with 71 additions and 66 deletions

View File

@ -30,12 +30,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import ( from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes, concatenate_episodes,
get_default_encoding, get_default_encoding,
save_images_concurrently, save_images_concurrently,
) )
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -24,8 +24,11 @@ from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes from lerobot.common.datasets.push_dataset_to_hub.utils import (
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch calculate_episode_data_index,
concatenate_episodes,
)
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame from lerobot.common.datasets.video_utils import VideoFrame

View File

@ -26,8 +26,8 @@ import torch
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import VideoFrame from lerobot.common.datasets.video_utils import VideoFrame

View File

@ -42,12 +42,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import ( from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes, concatenate_episodes,
get_default_encoding, get_default_encoding,
save_images_concurrently, save_images_concurrently,
) )
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -27,12 +27,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import ( from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes, concatenate_episodes,
get_default_encoding, get_default_encoding,
save_images_concurrently, save_images_concurrently,
) )
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -28,12 +28,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import ( from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes, concatenate_episodes,
get_default_encoding, get_default_encoding,
save_images_concurrently, save_images_concurrently,
) )
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -16,7 +16,9 @@
import inspect import inspect
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from typing import Dict
import datasets
import numpy import numpy
import PIL import PIL
import torch import torch
@ -72,3 +74,58 @@ def check_repo_id(repo_id: str) -> None:
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
(e.g. 'lerobot/pusht'), but contains '{repo_id}'.""" (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
) )
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index

View File

@ -27,12 +27,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import ( from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes, concatenate_episodes,
get_default_encoding, get_default_encoding,
save_images_concurrently, save_images_concurrently,
) )
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -18,7 +18,7 @@ import warnings
from itertools import accumulate from itertools import accumulate
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Any, Dict from typing import Any
import datasets import datasets
import jsonlines import jsonlines
@ -368,61 +368,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
return delta_indices return delta_indices
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index
def cycle(iterable): def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders. """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.

View File

@ -15,9 +15,9 @@
# limitations under the License. # limitations under the License.
from datasets import Dataset from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )

View File

@ -7,8 +7,8 @@ import pytest
import torch import torch
from datasets import Dataset from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (