fix: Add translation function for non-sequential episode indices

When episodes are removed from a LeRobotDataset, the remaining
episode indices are no longer sequential, which causes indexing errors
in get_episode_data(). This happens because episode_data_index tensors
are always indexed sequentially, while the episode indices can be
arbitrary. This commit introduces a helper function,
translate_episode_index_to_position(), to perform the conversion.
Ben Sprenger 2025-03-16 15:07:22 +01:00
parent 284bc5bfe3
commit 7fe463b5dd
4 changed files with 49 additions and 11 deletions
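
To make the mismatch concrete before the diffs below, here is a small illustration (not part of the commit; the frame counts are made up and the structures only mimic LeRobotDataset fields): after episodes are removed, episode_data_index stays packed from 0 to len(episodes)-1 while the episode indices do not.

import torch

# Dataset originally had episodes [0, 1, 2, 3, 4]; episodes 1 and 3 were removed.
remaining_episode_indices = [0, 2, 4]

# episode_data_index is packed sequentially: one entry per *remaining* episode.
episode_data_index = {
    "from": torch.tensor([0, 100, 180]),   # frame where each remaining episode starts
    "to": torch.tensor([100, 180, 300]),   # frame where each remaining episode ends
}

# Indexing with the raw episode index is wrong once indices are non-sequential;
# for episode 4 it is out of range altogether:
# episode_data_index["from"][4]  # IndexError: only 3 entries exist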

View File

@@ -811,3 +811,33 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def translate_episode_index_to_position(episode_dicts: dict[dict], episode_index: int) -> int:
+    """
+    Translates an actual episode index to its position in the sequential episode_data_index tensors.
+
+    When episodes are removed from a dataset, the remaining episode indices may no longer be sequential
+    (e.g., they could be [0, 3, 7, 10]). However, the dataset's episode_data_index tensors are always
+    indexed sequentially from 0 to len(episodes)-1. This function provides the mapping between these
+    two indexing schemes.
+
+    Example:
+        If a dataset originally had episodes [0, 1, 2, 3, 4] but episodes 1 and 3 were removed,
+        the remaining episodes would be [0, 2, 4]. In this case:
+        - Episode index 0 would be at position 0
+        - Episode index 2 would be at position 1
+        - Episode index 4 would be at position 2
+        So translate_episode_index_to_position(episode_dicts, 4) would return 2.
+
+    Args:
+        episode_dicts (dict[dict]): Dictionary of episode dictionaries or list of episode indices
+        episode_index (int): The actual episode index to translate
+
+    Returns:
+        int: The position of the episode in the episode_data_index tensors
+    """
+    episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
+    position = episode_to_position[episode_index]
+
+    return position
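
Not part of the patch, but for reference, a minimal usage sketch of the new helper that mirrors the docstring's example (the stand-in episodes dict assumes dataset.meta.episodes is keyed by the original, possibly non-sequential episode index):

# Stand-in for dataset.meta.episodes after removing episodes 1 and 3 (assumed shape).
episodes = {0: {"episode_index": 0}, 2: {"episode_index": 2}, 4: {"episode_index": 4}}

translate_episode_index_to_position(episodes, 0)  # -> 0
translate_episode_index_to_position(episodes, 2)  # -> 1
translate_episode_index_to_position(episodes, 4)  # -> 2

# The returned position is what can safely index episode_data_index["from"] / ["to"].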

View File

@@ -19,7 +19,7 @@ from tqdm import tqdm
 
 from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import write_episode_stats
+from lerobot.common.datasets.utils import translate_episode_index_to_position, write_episode_stats
 
 
 def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
@@ -31,8 +31,9 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
 
 
 def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
-    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
-    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_idx)
+    ep_start_idx = dataset.episode_data_index["from"][index_position]
+    ep_end_idx = dataset.episode_data_index["to"][index_position]
     ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
 
     ep_stats = {}

View File

@@ -75,12 +75,14 @@ import torch.utils.data
 import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import translate_episode_index_to_position
 
 
 class EpisodeSampler(torch.utils.data.Sampler):
     def __init__(self, dataset: LeRobotDataset, episode_index: int):
-        from_idx = dataset.episode_data_index["from"][episode_index].item()
-        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position].item()
+        to_idx = dataset.episode_data_index["to"][index_position].item()
         self.frame_ids = range(from_idx, to_idx)
 
     def __iter__(self) -> Iterator:

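For context, a sketch of how a sampler like this is typically wired into a torch DataLoader (not part of the diff; the repo id, batch size, and episode index are placeholders, and EpisodeSampler is assumed to be in scope from the script above):

import torch.utils.data

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")            # placeholder repo id
sampler = EpisodeSampler(dataset, episode_index=4)   # raw (possibly non-sequential) episode index
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for batch in dataloader:
    ...  # only frames belonging to the selected episode are yielded
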
View File

@@ -69,7 +69,7 @@ from flask import Flask, redirect, render_template, request, url_for
 
 from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import IterableNamespace
+from lerobot.common.datasets.utils import IterableNamespace, translate_episode_index_to_position
 from lerobot.common.utils.utils import init_logging
@@ -207,7 +207,9 @@ def run_server(
 
         if episodes is None:
             episodes = list(
-                range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+                dataset.meta.episodes
+                if isinstance(dataset, LeRobotDataset)
+                else range(dataset.total_episodes)
             )
 
         return render_template(
@@ -268,8 +270,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     selected_columns.insert(0, "timestamp")
 
     if isinstance(dataset, LeRobotDataset):
-        from_idx = dataset.episode_data_index["from"][episode_index]
-        to_idx = dataset.episode_data_index["to"][episode_index]
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position]
+        to_idx = dataset.episode_data_index["to"][index_position]
         data = (
             dataset.hf_dataset.select(range(from_idx, to_idx))
             .select_columns(selected_columns)
@@ -305,7 +308,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
 
 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     # get first frame of episode (hack to get video_path of the episode)
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()
     return [
         dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
         for key in dataset.meta.video_keys
@@ -318,7 +322,8 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
         return None
 
     # get first frame index
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()
 
     language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
     # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored