From b0cca75e5e5eef8686227991787a13bfc1fbd7c1 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sat, 19 Apr 2025 19:11:53 +0530 Subject: [PATCH] Progress on aggregate_datasets --- lerobot/common/datasets/aggregate.py | 180 ++++++++++-------- lerobot/common/datasets/lerobot_dataset.py | 8 +- .../v30/convert_dataset_v21_to_v30.py | 16 +- tests/fixtures/dataset_factories.py | 9 +- tests/fixtures/files.py | 1 + tests/fixtures/hub.py | 2 + tests/test_aggregate_datasets.py | 20 +- 7 files changed, 149 insertions(+), 87 deletions(-) diff --git a/lerobot/common/datasets/aggregate.py b/lerobot/common/datasets/aggregate.py index df74e767..67cc3ee4 100644 --- a/lerobot/common/datasets/aggregate.py +++ b/lerobot/common/datasets/aggregate.py @@ -1,11 +1,13 @@ import logging +from pathlib import Path import shutil import pandas as pd import tqdm +from lerobot.common.datasets.compute_stats import aggregate_stats from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task +from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks from lerobot.common.utils.utils import init_logging @@ -30,22 +32,46 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): return fps, robot_type, features - -def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_task_index): +def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks): def _update(row): row["episode_index"] = row["episode_index"] + episode_index_to_add - row["task_index"] = task_index_to_global_task_index[row["task_index"]] + task = old_tasks.iloc[row["task_index"]].name + row["task_index"] = new_tasks.loc[task].task_index.item() return row - return _update +def get_update_meta_func( + meta_chunk_index_to_add, + meta_file_index_to_add, + data_chunk_index_to_add, + data_file_index_to_add, + videos_chunk_index_to_add, + videos_file_index_to_add, + frame_index_to_add, +): + def _update(row): + row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_chunk_index_to_add + row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_file_index_to_add + row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add + row["data/file_index"] = row["data/file_index"] + data_file_index_to_add + for key in videos_chunk_index_to_add: + row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key] + row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key] + row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add + row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add + return row + return _update -def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None): +def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]=None, aggr_root=None): logging.info("Start aggregate_datasets") - all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids] + if roots is None: + all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids] + else: + all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)] fps, robot_type, features = validate_all_metadata(all_metadata) + video_keys = [key for key in features if features[key]["dtype"] == "video"] # Create resulting dataset folder aggr_meta = LeRobotDatasetMetadata.create( @@ -55,95 +81,99 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None): features=features, root=aggr_root, ) + aggr_root = aggr_meta.root logging.info("Find all tasks") - # find all tasks, deduplicate them, create new task indices for each dataset - # indexed by dataset index - datasets_task_index_to_aggr_task_index = {} - aggr_task_index = 0 - for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")): - task_index_to_aggr_task_index = {} + unique_tasks = pd.concat([meta.tasks for meta in all_metadata]).index.unique() + aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks) - for task_index, task in meta.tasks.items(): - if task not in aggr_meta.task_to_task_index: - # add the task to aggr tasks mappings - aggr_meta.tasks[aggr_task_index] = task - aggr_meta.task_to_task_index[task] = aggr_task_index - aggr_task_index += 1 + num_episodes = 0 + num_frames = 0 - # add task_index anyway - task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task] + aggr_meta_chunk_idx = 0 + aggr_meta_file_idx = 0 - datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index + aggr_data_chunk_idx = 0 + aggr_data_file_idx = 0 - logging.info("Copy data and videos") - aggr_episode_index_shift = 0 - for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Copy data and videos")): - # cp data - for episode_index in range(meta.total_episodes): - aggr_episode_index = episode_index + aggr_episode_index_shift - data_path = meta.root / meta.get_data_file_path(episode_index) - aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index) + aggr_videos_chunk_idx = {key: 0 for key in video_keys} + aggr_videos_file_idx = {key: 0 for key in video_keys} - # update episode_index and task_index - df = pd.read_parquet(data_path) - update_row_func = get_update_episode_and_task_func( - aggr_episode_index_shift, datasets_task_index_to_aggr_task_index[dataset_index] + for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"): + + meta_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])]) + for chunk_idx, file_idx in meta_chunk_file_ids: + path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + df = pd.read_parquet(path) + update_meta_func = get_update_meta_func( + aggr_meta_chunk_idx, + aggr_meta_file_idx, + aggr_data_chunk_idx, + aggr_data_file_idx, + aggr_videos_chunk_idx, + aggr_videos_file_idx, + num_frames, ) - df = df.apply(update_row_func, axis=1) - - aggr_data_path.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(aggr_data_path) + df = df.apply(update_meta_func, axis=1) + + aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx) + aggr_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(aggr_path) + + aggr_meta_file_idx += 1 + if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE: + aggr_meta_file_idx = 0 + aggr_meta_chunk_idx += 1 # cp videos - for episode_index in range(meta.total_episodes): - aggr_episode_index = episode_index + aggr_episode_index_shift - for vid_key in meta.video_keys: - video_path = meta.root / meta.get_video_file_path(episode_index, vid_key) - aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key) - aggr_video_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy(video_path, aggr_video_path) + for key in video_keys: + video_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])]) + for chunk_idx, file_idx in video_chunk_file_ids: + path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx) + aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key]) + aggr_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(str(path), str(aggr_path)) # copy_command = f"cp {video_path} {aggr_video_path} &" # subprocess.Popen(copy_command, shell=True) - # populate episodes - for episode_index, episode_dict in meta.episodes.items(): - aggr_episode_index = episode_index + aggr_episode_index_shift - episode_dict["episode_index"] = aggr_episode_index - aggr_meta.episodes[aggr_episode_index] = episode_dict + aggr_videos_file_idx[key] += 1 + if aggr_videos_file_idx[key] >= DEFAULT_CHUNK_SIZE: + aggr_videos_file_idx[key] = 0 + aggr_videos_chunk_idx[key] += 1 - # populate episodes_stats - for episode_index, episode_stats in meta.episodes_stats.items(): - aggr_episode_index = episode_index + aggr_episode_index_shift - aggr_meta.episodes_stats[aggr_episode_index] = episode_stats + data_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])]) + for chunk_idx, file_idx in data_chunk_file_ids: + path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + df = pd.read_parquet(path) + update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks) + df = df.apply(update_data_func, axis=1) - # populate info - aggr_meta.info["total_episodes"] += meta.total_episodes - aggr_meta.info["total_frames"] += meta.total_frames - aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes + aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx) + aggr_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(aggr_path) - aggr_episode_index_shift += meta.total_episodes + aggr_data_file_idx += 1 + if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE: + aggr_data_file_idx = 0 + aggr_data_chunk_idx += 1 + + num_episodes += meta.total_episodes + num_frames += meta.total_frames - logging.info("write meta data") - - aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1) - aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"} - - # create a new episodes jsonl with updated episode_index using write_episode - for episode_dict in aggr_meta.episodes.values(): - write_episode(episode_dict, aggr_meta.root) - - # create a new episode_stats jsonl with updated episode_index using write_episode_stats - for episode_index, episode_stats in aggr_meta.episodes_stats.items(): - legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root) - - # create a new task jsonl with updated episode_index using write_task - for task_index, task in aggr_meta.tasks.items(): - legacy_write_task(task_index, task, aggr_meta.root) + logging.info("write tasks") + write_tasks(aggr_meta.tasks, aggr_meta.root) + logging.info("write info") + aggr_meta.info["total_episodes"] = sum([meta.total_episodes for meta in all_metadata]) + aggr_meta.info["total_frames"] = sum([meta.total_frames for meta in all_metadata]) + aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.total_episodes}"} write_info(aggr_meta.info, aggr_meta.root) + logging.info("write stats") + aggr_meta.stats = aggregate_stats([meta.stats for meta in all_metadata]) + write_stats(aggr_meta.stats, aggr_meta.root) + if __name__ == "__main__": init_logging() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 52e1b015..6fb68b60 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -712,8 +712,8 @@ class LeRobotDataset(torch.utils.data.Dataset): return get_hf_features_from_features(self.features) def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: - ep_start = self.meta.episodes["data/from_index"][ep_idx] - ep_end = self.meta.episodes["data/to_index"][ep_idx] + ep_start = self.meta.episodes["dataset_from_index"][ep_idx] + ep_end = self.meta.episodes["dataset_to_index"][ep_idx] query_indices = { key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx] for key, delta_idx in self.delta_indices.items() @@ -1017,8 +1017,8 @@ class LeRobotDataset(torch.utils.data.Dataset): metadata = { "data/chunk_index": chunk_idx, "data/file_index": file_idx, - "data/from_index": latest_num_frames, - "data/to_index": latest_num_frames + ep_num_frames, + "dataset_from_index": latest_num_frames, + "dataset_to_index": latest_num_frames + ep_num_frames, } return metadata diff --git a/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py b/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py index 7a4ff6c9..83a6145c 100644 --- a/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py +++ b/lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py @@ -27,6 +27,7 @@ from datasets import Dataset from huggingface_hub import HfApi, snapshot_download from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.compute_stats import aggregate_stats from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, @@ -46,6 +47,7 @@ from lerobot.common.datasets.utils import ( update_chunk_file_indices, write_episodes, write_info, + write_stats, write_tasks, ) @@ -306,6 +308,9 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_ ) write_episodes(ds_episodes, new_root) + stats = aggregate_stats(list(episodes_stats.values())) + write_stats(stats, new_root) + def convert_info(root, new_root): info = load_info(root) @@ -330,9 +335,16 @@ def convert_dataset( branch: str | None = None, num_workers: int = 4, ): + + root = HF_LEROBOT_HOME / repo_id + old_root = HF_LEROBOT_HOME / f"{repo_id}_old" new_root = HF_LEROBOT_HOME / f"{repo_id}_v30" + if old_root.is_dir() and root.is_dir(): + shutil.rmtree(str(root)) + shutil.move(str(old_root), str(root)) + if new_root.is_dir(): shutil.rmtree(new_root) @@ -349,7 +361,7 @@ def convert_dataset( episodes_videos_metadata = convert_videos(root, new_root) convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata) - shutil.move(str(root), str(root) + "_old") + shutil.move(str(root), str(old_root)) shutil.move(str(new_root), str(root)) # TODO(racdene) @@ -365,7 +377,7 @@ def convert_dataset( hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") - LeRobotDataset(repo_id).push_to_hub() + # LeRobotDataset(repo_id).push_to_hub() if __name__ == "__main__": diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index b9c33483..256952bd 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -212,6 +212,7 @@ def tasks_factory(): def episodes_factory(tasks_factory, stats_factory): def _create_episodes( features: dict[str], + fps: int = DEFAULT_FPS, total_episodes: int = 3, total_frames: int = 400, video_keys: list[str] | None = None, @@ -252,6 +253,8 @@ def episodes_factory(tasks_factory, stats_factory): for video_key in video_keys: d[f"videos/{video_key}/chunk_index"] = [] d[f"videos/{video_key}/file_index"] = [] + d[f"videos/{video_key}/from_timestamp"] = [] + d[f"videos/{video_key}/to_timestamp"] = [] for stats_key in flatten_dict({"stats": stats_factory(features)}): d[stats_key] = [] @@ -281,6 +284,8 @@ def episodes_factory(tasks_factory, stats_factory): for video_key in video_keys: d[f"videos/{video_key}/chunk_index"].append(0) d[f"videos/{video_key}/file_index"].append(0) + d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps) + d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps) # Add stats columns like "stats/action/max" for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items(): @@ -306,7 +311,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar if features is None: features = features_factory() if episodes is None: - episodes = episodes_factory(features) + episodes = episodes_factory(features, fps) timestamp_col = np.array([], dtype=np.float32) frame_index_col = np.array([], dtype=np.int64) @@ -379,6 +384,7 @@ def lerobot_dataset_metadata_factory( video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"] episodes = episodes_factory( features=info["features"], + fps=info["fps"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, @@ -441,6 +447,7 @@ def lerobot_dataset_factory( video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"] episodes_metadata = episodes_factory( features=info["features"], + fps=info["fps"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index ad4c3c95..805fee3a 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -73,6 +73,7 @@ def create_tasks(tasks_factory): def create_episodes(episodes_factory): def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None): if episodes is None: + # TODO(rcadene): add features, fps as arguments episodes = episodes_factory() write_episodes(episodes, dir) diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index a7d72323..6caa9246 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -62,6 +62,7 @@ def mock_snapshot_download_factory( if episodes is None: episodes = episodes_factory( features=info["features"], + fps=info["fps"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks, @@ -121,6 +122,7 @@ def mock_snapshot_download_factory( create_tasks(local_dir, tasks) if has_episodes: create_episodes(local_dir, episodes) + # TODO(rcadene): create_videos? if has_data: create_hf_dataset(local_dir, hf_dataset) diff --git a/tests/test_aggregate_datasets.py b/tests/test_aggregate_datasets.py index ad5c2022..7380eced 100644 --- a/tests/test_aggregate_datasets.py +++ b/tests/test_aggregate_datasets.py @@ -1,19 +1,29 @@ from lerobot.common.datasets.aggregate import aggregate_datasets +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from tests.fixtures.constants import DUMMY_REPO_ID def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): - dataset_0 = lerobot_dataset_factory( + ds_0 = lerobot_dataset_factory( root=tmp_path / "test_0", - repo_id=DUMMY_REPO_ID + "_0", + repo_id=f"{DUMMY_REPO_ID}_0", total_episodes=10, total_frames=400, ) - dataset_1 = lerobot_dataset_factory( + ds_1 = lerobot_dataset_factory( root=tmp_path / "test_1", - repo_id=DUMMY_REPO_ID + "_1", + repo_id=f"{DUMMY_REPO_ID}_1", total_episodes=10, total_frames=400, ) - dataset_2 = aggregate_datasets([dataset_0, dataset_1]) + aggregate_datasets( + repo_ids=[ds_0.repo_id, ds_1.repo_id], + roots=[ds_0.root, ds_1.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_aggr", + aggr_root=tmp_path / "test_aggr" + ) + + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr") + for item in aggr_ds: + pass