fix: Add translation function for non-sequential episode indices
When episodes are removed from a LeRobotDataset, the remaining episode indices are no longer sequential, which causes indexing errors in get_episode_data() and in other call sites that index the episode_data_index tensors directly by episode index. Those tensors are always indexed sequentially from 0 to len(episodes) - 1, while the episode indices themselves can be arbitrary. This commit introduces a helper, translate_episode_index_to_position(), that converts an episode index into its position in those tensors, and updates the affected call sites to use it.
parent 284bc5bfe3
commit 7fe463b5dd
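For context, here is a minimal, self-contained sketch of the idea (illustrative only: the episode indices and frame ranges below are made up, and the helper body is reproduced from the diff further down). Indexing the tensors with the raw episode index 4 would fail, since they only hold three entries; indexing with the translated position 2 works.

# Minimal sketch of the index-translation idea; the episodes and frame
# ranges here are invented illustration values, not taken from a real dataset.
import torch

def translate_episode_index_to_position(episode_dicts, episode_index):
    # Map each surviving episode index to its sequential position.
    episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
    return episode_to_position[episode_index]

# Episodes 1 and 3 were removed, so only 0, 2 and 4 remain.
episodes = {0: {}, 2: {}, 4: {}}
episode_data_index = {
    "from": torch.tensor([0, 100, 250]),  # start frame of each remaining episode
    "to": torch.tensor([100, 250, 400]),  # end frame (exclusive) of each remaining episode
}

position = translate_episode_index_to_position(episodes, 4)  # -> 2
from_idx = episode_data_index["from"][position].item()       # 250
to_idx = episode_data_index["to"][position].item()           # 400
# Indexing with the raw episode index (4) would raise an IndexError,
# since the tensors only have 3 entries.
print(position, from_idx, to_idx)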
@@ -811,3 +811,33 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def translate_episode_index_to_position(episode_dicts: dict[dict], episode_index: int) -> int:
+    """
+    Translates an actual episode index to its position in the sequential episode_data_index tensors.
+
+    When episodes are removed from a dataset, the remaining episode indices may no longer be sequential
+    (e.g., they could be [0, 3, 7, 10]). However, the dataset's episode_data_index tensors are always
+    indexed sequentially from 0 to len(episodes)-1. This function provides the mapping between these
+    two indexing schemes.
+
+    Example:
+        If a dataset originally had episodes [0, 1, 2, 3, 4] but episodes 1 and 3 were removed,
+        the remaining episodes would be [0, 2, 4]. In this case:
+        - Episode index 0 would be at position 0
+        - Episode index 2 would be at position 1
+        - Episode index 4 would be at position 2
+
+        So translate_episode_index_to_position(episode_dicts, 4) would return 2.
+
+    Args:
+        episode_dicts (dict[dict]): Dictionary of episode dictionaries or list of episode indices
+        episode_index (int): The actual episode index to translate
+
+    Returns:
+        int: The position of the episode in the episode_data_index tensors
+    """
+    episode_to_position = {ep_idx: i for i, ep_idx in enumerate(episode_dicts)}
+    position = episode_to_position[episode_index]
+    return position
@@ -19,7 +19,7 @@ from tqdm import tqdm
 
 from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import write_episode_stats
+from lerobot.common.datasets.utils import translate_episode_index_to_position, write_episode_stats
 
 
 def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:

@@ -31,8 +31,9 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
 
 
 def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
-    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
-    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_idx)
+    ep_start_idx = dataset.episode_data_index["from"][index_position]
+    ep_end_idx = dataset.episode_data_index["to"][index_position]
     ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
 
     ep_stats = {}
@@ -75,12 +75,14 @@ import torch.utils.data
 import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import translate_episode_index_to_position
 
 
 class EpisodeSampler(torch.utils.data.Sampler):
     def __init__(self, dataset: LeRobotDataset, episode_index: int):
-        from_idx = dataset.episode_data_index["from"][episode_index].item()
-        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position].item()
+        to_idx = dataset.episode_data_index["to"][index_position].item()
         self.frame_ids = range(from_idx, to_idx)
 
     def __iter__(self) -> Iterator:
@@ -69,7 +69,7 @@ from flask import Flask, redirect, render_template, request, url_for
 
 from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.utils import IterableNamespace
+from lerobot.common.datasets.utils import IterableNamespace, translate_episode_index_to_position
 from lerobot.common.utils.utils import init_logging
 
 
@@ -207,7 +207,9 @@ def run_server(
 
     if episodes is None:
         episodes = list(
-            range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+            dataset.meta.episodes
+            if isinstance(dataset, LeRobotDataset)
+            else range(dataset.total_episodes)
         )
 
     return render_template(
@@ -268,8 +270,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     selected_columns.insert(0, "timestamp")
 
     if isinstance(dataset, LeRobotDataset):
-        from_idx = dataset.episode_data_index["from"][episode_index]
-        to_idx = dataset.episode_data_index["to"][episode_index]
+        index_position = translate_episode_index_to_position(dataset.meta.episodes, episode_index)
+        from_idx = dataset.episode_data_index["from"][index_position]
+        to_idx = dataset.episode_data_index["to"][index_position]
         data = (
             dataset.hf_dataset.select(range(from_idx, to_idx))
             .select_columns(selected_columns)
@@ -305,7 +308,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
 
 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     # get first frame of episode (hack to get video_path of the episode)
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()
     return [
         dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
         for key in dataset.meta.video_keys
@@ -318,7 +322,8 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
         return None
 
     # get first frame index
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    index_position = translate_episode_index_to_position(dataset.meta.episodes, ep_index)
+    first_frame_idx = dataset.episode_data_index["from"][index_position].item()
 
     language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
     # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
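As a rough usage sketch of the patched call sites (hypothetical, not part of the commit; it assumes a dataset already available locally, with "lerobot/pusht" used only as a placeholder repo id), every surviving episode index should resolve to an in-range position in the episode_data_index tensors:

# Hypothetical sanity check: each remaining episode index translates to a
# valid sequential position. The repo id below is a placeholder example.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import translate_episode_index_to_position

dataset = LeRobotDataset("lerobot/pusht")
for ep_idx in dataset.meta.episodes:
    pos = translate_episode_index_to_position(dataset.meta.episodes, ep_idx)
    assert 0 <= pos < len(dataset.episode_data_index["from"])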