Merge 7fe463b5dd into 1c873df5c0 (commit 000e447cde)
@@ -601,7 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

     def get_episodes_file_paths(self) -> list[Path]:
-        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+        episodes = self.episodes if self.episodes is not None else list(self.meta.episodes)
         fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
         if len(self.meta.video_keys) > 0:
             video_files = [

@@ -811,3 +811,33 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
         f"In episode_buffer not in features: {buffer_keys - set(features)}"
         f"In features not in episode_buffer: {set(features) - buffer_keys}"
     )
+
+
+def translate_episode_index_to_position(episode_dicts: dict[dict], episode_index: int) -> int:
+    """
+    Translates an actual episode index to its position in the sequential episode_data_index tensors.
+
+    When episodes are removed from a dataset, the remaining episode indices may no longer be sequential
+    (e.g., they could be [0, 3, 7, 10]). However, the dataset's episode_data_index tensors are always
+    indexed sequentially from 0 to len(episodes)-1. This function provides the mapping between these
+    two indexing schemes.
+
+    Example:
+        If a dataset originally had episodes [0, 1, 2, 3, 4] but episodes 1 and 3 were removed,
+        the remaining episodes would be [0, 2, 4]. In this case:
+        - Episode index 0 would be at position 0
+        - Episode index 2 would be at position 1
+        - Episode index 4 would be at position 2
+
+        So translate_episode_index_to_position(episode_dicts, 4) would return 2.
+
+    Args:
+        episode_dicts (dict[dict]): Dictionary of episode dictionaries or list of episode indices
+        episode_index (int): The actual episode index to translate
+
+    Returns:
+        int: The position of the episode in the episode_data_index tensors
+    """
+    episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
+    position = episode_to_position[episode_index]
+    return position

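Several of the following hunks apply the same idiom: translate the real episode index to its sequential position before indexing the episode_data_index tensors. A minimal sketch of that pattern, not part of the diff (the helper name episode_frame_range is hypothetical; dataset is assumed to be a LeRobotDataset whose meta.episodes keys are the surviving episode indices):

    from lerobot.common.datasets.utils import translate_episode_index_to_position

    def episode_frame_range(dataset, episode_index: int) -> range:
        # Map e.g. episode index 4 in [0, 2, 4] to position 2, then read the
        # sequential "from"/"to" tensors at that position.
        pos = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
        from_idx = dataset.episode_data_index["from"][pos].item()
        to_idx = dataset.episode_data_index["to"][pos].item()
        return range(from_idx, to_idx)
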
@@ -19,7 +19,7 @@ from tqdm import tqdm

 from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import write_episode_stats
+from lerobot.common.datasets.utils import translate_episode_index_to_position, write_episode_stats


 def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:

@@ -31,8 +31,9 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:


 def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
-    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
-    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_idx)
+    ep_start_idx = dataset.episode_data_index["from"][index_position]
+    ep_end_idx = dataset.episode_data_index["to"][index_position]
     ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))

     ep_stats = {}

@@ -0,0 +1,362 @@ (new file; every line is an addition, shown here without "+" markers)
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import shutil
import sys
import tempfile
import time
from copy import deepcopy
from pathlib import Path

from huggingface_hub import HfApi

from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
    EPISODES_PATH,
    EPISODES_STATS_PATH,
    INFO_PATH,
    TASKS_PATH,
    append_jsonlines,
    create_lerobot_dataset_card,
    write_episode,
    write_episode_stats,
    write_info,
)
from lerobot.common.utils.utils import init_logging


def remove_episodes(
    dataset: LeRobotDataset,
    episodes_to_remove: list[int],
    backup: str | Path | bool = False,
) -> LeRobotDataset:
    """
    Removes specified episodes from a LeRobotDataset and updates all metadata and files accordingly.

    Args:
        dataset: The LeRobotDataset to modify
        episodes_to_remove: List of episode indices to remove
        backup: Controls backup behavior:
            - False: No backup is created
            - True: Create backup at default location next to dataset
            - str/Path: Create backup at the specified location

    Returns:
        Updated LeRobotDataset with specified episodes removed
    """
    if not episodes_to_remove:
        return dataset

    if not all(ep_idx in dataset.meta.episodes for ep_idx in episodes_to_remove):
        raise ValueError("Episodes to remove must be valid episode indices in the dataset")

    # Calculate the new metadata
    new_meta = deepcopy(dataset.meta)
    new_meta.info["total_episodes"] -= len(episodes_to_remove)
    new_meta.info["total_frames"] -= sum(
        dataset.meta.episodes[ep_idx]["length"] for ep_idx in episodes_to_remove
    )

    for ep_idx in episodes_to_remove:
        new_meta.episodes.pop(ep_idx)
        new_meta.episodes_stats.pop(ep_idx)
    new_meta.stats = aggregate_stats(list(new_meta.episodes_stats.values()))

    tasks = {task for ep in new_meta.episodes.values() if "tasks" in ep for task in ep["tasks"]}
    new_meta.tasks = {new_meta.get_task_index(task): task for task in tasks}
    new_meta.task_to_task_index = {task: idx for idx, task in new_meta.tasks.items()}
    new_meta.info["total_tasks"] = len(new_meta.tasks)

    new_meta.info["total_videos"] = (
        (new_meta.info["total_episodes"]) * len(dataset.meta.video_keys) if dataset.meta.video_keys else 0
    )

    if "splits" in new_meta.info:
        new_meta.info["splits"] = {"train": f"0:{new_meta.info['total_episodes']}"}

    # Now that the metadata is recalculated, we update the dataset files by
    # removing the files related to the specified episodes. We perform a safe
    # update such that if an error occurs, any changes are rolled back and the
    # dataset files are left in their original state. Optionally, a non-temporary
    # full backup can be made so that we also have the dataset in its original state.
    if backup:
        backup_path = (
            Path(backup)
            if isinstance(backup, (str, Path))
            else dataset.root.parent / f"{dataset.root.name}_backup_{int(time.time())}"
        )
        _backup_folder(dataset.root, backup_path)

    _update_dataset_files(
        new_meta,
        episodes_to_remove,
    )

    updated_dataset = LeRobotDataset(
        repo_id=dataset.repo_id,
        root=dataset.root,
        episodes=None,  # Load all episodes
        image_transforms=dataset.image_transforms,
        delta_timestamps=dataset.delta_timestamps,
        tolerance_s=dataset.tolerance_s,
        revision=dataset.revision,
        download_videos=False,  # No need to download, we just saved them
        video_backend=dataset.video_backend,
    )

    return updated_dataset


def _move_file(src: Path, dest: Path) -> None:
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(src, dest)


def _update_dataset_files(new_meta: LeRobotDatasetMetadata, episodes_to_remove: list[int]):
    """Update dataset files.

    This function performs a safe update for dataset files. It moves modified or removed
    episode files to a temporary directory. Once all changes are made, the temporary
    directory is deleted. If an error occurs during the update, all changes are rolled
    back and the original dataset files are restored.

    Args:
        new_meta (LeRobotDatasetMetadata): Updated metadata object containing the new
            dataset state after removing episodes
        episodes_to_remove (list[int]): List of episode indices to remove from the dataset

    Raises:
        Exception: If any operation fails, rolls back all changes and re-raises the original exception
    """
    with tempfile.TemporaryDirectory(prefix="lerobot_backup_temp_") as backup_path:
        backup_dir = Path(backup_path)

        # Init empty containers s.t. they are guaranteed to exist in the except block
        metadata_files = {}
        rel_data_paths = []
        rel_video_paths = []

        try:
            # Step 1: Update metadata files
            metadata_files = {
                INFO_PATH: lambda: write_info(new_meta.info, new_meta.root),
                EPISODES_PATH: lambda: [
                    write_episode(ep, new_meta.root) for ep in new_meta.episodes.values()
                ],
                TASKS_PATH: lambda: [
                    append_jsonlines({"task_index": idx, "task": task}, new_meta.root / TASKS_PATH)
                    for idx, task in new_meta.tasks.items()
                ],
                EPISODES_STATS_PATH: lambda: [
                    write_episode_stats(idx, stats, new_meta.root)
                    for idx, stats in new_meta.episodes_stats.items()
                ],
            }
            for file_path, update_func in metadata_files.items():
                _move_file(new_meta.root / file_path, backup_dir / file_path)
                update_func()

            # Step 2: Update data and video
            rel_data_paths = [new_meta.get_data_file_path(ep_idx) for ep_idx in episodes_to_remove]
            rel_video_paths = [
                new_meta.get_video_file_path(ep_idx, vid_key)
                for ep_idx in episodes_to_remove
                for vid_key in new_meta.video_keys
            ]
            for rel_path in rel_data_paths + rel_video_paths:
                if (new_meta.root / rel_path).exists():
                    _move_file(new_meta.root / rel_path, backup_dir / rel_path)

        except Exception as e:
            logging.error(f"Error updating dataset files: {str(e)}. Rolling back changes.")

            # Restore metadata files
            for file_path in metadata_files:
                if (backup_dir / file_path).exists():
                    _move_file(backup_dir / file_path, new_meta.root / file_path)

            # Restore data and video files
            for rel_file_path in rel_data_paths + rel_video_paths:
                if (backup_dir / rel_file_path).exists():
                    _move_file(backup_dir / rel_file_path, new_meta.root / rel_file_path)

            raise e


def _backup_folder(target_dir: Path, backup_path: Path) -> None:
    if backup_path.resolve() == target_dir.resolve() or backup_path.resolve().is_relative_to(
        target_dir.resolve()
    ):
        raise ValueError(
            f"Backup directory '{backup_path}' cannot be inside the dataset "
            f"directory '{target_dir}' as this would cause infinite recursion"
        )

    backup_path.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f"Creating backup at: {backup_path}")
    shutil.copytree(target_dir, backup_path)


def _parse_episodes_list(episodes_str: str) -> list[int]:
    """
    Parse a string of episode indices, ranges, and comma-separated lists into a list of integers.
    """
    episodes = []
    for ep in episodes_str.split(","):
        if "-" in ep:
            start, end = ep.split("-")
            episodes.extend(range(int(start), int(end) + 1))
        else:
            episodes.append(int(ep))
    return episodes


def _delete_hub_file(hub_api: HfApi, repo_id: str, file_path: str, branch: str | None = None):
    try:
        if hub_api.file_exists(
            repo_id,
            file_path,
            repo_type="dataset",
            revision=branch,
        ):
            hub_api.delete_file(
                path_in_repo=file_path,
                repo_id=repo_id,
                repo_type="dataset",
                revision=branch,
            )
    except Exception as e:
        logging.error(f"Error removing file '{file_path}' from the hub: {str(e)}")


def _remove_episodes_from_hub(
    updated_dataset: LeRobotDataset, episodes_to_remove: list[int], branch: str | None = None
):
    """Remove episodes from the hub repository."""
    hub_api = HfApi()

    for ep_idx in episodes_to_remove:
        data_path = str(updated_dataset.meta.get_data_file_path(ep_idx))
        _delete_hub_file(hub_api, updated_dataset.repo_id, data_path, branch)

        for vid_key in updated_dataset.meta.video_keys:
            video_path = str(updated_dataset.meta.get_video_file_path(ep_idx, vid_key))
            _delete_hub_file(hub_api, updated_dataset.repo_id, video_path, branch)


def main():
    parser = argparse.ArgumentParser(description="Remove episodes from a LeRobot dataset")
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for the dataset stored locally. By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "-e",
        "--episodes",
        type=str,
        required=True,
        help="Episodes to remove. Can be a single index, comma-separated indices, or ranges (e.g., '1-5,7,10-12')",
    )
    parser.add_argument(
        "-b",
        "--backup",
        nargs="?",
        const=True,
        default=False,
        help="Create a backup before modifying the dataset. Without a value, creates a backup in the default location. "
        "With a value, either 'true'/'false' or a path to store the backup.",
    )
    parser.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload to Hugging Face hub.",
    )
    parser.add_argument(
        "--private",
        type=int,
        default=0,
        help="If set, the repository on the Hub will be private",
    )
    parser.add_argument(
        "--tags",
        type=str,
        nargs="+",
        help="List of tags to apply to the dataset on the Hub",
    )
    parser.add_argument("--license", type=str, default=None, help="License to use for the dataset on the Hub")
    parser.add_argument("--branch", type=str, default=None, help="Branch to push the dataset to on the Hub")
    args = parser.parse_args()

    # Parse the backup argument
    backup_value = args.backup
    if isinstance(backup_value, str):
        if backup_value.lower() == "true":
            backup_value = True
        elif backup_value.lower() == "false":
            backup_value = False
        # Otherwise, it's treated as a path

    # Parse episodes to remove
    episodes_to_remove = _parse_episodes_list(args.episodes)
    if not episodes_to_remove:
        logging.warning("No episodes specified to remove")
        sys.exit(0)

    # Load the dataset
    logging.info(f"Loading dataset '{args.repo_id}'...")
    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
    logging.info(f"Dataset has {dataset.meta.total_episodes} episodes")

    # Modify the dataset
    logging.info(f"Removing {len(set(episodes_to_remove))} episodes: {sorted(set(episodes_to_remove))}")
    updated_dataset = remove_episodes(
        dataset=dataset,
        episodes_to_remove=episodes_to_remove,
        backup=backup_value,
    )
    logging.info(
        f"Successfully removed episodes. Dataset now has {updated_dataset.meta.total_episodes} episodes."
    )

    if args.push_to_hub:
        logging.info("Pushing dataset to hub...")

        updated_dataset.push_to_hub(
            tags=args.tags, private=bool(args.private), license=args.license, branch=args.branch
        )
        updated_card = create_lerobot_dataset_card(
            tags=args.tags, dataset_info=updated_dataset.meta.info, license=args.license
        )
        updated_card.push_to_hub(repo_id=updated_dataset.repo_id, repo_type="dataset", revision=args.branch)
        _remove_episodes_from_hub(updated_dataset, episodes_to_remove, branch=args.branch)

        logging.info("Dataset pushed to hub.")


if __name__ == "__main__":
    init_logging()
    main()

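For context, a short usage sketch that is not part of the diff. The module path lerobot.common.datasets.episode_utils is inferred from the test import further below; the repo id, episode indices, and CLI invocation are illustrative assumptions.

    # Hypothetical programmatic use of the new API:
    from lerobot.common.datasets.episode_utils import remove_episodes
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(repo_id="lerobot/pusht")  # any locally available LeRobotDataset
    updated = remove_episodes(dataset, episodes_to_remove=[1, 3], backup=True)  # backup lands next to the dataset root
    print(updated.meta.total_episodes)  # original count minus the two removed episodes

    # Roughly equivalent CLI call (script location assumed; adjust to wherever this file is installed):
    #   python -m lerobot.common.datasets.episode_utils --repo-id lerobot/pusht -e "1,3" -b --push-to-hub 0
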
@@ -75,12 +75,14 @@ import torch.utils.data
 import tqdm

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import translate_episode_index_to_position


 class EpisodeSampler(torch.utils.data.Sampler):
     def __init__(self, dataset: LeRobotDataset, episode_index: int):
-        from_idx = dataset.episode_data_index["from"][episode_index].item()
-        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position].item()
+        to_idx = dataset.episode_data_index["to"][index_position].item()
         self.frame_ids = range(from_idx, to_idx)

     def __iter__(self) -> Iterator:

@@ -69,7 +69,7 @@ from flask import Flask, redirect, render_template, request, url_for

 from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import IterableNamespace
+from lerobot.common.datasets.utils import IterableNamespace, translate_episode_index_to_position
 from lerobot.common.utils.utils import init_logging

@@ -207,7 +207,9 @@ def run_server(

     if episodes is None:
         episodes = list(
-            range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+            dataset.meta.episodes
+            if isinstance(dataset, LeRobotDataset)
+            else range(dataset.total_episodes)
         )

     return render_template(

@@ -268,8 +270,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     selected_columns.insert(0, "timestamp")

     if isinstance(dataset, LeRobotDataset):
-        from_idx = dataset.episode_data_index["from"][episode_index]
-        to_idx = dataset.episode_data_index["to"][episode_index]
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position]
+        to_idx = dataset.episode_data_index["to"][index_position]
         data = (
             dataset.hf_dataset.select(range(from_idx, to_idx))
             .select_columns(selected_columns)

@@ -305,7 +308,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)

 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     # get first frame of episode (hack to get video_path of the episode)
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()
     return [
         dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
         for key in dataset.meta.video_keys

@@ -318,7 +322,8 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
         return None

     # get first frame index
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()

     language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
     # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored

@@ -19,6 +19,7 @@ import re
 from copy import deepcopy
+from itertools import chain
 from pathlib import Path
 from unittest.mock import patch

 import numpy as np
 import pytest

@@ -28,6 +29,7 @@ from PIL import Image
 from safetensors.torch import load_file

 import lerobot
+from lerobot.common.datasets.episode_utils import remove_episodes
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.image_writer import image_array_to_pil_image
 from lerobot.common.datasets.lerobot_dataset import (

@@ -580,3 +582,37 @@ def test_dataset_feature_with_forward_slash_raises_error():
         fps=30,
         features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
     )
+
+
+@pytest.mark.parametrize(
+    "total_episodes, total_frames, episodes_to_remove",
+    [
+        (3, 30, [1]),
+        (3, 30, [0, 2]),
+        (4, 50, [1, 2]),
+    ],
+)
+def test_remove_episodes(tmp_path, lerobot_dataset_factory, total_episodes, total_frames, episodes_to_remove):
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "test",
+        total_episodes=total_episodes,
+        total_frames=total_frames,
+    )
+    num_frames_to_remove = 0
+    for ep in episodes_to_remove:
+        num_frames_to_remove += (
+            dataset.episode_data_index["to"][ep].item() - dataset.episode_data_index["from"][ep].item()
+        )
+
+    with (
+        patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.side_effect = lambda repo_id, version: version
+        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(dataset.root)
+        updated_dataset = remove_episodes(dataset, episodes_to_remove)
+
+    assert updated_dataset.meta.total_episodes == total_episodes - len(episodes_to_remove)
+    assert updated_dataset.meta.total_frames == total_frames - num_frames_to_remove
+    for i, ep_meta in enumerate(updated_dataset.meta.episodes.values()):
+        assert ep_meta["episode_index"] == i