From 8da16b16250b9724b5ef57236236bd5f3885ceec Mon Sep 17 00:00:00 2001 From: Ben Sprenger Date: Fri, 7 Mar 2025 15:03:28 +0100 Subject: [PATCH 1/2] feat: add remove_episodes utility This commit introduces a remove_episodes function/CLI tool to remove specific episodes from a dataset, and will automatically modify all required data, video, and metadata. Optionally, the script will push the modified dataset to the hub (True by default). The function will safely remove the episodes, meaning that if at any point during the process a failure occurs, the original dataset is preserved. Additionally, the original dataset is optionally backed up in case it is needed to revert to. --- lerobot/common/datasets/lerobot_dataset.py | 2 +- lerobot/scripts/remove_episodes.py | 407 +++++++++++++++++++++ tests/datasets/test_datasets.py | 36 ++ 3 files changed, 444 insertions(+), 1 deletion(-) create mode 100644 lerobot/scripts/remove_episodes.py diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d8da85d6..d1c41104 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -601,7 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) def get_episodes_file_paths(self) -> list[Path]: - episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) + episodes = self.episodes if self.episodes is not None else list(self.meta.episodes) fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] if len(self.meta.video_keys) > 0: video_files = [ diff --git a/lerobot/scripts/remove_episodes.py b/lerobot/scripts/remove_episodes.py new file mode 100644 index 00000000..fe778a48 --- /dev/null +++ b/lerobot/scripts/remove_episodes.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import contextlib +import logging +import shutil +import sys +import tempfile +import time +from copy import deepcopy +from pathlib import Path + +from huggingface_hub import HfApi +from huggingface_hub.utils import RevisionNotFoundError + +from lerobot.common.datasets.compute_stats import aggregate_stats +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.common.datasets.utils import ( + EPISODES_PATH, + EPISODES_STATS_PATH, + INFO_PATH, + TASKS_PATH, + append_jsonlines, + create_lerobot_dataset_card, + write_episode, + write_episode_stats, + write_info, +) +from lerobot.common.utils.utils import init_logging + + +def remove_episodes( + dataset: LeRobotDataset, + episodes_to_remove: list[int], + backup: str | Path | bool = False, +) -> LeRobotDataset: + """ + Removes specified episodes from a LeRobotDataset and updates all metadata and files accordingly. + + Args: + dataset: The LeRobotDataset to modify + episodes_to_remove: List of episode indices to remove + backup: Controls backup behavior: + - False: No backup is created + - True: Create backup at default location next to dataset + - str/Path: Create backup at the specified location + + Returns: + Updated LeRobotDataset with specified episodes removed + """ + if not episodes_to_remove: + return dataset + + if not all(ep_idx in dataset.meta.episodes for ep_idx in episodes_to_remove): + raise ValueError("Episodes to remove must be valid episode indices in the dataset") + + # Calculate the new metadata + new_meta = deepcopy(dataset.meta) + new_meta.info["total_episodes"] -= len(episodes_to_remove) + new_meta.info["total_frames"] -= sum( + dataset.meta.episodes[ep_idx]["length"] for ep_idx in episodes_to_remove + ) + + for ep_idx in episodes_to_remove: + new_meta.episodes.pop(ep_idx) + new_meta.episodes_stats.pop(ep_idx) + new_meta.stats = aggregate_stats(list(new_meta.episodes_stats.values())) + + tasks = {task for ep in new_meta.episodes.values() if "tasks" in ep for task in ep["tasks"]} + new_meta.tasks = {new_meta.get_task_index(task): task for task in tasks} + new_meta.task_to_task_index = {task: idx for idx, task in new_meta.tasks.items()} + new_meta.info["total_tasks"] = len(new_meta.tasks) + + new_meta.info["total_videos"] = ( + (new_meta.info["total_episodes"]) * len(dataset.meta.video_keys) if dataset.meta.video_keys else 0 + ) + + if "splits" in new_meta.info: + new_meta.info["splits"] = {"train": f"0:{new_meta.info['total_episodes']}"} + + # Now that the metadata is recalculated, we update the dataset files by + # removing the files related to the specified episodes. We perform a safe + # update such that if an error occurs, any changes are rolled back and the + # dataset files are left in its original state. Optionally, a non-temporary + # full backup can be made so that we also have the dataset in its original state. + if backup: + backup_path = ( + Path(backup) + if isinstance(backup, (str, Path)) + else dataset.root.parent / f"{dataset.root.name}_backup_{int(time.time())}" + ) + _backup_folder(dataset.root, backup_path) + + _update_dataset_files( + new_meta, + episodes_to_remove, + ) + + updated_dataset = LeRobotDataset( + repo_id=dataset.repo_id, + root=dataset.root, + episodes=None, # Load all episodes + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + revision=dataset.revision, + download_videos=False, # No need to download, we just saved them + video_backend=dataset.video_backend, + ) + + return updated_dataset + + +def _move_file(src: Path, dest: Path) -> None: + dest.parent.mkdir(parents=True, exist_ok=True) + shutil.move(src, dest) + + +def _update_dataset_files(new_meta: LeRobotDatasetMetadata, episodes_to_remove: list[int]): + """Update dataset files. + + This function performs a safe update for dataset files. It moves modified or removed + episode files to a temporary directory. Once all changes are made, the temporary + directory is deleted. If an error occurs during the update, all changes are rolled + back and the original dataset files are restored. + + Args: + new_meta (LeRobotDatasetMetadata): Updated metadata object containing the new + dataset state after removing episodes + episodes_to_remove (list[int]): List of episode indices to remove from the dataset + + Raises: + Exception: If any operation fails, rolls back all changes and re-raises the original exception + """ + with tempfile.TemporaryDirectory(prefix="lerobot_backup_temp_") as backup_path: + backup_dir = Path(backup_path) + + # Init empty containers s.t. they are guaranteed to exist in the except block + metadata_files = {} + rel_data_paths = [] + rel_video_paths = [] + + try: + # Step 1: Update metadata files + metadata_files = { + INFO_PATH: lambda: write_info(new_meta.info, new_meta.root), + EPISODES_PATH: lambda: [ + write_episode(ep, new_meta.root) for ep in new_meta.episodes.values() + ], + TASKS_PATH: lambda: [ + append_jsonlines({"task_index": idx, "task": task}, new_meta.root / TASKS_PATH) + for idx, task in new_meta.tasks.items() + ], + EPISODES_STATS_PATH: lambda: [ + write_episode_stats(idx, stats, new_meta.root) + for idx, stats in new_meta.episodes_stats.items() + ], + } + for file_path, update_func in metadata_files.items(): + _move_file(new_meta.root / file_path, backup_dir / file_path) + update_func() + + # Step 2: Update data and video + rel_data_paths = [new_meta.get_data_file_path(ep_idx) for ep_idx in episodes_to_remove] + rel_video_paths = [ + new_meta.get_video_file_path(ep_idx, vid_key) + for ep_idx in episodes_to_remove + for vid_key in new_meta.video_keys + ] + for rel_path in rel_data_paths + rel_video_paths: + if (new_meta.root / rel_path).exists(): + _move_file(new_meta.root / rel_path, backup_dir / rel_path) + + except Exception as e: + logging.error(f"Error updating dataset files: {str(e)}. Rolling back changes.") + + # Restore metadata files + for file_path in metadata_files: + if (backup_dir / file_path).exists(): + _move_file(backup_dir / file_path, new_meta.root / file_path) + + # Restore data and video files + for rel_file_path in rel_data_paths + rel_video_paths: + if (backup_dir / rel_file_path).exists(): + _move_file(backup_dir / rel_file_path, new_meta.root / rel_file_path) + + raise e + + +def _backup_folder(target_dir: Path, backup_path: Path) -> None: + if backup_path.resolve() == target_dir.resolve() or backup_path.resolve().is_relative_to( + target_dir.resolve() + ): + raise ValueError( + f"Backup directory '{backup_path}' cannot be inside the dataset " + f"directory '{target_dir}' as this would cause infinite recursion" + ) + + backup_path.parent.mkdir(parents=True, exist_ok=True) + logging.info(f"Creating backup at: {backup_path}") + shutil.copytree(target_dir, backup_path) + + +def _parse_episodes_list(episodes_str: str) -> list[int]: + """ + Parse a string of episode indices, ranges, and comma-separated lists into a list of integers. + """ + episodes = [] + for ep in episodes_str.split(","): + if "-" in ep: + start, end = ep.split("-") + episodes.extend(range(int(start), int(end) + 1)) + else: + episodes.append(int(ep)) + return episodes + + +def _delete_hub_file(hub_api: HfApi, repo_id: str, file_path: str, branch: str): + try: + with contextlib.suppress(RevisionNotFoundError): + if hub_api.file_exists( + repo_id, + file_path, + repo_type="dataset", + revision=branch, + ): + hub_api.delete_file( + path_in_repo=file_path, + repo_id=repo_id, + repo_type="dataset", + revision=branch, + ) + logging.info(f"Deleted '{file_path}' from branch '{branch}'") + except Exception as e: + logging.error(f"Error removing file '{file_path}' from the hub: {str(e)}") + + +def _remove_episodes_from_hub( + updated_dataset: LeRobotDataset, episodes_to_remove: list[int], branch: str | None = None +): + """Remove episodes from the hub repository at a specific revision.""" + hub_api = HfApi() + + try: + for ep_idx in episodes_to_remove: + data_path = str(updated_dataset.meta.get_data_file_path(ep_idx)) + _delete_hub_file(hub_api, updated_dataset.repo_id, data_path, branch) + + for vid_key in updated_dataset.meta.video_keys: + video_path = str(updated_dataset.meta.get_video_file_path(ep_idx, vid_key)) + _delete_hub_file(hub_api, updated_dataset.repo_id, video_path, branch) + + logging.info(f"Successfully removed episode files from Hub on branch '{branch}'") + + except RevisionNotFoundError: + logging.error(f"Branch '{branch}' not found in repository '{updated_dataset.repo_id}'") + except Exception as e: + logging.error(f"Error during Hub operations: {str(e)}") + + +def main(): + parser = argparse.ArgumentParser(description="Remove episodes from a LeRobot dataset") + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", + ) + parser.add_argument( + "--root", + type=Path, + default=None, + help="Root directory for the dataset stored locally. By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.", + ) + parser.add_argument( + "-e", + "--episodes", + type=str, + required=True, + help="Episodes to remove. Can be a single index, comma-separated indices, or ranges (e.g., '1-5,7,10-12')", + ) + parser.add_argument( + "-b", + "--backup", + nargs="?", + const=True, + default=False, + help="Create a backup before modifying the dataset. Without a value, creates a backup in the default location. " + "With a value, either 'true'/'false' or a path to store the backup.", + ) + parser.add_argument( + "--push-to-hub", + type=int, + default=1, + help="Upload to Hugging Face hub.", + ) + parser.add_argument( + "--private", + type=int, + default=0, + help="If set, the repository on the Hub will be private", + ) + parser.add_argument( + "--tags", + type=str, + nargs="+", + help="List of tags to apply to the dataset on the Hub", + ) + parser.add_argument("--license", type=str, default=None, help="License to use for the dataset on the Hub") + args = parser.parse_args() + + # Parse the backup argument + backup_value = args.backup + if isinstance(backup_value, str): + if backup_value.lower() == "true": + backup_value = True + elif backup_value.lower() == "false": + backup_value = False + # Otherwise, it's treated as a path + + # Parse episodes to remove + episodes_to_remove = _parse_episodes_list(args.episodes) + if not episodes_to_remove: + logging.warning("No episodes specified to remove") + sys.exit(0) + + # Load the dataset + logging.info(f"Loading dataset '{args.repo_id}'...") + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) + logging.info(f"Dataset has {dataset.meta.total_episodes} episodes") + + target_revision_tag = dataset.revision + push_branch = "main" + logging.info( + f"Dataset loaded using revision tag: {target_revision_tag}. Changes will be pushed to 'main' and this tag will be updated." + ) + + # Modify the dataset + logging.info(f"Removing {len(set(episodes_to_remove))} episodes: {sorted(set(episodes_to_remove))}") + updated_dataset = remove_episodes( + dataset=dataset, + episodes_to_remove=episodes_to_remove, + backup=backup_value, + ) + logging.info( + f"Successfully removed episodes. Dataset now has {updated_dataset.meta.total_episodes} episodes." + ) + + if args.push_to_hub: + logging.info("Pushing dataset to hub...") + + updated_dataset.push_to_hub( + tags=args.tags, + private=bool(args.private), + license=args.license, + branch=push_branch, + tag_version=False, # Disable automatic tagging here, we'll do it manually later + ) + updated_card = create_lerobot_dataset_card( + tags=args.tags, dataset_info=updated_dataset.meta.info, license=args.license + ) + updated_card.push_to_hub(repo_id=updated_dataset.repo_id, repo_type="dataset", revision=push_branch) + _remove_episodes_from_hub(updated_dataset, episodes_to_remove, branch=push_branch) + + logging.info( + f"Updating tag '{target_revision_tag}' to point to the latest commit on branch '{push_branch}'..." + ) + hub_api = HfApi() + try: + # Delete the old tag first if it exists + with contextlib.suppress(RevisionNotFoundError): + hub_api.delete_tag(updated_dataset.repo_id, tag=target_revision_tag, repo_type="dataset") + logging.info(f"Deleted existing tag '{target_revision_tag}'.") + + # Create the new tag pointing to the head of the push branch + hub_api.create_tag( + updated_dataset.repo_id, + tag=target_revision_tag, + revision=push_branch, + repo_type="dataset", + ) + logging.info( + f"Successfully created tag '{target_revision_tag}' pointing to branch '{push_branch}'." + ) + + except Exception as e: + logging.error(f"Error during tag update for '{target_revision_tag}': {str(e)}") + + logging.info("Dataset pushed to hub.") + + +if __name__ == "__main__": + init_logging() + main() diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 81447089..709d2328 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -19,6 +19,7 @@ import re from copy import deepcopy from itertools import chain from pathlib import Path +from unittest.mock import patch import numpy as np import pytest @@ -28,6 +29,7 @@ from PIL import Image from safetensors.torch import load_file import lerobot +from lerobot.common.datasets.episode_utils import remove_episodes from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.image_writer import image_array_to_pil_image from lerobot.common.datasets.lerobot_dataset import ( @@ -580,3 +582,37 @@ def test_dataset_feature_with_forward_slash_raises_error(): fps=30, features={"a/b": {"dtype": "float32", "shape": 2, "names": None}}, ) + + +@pytest.mark.parametrize( + "total_episodes, total_frames, episodes_to_remove", + [ + (3, 30, [1]), + (3, 30, [0, 2]), + (4, 50, [1, 2]), + ], +) +def test_remove_episodes(tmp_path, lerobot_dataset_factory, total_episodes, total_frames, episodes_to_remove): + dataset = lerobot_dataset_factory( + root=tmp_path / "test", + total_episodes=total_episodes, + total_frames=total_frames, + ) + num_frames_to_remove = 0 + for ep in episodes_to_remove: + num_frames_to_remove += ( + dataset.episode_data_index["to"][ep].item() - dataset.episode_data_index["from"][ep].item() + ) + + with ( + patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.side_effect = lambda repo_id, version: version + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(dataset.root) + updated_dataset = remove_episodes(dataset, episodes_to_remove) + + assert updated_dataset.meta.total_episodes == total_episodes - len(episodes_to_remove) + assert updated_dataset.meta.total_frames == total_frames - num_frames_to_remove + for i, ep_meta in enumerate(updated_dataset.meta.episodes.values()): + assert ep_meta["episode_index"] == i From 997133c6bc5811d5e9ecf340d100916e232b4faa Mon Sep 17 00:00:00 2001 From: Ben Sprenger Date: Sun, 16 Mar 2025 15:07:22 +0100 Subject: [PATCH 2/2] fix: Add translation function for non-sequential episode indices When episodes are removed from a LeRobotDataset, the remaining episode indices are no longer sequential, which causes indexing errors in get_episode_data(). This happens because episode_data_index tensors are always indexed sequentially, while the episode indices can be arbitrary. This commit introduces a helper function to make the conversion. --- lerobot/common/datasets/lerobot_dataset.py | 6 ++-- lerobot/common/datasets/utils.py | 30 ++++++++++++++++++++ lerobot/common/datasets/v21/convert_stats.py | 7 +++-- lerobot/scripts/visualize_dataset.py | 6 ++-- lerobot/scripts/visualize_dataset_html.py | 17 +++++++---- 5 files changed, 53 insertions(+), 13 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d1c41104..7dabb6a2 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -58,6 +58,7 @@ from lerobot.common.datasets.utils import ( load_info, load_stats, load_tasks, + translate_episode_index_to_position, validate_episode_buffer, validate_frame, write_episode, @@ -663,8 +664,9 @@ class LeRobotDataset(torch.utils.data.Dataset): return get_hf_features_from_features(self.features) def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: - ep_start = self.episode_data_index["from"][ep_idx] - ep_end = self.episode_data_index["to"][ep_idx] + index_position = translate_episode_index_to_position(self.meta.episodes, ep_idx) + ep_start = self.episode_data_index["from"][index_position] + ep_end = self.episode_data_index["to"][index_position] query_indices = { key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] for key, delta_idx in self.delta_indices.items() diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 9d8a54db..f5abf1aa 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -811,3 +811,33 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: f"In episode_buffer not in features: {buffer_keys - set(features)}" f"In features not in episode_buffer: {set(features) - buffer_keys}" ) + + +def translate_episode_index_to_position(episode_dicts: dict[dict], episode_index: int) -> int: + """ + Translates an actual episode index to its position in the sequential episode_data_index tensors. + + When episodes are removed from a dataset, the remaining episode indices may no longer be sequential + (e.g., they could be [0, 3, 7, 10]). However, the dataset's episode_data_index tensors are always + indexed sequentially from 0 to len(episodes)-1. This function provides the mapping between these + two indexing schemes. + + Example: + If a dataset originally had episodes [0, 1, 2, 3, 4] but episodes 1 and 3 were removed, + the remaining episodes would be [0, 2, 4]. In this case: + - Episode index 0 would be at position 0 + - Episode index 2 would be at position 1 + - Episode index 4 would be at position 2 + + So translate_episode_index_to_position(episode_dicts, 4) would return 2. + + Args: + episode_dicts (dict[dict]): Dictionary of episode dictionaries or list of episode indices + episode_index (int): The actual episode index to translate + + Returns: + int: The position of the episode in the episode_data_index tensors + """ + episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)} + position = episode_to_position[episode_index] + return position diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py index 4a20b427..5fd6e3b9 100644 --- a/lerobot/common/datasets/v21/convert_stats.py +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -19,7 +19,7 @@ from tqdm import tqdm from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.utils import write_episode_stats +from lerobot.common.datasets.utils import translate_episode_index_to_position, write_episode_stats def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: @@ -31,8 +31,9 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): - ep_start_idx = dataset.episode_data_index["from"][ep_idx] - ep_end_idx = dataset.episode_data_index["to"][ep_idx] + index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_idx) + ep_start_idx = dataset.episode_data_index["from"][index_position] + ep_end_idx = dataset.episode_data_index["to"][index_position] ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) ep_stats = {} diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index cdfea6b8..a25364eb 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -75,12 +75,14 @@ import torch.utils.data import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import translate_episode_index_to_position class EpisodeSampler(torch.utils.data.Sampler): def __init__(self, dataset: LeRobotDataset, episode_index: int): - from_idx = dataset.episode_data_index["from"][episode_index].item() - to_idx = dataset.episode_data_index["to"][episode_index].item() + index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index) + from_idx = dataset.episode_data_index["from"][index_position].item() + to_idx = dataset.episode_data_index["to"][index_position].item() self.frame_ids = range(from_idx, to_idx) def __iter__(self) -> Iterator: diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 0fc21a8f..42d8d1da 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -69,7 +69,7 @@ from flask import Flask, redirect, render_template, request, url_for from lerobot import available_datasets from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.utils import IterableNamespace +from lerobot.common.datasets.utils import IterableNamespace, translate_episode_index_to_position from lerobot.common.utils.utils import init_logging @@ -207,7 +207,9 @@ def run_server( if episodes is None: episodes = list( - range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) + dataset.meta.episodes + if isinstance(dataset, LeRobotDataset) + else range(dataset.total_episodes) ) return render_template( @@ -268,8 +270,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) selected_columns.insert(0, "timestamp") if isinstance(dataset, LeRobotDataset): - from_idx = dataset.episode_data_index["from"][episode_index] - to_idx = dataset.episode_data_index["to"][episode_index] + index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index) + from_idx = dataset.episode_data_index["from"][index_position] + to_idx = dataset.episode_data_index["to"][index_position] data = ( dataset.hf_dataset.select(range(from_idx, to_idx)) .select_columns(selected_columns) @@ -305,7 +308,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]: # get first frame of episode (hack to get video_path of the episode) - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index) + first_frame_idx = dataset.episode_data_index["from"][index_position].item() return [ dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.meta.video_keys @@ -318,7 +322,8 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> return None # get first frame index - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index) + first_frame_idx = dataset.episode_data_index["from"][index_position].item() language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored