Aggregate works

parent e2e6f6e666
commit c36d2253d0

@@ -36,7 +36,6 @@ python examples/port_datasets/openx_rlds.py \
 import argparse
 import logging
 import re
-import shutil
 import time
 from pathlib import Path
 
@@ -316,9 +315,9 @@ def main():
 
     args = parser.parse_args()
 
-    droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
-    if droid_dir.exists():
-        shutil.rmtree(droid_dir)
+    # droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
+    # if droid_dir.exists():
+    #     shutil.rmtree(droid_dir)
 
     create_lerobot_dataset(**vars(args))
 
@@ -0,0 +1,52 @@
+from pathlib import Path
+
+import tqdm
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
+
+
+def main():
+    repo_id = "cadene/droid"
+    datetime = "2025-02-22_11-23-54"
+    port_log_dir = Path(f"/fsx/remi_cadene/logs/{datetime}_port_openx_droid")
+
+    compl_dir = port_log_dir / "completions"
+
+    paths = list(compl_dir.glob("*"))
+    total_items = len(paths)
+
+    # Use tqdm with the total parameter
+    wrong_completions = []
+    error_messages = []
+    for i, path in tqdm.tqdm(enumerate(paths), total=total_items):
+        try:
+            rank = path.name.lstrip("0")
+            if rank == "":
+                rank = 0
+            meta = LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}")
+            last_episode_index = meta.total_episodes - 1
+            last_ep_data_path = meta.root / meta.get_data_file_path(last_episode_index)
+
+            if not last_ep_data_path.exists():
+                raise ValueError(path)
+
+            for vid_key in meta.video_keys:
+                last_ep_vid_path = meta.root / meta.get_video_file_path(last_episode_index, vid_key)
+                if not last_ep_vid_path.exists():
+                    raise ValueError(path)
+
+        except Exception as e:
+            error_messages.append(str(e))
+            wrong_completions.append(path)
+
+    for path, error_msg in zip(wrong_completions, error_messages, strict=False):
+        print(path)
+        print(error_msg)
+        print()
+        # path.unlink()
+
+    print(f"Error {len(wrong_completions)} / {total_items}")
+
+
+if __name__ == "__main__":
+    main()
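
The rank parsing in the added script strips leading zeros from zero-padded completion file names and falls back to 0 when the name is all zeros. A minimal equivalent sketch (a hypothetical helper, not part of the commit) relying on int() handling leading zeros directly:

def parse_rank(name: str) -> int:
    # int() already discards leading zeros: "0012" -> 12, "0000" -> 0
    return int(name)


assert parse_rank("0000") == 0
assert parse_rank("0012") == 12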
@@ -1,10 +1,12 @@
-import shutil
-from pathlib import Path
+import logging
+import subprocess
 
 import pandas as pd
+import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
+from lerobot.common.utils.utils import init_logging
 
 
 def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -14,7 +16,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
     robot_type = all_metadata[0].robot_type
     features = all_metadata[0].features
 
-    for meta in all_metadata:
+    for meta in tqdm.tqdm(all_metadata):
         if fps != meta.fps:
             raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
         if robot_type != meta.robot_type:
@@ -39,6 +41,7 @@ def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_
 
 
 def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str, root=None):
+    logging.info("start aggregate_datasets")
     fps, robot_type, features = validate_all_metadata(all_metadata)
 
     # Create resulting dataset folder
@@ -50,11 +53,12 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
         root=root,
     )
 
+    logging.info("find all tasks")
     # find all tasks, deduplicate them, create new task indices for each dataset
    # indexed by dataset index
     datasets_task_index_to_aggr_task_index = {}
     aggr_task_index = 0
-    for dataset_index, meta in enumerate(all_metadata):
+    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
         task_index_to_aggr_task_index = {}
 
         for task_index, task in meta.tasks.items():
@@ -69,8 +73,9 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
 
         datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
 
+    logging.info("cp data and videos")
     aggr_episode_index_shift = 0
-    for dataset_index, meta in enumerate(all_metadata):
+    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
         # cp data
         for episode_index in range(meta.total_episodes):
             aggr_episode_index = episode_index + aggr_episode_index_shift
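
The two hunks above set up the per-dataset task remapping: task strings are deduplicated across datasets and each dataset's local task_index is mapped to a single aggregated index. The remap body itself falls outside the shown context; a minimal sketch of the idea, where build_task_remapping and seen_tasks are assumed names not present in the diff:

def build_task_remapping(all_metadata):
    # Deduplicate task strings across datasets and map each dataset's local
    # task_index to one aggregated task index.
    seen_tasks = {}  # task string -> aggregated task index (assumed helper)
    datasets_task_index_to_aggr_task_index = {}
    aggr_task_index = 0
    for dataset_index, meta in enumerate(all_metadata):
        task_index_to_aggr_task_index = {}
        for task_index, task in meta.tasks.items():
            if task not in seen_tasks:
                seen_tasks[task] = aggr_task_index
                aggr_task_index += 1
            task_index_to_aggr_task_index[task_index] = seen_tasks[task]
        datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
    return datasets_task_index_to_aggr_task_index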
@@ -94,7 +99,10 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
             video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
             aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
             aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
-            shutil.copy(video_path, aggr_video_path)
+            # shutil.copy(video_path, aggr_video_path)
+
+            copy_command = f"cp {video_path} {aggr_video_path} &"
+            subprocess.Popen(copy_command, shell=True)
 
         # populate episodes
         for episode_index, episode_dict in meta.episodes.items():
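
The change above replaces the synchronous shutil.copy with a fire-and-forget `cp ... &` launched through a shell, so video copies run in the background and nothing waits for them before metadata is written. A hedged sketch of one way to keep the parallelism while still being able to wait for completion (not what the commit does; it avoids shell=True and keeps process handles):

import subprocess
from pathlib import Path


def copy_videos_in_parallel(pairs: list[tuple[Path, Path]]) -> None:
    # Launch one cp per (src, dst) pair without a shell, then block until all finish.
    procs = []
    for src, dst in pairs:
        dst.parent.mkdir(parents=True, exist_ok=True)
        procs.append(subprocess.Popen(["cp", str(src), str(dst)]))
    for proc in procs:
        proc.wait()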
@@ -109,11 +117,13 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
 
         # populate info
         aggr_meta.info["total_episodes"] += meta.total_episodes
-        aggr_meta.info["total_frames"] += meta.total_episodes
+        aggr_meta.info["total_frames"] += meta.total_frames
         aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
 
         aggr_episode_index_shift += meta.total_episodes
 
+    logging.info("write meta data")
+
     aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
     aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
 
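
To make the bookkeeping above concrete, here is a sketch with made-up numbers (not from the commit): aggregating two per-rank datasets of 10 episodes and 400 frames each, with one assumed video key, accumulates the totals and the episode-index shift as follows.

# Two hypothetical source datasets, mirroring the totals the loop accumulates.
per_rank = [{"total_episodes": 10, "total_frames": 400}, {"total_episodes": 10, "total_frames": 400}]
video_keys = ["observation.images.cam"]  # assumed single camera

aggr = {"total_episodes": 0, "total_frames": 0, "total_videos": 0}
shift = 0
for meta in per_rank:
    aggr["total_episodes"] += meta["total_episodes"]
    aggr["total_frames"] += meta["total_frames"]
    aggr["total_videos"] += len(video_keys) * meta["total_episodes"]
    shift += meta["total_episodes"]

assert aggr == {"total_episodes": 20, "total_frames": 800, "total_videos": 20}
assert shift == 20  # episodes of the second dataset are re-indexed starting at 10
splits = {"train": f"0:{aggr['total_episodes']}"}  # -> {"train": "0:20"}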
@@ -133,30 +143,30 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
 
 
 if __name__ == "__main__":
+    init_logging()
     repo_id = "cadene/droid"
+    aggr_repo_id = "cadene/droid"
     datetime = "2025-02-22_11-23-54"
 
-    root = Path(f"/tmp/{repo_id}")
-    if root.exists():
-        shutil.rmtree(root)
+    # root = Path(f"/tmp/{repo_id}")
+    # if root.exists():
+    #     shutil.rmtree(root)
+    root = None
 
-    all_metadata = [
-        LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_0"),
-        LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_1"),
-    ]
+    # all_metadata = [LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}") for rank in range(2048)]
 
-    aggregate_datasets(
-        all_metadata,
-        repo_id,
-        root=root,
-    )
+    # aggregate_datasets(
+    #     all_metadata,
+    #     aggr_repo_id,
+    #     root=root,
+    # )
 
     aggr_dataset = LeRobotDataset(
-        repo_id=repo_id,
+        repo_id=aggr_repo_id,
        root=root,
     )
-    aggr_dataset.push_to_hub()
+    aggr_dataset.push_to_hub(tags=["openx"])
 
     # for meta in all_metadata:
    #     dataset = LeRobotDataset(repo_id=meta.repo_id, root=meta.root)
-    #     dataset.push_to_hub()
+    #     dataset.push_to_hub(tags=["openx"])
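
For reference, the commented-out block above, uncommented, corresponds to the full aggregation flow across all 2048 ranks; this is only the same code from the diff re-assembled, nothing new:

all_metadata = [
    LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}") for rank in range(2048)
]

aggregate_datasets(
    all_metadata,
    aggr_repo_id,
    root=root,
)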
@@ -0,0 +1,19 @@
+from lerobot.common.datasets.aggregate import aggregate_datasets
+from tests.fixtures.constants import DUMMY_REPO_ID
+
+
+def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
+    dataset_0 = lerobot_dataset_factory(
+        root=tmp_path / "test_0",
+        repo_id=DUMMY_REPO_ID + "_0",
+        total_episodes=10,
+        total_frames=400,
+    )
+    dataset_1 = lerobot_dataset_factory(
+        root=tmp_path / "test_1",
+        repo_id=DUMMY_REPO_ID + "_1",
+        total_episodes=10,
+        total_frames=400,
+    )
+
+    dataset_2 = aggregate_datasets([dataset_0, dataset_1])
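
The new test only builds two fixture datasets and calls aggregate_datasets on them. A possible extension, assuming the call returns an aggregated LeRobotDataset exposing its metadata via .meta (an assumption, not confirmed by the diff), would be to append assertions on the combined totals:

    # Hypothetical assertions, assuming dataset_2 carries the aggregated metadata.
    assert dataset_2.meta.total_episodes == dataset_0.meta.total_episodes + dataset_1.meta.total_episodes
    assert dataset_2.meta.total_frames == dataset_0.meta.total_frames + dataset_1.meta.total_frames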