From ff0029f84bf1f3311e0b61c5b41eeae1c588330d Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Sat, 22 Feb 2025 15:33:47 +0000
Subject: [PATCH] aggregate works

---
 .../port_datasets/openx_rlds_datatrove.py |   4 +-
 lerobot/common/datasets/aggregate.py      | 162 ++++++++++++++++++
 2 files changed, 164 insertions(+), 2 deletions(-)
 create mode 100644 lerobot/common/datasets/aggregate.py

diff --git a/examples/port_datasets/openx_rlds_datatrove.py b/examples/port_datasets/openx_rlds_datatrove.py
index 82d6cb3f..b4910544 100644
--- a/examples/port_datasets/openx_rlds_datatrove.py
+++ b/examples/port_datasets/openx_rlds_datatrove.py
@@ -69,12 +69,12 @@ def main(slurm=True):
             "job_name": port_job_name,
             "tasks": 2048,
             # "workers": 20,  # 8 * 16,
-            "workers": 1,  # 8 * 16,
+            "workers": 20,  # 8 * 16,
             "time": "08:00:00",
             "partition": "hopper-cpu",
             "cpus_per_task": 24,
             "mem_per_cpu_gb": 2,
-            # "max_array_launch_parallel": True,
+            "max_array_launch_parallel": True,
         }
     else:
         executor_class = LocalPipelineExecutor
diff --git a/lerobot/common/datasets/aggregate.py b/lerobot/common/datasets/aggregate.py
new file mode 100644
index 00000000..1f02e07e
--- /dev/null
+++ b/lerobot/common/datasets/aggregate.py
@@ -0,0 +1,162 @@
+import shutil
+from pathlib import Path
+
+import pandas as pd
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
+
+
+def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
+    # validate same fps, robot_type, features
+
+    fps = all_metadata[0].fps
+    robot_type = all_metadata[0].robot_type
+    features = all_metadata[0].features
+
+    for meta in all_metadata:
+        if fps != meta.fps:
+            raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
+        if robot_type != meta.robot_type:
+            raise ValueError(
+                f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
+            )
+        if features != meta.features:
+            raise ValueError(
+                f"Same features are expected, but got features={meta.features} instead of {features}."
+            )
+
+    return fps, robot_type, features
+
+
+def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_task_index):
+    def _update(row):
+        row["episode_index"] = row["episode_index"] + episode_index_to_add
+        row["task_index"] = task_index_to_global_task_index[row["task_index"]]
+        return row
+
+    return _update
+
+
+def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str, root=None):
+    fps, robot_type, features = validate_all_metadata(all_metadata)
+
+    # Create resulting dataset folder
+    aggr_meta = LeRobotDatasetMetadata.create(
+        repo_id=repo_id,
+        fps=fps,
+        robot_type=robot_type,
+        features=features,
+        root=root,
+    )
+
+    # find all tasks, deduplicate them, create new task indices for each dataset
+    # indexed by dataset index
+    datasets_task_index_to_aggr_task_index = {}
+    aggr_task_index = 0
+    for dataset_index, meta in enumerate(all_metadata):
+        task_index_to_aggr_task_index = {}
+
+        for task_index, task in meta.tasks.items():
+            if task not in aggr_meta.task_to_task_index:
+                # add the task to aggr tasks mappings
+                aggr_meta.tasks[aggr_task_index] = task
+                aggr_meta.task_to_task_index[task] = aggr_task_index
+                aggr_task_index += 1
+
+            # map the local task_index to the aggregated task_index in all cases
+            task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
+
+        datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
+
+    aggr_episode_index_shift = 0
+    for dataset_index, meta in enumerate(all_metadata):
+        # cp data
+        for episode_index in range(meta.total_episodes):
+            aggr_episode_index = episode_index + aggr_episode_index_shift
+            data_path = meta.root / meta.get_data_file_path(episode_index)
+            aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
+
+            # update episode_index and task_index
+            df = pd.read_parquet(data_path)
+            update_row_func = get_update_episode_and_task_func(
+                aggr_episode_index_shift, datasets_task_index_to_aggr_task_index[dataset_index]
+            )
+            df = df.apply(update_row_func, axis=1)
+
+            aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
+            df.to_parquet(aggr_data_path)
+
+        # cp videos
+        for episode_index in range(meta.total_episodes):
+            aggr_episode_index = episode_index + aggr_episode_index_shift
+            for vid_key in meta.video_keys:
+                video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
+                aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
+                aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
+                shutil.copy(video_path, aggr_video_path)
+
+        # populate episodes
+        for episode_index, episode_dict in meta.episodes.items():
+            aggr_episode_index = episode_index + aggr_episode_index_shift
+            episode_dict["episode_index"] = aggr_episode_index
+            aggr_meta.episodes[aggr_episode_index] = episode_dict
+
+        # populate episodes_stats
+        for episode_index, episode_stats in meta.episodes_stats.items():
+            aggr_episode_index = episode_index + aggr_episode_index_shift
+            aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
+
+        # populate info
+        aggr_meta.info["total_episodes"] += meta.total_episodes
+        aggr_meta.info["total_frames"] += meta.total_frames
+        aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
+
+        aggr_episode_index_shift += meta.total_episodes
+
+    aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1) + 1
+    aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
+
+    # create a new episodes jsonl with updated episode_index using write_episode
+    for episode_dict in aggr_meta.episodes.values():
+        write_episode(episode_dict, aggr_meta.root)
+
+    # create a new episode_stats jsonl with updated episode_index using write_episode_stats
+    for episode_index, episode_stats in aggr_meta.episodes_stats.items():
+        write_episode_stats(episode_index, episode_stats, aggr_meta.root)
+
+    # create a new tasks jsonl with updated task_index using write_task
+    for task_index, task in aggr_meta.tasks.items():
+        write_task(task_index, task, aggr_meta.root)
+
+    write_info(aggr_meta.info, aggr_meta.root)
+
+
+if __name__ == "__main__":
+    repo_id = "cadene/droid"
+    datetime = "2025-02-22_11-23-54"
+
+    root = Path(f"/tmp/{repo_id}")
+    if root.exists():
+        shutil.rmtree(root)
+
+    all_metadata = [
+        LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_0"),
+        LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_1"),
+    ]
+
+    aggregate_datasets(
+        all_metadata,
+        repo_id,
+        root=root,
+    )
+
+    aggr_dataset = LeRobotDataset(
+        repo_id=repo_id,
+        root=root,
+    )
+    aggr_dataset.push_to_hub()
+
+    # for meta in all_metadata:
+    #     dataset = LeRobotDataset(repo_id=meta.repo_id, root=meta.root)
+    #     dataset.push_to_hub()
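
For reference, a minimal sketch (not part of the patch itself) of what the per-row rewrite in aggregate_datasets does to one shard's parquet rows. The episode shift of 10 and the task-index mapping {0: 3, 1: 0} are made-up values chosen for illustration, and the import assumes the patch above is applied so that get_update_episode_and_task_func exists.

import pandas as pd

from lerobot.common.datasets.aggregate import get_update_episode_and_task_func

# Toy stand-in for one shard's episode parquet (illustrative values only).
df = pd.DataFrame({"episode_index": [0, 0, 1, 1], "task_index": [0, 0, 1, 1]})

# Assume earlier shards already contributed 10 episodes, and this shard's local
# task indices {0, 1} were deduplicated to aggregated task indices {3, 0}.
update_row = get_update_episode_and_task_func(
    episode_index_to_add=10,
    task_index_to_global_task_index={0: 3, 1: 0},
)
df = df.apply(update_row, axis=1)

print(df["episode_index"].tolist())  # [10, 10, 11, 11]
print(df["task_index"].tolist())  # [3, 3, 0, 0]

The same offset and remap are applied episode by episode before each rewritten parquet is written into the aggregated dataset's layout.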