From c36d2253d0b48f6a66a0111fd0018ab5f7a236b1 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Sun, 23 Feb 2025 18:18:46 +0000
Subject: [PATCH] Aggregate works

---
 examples/port_datasets/openx_rlds.py          |  7 ++-
 .../port_datasets/openx_rlds_completed.py     | 52 ++++++++++++++++++
 lerobot/common/datasets/aggregate.py          | 54 +++++++++++--------
 tests/test_aggregate_datasets.py              | 19 +++++++
 4 files changed, 106 insertions(+), 26 deletions(-)
 create mode 100644 examples/port_datasets/openx_rlds_completed.py
 create mode 100644 tests/test_aggregate_datasets.py

diff --git a/examples/port_datasets/openx_rlds.py b/examples/port_datasets/openx_rlds.py
index 51a92773..db051b5f 100644
--- a/examples/port_datasets/openx_rlds.py
+++ b/examples/port_datasets/openx_rlds.py
@@ -36,7 +36,6 @@ python examples/port_datasets/openx_rlds.py \
 import argparse
 import logging
 import re
-import shutil
 import time
 from pathlib import Path

@@ -316,9 +315,9 @@

     args = parser.parse_args()

-    droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
-    if droid_dir.exists():
-        shutil.rmtree(droid_dir)
+    # droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
+    # if droid_dir.exists():
+    #     shutil.rmtree(droid_dir)

     create_lerobot_dataset(**vars(args))

diff --git a/examples/port_datasets/openx_rlds_completed.py b/examples/port_datasets/openx_rlds_completed.py
new file mode 100644
index 00000000..849d65cf
--- /dev/null
+++ b/examples/port_datasets/openx_rlds_completed.py
@@ -0,0 +1,52 @@
+from pathlib import Path
+
+import tqdm
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
+
+
+def main():
+    repo_id = "cadene/droid"
+    datetime = "2025-02-22_11-23-54"
+    port_log_dir = Path(f"/fsx/remi_cadene/logs/{datetime}_port_openx_droid")
+
+    compl_dir = port_log_dir / "completions"
+
+    paths = list(compl_dir.glob("*"))
+    total_items = len(paths)
+
+    # Use tqdm with the total parameter
+    wrong_completions = []
+    error_messages = []
+    for i, path in tqdm.tqdm(enumerate(paths), total=total_items):
+        try:
+            rank = path.name.lstrip("0")
+            if rank == "":
+                rank = 0
+            meta = LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}")
+            last_episode_index = meta.total_episodes - 1
+            last_ep_data_path = meta.root / meta.get_data_file_path(last_episode_index)
+
+            if not last_ep_data_path.exists():
+                raise ValueError(path)
+
+            for vid_key in meta.video_keys:
+                last_ep_vid_path = meta.root / meta.get_video_file_path(last_episode_index, vid_key)
+                if not last_ep_vid_path.exists():
+                    raise ValueError(path)
+
+        except Exception as e:
+            error_messages.append(str(e))
+            wrong_completions.append(path)
+
+    for path, error_msg in zip(wrong_completions, error_messages, strict=False):
+        print(path)
+        print(error_msg)
+        print()
+        # path.unlink()
+
+    print(f"Error {len(wrong_completions)} / {total_items}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/lerobot/common/datasets/aggregate.py b/lerobot/common/datasets/aggregate.py
index 1f02e07e..d891e008 100644
--- a/lerobot/common/datasets/aggregate.py
+++ b/lerobot/common/datasets/aggregate.py
@@ -1,10 +1,12 @@
-import shutil
-from pathlib import Path
+import logging
+import subprocess

 import pandas as pd
+import tqdm

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
+from lerobot.common.utils.utils import init_logging


 def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -14,7 +16,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
     robot_type = all_metadata[0].robot_type
     features = all_metadata[0].features

-    for meta in all_metadata:
+    for meta in tqdm.tqdm(all_metadata):
         if fps != meta.fps:
             raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
         if robot_type != meta.robot_type:
@@ -39,6 +41,7 @@ def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_


 def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str, root=None):
+    logging.info("start aggregate_datasets")
     fps, robot_type, features = validate_all_metadata(all_metadata)

     # Create resulting dataset folder
@@ -50,11 +53,12 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
         root=root,
     )

+    logging.info("find all tasks")
     # find all tasks, deduplicate them, create new task indices for each dataset
     # indexed by dataset index
     datasets_task_index_to_aggr_task_index = {}
     aggr_task_index = 0
-    for dataset_index, meta in enumerate(all_metadata):
+    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
         task_index_to_aggr_task_index = {}

         for task_index, task in meta.tasks.items():
@@ -69,8 +73,9 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,

         datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index

+    logging.info("cp data and videos")
     aggr_episode_index_shift = 0
-    for dataset_index, meta in enumerate(all_metadata):
+    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
         # cp data
         for episode_index in range(meta.total_episodes):
             aggr_episode_index = episode_index + aggr_episode_index_shift
@@ -94,7 +99,10 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
             video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
             aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
             aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
-            shutil.copy(video_path, aggr_video_path)
+            # shutil.copy(video_path, aggr_video_path)
+
+            copy_command = f"cp {video_path} {aggr_video_path} &"
+            subprocess.Popen(copy_command, shell=True)

         # populate episodes
         for episode_index, episode_dict in meta.episodes.items():
@@ -109,11 +117,13 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,

         # populate info
         aggr_meta.info["total_episodes"] += meta.total_episodes
-        aggr_meta.info["total_frames"] += meta.total_episodes
+        aggr_meta.info["total_frames"] += meta.total_frames
         aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes

         aggr_episode_index_shift += meta.total_episodes

+    logging.info("write meta data")
+
     aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
     aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}

@@ -133,30 +143,30 @@


 if __name__ == "__main__":
+    init_logging()
     repo_id = "cadene/droid"
+    aggr_repo_id = "cadene/droid"
     datetime = "2025-02-22_11-23-54"

-    root = Path(f"/tmp/{repo_id}")
-    if root.exists():
-        shutil.rmtree(root)
+    # root = Path(f"/tmp/{repo_id}")
+    # if root.exists():
+    #     shutil.rmtree(root)
+    root = None

-    all_metadata = [
-        LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_0"),
-        LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_1"),
-    ]
+    # all_metadata = [LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}") for rank in range(2048)]

-    aggregate_datasets(
-        all_metadata,
-        repo_id,
-        root=root,
-    )
+    # aggregate_datasets(
+    #     all_metadata,
+    #     aggr_repo_id,
+    #     root=root,
+    # )

     aggr_dataset = LeRobotDataset(
-        repo_id=repo_id,
+        repo_id=aggr_repo_id,
         root=root,
     )
-    aggr_dataset.push_to_hub()
+    aggr_dataset.push_to_hub(tags=["openx"])

     # for meta in all_metadata:
     #     dataset = LeRobotDataset(repo_id=meta.repo_id, root=meta.root)
-    #     dataset.push_to_hub()
+    #     dataset.push_to_hub(tags=["openx"])
diff --git a/tests/test_aggregate_datasets.py b/tests/test_aggregate_datasets.py
new file mode 100644
index 00000000..ad5c2022
--- /dev/null
+++ b/tests/test_aggregate_datasets.py
@@ -0,0 +1,19 @@
+from lerobot.common.datasets.aggregate import aggregate_datasets
+from tests.fixtures.constants import DUMMY_REPO_ID
+
+
+def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
+    dataset_0 = lerobot_dataset_factory(
+        root=tmp_path / "test_0",
+        repo_id=DUMMY_REPO_ID + "_0",
+        total_episodes=10,
+        total_frames=400,
+    )
+    dataset_1 = lerobot_dataset_factory(
+        root=tmp_path / "test_1",
+        repo_id=DUMMY_REPO_ID + "_1",
+        total_episodes=10,
+        total_frames=400,
+    )
+
+    aggregate_datasets([dataset_0.meta, dataset_1.meta], DUMMY_REPO_ID + "_aggr", root=tmp_path / "test_aggr")
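
Note for readers (outside the patch): the heart of aggregate_datasets is index
bookkeeping. Task strings are deduplicated into global task indices, and each
dataset's episode indices are shifted by a running aggr_episode_index_shift.
A minimal standalone sketch of that remapping, with plain dicts standing in
for LeRobotDatasetMetadata (all names below are illustrative, not from the
patch):

    # Sketch of the bookkeeping performed by aggregate_datasets; plain dicts
    # stand in for LeRobotDatasetMetadata.

    def build_task_remaps(all_tasks: list[dict[int, str]]) -> list[dict[int, int]]:
        # Deduplicate task strings across datasets and give each unique task a
        # global index, mirroring datasets_task_index_to_aggr_task_index.
        aggr_task_index = 0
        task_to_aggr_index: dict[str, int] = {}
        remaps: list[dict[int, int]] = []
        for tasks in all_tasks:
            remap: dict[int, int] = {}
            for task_index, task in tasks.items():
                if task not in task_to_aggr_index:
                    task_to_aggr_index[task] = aggr_task_index
                    aggr_task_index += 1
                remap[task_index] = task_to_aggr_index[task]
            remaps.append(remap)
        return remaps

    def shift_episode_indices(episodes_per_dataset: list[int]) -> list[tuple[int, int]]:
        # Map each local episode index to its aggregated index, mirroring the
        # aggr_episode_index_shift accumulation in the patch.
        pairs: list[tuple[int, int]] = []
        shift = 0
        for total_episodes in episodes_per_dataset:
            pairs.extend((ep, ep + shift) for ep in range(total_episodes))
            shift += total_episodes
        return pairs

    if __name__ == "__main__":
        assert build_task_remaps([{0: "pick"}, {0: "place", 1: "pick"}]) == [{0: 0}, {0: 1, 1: 0}]
        assert shift_episode_indices([2, 3]) == [(0, 0), (1, 1), (0, 2), (1, 3), (2, 4)]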
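Note for readers (outside the patch): subprocess.Popen(copy_command, shell=True)
returns without waiting, and the trailing "&" makes the intermediate shell exit
immediately, so the returned handle cannot be used to wait for the cp itself;
the copies can still be running when the metadata is written at the end of
aggregate_datasets. A hedged alternative sketch, not what the patch does, that
keeps the copies concurrent but joins them before returning:

    # Illustrative variant: concurrent copies that are awaited before the
    # caller moves on to writing metadata. Paths here are hypothetical.
    import shutil
    from concurrent.futures import ThreadPoolExecutor
    from pathlib import Path

    def copy_videos(pairs: list[tuple[Path, Path]], max_workers: int = 16) -> None:
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = []
            for src, dst in pairs:
                dst.parent.mkdir(parents=True, exist_ok=True)
                futures.append(pool.submit(shutil.copy, src, dst))
            for future in futures:
                # result() blocks until the copy finishes and re-raises any error.
                future.result()

File I/O releases the GIL, so a thread pool overlaps the copies much like the
backgrounded cp processes do, while still allowing errors to surface.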