Merge 7fe463b5dd into 1c873df5c0
commit 000e447cde
@@ -601,7 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

     def get_episodes_file_paths(self) -> list[Path]:
-        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+        episodes = self.episodes if self.episodes is not None else list(self.meta.episodes)
         fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
         if len(self.meta.video_keys) > 0:
             video_files = [
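The change above swaps range(self.meta.total_episodes) for the actual episode keys: once episodes have been removed, the surviving episode indices are no longer guaranteed to be contiguous. A minimal sketch of the difference, with toy values (assuming meta.episodes is a dict keyed by the surviving episode indices):

    # Suppose episodes 1 and 3 were removed from a 5-episode dataset.
    episodes_meta = {0: {"length": 10}, 2: {"length": 12}, 4: {"length": 8}}
    total_episodes = len(episodes_meta)
    print(list(range(total_episodes)))  # [0, 1, 2] -> points at episodes that may no longer exist
    print(list(episodes_meta))          # [0, 2, 4] -> the episodes that actually remain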
@@ -811,3 +811,33 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def translate_episode_index_to_position(episode_dicts: dict[dict], episode_index: int) -> int:
+    """
+    Translates an actual episode index to its position in the sequential episode_data_index tensors.
+
+    When episodes are removed from a dataset, the remaining episode indices may no longer be sequential
+    (e.g., they could be [0, 3, 7, 10]). However, the dataset's episode_data_index tensors are always
+    indexed sequentially from 0 to len(episodes)-1. This function provides the mapping between these
+    two indexing schemes.
+
+    Example:
+        If a dataset originally had episodes [0, 1, 2, 3, 4] but episodes 1 and 3 were removed,
+        the remaining episodes would be [0, 2, 4]. In this case:
+        - Episode index 0 would be at position 0
+        - Episode index 2 would be at position 1
+        - Episode index 4 would be at position 2
+
+        So translate_episode_index_to_position(episode_dicts, 4) would return 2.
+
+    Args:
+        episode_dicts (dict[dict]): Dictionary of episode dictionaries or list of episode indices
+        episode_index (int): The actual episode index to translate
+
+    Returns:
+        int: The position of the episode in the episode_data_index tensors
+    """
+    episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
+    position = episode_to_position[episode_index]
+    return position
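For clarity, the mapping implemented by translate_episode_index_to_position can be reproduced standalone. A toy sketch (values are hypothetical, assuming the episode dicts are keyed by the surviving episode indices):

    episode_dicts = {0: {"length": 10}, 2: {"length": 12}, 4: {"length": 8}}  # episodes 1 and 3 removed
    episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
    assert episode_to_position[4] == 2  # episode 4 sits at position 2 of episode_data_index
    assert episode_to_position[0] == 0  # episode 0 keeps position 0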
@@ -19,7 +19,7 @@ from tqdm import tqdm

 from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import write_episode_stats
+from lerobot.common.datasets.utils import translate_episode_index_to_position, write_episode_stats


 def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
@@ -31,8 +31,9 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_


 def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
-    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
-    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_idx)
+    ep_start_idx = dataset.episode_data_index["from"][index_position]
+    ep_end_idx = dataset.episode_data_index["to"][index_position]
     ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))

     ep_stats = {}
@@ -0,0 +1,362 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+import shutil
+import sys
+import tempfile
+import time
+from copy import deepcopy
+from pathlib import Path
+
+from huggingface_hub import HfApi
+
+from lerobot.common.datasets.compute_stats import aggregate_stats
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.common.datasets.utils import (
+    EPISODES_PATH,
+    EPISODES_STATS_PATH,
+    INFO_PATH,
+    TASKS_PATH,
+    append_jsonlines,
+    create_lerobot_dataset_card,
+    write_episode,
+    write_episode_stats,
+    write_info,
+)
+from lerobot.common.utils.utils import init_logging
+
+
+def remove_episodes(
+    dataset: LeRobotDataset,
+    episodes_to_remove: list[int],
+    backup: str | Path | bool = False,
+) -> LeRobotDataset:
+    """
+    Removes specified episodes from a LeRobotDataset and updates all metadata and files accordingly.
+
+    Args:
+        dataset: The LeRobotDataset to modify
+        episodes_to_remove: List of episode indices to remove
+        backup: Controls backup behavior:
+            - False: No backup is created
+            - True: Create backup at default location next to dataset
+            - str/Path: Create backup at the specified location
+
+    Returns:
+        Updated LeRobotDataset with specified episodes removed
+    """
+    if not episodes_to_remove:
+        return dataset
+
+    if not all(ep_idx in dataset.meta.episodes for ep_idx in episodes_to_remove):
+        raise ValueError("Episodes to remove must be valid episode indices in the dataset")
+
+    # Calculate the new metadata
+    new_meta = deepcopy(dataset.meta)
+    new_meta.info["total_episodes"] -= len(episodes_to_remove)
+    new_meta.info["total_frames"] -= sum(
+        dataset.meta.episodes[ep_idx]["length"] for ep_idx in episodes_to_remove
+    )
+
+    for ep_idx in episodes_to_remove:
+        new_meta.episodes.pop(ep_idx)
+        new_meta.episodes_stats.pop(ep_idx)
+    new_meta.stats = aggregate_stats(list(new_meta.episodes_stats.values()))
+
+    tasks = {task for ep in new_meta.episodes.values() if "tasks" in ep for task in ep["tasks"]}
+    new_meta.tasks = {new_meta.get_task_index(task): task for task in tasks}
+    new_meta.task_to_task_index = {task: idx for idx, task in new_meta.tasks.items()}
+    new_meta.info["total_tasks"] = len(new_meta.tasks)
+
+    new_meta.info["total_videos"] = (
+        (new_meta.info["total_episodes"]) * len(dataset.meta.video_keys) if dataset.meta.video_keys else 0
+    )
+
+    if "splits" in new_meta.info:
+        new_meta.info["splits"] = {"train": f"0:{new_meta.info['total_episodes']}"}
+
+    # Now that the metadata is recalculated, we update the dataset files by
+    # removing the files related to the specified episodes. We perform a safe
+    # update such that if an error occurs, any changes are rolled back and the
+    # dataset files are left in their original state. Optionally, a non-temporary
+    # full backup can be made so that we also have the dataset in its original state.
+    if backup:
+        backup_path = (
+            Path(backup)
+            if isinstance(backup, (str, Path))
+            else dataset.root.parent / f"{dataset.root.name}_backup_{int(time.time())}"
+        )
+        _backup_folder(dataset.root, backup_path)
+
+    _update_dataset_files(
+        new_meta,
+        episodes_to_remove,
+    )
+
+    updated_dataset = LeRobotDataset(
+        repo_id=dataset.repo_id,
+        root=dataset.root,
+        episodes=None,  # Load all episodes
+        image_transforms=dataset.image_transforms,
+        delta_timestamps=dataset.delta_timestamps,
+        tolerance_s=dataset.tolerance_s,
+        revision=dataset.revision,
+        download_videos=False,  # No need to download, we just saved them
+        video_backend=dataset.video_backend,
+    )
+
+    return updated_dataset
+
+
+def _move_file(src: Path, dest: Path) -> None:
+    dest.parent.mkdir(parents=True, exist_ok=True)
+    shutil.move(src, dest)
+
+
+def _update_dataset_files(new_meta: LeRobotDatasetMetadata, episodes_to_remove: list[int]):
+    """Update dataset files.
+
+    This function performs a safe update for dataset files. It moves modified or removed
+    episode files to a temporary directory. Once all changes are made, the temporary
+    directory is deleted. If an error occurs during the update, all changes are rolled
+    back and the original dataset files are restored.
+
+    Args:
+        new_meta (LeRobotDatasetMetadata): Updated metadata object containing the new
+            dataset state after removing episodes
+        episodes_to_remove (list[int]): List of episode indices to remove from the dataset
+
+    Raises:
+        Exception: If any operation fails, rolls back all changes and re-raises the original exception
+    """
+    with tempfile.TemporaryDirectory(prefix="lerobot_backup_temp_") as backup_path:
+        backup_dir = Path(backup_path)
+
+        # Init empty containers s.t. they are guaranteed to exist in the except block
+        metadata_files = {}
+        rel_data_paths = []
+        rel_video_paths = []
+
+        try:
+            # Step 1: Update metadata files
+            metadata_files = {
+                INFO_PATH: lambda: write_info(new_meta.info, new_meta.root),
+                EPISODES_PATH: lambda: [
+                    write_episode(ep, new_meta.root) for ep in new_meta.episodes.values()
+                ],
+                TASKS_PATH: lambda: [
+                    append_jsonlines({"task_index": idx, "task": task}, new_meta.root / TASKS_PATH)
+                    for idx, task in new_meta.tasks.items()
+                ],
+                EPISODES_STATS_PATH: lambda: [
+                    write_episode_stats(idx, stats, new_meta.root)
+                    for idx, stats in new_meta.episodes_stats.items()
+                ],
+            }
+            for file_path, update_func in metadata_files.items():
+                _move_file(new_meta.root / file_path, backup_dir / file_path)
+                update_func()
+
+            # Step 2: Update data and video
+            rel_data_paths = [new_meta.get_data_file_path(ep_idx) for ep_idx in episodes_to_remove]
+            rel_video_paths = [
+                new_meta.get_video_file_path(ep_idx, vid_key)
+                for ep_idx in episodes_to_remove
+                for vid_key in new_meta.video_keys
+            ]
+            for rel_path in rel_data_paths + rel_video_paths:
+                if (new_meta.root / rel_path).exists():
+                    _move_file(new_meta.root / rel_path, backup_dir / rel_path)
+
+        except Exception as e:
+            logging.error(f"Error updating dataset files: {str(e)}. Rolling back changes.")
+
+            # Restore metadata files
+            for file_path in metadata_files:
+                if (backup_dir / file_path).exists():
+                    _move_file(backup_dir / file_path, new_meta.root / file_path)
+
+            # Restore data and video files
+            for rel_file_path in rel_data_paths + rel_video_paths:
+                if (backup_dir / rel_file_path).exists():
+                    _move_file(backup_dir / rel_file_path, new_meta.root / rel_file_path)
+
+            raise e
+
+
+def _backup_folder(target_dir: Path, backup_path: Path) -> None:
+    if backup_path.resolve() == target_dir.resolve() or backup_path.resolve().is_relative_to(
+        target_dir.resolve()
+    ):
+        raise ValueError(
+            f"Backup directory '{backup_path}' cannot be inside the dataset "
+            f"directory '{target_dir}' as this would cause infinite recursion"
+        )
+
+    backup_path.parent.mkdir(parents=True, exist_ok=True)
+    logging.info(f"Creating backup at: {backup_path}")
+    shutil.copytree(target_dir, backup_path)
+
+
+def _parse_episodes_list(episodes_str: str) -> list[int]:
+    """
+    Parse a string of episode indices, ranges, and comma-separated lists into a list of integers.
+    """
+    episodes = []
+    for ep in episodes_str.split(","):
+        if "-" in ep:
+            start, end = ep.split("-")
+            episodes.extend(range(int(start), int(end) + 1))
+        else:
+            episodes.append(int(ep))
+    return episodes
+
+
+def _delete_hub_file(hub_api: HfApi, repo_id: str, file_path: str, branch: str | None = None):
+    try:
+        if hub_api.file_exists(
+            repo_id,
+            file_path,
+            repo_type="dataset",
+            revision=branch,
+        ):
+            hub_api.delete_file(
+                path_in_repo=file_path,
+                repo_id=repo_id,
+                repo_type="dataset",
+                revision=branch,
+            )
+    except Exception as e:
+        logging.error(f"Error removing file '{file_path}' from the hub: {str(e)}")
+
+
+def _remove_episodes_from_hub(
+    updated_dataset: LeRobotDataset, episodes_to_remove: list[int], branch: str | None = None
+):
+    """Remove episodes from the hub repository."""
+    hub_api = HfApi()
+
+    for ep_idx in episodes_to_remove:
+        data_path = str(updated_dataset.meta.get_data_file_path(ep_idx))
+        _delete_hub_file(hub_api, updated_dataset.repo_id, data_path, branch)
+
+        for vid_key in updated_dataset.meta.video_keys:
+            video_path = str(updated_dataset.meta.get_video_file_path(ep_idx, vid_key))
+            _delete_hub_file(hub_api, updated_dataset.repo_id, video_path, branch)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Remove episodes from a LeRobot dataset")
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
+    )
+    parser.add_argument(
+        "--root",
+        type=Path,
+        default=None,
+        help="Root directory for the dataset stored locally. By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
+    )
+    parser.add_argument(
+        "-e",
+        "--episodes",
+        type=str,
+        required=True,
+        help="Episodes to remove. Can be a single index, comma-separated indices, or ranges (e.g., '1-5,7,10-12')",
+    )
+    parser.add_argument(
+        "-b",
+        "--backup",
+        nargs="?",
+        const=True,
+        default=False,
+        help="Create a backup before modifying the dataset. Without a value, creates a backup in the default location. "
+        "With a value, either 'true'/'false' or a path to store the backup.",
+    )
+    parser.add_argument(
+        "--push-to-hub",
+        type=int,
+        default=1,
+        help="Upload to Hugging Face hub.",
+    )
+    parser.add_argument(
+        "--private",
+        type=int,
+        default=0,
+        help="If set, the repository on the Hub will be private",
+    )
+    parser.add_argument(
+        "--tags",
+        type=str,
+        nargs="+",
+        help="List of tags to apply to the dataset on the Hub",
+    )
+    parser.add_argument("--license", type=str, default=None, help="License to use for the dataset on the Hub")
+    parser.add_argument("--branch", type=str, default=None, help="Branch to push the dataset to on the Hub")
+    args = parser.parse_args()
+
+    # Parse the backup argument
+    backup_value = args.backup
+    if isinstance(backup_value, str):
+        if backup_value.lower() == "true":
+            backup_value = True
+        elif backup_value.lower() == "false":
+            backup_value = False
+        # Otherwise, it's treated as a path
+
+    # Parse episodes to remove
+    episodes_to_remove = _parse_episodes_list(args.episodes)
+    if not episodes_to_remove:
+        logging.warning("No episodes specified to remove")
+        sys.exit(0)
+
+    # Load the dataset
+    logging.info(f"Loading dataset '{args.repo_id}'...")
+    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
+    logging.info(f"Dataset has {dataset.meta.total_episodes} episodes")
+
+    # Modify the dataset
+    logging.info(f"Removing {len(set(episodes_to_remove))} episodes: {sorted(set(episodes_to_remove))}")
+    updated_dataset = remove_episodes(
+        dataset=dataset,
+        episodes_to_remove=episodes_to_remove,
+        backup=backup_value,
+    )
+    logging.info(
+        f"Successfully removed episodes. Dataset now has {updated_dataset.meta.total_episodes} episodes."
+    )
+
+    if args.push_to_hub:
+        logging.info("Pushing dataset to hub...")
+
+        updated_dataset.push_to_hub(
+            tags=args.tags, private=bool(args.private), license=args.license, branch=args.branch
+        )
+        updated_card = create_lerobot_dataset_card(
+            tags=args.tags, dataset_info=updated_dataset.meta.info, license=args.license
+        )
+        updated_card.push_to_hub(repo_id=updated_dataset.repo_id, repo_type="dataset", revision=args.branch)
+        _remove_episodes_from_hub(updated_dataset, episodes_to_remove, branch=args.branch)
+
+        logging.info("Dataset pushed to hub.")
+
+
+if __name__ == "__main__":
+    init_logging()
+    main()
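A hedged usage sketch for the new remove_episodes utility. The Python import path follows the test changes further down (lerobot.common.datasets.episode_utils); the script location in the CLI example is an assumption, since the diff view does not show file names:

    from lerobot.common.datasets.episode_utils import remove_episodes  # path assumed from the test imports below
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(repo_id="lerobot/pusht")
    updated = remove_episodes(dataset, episodes_to_remove=[1, 3], backup=True)  # backup lands next to the dataset root
    print(updated.meta.total_episodes)

From the command line, something along these lines (script path assumed):

    python path/to/remove_episodes.py --repo-id lerobot/pusht -e "1-5,7,10-12" -b --push-to-hub 0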
@@ -75,12 +75,14 @@ import torch.utils.data
 import tqdm

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import translate_episode_index_to_position


 class EpisodeSampler(torch.utils.data.Sampler):
     def __init__(self, dataset: LeRobotDataset, episode_index: int):
-        from_idx = dataset.episode_data_index["from"][episode_index].item()
-        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position].item()
+        to_idx = dataset.episode_data_index["to"][index_position].item()
         self.frame_ids = range(from_idx, to_idx)

     def __iter__(self) -> Iterator:
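The sampler now resolves the positional index first, so iterating a single episode keeps working when episode indices are non-contiguous. A hedged sketch of typical use (the DataLoader wiring is an assumption, not part of this diff):

    import torch

    dataset = LeRobotDataset(repo_id="lerobot/pusht")   # hypothetical dataset
    sampler = EpisodeSampler(dataset, episode_index=4)  # episode 4 may sit at an earlier position
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
    for batch in loader:
        pass  # only frames belonging to episode 4 are yielded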
@@ -69,7 +69,7 @@ from flask import Flask, redirect, render_template, request, url_for

 from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import IterableNamespace
+from lerobot.common.datasets.utils import IterableNamespace, translate_episode_index_to_position
 from lerobot.common.utils.utils import init_logging

@@ -207,7 +207,9 @@ def run_server(

     if episodes is None:
         episodes = list(
-            range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+            dataset.meta.episodes
+            if isinstance(dataset, LeRobotDataset)
+            else range(dataset.total_episodes)
         )

     return render_template(
@@ -268,8 +270,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     selected_columns.insert(0, "timestamp")

     if isinstance(dataset, LeRobotDataset):
-        from_idx = dataset.episode_data_index["from"][episode_index]
-        to_idx = dataset.episode_data_index["to"][episode_index]
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position]
+        to_idx = dataset.episode_data_index["to"][index_position]
         data = (
             dataset.hf_dataset.select(range(from_idx, to_idx))
             .select_columns(selected_columns)
@@ -305,7 +308,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)

 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     # get first frame of episode (hack to get video_path of the episode)
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()
     return [
         dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
         for key in dataset.meta.video_keys
@@ -318,7 +322,8 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
         return None

     # get first frame index
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()

     language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
     # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
@@ -19,6 +19,7 @@ import re
 from copy import deepcopy
 from itertools import chain
 from pathlib import Path
+from unittest.mock import patch

 import numpy as np
 import pytest
@@ -28,6 +29,7 @@ from PIL import Image
 from safetensors.torch import load_file

 import lerobot
+from lerobot.common.datasets.episode_utils import remove_episodes
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.image_writer import image_array_to_pil_image
 from lerobot.common.datasets.lerobot_dataset import (
@@ -580,3 +582,37 @@ def test_dataset_feature_with_forward_slash_raises_error():
             fps=30,
             features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
         )
+
+
+@pytest.mark.parametrize(
+    "total_episodes, total_frames, episodes_to_remove",
+    [
+        (3, 30, [1]),
+        (3, 30, [0, 2]),
+        (4, 50, [1, 2]),
+    ],
+)
+def test_remove_episodes(tmp_path, lerobot_dataset_factory, total_episodes, total_frames, episodes_to_remove):
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "test",
+        total_episodes=total_episodes,
+        total_frames=total_frames,
+    )
+    num_frames_to_remove = 0
+    for ep in episodes_to_remove:
+        num_frames_to_remove += (
+            dataset.episode_data_index["to"][ep].item() - dataset.episode_data_index["from"][ep].item()
+        )
+
+    with (
+        patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.side_effect = lambda repo_id, version: version
+        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(dataset.root)
+        updated_dataset = remove_episodes(dataset, episodes_to_remove)
+
+    assert updated_dataset.meta.total_episodes == total_episodes - len(episodes_to_remove)
+    assert updated_dataset.meta.total_frames == total_frames - num_frames_to_remove
+    for i, ep_meta in enumerate(updated_dataset.meta.episodes.values()):
+        assert ep_meta["episode_index"] == i
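To exercise the new test locally, something like the following should work (the test file name is an assumption based on the imports shown above):

    pytest tests/test_datasets.py -k test_remove_episodes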