Move calculate_episode_data_index
This commit is contained in:
parent
74270c8c91
commit
7b159a6b22
|
@ -30,12 +30,12 @@ from PIL import Image as PILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
concatenate_episodes,
|
concatenate_episodes,
|
||||||
get_default_encoding,
|
get_default_encoding,
|
||||||
save_images_concurrently,
|
save_images_concurrently,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||||
|
|
|
@ -24,8 +24,11 @@ from datasets import Dataset, Features, Image, Value
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||||
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
calculate_episode_data_index,
|
||||||
|
concatenate_episodes,
|
||||||
|
)
|
||||||
|
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame
|
from lerobot.common.datasets.video_utils import VideoFrame
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,8 +26,8 @@ import torch
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame
|
from lerobot.common.datasets.video_utils import VideoFrame
|
||||||
|
|
|
@ -42,12 +42,12 @@ from PIL import Image as PILImage
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
|
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
concatenate_episodes,
|
concatenate_episodes,
|
||||||
get_default_encoding,
|
get_default_encoding,
|
||||||
save_images_concurrently,
|
save_images_concurrently,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||||
|
|
|
@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
concatenate_episodes,
|
concatenate_episodes,
|
||||||
get_default_encoding,
|
get_default_encoding,
|
||||||
save_images_concurrently,
|
save_images_concurrently,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||||
|
|
|
@ -28,12 +28,12 @@ from PIL import Image as PILImage
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
concatenate_episodes,
|
concatenate_episodes,
|
||||||
get_default_encoding,
|
get_default_encoding,
|
||||||
save_images_concurrently,
|
save_images_concurrently,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||||
|
|
|
@ -16,7 +16,9 @@
|
||||||
import inspect
|
import inspect
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy
|
import numpy
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
|
@ -72,3 +74,58 @@ def check_repo_id(repo_id: str) -> None:
|
||||||
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
||||||
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(aliberts): remove
|
||||||
|
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||||
|
- "from": A tensor containing the starting index of each episode.
|
||||||
|
- "to": A tensor containing the ending index of each episode.
|
||||||
|
"""
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
|
current_episode = None
|
||||||
|
"""
|
||||||
|
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||||
|
For instance, the following is a valid episode_index:
|
||||||
|
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||||
|
|
||||||
|
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||||
|
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||||
|
{
|
||||||
|
"from": [0, 3, 7],
|
||||||
|
"to": [3, 7, 12]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if len(hf_dataset) == 0:
|
||||||
|
episode_data_index = {
|
||||||
|
"from": torch.tensor([]),
|
||||||
|
"to": torch.tensor([]),
|
||||||
|
}
|
||||||
|
return episode_data_index
|
||||||
|
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||||
|
if episode_idx != current_episode:
|
||||||
|
# We encountered a new episode, so we append its starting location to the "from" list
|
||||||
|
episode_data_index["from"].append(idx)
|
||||||
|
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||||
|
if current_episode is not None:
|
||||||
|
episode_data_index["to"].append(idx)
|
||||||
|
# Let's keep track of the current episode index
|
||||||
|
current_episode = episode_idx
|
||||||
|
else:
|
||||||
|
# We are still in the same episode, so there is nothing for us to do here
|
||||||
|
pass
|
||||||
|
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||||
|
episode_data_index["to"].append(idx + 1)
|
||||||
|
|
||||||
|
for k in ["from", "to"]:
|
||||||
|
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||||
|
|
||||||
|
return episode_data_index
|
||||||
|
|
|
@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
concatenate_episodes,
|
concatenate_episodes,
|
||||||
get_default_encoding,
|
get_default_encoding,
|
||||||
save_images_concurrently,
|
save_images_concurrently,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||||
|
|
|
@ -18,7 +18,7 @@ import warnings
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import jsonlines
|
import jsonlines
|
||||||
|
@ -368,61 +368,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
|
||||||
return delta_indices
|
return delta_indices
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): remove
|
|
||||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
|
||||||
- "from": A tensor containing the starting index of each episode.
|
|
||||||
- "to": A tensor containing the ending index of each episode.
|
|
||||||
"""
|
|
||||||
episode_data_index = {"from": [], "to": []}
|
|
||||||
|
|
||||||
current_episode = None
|
|
||||||
"""
|
|
||||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
|
||||||
For instance, the following is a valid episode_index:
|
|
||||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
|
||||||
|
|
||||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
|
||||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
|
||||||
{
|
|
||||||
"from": [0, 3, 7],
|
|
||||||
"to": [3, 7, 12]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
if len(hf_dataset) == 0:
|
|
||||||
episode_data_index = {
|
|
||||||
"from": torch.tensor([]),
|
|
||||||
"to": torch.tensor([]),
|
|
||||||
}
|
|
||||||
return episode_data_index
|
|
||||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
|
||||||
if episode_idx != current_episode:
|
|
||||||
# We encountered a new episode, so we append its starting location to the "from" list
|
|
||||||
episode_data_index["from"].append(idx)
|
|
||||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
|
||||||
if current_episode is not None:
|
|
||||||
episode_data_index["to"].append(idx)
|
|
||||||
# Let's keep track of the current episode index
|
|
||||||
current_episode = episode_idx
|
|
||||||
else:
|
|
||||||
# We are still in the same episode, so there is nothing for us to do here
|
|
||||||
pass
|
|
||||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
|
||||||
episode_data_index["to"].append(idx + 1)
|
|
||||||
|
|
||||||
for k in ["from", "to"]:
|
|
||||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
|
||||||
|
|
||||||
return episode_data_index
|
|
||||||
|
|
||||||
|
|
||||||
def cycle(iterable):
|
def cycle(iterable):
|
||||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||||
|
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,8 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
|
|
Loading…
Reference in New Issue