fix: Add translation function for non-sequential episode indices
When episodes are removed from a LeRobotDataset, the remaining episode indices are no longer sequential, which causes indexing errors in get_episode_data(). This happens because episode_data_index tensors are always indexed sequentially, while the episode indices can be arbitrary. This commit introduces a helper function to make the conversion.
parent 284bc5bfe3
commit 7fe463b5dd
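For illustration, a minimal sketch of the failure this commit fixes and how the new helper resolves it. The episode indices and frame boundaries below are made up; translate_episode_index_to_position is the helper added to lerobot.common.datasets.utils in the diff below.

import torch

from lerobot.common.datasets.utils import translate_episode_index_to_position

# Hypothetical state of a dataset after episodes 1 and 3 were removed:
# three episodes remain, but their indices are no longer sequential.
episodes = {0: {}, 2: {}, 4: {}}  # stand-in for dataset.meta.episodes (keys are episode indices)
episode_data_index = {            # always indexed sequentially from 0 to len(episodes) - 1
    "from": torch.tensor([0, 100, 190]),
    "to": torch.tensor([100, 190, 310]),
}

# Indexing with the raw episode index is wrong here (and raises IndexError for episode 4):
# episode_data_index["from"][4]

# Translating the episode index to its sequential position gives the correct slice:
position = translate_episode_index_to_position(episodes, 4)  # -> 2
from_idx = episode_data_index["from"][position].item()       # -> 190
to_idx = episode_data_index["to"][position].item()           # -> 310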
@@ -811,3 +811,33 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def translate_episode_index_to_position(episode_dicts: dict[dict], episode_index: int) -> int:
+    """
+    Translates an actual episode index to its position in the sequential episode_data_index tensors.
+
+    When episodes are removed from a dataset, the remaining episode indices may no longer be sequential
+    (e.g., they could be [0, 3, 7, 10]). However, the dataset's episode_data_index tensors are always
+    indexed sequentially from 0 to len(episodes)-1. This function provides the mapping between these
+    two indexing schemes.
+
+    Example:
+        If a dataset originally had episodes [0, 1, 2, 3, 4] but episodes 1 and 3 were removed,
+        the remaining episodes would be [0, 2, 4]. In this case:
+        - Episode index 0 would be at position 0
+        - Episode index 2 would be at position 1
+        - Episode index 4 would be at position 2
+
+        So translate_episode_index_to_position(episode_dicts, 4) would return 2.
+
+    Args:
+        episode_dicts (dict[dict]): Dictionary of episode dictionaries or list of episode indices
+        episode_index (int): The actual episode index to translate
+
+    Returns:
+        int: The position of the episode in the episode_data_index tensors
+    """
+    episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
+    position = episode_to_position[episode_index]
+    return position
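One detail worth noting about the helper above: Python's enumerate() iterates over a dict's keys, so the mapping comes out the same whether episode_dicts is a dict keyed by episode index or a plain list of episode indices. A quick sketch with made-up indices:

episodes_as_dict = {0: {"length": 100}, 2: {"length": 90}, 4: {"length": 120}}
episodes_as_list = [0, 2, 4]

# enumerate() over a dict yields its keys, so both forms produce the same mapping.
assert {ep: i for i, ep in enumerate(episodes_as_dict)} == {0: 0, 2: 1, 4: 2}
assert {ep: i for i, ep in enumerate(episodes_as_list)} == {0: 0, 2: 1, 4: 2}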
@@ -19,7 +19,7 @@ from tqdm import tqdm

 from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import write_episode_stats
+from lerobot.common.datasets.utils import translate_episode_index_to_position, write_episode_stats


 def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
@@ -31,8 +31,9 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_


 def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
-    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
-    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_idx)
+    ep_start_idx = dataset.episode_data_index["from"][index_position]
+    ep_end_idx = dataset.episode_data_index["to"][index_position]
     ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))

     ep_stats = {}
@@ -75,12 +75,14 @@ import torch.utils.data
 import tqdm

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import translate_episode_index_to_position


 class EpisodeSampler(torch.utils.data.Sampler):
     def __init__(self, dataset: LeRobotDataset, episode_index: int):
-        from_idx = dataset.episode_data_index["from"][episode_index].item()
-        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position].item()
+        to_idx = dataset.episode_data_index["to"][index_position].item()
         self.frame_ids = range(from_idx, to_idx)

     def __iter__(self) -> Iterator:
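A hypothetical usage sketch of the updated EpisodeSampler (the repo id is illustrative, the import path for EpisodeSampler is assumed, and the dataset is assumed to have non-sequential episode indices such as [0, 2, 4] after removals):

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.visualize_dataset import EpisodeSampler  # assumed module path

dataset = LeRobotDataset("user/dataset_with_removed_episodes")  # illustrative repo id
sampler = EpisodeSampler(dataset, episode_index=4)  # works even though episode 4 sits at position 2
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
batch = next(iter(dataloader))  # frames drawn only from episode 4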
@@ -69,7 +69,7 @@ from flask import Flask, redirect, render_template, request, url_for

 from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import IterableNamespace
+from lerobot.common.datasets.utils import IterableNamespace, translate_episode_index_to_position
 from lerobot.common.utils.utils import init_logging

@@ -207,7 +207,9 @@ def run_server(

     if episodes is None:
         episodes = list(
-            range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+            dataset.meta.episodes
+            if isinstance(dataset, LeRobotDataset)
+            else range(dataset.total_episodes)
         )

     return render_template(
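The run_server change above lists dataset.meta.episodes instead of range(num_episodes) because a plain range regenerates indices that may have been removed. A sketch with made-up indices (episodes 1 and 3 assumed removed):

meta_episodes = {0: {}, 2: {}, 4: {}}  # stand-in for dataset.meta.episodes
num_episodes = len(meta_episodes)

list(range(num_episodes))  # [0, 1, 2] -- includes removed indices, misses episode 4
list(meta_episodes)        # [0, 2, 4] -- the episode indices that actually exist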
@@ -268,8 +270,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     selected_columns.insert(0, "timestamp")

     if isinstance(dataset, LeRobotDataset):
-        from_idx = dataset.episode_data_index["from"][episode_index]
-        to_idx = dataset.episode_data_index["to"][episode_index]
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position]
+        to_idx = dataset.episode_data_index["to"][index_position]
         data = (
             dataset.hf_dataset.select(range(from_idx, to_idx))
             .select_columns(selected_columns)
@@ -305,7 +308,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)

 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     # get first frame of episode (hack to get video_path of the episode)
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()
     return [
         dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
         for key in dataset.meta.video_keys
@@ -318,7 +322,8 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
         return None

     # get first frame index
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()

     language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
     # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored