Ben Sprenger 2025-04-04 20:50:09 -07:00 committed by GitHub
commit 000e447cde
7 changed files with 448 additions and 12 deletions

View File

@@ -601,7 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
episodes = self.episodes if self.episodes is not None else list(self.meta.episodes)
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [

View File

@@ -811,3 +811,33 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
def translate_episode_index_to_position(episode_dicts: dict[dict], episode_index: int) -> int:
"""
Translates an actual episode index to its position in the sequential episode_data_index tensors.
When episodes are removed from a dataset, the remaining episode indices may no longer be sequential
(e.g., they could be [0, 3, 7, 10]). However, the dataset's episode_data_index tensors are always
indexed sequentially from 0 to len(episodes)-1. This function provides the mapping between these
two indexing schemes.
Example:
If a dataset originally had episodes [0, 1, 2, 3, 4] but episodes 1 and 3 were removed,
the remaining episodes would be [0, 2, 4]. In this case:
- Episode index 0 would be at position 0
- Episode index 2 would be at position 1
- Episode index 4 would be at position 2
So translate_episode_index_to_position(episode_dicts, 4) would return 2.
Args:
episode_dicts (dict[dict]): Dictionary of episode dictionaries or list of episode indices
episode_index (int): The actual episode index to translate
Returns:
int: The position of the episode in the episode_data_index tensors
"""
episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
position = episode_to_position[episode_index]
return position
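For instance, with the non-sequential episodes from the docstring above (a minimal sketch; the episode dicts are placeholders for real metadata entries):

episode_dicts = {0: {"episode_index": 0}, 2: {"episode_index": 2}, 4: {"episode_index": 4}}

translate_episode_index_to_position(episode_dicts, 0)  # -> 0
translate_episode_index_to_position(episode_dicts, 2)  # -> 1
translate_episode_index_to_position(episode_dicts, 4)  # -> 2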

View File

@@ -19,7 +19,7 @@ from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
from lerobot.common.datasets.utils import translate_episode_index_to_position, write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
@@ -31,8 +31,9 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_idx)
ep_start_idx = dataset.episode_data_index["from"][index_position]
ep_end_idx = dataset.episode_data_index["to"][index_position]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
ep_stats = {}

View File

@@ -0,0 +1,362 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import shutil
import sys
import tempfile
import time
from copy import deepcopy
from pathlib import Path
from huggingface_hub import HfApi
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
TASKS_PATH,
append_jsonlines,
create_lerobot_dataset_card,
write_episode,
write_episode_stats,
write_info,
)
from lerobot.common.utils.utils import init_logging
def remove_episodes(
dataset: LeRobotDataset,
episodes_to_remove: list[int],
backup: str | Path | bool = False,
) -> LeRobotDataset:
"""
Removes specified episodes from a LeRobotDataset and updates all metadata and files accordingly.
Args:
dataset: The LeRobotDataset to modify
episodes_to_remove: List of episode indices to remove
backup: Controls backup behavior:
- False: No backup is created
- True: Create backup at default location next to dataset
- str/Path: Create backup at the specified location
Returns:
Updated LeRobotDataset with specified episodes removed
"""
if not episodes_to_remove:
return dataset
if not all(ep_idx in dataset.meta.episodes for ep_idx in episodes_to_remove):
raise ValueError("Episodes to remove must be valid episode indices in the dataset")
# Calculate the new metadata
new_meta = deepcopy(dataset.meta)
new_meta.info["total_episodes"] -= len(episodes_to_remove)
new_meta.info["total_frames"] -= sum(
dataset.meta.episodes[ep_idx]["length"] for ep_idx in episodes_to_remove
)
for ep_idx in episodes_to_remove:
new_meta.episodes.pop(ep_idx)
new_meta.episodes_stats.pop(ep_idx)
new_meta.stats = aggregate_stats(list(new_meta.episodes_stats.values()))
tasks = {task for ep in new_meta.episodes.values() if "tasks" in ep for task in ep["tasks"]}
new_meta.tasks = {new_meta.get_task_index(task): task for task in tasks}
new_meta.task_to_task_index = {task: idx for idx, task in new_meta.tasks.items()}
new_meta.info["total_tasks"] = len(new_meta.tasks)
new_meta.info["total_videos"] = (
(new_meta.info["total_episodes"]) * len(dataset.meta.video_keys) if dataset.meta.video_keys else 0
)
if "splits" in new_meta.info:
new_meta.info["splits"] = {"train": f"0:{new_meta.info['total_episodes']}"}
# Now that the metadata is recalculated, we update the dataset files by
# removing the files related to the specified episodes. We perform a safe
# update such that if an error occurs, any changes are rolled back and the
dataset files are left in their original state. Optionally, a non-temporary
# full backup can be made so that we also have the dataset in its original state.
if backup:
backup_path = (
Path(backup)
if isinstance(backup, (str, Path))
else dataset.root.parent / f"{dataset.root.name}_backup_{int(time.time())}"
)
_backup_folder(dataset.root, backup_path)
_update_dataset_files(
new_meta,
episodes_to_remove,
)
updated_dataset = LeRobotDataset(
repo_id=dataset.repo_id,
root=dataset.root,
episodes=None, # Load all episodes
image_transforms=dataset.image_transforms,
delta_timestamps=dataset.delta_timestamps,
tolerance_s=dataset.tolerance_s,
revision=dataset.revision,
download_videos=False, # No need to download, we just saved them
video_backend=dataset.video_backend,
)
return updated_dataset
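A usage sketch for the function above, with the module path taken from the accompanying tests (the repo id and episode indices are illustrative):

from lerobot.common.datasets.episode_utils import remove_episodes
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
# backup=True copies the whole dataset next to its root before any file is touched.
updated = remove_episodes(dataset, episodes_to_remove=[1, 3], backup=True)
print(updated.meta.total_episodes)  # original episode count minus 2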
def _move_file(src: Path, dest: Path) -> None:
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.move(src, dest)
def _update_dataset_files(new_meta: LeRobotDatasetMetadata, episodes_to_remove: list[int]):
"""Update dataset files.
This function performs a safe update for dataset files. It moves modified or removed
episode files to a temporary directory. Once all changes are made, the temporary
directory is deleted. If an error occurs during the update, all changes are rolled
back and the original dataset files are restored.
Args:
new_meta (LeRobotDatasetMetadata): Updated metadata object containing the new
dataset state after removing episodes
episodes_to_remove (list[int]): List of episode indices to remove from the dataset
Raises:
Exception: If any operation fails, rolls back all changes and re-raises the original exception
"""
with tempfile.TemporaryDirectory(prefix="lerobot_backup_temp_") as backup_path:
backup_dir = Path(backup_path)
# Init empty containers s.t. they are guaranteed to exist in the except block
metadata_files = {}
rel_data_paths = []
rel_video_paths = []
try:
# Step 1: Update metadata files
metadata_files = {
INFO_PATH: lambda: write_info(new_meta.info, new_meta.root),
EPISODES_PATH: lambda: [
write_episode(ep, new_meta.root) for ep in new_meta.episodes.values()
],
TASKS_PATH: lambda: [
append_jsonlines({"task_index": idx, "task": task}, new_meta.root / TASKS_PATH)
for idx, task in new_meta.tasks.items()
],
EPISODES_STATS_PATH: lambda: [
write_episode_stats(idx, stats, new_meta.root)
for idx, stats in new_meta.episodes_stats.items()
],
}
for file_path, update_func in metadata_files.items():
_move_file(new_meta.root / file_path, backup_dir / file_path)
update_func()
# Step 2: Update data and video
rel_data_paths = [new_meta.get_data_file_path(ep_idx) for ep_idx in episodes_to_remove]
rel_video_paths = [
new_meta.get_video_file_path(ep_idx, vid_key)
for ep_idx in episodes_to_remove
for vid_key in new_meta.video_keys
]
for rel_path in rel_data_paths + rel_video_paths:
if (new_meta.root / rel_path).exists():
_move_file(new_meta.root / rel_path, backup_dir / rel_path)
except Exception as e:
logging.error(f"Error updating dataset files: {str(e)}. Rolling back changes.")
# Restore metadata files
for file_path in metadata_files:
if (backup_dir / file_path).exists():
_move_file(backup_dir / file_path, new_meta.root / file_path)
# Restore data and video files
for rel_file_path in rel_data_paths + rel_video_paths:
if (backup_dir / rel_file_path).exists():
_move_file(backup_dir / rel_file_path, new_meta.root / rel_file_path)
raise e
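Stripped of the lerobot-specific metadata handling, the same move-then-restore idea looks roughly like this (a sketch; names are illustrative):

import shutil
import tempfile
from pathlib import Path
from typing import Callable

def safe_update(root: Path, rel_paths: list[Path], write_new: Callable[[], None]) -> None:
    # Park the files being replaced in a temporary directory, write the new
    # versions, and move the originals back if anything goes wrong.
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        moved: list[Path] = []
        try:
            for rel in rel_paths:
                (tmp_dir / rel).parent.mkdir(parents=True, exist_ok=True)
                shutil.move(root / rel, tmp_dir / rel)
                moved.append(rel)
            write_new()  # caller-supplied callback that writes the updated files
        except Exception:
            for rel in moved:
                shutil.move(tmp_dir / rel, root / rel)
            raise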
def _backup_folder(target_dir: Path, backup_path: Path) -> None:
if backup_path.resolve() == target_dir.resolve() or backup_path.resolve().is_relative_to(
target_dir.resolve()
):
raise ValueError(
f"Backup directory '{backup_path}' cannot be inside the dataset "
f"directory '{target_dir}' as this would cause infinite recursion"
)
backup_path.parent.mkdir(parents=True, exist_ok=True)
logging.info(f"Creating backup at: {backup_path}")
shutil.copytree(target_dir, backup_path)
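For illustration, a backup destination nested inside the dataset root is rejected before anything is copied (paths are hypothetical):

_backup_folder(Path("/data/pusht"), Path("/data/pusht_backup"))  # copies the dataset tree
_backup_folder(Path("/data/pusht"), Path("/data/pusht/backup"))  # raises ValueError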
def _parse_episodes_list(episodes_str: str) -> list[int]:
"""
Parse a string of episode indices, ranges, and comma-separated lists into a list of integers.
"""
episodes = []
for ep in episodes_str.split(","):
if "-" in ep:
start, end = ep.split("-")
episodes.extend(range(int(start), int(end) + 1))
else:
episodes.append(int(ep))
return episodes
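For example, the range syntax from the CLI help expands as:

_parse_episodes_list("1-5,7,10-12")
# -> [1, 2, 3, 4, 5, 7, 10, 11, 12]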
def _delete_hub_file(hub_api: HfApi, repo_id: str, file_path: str, branch: str | None = None):
try:
if hub_api.file_exists(
repo_id,
file_path,
repo_type="dataset",
revision=branch,
):
hub_api.delete_file(
path_in_repo=file_path,
repo_id=repo_id,
repo_type="dataset",
revision=branch,
)
except Exception as e:
logging.error(f"Error removing file '{file_path}' from the hub: {str(e)}")
def _remove_episodes_from_hub(
updated_dataset: LeRobotDataset, episodes_to_remove: list[int], branch: str | None = None
):
"""Remove episodes from the hub repository."""
hub_api = HfApi()
for ep_idx in episodes_to_remove:
data_path = str(updated_dataset.meta.get_data_file_path(ep_idx))
_delete_hub_file(hub_api, updated_dataset.repo_id, data_path, branch)
for vid_key in updated_dataset.meta.video_keys:
video_path = str(updated_dataset.meta.get_video_file_path(ep_idx, vid_key))
_delete_hub_file(hub_api, updated_dataset.repo_id, video_path, branch)
def main():
parser = argparse.ArgumentParser(description="Remove episodes from a LeRobot dataset")
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for the dataset stored locally. By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"-e",
"--episodes",
type=str,
required=True,
help="Episodes to remove. Can be a single index, comma-separated indices, or ranges (e.g., '1-5,7,10-12')",
)
parser.add_argument(
"-b",
"--backup",
nargs="?",
const=True,
default=False,
help="Create a backup before modifying the dataset. Without a value, creates a backup in the default location. "
"With a value, either 'true'/'false' or a path to store the backup.",
)
parser.add_argument(
"--push-to-hub",
type=int,
default=1,
help="Upload to Hugging Face hub.",
)
parser.add_argument(
"--private",
type=int,
default=0,
help="If set, the repository on the Hub will be private",
)
parser.add_argument(
"--tags",
type=str,
nargs="+",
help="List of tags to apply to the dataset on the Hub",
)
parser.add_argument("--license", type=str, default=None, help="License to use for the dataset on the Hub")
parser.add_argument("--branch", type=str, default=None, help="Branch to push the dataset to on the Hub")
args = parser.parse_args()
# Parse the backup argument
backup_value = args.backup
if isinstance(backup_value, str):
if backup_value.lower() == "true":
backup_value = True
elif backup_value.lower() == "false":
backup_value = False
# Otherwise, it's treated as a path
# Parse episodes to remove
episodes_to_remove = _parse_episodes_list(args.episodes)
if not episodes_to_remove:
logging.warning("No episodes specified to remove")
sys.exit(0)
# Load the dataset
logging.info(f"Loading dataset '{args.repo_id}'...")
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
logging.info(f"Dataset has {dataset.meta.total_episodes} episodes")
# Modify the dataset
logging.info(f"Removing {len(set(episodes_to_remove))} episodes: {sorted(set(episodes_to_remove))}")
updated_dataset = remove_episodes(
dataset=dataset,
episodes_to_remove=episodes_to_remove,
backup=backup_value,
)
logging.info(
f"Successfully removed episodes. Dataset now has {updated_dataset.meta.total_episodes} episodes."
)
if args.push_to_hub:
logging.info("Pushing dataset to hub...")
updated_dataset.push_to_hub(
tags=args.tags, private=bool(args.private), license=args.license, branch=args.branch
)
updated_card = create_lerobot_dataset_card(
tags=args.tags, dataset_info=updated_dataset.meta.info, license=args.license
)
updated_card.push_to_hub(repo_id=updated_dataset.repo_id, repo_type="dataset", revision=args.branch)
_remove_episodes_from_hub(updated_dataset, episodes_to_remove, branch=args.branch)
logging.info("Dataset pushed to hub.")
if __name__ == "__main__":
init_logging()
main()
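A minimal sketch of how the --backup flag resolves under nargs="?" with const=True, before the normalization in main() above (values are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-b", "--backup", nargs="?", const=True, default=False)

parser.parse_args([]).backup                        # False: flag absent
parser.parse_args(["-b"]).backup                    # True: flag given without a value (const)
parser.parse_args(["-b", "false"]).backup           # "false": a string, normalized to False later
parser.parse_args(["-b", "~/pusht_backup"]).backup  # "~/pusht_backup": treated as a backup path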

View File

@@ -75,12 +75,14 @@ import torch.utils.data
import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import translate_episode_index_to_position
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
from_idx = dataset.episode_data_index["from"][index_position].item()
to_idx = dataset.episode_data_index["to"][index_position].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:

View File

@@ -69,7 +69,7 @@ from flask import Flask, redirect, render_template, request, url_for
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import IterableNamespace
from lerobot.common.datasets.utils import IterableNamespace, translate_episode_index_to_position
from lerobot.common.utils.utils import init_logging
@@ -207,7 +207,9 @@ def run_server(
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
dataset.meta.episodes
if isinstance(dataset, LeRobotDataset)
else range(dataset.total_episodes)
)
return render_template(
@@ -268,8 +270,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
from_idx = dataset.episode_data_index["from"][index_position]
to_idx = dataset.episode_data_index["to"][index_position]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
@@ -305,7 +308,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
first_frame_idx = dataset.episode_data_index["from"][index_position].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.meta.video_keys
@@ -318,7 +322,8 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
first_frame_idx = dataset.episode_data_index["from"][index_position].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored

View File

@@ -19,6 +19,7 @@ import re
from copy import deepcopy
from itertools import chain
from pathlib import Path
from unittest.mock import patch
import numpy as np
import pytest
@@ -28,6 +29,7 @@ from PIL import Image
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.episode_utils import remove_episodes
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.image_writer import image_array_to_pil_image
from lerobot.common.datasets.lerobot_dataset import (
@@ -580,3 +582,37 @@ def test_dataset_feature_with_forward_slash_raises_error():
fps=30,
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
)
@pytest.mark.parametrize(
"total_episodes, total_frames, episodes_to_remove",
[
(3, 30, [1]),
(3, 30, [0, 2]),
(4, 50, [1, 2]),
],
)
def test_remove_episodes(tmp_path, lerobot_dataset_factory, total_episodes, total_frames, episodes_to_remove):
dataset = lerobot_dataset_factory(
root=tmp_path / "test",
total_episodes=total_episodes,
total_frames=total_frames,
)
num_frames_to_remove = 0
for ep in episodes_to_remove:
num_frames_to_remove += (
dataset.episode_data_index["to"][ep].item() - dataset.episode_data_index["from"][ep].item()
)
with (
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.side_effect = lambda repo_id, version: version
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(dataset.root)
updated_dataset = remove_episodes(dataset, episodes_to_remove)
assert updated_dataset.meta.total_episodes == total_episodes - len(episodes_to_remove)
assert updated_dataset.meta.total_frames == total_frames - num_frames_to_remove
for i, ep_meta in enumerate(updated_dataset.meta.episodes.values()):
assert ep_meta["episode_index"] == i