diff --git a/README.md b/README.md index 59929341..3a1735f5 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,7 @@ Under the hood, the `LeRobotDataset` format makes use of several ways to seriali Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects: ``` +TODO: IMPROVE dataset attributes: ├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example: │ ├ observation.images.cam_high (VideoFrame): @@ -190,7 +191,7 @@ dataset attributes: │ ├ timestamp (float32): timestamp in the episode │ ├ next.done (bool): indicates the end of en episode ; True for the last frame in each episode │ └ index (int64): general index in the whole dataset - ├ episode_data_index: contains 2 tensors with the start and end indices of each episode + ├ meta: contains 2 tensors with the start and end indices of each episode │ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,) starts with 0 │ └ to: (1D int64 tensor): last frame index for each episode — shape (num episodes,) ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index e9066487..fb1e4396 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -108,7 +108,8 @@ def save_decoded_frames( def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: - ep_num_images = dataset.episode_data_index["to"][0].item() + episode_index = 0 + ep_num_images = dataset.meta.episodes["length"][episode_index] if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images: return @@ -265,7 +266,8 @@ def benchmark_encoding_decoding( overwrite=True, ) - ep_num_images = dataset.episode_data_index["to"][0].item() + episode_index = 0 + ep_num_images = dataset.meta.episodes["length"][episode_index] width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:]) num_pixels = width * height video_size_bytes = video_path.stat().st_size diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py index 96c104b6..3f4662de 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/1_load_lerobot_dataset.py @@ -78,11 +78,11 @@ print(dataset.hf_dataset) # LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working # with the latter, like iterating through the dataset. # The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by -# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access +# episodes, you can access the frame indices of any episode using dataset.meta.episodes. 
Here, we access # frame indices associated to the first episode: episode_index = 0 -from_idx = dataset.episode_data_index["from"][episode_index].item() -to_idx = dataset.episode_data_index["to"][episode_index].item() +from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] +to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] # Then we grab all the image frames from the first camera: camera_key = dataset.meta.camera_keys[0] diff --git a/examples/advanced/1_add_image_transforms.py b/examples/advanced/1_add_image_transforms.py index 882710e3..29539703 100644 --- a/examples/advanced/1_add_image_transforms.py +++ b/examples/advanced/1_add_image_transforms.py @@ -17,7 +17,7 @@ dataset = LeRobotDataset(dataset_repo_id, episodes=[0]) # This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)` # Get the index of the first observation in the first episode -first_idx = dataset.episode_data_index["from"][0].item() +first_idx = dataset.meta.episodes["dataset_from_index"][0] # Get the frame corresponding to the first camera frame = dataset[first_idx][dataset.meta.camera_keys[0]] diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 018b2241..932baec5 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -51,6 +51,7 @@ from lerobot.common.datasets.utils import ( get_features_from_robot, get_hf_dataset_size_in_mb, get_hf_features_from_features, + get_parquet_file_size_in_mb, get_parquet_num_frames, get_safe_version, get_video_duration_in_s, @@ -59,15 +60,16 @@ from lerobot.common.datasets.utils import ( load_episodes, load_info, load_nested_dataset, + load_stats, load_tasks, update_chunk_file_indices, validate_episode_buffer, validate_frame, write_info, write_json, + write_stats, write_tasks, ) -from lerobot.common.datasets.v30.convert_dataset_v21_to_v30 import get_parquet_file_size_in_mb from lerobot.common.datasets.video_utils import ( VideoFrame, decode_video_frames_torchvision, @@ -111,8 +113,7 @@ class LeRobotDatasetMetadata: check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) self.episodes = load_episodes(self.root) - # TODO(rcadene): https://huggingface.slack.com/archives/C02V51Q3800/p1743517952388249?thread_ts=1742896075.499119&cid=C02V51Q3800 - # self.stats = aggregate_stats(list(self.episodes_stats.values())) + self.stats = load_stats(self.root) def pull_from_repo( self, @@ -272,10 +273,17 @@ class LeRobotDatasetMetadata: chunk_idx, file_idx = 0, 0 df["meta/episodes/chunk_index"] = [chunk_idx] df["meta/episodes/file_index"] = [file_idx] + df["dataset_from_index"] = [0] + df["dataset_to_index"] = [len(df)] else: # Retrieve information from the latest parquet file latest_ep = self.episodes.with_format( - columns=["meta/episodes/chunk_index", "meta/episodes/file_index"] + columns=[ + "meta/episodes/chunk_index", + "meta/episodes/file_index", + "dataset_from_index", + "dataset_to_index", + ] )[-1] chunk_idx, file_idx = ( latest_ep["meta/episodes/chunk_index"], @@ -285,16 +293,18 @@ class LeRobotDatasetMetadata: latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) - # Determine if a new parquet file is needed if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb: # Size limit is reached, prepare new parquet file chunk_idx, file_idx = 
update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - else: - # Update the existing parquet file with new row - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] + + # Update the existing pandas dataframe with new row + df["meta/episodes/chunk_index"] = [chunk_idx] + df["meta/episodes/file_index"] = [file_idx] + df["dataset_from_index"] = [latest_ep["dataset_to_index"]] + df["dataset_to_index"] = [latest_ep["dataset_to_index"] + len(df)] + + if latest_size_in_mb + ep_size_in_mb < self.files_size_in_mb: + # Size limit wasnt reached, concatenate latest dataframe with new one latest_df = pd.read_parquet(latest_path) df = pd.concat([latest_df, df], ignore_index=True) @@ -333,8 +343,8 @@ class LeRobotDatasetMetadata: self.update_video_info() write_info(self.info, self.root) - self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats - # TODO: write stats + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats + write_stats(self.stats, self.root) def update_video_info(self) -> None: """ @@ -401,8 +411,7 @@ class LeRobotDatasetMetadata: obj.tasks = None obj.episodes = None - # TODO(rcadene) stats - obj.stats = {} + obj.stats = None obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() diff --git a/lerobot/common/datasets/online_buffer.py b/lerobot/common/datasets/online_buffer.py index d907e468..3c8fe176 100644 --- a/lerobot/common/datasets/online_buffer.py +++ b/lerobot/common/datasets/online_buffer.py @@ -337,13 +337,11 @@ def compute_sampler_weights( if len(offline_dataset) > 0: offline_data_mask_indices = [] for start_index, end_index in zip( - offline_dataset.episode_data_index["from"], - offline_dataset.episode_data_index["to"], + offline_dataset.meta.episodes["dataset_from_index"], + offline_dataset.meta.episodes["dataset_to_index"], strict=True, ): - offline_data_mask_indices.extend( - range(start_index.item(), end_index.item() - offline_drop_n_last_frames) - ) + offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames)) offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool) offline_data_mask[torch.tensor(offline_data_mask_indices)] = True weights.append( diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py index 2f6c15c1..02fdc63d 100644 --- a/lerobot/common/datasets/sampler.py +++ b/lerobot/common/datasets/sampler.py @@ -21,7 +21,8 @@ import torch class EpisodeAwareSampler: def __init__( self, - episode_data_index: dict, + dataset_from_indices: list[int], + dataset_to_indices: list[int], episode_indices_to_use: Union[list, None] = None, drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, @@ -30,7 +31,8 @@ class EpisodeAwareSampler: """Sampler that optionally incorporates episode boundary information. Args: - episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. + dataset_from_indices: List of indices containing the start of each episode in the dataset. + dataset_to_indices: List of indices containing the end of each episode in the dataset. episode_indices_to_use: List of episode indices to use. If None, all episodes are used. Assumes that episodes are indexed from 0 to N-1. 
drop_n_first_frames: Number of frames to drop from the start of each episode. @@ -39,12 +41,10 @@ class EpisodeAwareSampler: """ indices = [] for episode_idx, (start_index, end_index) in enumerate( - zip(episode_data_index["from"], episode_data_index["to"], strict=True) + zip(dataset_from_indices, dataset_to_indices, strict=True) ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: - indices.extend( - range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) - ) + indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) self.indices = indices self.shuffle = shuffle diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 437a854c..d0cefb3c 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -21,7 +21,6 @@ import shutil import subprocess import tempfile from collections.abc import Iterator -from itertools import accumulate from pathlib import Path from pprint import pformat from types import SimpleNamespace @@ -56,23 +55,23 @@ DEFAULT_FILE_SIZE_IN_MB = 500.0 # Max size per file # Keep legacy for `convert_dataset_v21_to_v30.py` LEGACY_EPISODES_PATH = "meta/episodes.jsonl" -LEGACY_STATS_PATH = "meta/stats.json" LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" LEGACY_TASKS_PATH = "meta/tasks.jsonl" LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" -# TODO -DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" +DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png" + +INFO_PATH = "meta/info.json" +STATS_PATH = "meta/stats.json" EPISODES_DIR = "meta/episodes" DATA_DIR = "data" VIDEO_DIR = "videos" -INFO_PATH = "meta/info.json" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" -DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_TASKS_PATH = "meta/tasks.parquet" +DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" @@ -95,6 +94,12 @@ DEFAULT_FEATURES = { } +def get_parquet_file_size_in_mb(parquet_path): + metadata = pq.read_metadata(parquet_path) + uncompressed_size = metadata.num_rows * metadata.row_group(0).total_byte_size + return uncompressed_size / (1024**2) + + def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: return hf_ds.data.nbytes / (1024**2) @@ -317,7 +322,7 @@ def load_info(local_dir: Path) -> dict: def write_stats(stats: dict, local_dir: Path): serialized_stats = serialize_dict(stats) - write_json(serialized_stats, local_dir / LEGACY_STATS_PATH) + write_json(serialized_stats, local_dir / STATS_PATH) def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: @@ -326,9 +331,9 @@ def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: - if not (local_dir / LEGACY_STATS_PATH).exists(): + if not (local_dir / STATS_PATH).exists(): return None - stats = load_json(local_dir / LEGACY_STATS_PATH) + stats = load_json(local_dir / STATS_PATH) return cast_stats_to_numpy(stats) @@ -375,13 +380,6 @@ def write_episodes(episodes: Dataset, local_dir: Path): if 
get_hf_dataset_size_in_mb(episodes) > DEFAULT_FILE_SIZE_IN_MB: raise NotImplementedError("Contact a maintainer.") - def add_chunk_file_indices(row): - row["chunk_index"] = 0 - row["file_index"] = 0 - return row - - episodes = episodes.map(add_chunk_file_indices) - fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) fpath.parent.mkdir(parents=True, exist_ok=True) episodes.to_parquet(fpath) @@ -642,20 +640,6 @@ def create_empty_dataset_info( } -def get_episode_data_index( - episode_dicts: dict[dict], episodes: list[int] | None = None -) -> dict[str, torch.Tensor]: - episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} - if episodes is not None: - episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} - - cumulative_lengths = list(accumulate(episode_lengths.values())) - return { - "from": torch.LongTensor([0] + cumulative_lengths[:-1]), - "to": torch.LongTensor(cumulative_lengths), - } - - def check_timestamps_sync( timestamps: np.ndarray, episode_indices: np.ndarray, diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index 38f852bb..69c8d58a 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -123,10 +123,10 @@ from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_VIDEO_PATH, - LEGACY_EPISODES_PATH, INFO_PATH, - LEGACY_STATS_PATH, + LEGACY_EPISODES_PATH, LEGACY_TASKS_PATH, + STATS_PATH, create_branch, create_lerobot_dataset_card, flatten_dict, @@ -188,7 +188,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None: serialized_stats = {key: value.tolist() for key, value in stats.items()} serialized_stats = unflatten_dict(serialized_stats) - json_path = v2_dir / LEGACY_STATS_PATH + json_path = v2_dir / STATS_PATH json_path.parent.mkdir(exist_ok=True, parents=True) with open(json_path, "w") as f: json.dump(serialized_stats, f, indent=4) @@ -296,9 +296,7 @@ def split_parquet_by_episodes( for ep_idx in range(ep_chunk_start, ep_chunk_end): ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) episode_lengths.insert(ep_idx, len(ep_table)) - output_file = output_dir / DEFAULT_DATA_PATH.format( - episode_chunk=ep_chunk, episode_index=ep_idx - ) + output_file = output_dir / DEFAULT_DATA_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx) pq.write_table(ep_table, output_file) return episode_lengths diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index 8d291b64..a3f210d6 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -23,7 +23,7 @@ import logging from huggingface_hub import HfApi from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.common.datasets.utils import LEGACY_EPISODES_STATS_PATH, LEGACY_STATS_PATH, load_stats, write_info +from lerobot.common.datasets.utils import LEGACY_EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats V20 = "v2.0" @@ -60,15 +60,15 @@ def convert_dataset( dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/") # delete old stats.json file - if (dataset.root / LEGACY_STATS_PATH).is_file: - (dataset.root / LEGACY_STATS_PATH).unlink() + if 
(dataset.root / STATS_PATH).is_file: + (dataset.root / STATS_PATH).unlink() hub_api = HfApi() if hub_api.file_exists( - repo_id=dataset.repo_id, filename=LEGACY_STATS_PATH, revision=branch, repo_type="dataset" + repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" ): hub_api.delete_file( - path_in_repo=LEGACY_STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" + path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" ) hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") diff --git a/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py b/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py index 21cc7e22..7a4ff6c9 100644 --- a/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py +++ b/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py @@ -18,15 +18,16 @@ python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \ """ import argparse +import shutil from pathlib import Path import pandas as pd -import pyarrow.parquet as pq import tqdm from datasets import Dataset -from huggingface_hub import snapshot_download +from huggingface_hub import HfApi, snapshot_download from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, @@ -34,6 +35,7 @@ from lerobot.common.datasets.utils import ( DEFAULT_VIDEO_PATH, concat_video_files, flatten_dict, + get_parquet_file_size_in_mb, get_parquet_num_frames, get_video_duration_in_s, get_video_size_in_mb, @@ -93,12 +95,6 @@ meta/info.json """ -def get_parquet_file_size_in_mb(parquet_path): - metadata = pq.read_metadata(parquet_path) - uncompressed_size = metadata.num_rows * metadata.row_group(0).total_byte_size - return uncompressed_size / (1024**2) - - # def generate_flat_ep_stats(episodes_stats): # for ep_idx, ep_stats in episodes_stats.items(): # flat_ep_stats = flatten_dict(ep_stats) @@ -148,8 +144,8 @@ def convert_data(root, new_root): "episode_index": ep_idx, "data/chunk_index": chunk_idx, "data/file_index": file_idx, - "data/from_index": num_frames, - "data/to_index": num_frames + ep_num_frames, + "dataset_from_index": num_frames, + "dataset_to_index": num_frames + ep_num_frames, } size_in_mb += ep_size_in_mb num_frames += ep_num_frames @@ -337,6 +333,9 @@ def convert_dataset( root = HF_LEROBOT_HOME / repo_id new_root = HF_LEROBOT_HOME / f"{repo_id}_v30" + if new_root.is_dir(): + shutil.rmtree(new_root) + snapshot_download( repo_id, repo_type="dataset", @@ -350,6 +349,24 @@ def convert_dataset( episodes_videos_metadata = convert_videos(root, new_root) convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata) + shutil.move(str(root), str(root) + "_old") + shutil.move(str(new_root), str(root)) + + # TODO(racdene) + if False: + hub_api = HfApi() + hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + hub_api.delete_files( + delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"], + repo_id=repo_id, + revision=branch, + repo_type="dataset", + ) + + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + + LeRobotDataset(repo_id).push_to_hub() + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index e36c697a..27d4f05e 100644 --- a/lerobot/scripts/train.py +++ 
b/lerobot/scripts/train.py @@ -167,7 +167,8 @@ def train(cfg: TrainPipelineConfig): if hasattr(cfg.policy, "drop_n_last_frames"): shuffle = False sampler = EpisodeAwareSampler( - dataset.episode_data_index, + dataset.meta.episodes["dataset_from_index"], + dataset.meta.episodes["dataset_to_index"], drop_n_last_frames=cfg.policy.drop_n_last_frames, shuffle=True, ) diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 11feb1af..29883b0f 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -79,8 +79,8 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset class EpisodeSampler(torch.utils.data.Sampler): def __init__(self, dataset: LeRobotDataset, episode_index: int): - from_idx = dataset.episode_data_index["from"][episode_index].item() - to_idx = dataset.episode_data_index["to"][episode_index].item() + from_idx = dataset.meta.episodes["dataset_from_index"][episode_index].item() + to_idx = dataset.meta.episodes["dataset_to_index"][episode_index].item() self.frame_ids = range(from_idx, to_idx) def __iter__(self) -> Iterator: diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index a6899ce9..9ac8b8dd 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -259,8 +259,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) selected_columns.insert(0, "timestamp") if isinstance(dataset, LeRobotDataset): - from_idx = dataset.episode_data_index["from"][episode_index] - to_idx = dataset.episode_data_index["to"][episode_index] + from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] + to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] data = ( dataset.hf_dataset.select(range(from_idx, to_idx)) .select_columns(selected_columns) @@ -296,7 +296,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]: # get first frame of episode (hack to get video_path of the episode) - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + first_frame_idx = dataset.meta.episodes["dataset_from_index"][ep_index] return [ dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.meta.video_keys @@ -309,7 +309,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> return None # get first frame index - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + first_frame_idx = dataset.meta.episodes["dataset_from_index"][ep_index] language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 4bffe9e0..9df64653 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -230,6 +230,8 @@ def episodes_factory(tasks_factory, stats_factory): "meta/episodes/file_index": [], "data/chunk_index": [], "data/file_index": [], + "dataset_from_index": [], + "dataset_to_index": [], "tasks": [], "length": [], } @@ -241,6 +243,7 @@ def episodes_factory(tasks_factory, stats_factory): for stats_key in flatten_dict({"stats": stats_factory(features)}): d[stats_key] = [] + num_frames = 0 remaining_tasks = list(tasks.index) for ep_idx in range(total_episodes): 
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 @@ -256,6 +259,8 @@ def episodes_factory(tasks_factory, stats_factory): d["meta/episodes/file_index"].append(0) d["data/chunk_index"].append(0) d["data/file_index"].append(0) + d["dataset_from_index"].append(num_frames) + d["dataset_to_index"].append(num_frames + lengths[ep_idx]) d["tasks"].append(episode_tasks) d["length"].append(lengths[ep_idx]) @@ -268,6 +273,8 @@ def episodes_factory(tasks_factory, stats_factory): for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items(): d[stats_key].append(stats) + num_frames += lengths[ep_idx] + return Dataset.from_dict(d) return _create_episodes @@ -283,10 +290,10 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar ) -> datasets.Dataset: if tasks is None: tasks = tasks_factory() - if episodes is None: - episodes = episodes_factory() if features is None: features = features_factory() + if episodes is None: + episodes = episodes_factory(features) timestamp_col = np.array([], dtype=np.float32) frame_index_col = np.array([], dtype=np.int64) diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 6fd9c7a2..91055e5c 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -10,7 +10,7 @@ from lerobot.common.datasets.utils import ( DEFAULT_EPISODES_PATH, DEFAULT_TASKS_PATH, INFO_PATH, - LEGACY_STATS_PATH, + STATS_PATH, ) from tests.fixtures.constants import LEROBOT_TEST_DIR @@ -70,7 +70,7 @@ def mock_snapshot_download_factory( # List all possible files all_files = [ INFO_PATH, - LEGACY_STATS_PATH, + STATS_PATH, # TODO(rcadene): remove naive chunk 0 file 0 ? DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0), DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0), diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index 3b77348c..b72c137d 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -47,17 +47,23 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): ) # save 2 first frames of first episode - i = dataset.episode_data_index["from"][0].item() + i = dataset.meta.episodes["dataset_from_index"][0].item() save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + i = int( + ( + dataset.meta.episodes["dataset_to_index"][0].item() + - dataset.meta.episodes["dataset_from_index"][0].item() + ) + / 2 + ) save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 last frames of first episode - i = dataset.episode_data_index["to"][0].item() + i = dataset.meta.episodes["dataset_to_index"][0].item() save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors") save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors") @@ -65,17 +71,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): # We currently cant because our test dataset only contains the first episode # # save 2 first frames of second episode - # i = dataset.episode_data_index["from"][1].item() + # i = dataset.meta.episodes["dataset_from_index"][1].item() # save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") # save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors") 
# # save 2 last frames of second episode - # i = dataset.episode_data_index["to"][1].item() + # i = dataset.meta.episodes["dataset_to_index"][1].item() # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors") # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors") # # save 2 last frames of last episode - # i = dataset.episode_data_index["to"][-1].item() + # i = dataset.meta.episodes["dataset_to_index"][-1].item() # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors") # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 003a60c9..195db04f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -507,17 +507,23 @@ def test_backward_compatibility(repo_id): ) # test2 first frames of first episode - i = dataset.episode_data_index["from"][0].item() + i = dataset.meta.episodes["dataset_from_index"][0].item() load_and_compare(i) load_and_compare(i + 1) # test 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + i = int( + ( + dataset.meta.episodes["dataset_to_index"][0].item() + - dataset.meta.episodes["dataset_from_index"][0].item() + ) + / 2 + ) load_and_compare(i) load_and_compare(i + 1) # test 2 last frames of first episode - i = dataset.episode_data_index["to"][0].item() + i = dataset.meta.episodes["dataset_to_index"][0].item() load_and_compare(i - 2) load_and_compare(i - 1) @@ -525,17 +531,17 @@ def test_backward_compatibility(repo_id): # We currently cant because our test dataset only contains the first episode # # test 2 first frames of second episode - # i = dataset.episode_data_index["from"][1].item() + # i = dataset.meta.episodes["dataset_from_index"][1].item() # load_and_compare(i) # load_and_compare(i + 1) # # test 2 last frames of second episode - # i = dataset.episode_data_index["to"][1].item() + # i = dataset.meta.episodes["dataset_to_index"][1].item() # load_and_compare(i - 2) # load_and_compare(i - 1) # # test 2 last frames of last episode - # i = dataset.episode_data_index["to"][-1].item() + # i = dataset.meta.episodes["dataset_to_index"][-1].item() # load_and_compare(i - 2) # load_and_compare(i - 1) diff --git a/tests/test_delta_timestamps.py b/tests/test_delta_timestamps.py index b27cc1eb..1477ec73 100644 --- a/tests/test_delta_timestamps.py +++ b/tests/test_delta_timestamps.py @@ -43,8 +43,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n def synced_timestamps_factory(hf_dataset_factory): def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: hf_dataset = hf_dataset_factory(fps=fps) - timestamps = torch.stack(hf_dataset["timestamp"]).numpy() - episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() + timestamps = hf_dataset["timestamp"].numpy() + episode_indices = hf_dataset["episode_index"].numpy() episode_data_index = calculate_episode_data_index(hf_dataset) return timestamps, episode_indices, episode_data_index diff --git a/tests/test_policies.py b/tests/test_policies.py index 9dab6176..5e08ea81 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -68,7 +68,11 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p }, } info = info_factory( - total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features + total_episodes=1, + total_frames=1, + total_tasks=1, + camera_features=camera_features, + 
motor_features=motor_features, ) ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info) return ds_meta diff --git a/tests/test_sampler.py b/tests/test_sampler.py index ee143f37..91cdb8bc 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -32,7 +32,7 @@ def test_drop_n_first_frames(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1) assert sampler.indices == [1, 4, 5] assert len(sampler) == 3 assert list(sampler) == [1, 4, 5] @@ -48,7 +48,7 @@ def test_drop_n_last_frames(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1) assert sampler.indices == [0, 3, 4] assert len(sampler) == 3 assert list(sampler) == [0, 3, 4] @@ -64,7 +64,9 @@ def test_episode_indices_to_use(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) + sampler = EpisodeAwareSampler( + episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2] + ) assert sampler.indices == [0, 1, 3, 4, 5] assert len(sampler) == 5 assert list(sampler) == [0, 1, 3, 4, 5] @@ -80,11 +82,11 @@ def test_shuffle(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, shuffle=False) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False) assert sampler.indices == [0, 1, 2, 3, 4, 5] assert len(sampler) == 6 assert list(sampler) == [0, 1, 2, 3, 4, 5] - sampler = EpisodeAwareSampler(episode_data_index, shuffle=True) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True) assert sampler.indices == [0, 1, 2, 3, 4, 5] assert len(sampler) == 6 assert set(sampler) == {0, 1, 2, 3, 4, 5} diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index b2f14694..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -from datasets import Dataset - -from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.common.datasets.utils import ( - hf_transform_to_torch, -) - - -def test_calculate_episode_data_index(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) - assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
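
Reviewer sketch (appended note, not part of the patch): the snippet below illustrates the access pattern this diff migrates to — episode boundaries read from `dataset.meta.episodes` instead of the removed `episode_data_index`, plus the new two-argument `EpisodeAwareSampler` signature. It mirrors the updated call sites in `examples/1_load_lerobot_dataset.py` and `lerobot/scripts/train.py`; the repo id `lerobot/pusht`, the `drop_n_last_frames=7` value, and the import paths (inferred from the touched file locations) are illustrative assumptions rather than anything prescribed by the patch.

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler

# Illustrative repo id; any dataset converted to the new format is expected to behave the same.
dataset = LeRobotDataset("lerobot/pusht")

# Episode boundaries now live in the episodes metadata rather than in episode_data_index.
episode_index = 0
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]

# Grab all frames of the first camera for that episode, as in the updated example script.
camera_key = dataset.meta.camera_keys[0]
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]

# The sampler now takes the two boundary columns directly instead of an episode_data_index dict.
sampler = EpisodeAwareSampler(
    dataset.meta.episodes["dataset_from_index"],
    dataset.meta.episodes["dataset_to_index"],
    drop_n_last_frames=7,  # illustrative value; train.py reads this from cfg.policy
    shuffle=True,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
```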