Progress on aggregate_datasets

This commit is contained in:
Remi Cadene 2025-04-19 19:11:53 +05:30
parent 54b5c805bf
commit b0cca75e5e
7 changed files with 149 additions and 87 deletions

View File

@ -1,11 +1,13 @@
import logging
from pathlib import Path
import shutil
import pandas as pd
import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks
from lerobot.common.utils.utils import init_logging
@ -30,22 +32,46 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
return fps, robot_type, features
def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_task_index):
def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
def _update(row):
row["episode_index"] = row["episode_index"] + episode_index_to_add
row["task_index"] = task_index_to_global_task_index[row["task_index"]]
task = old_tasks.iloc[row["task_index"]].name
row["task_index"] = new_tasks.loc[task].task_index.item()
return row
return _update
def get_update_meta_func(
meta_chunk_index_to_add,
meta_file_index_to_add,
data_chunk_index_to_add,
data_file_index_to_add,
videos_chunk_index_to_add,
videos_file_index_to_add,
frame_index_to_add,
):
def _update(row):
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_chunk_index_to_add
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_file_index_to_add
row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add
row["data/file_index"] = row["data/file_index"] + data_file_index_to_add
for key in videos_chunk_index_to_add:
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key]
row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add
row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
return row
return _update
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]=None, aggr_root=None):
logging.info("Start aggregate_datasets")
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
if roots is None:
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
else:
all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)]
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
# Create resulting dataset folder
aggr_meta = LeRobotDatasetMetadata.create(
@ -55,95 +81,99 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
features=features,
root=aggr_root,
)
aggr_root = aggr_meta.root
logging.info("Find all tasks")
# find all tasks, deduplicate them, create new task indices for each dataset
# indexed by dataset index
datasets_task_index_to_aggr_task_index = {}
aggr_task_index = 0
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
task_index_to_aggr_task_index = {}
unique_tasks = pd.concat([meta.tasks for meta in all_metadata]).index.unique()
aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
for task_index, task in meta.tasks.items():
if task not in aggr_meta.task_to_task_index:
# add the task to aggr tasks mappings
aggr_meta.tasks[aggr_task_index] = task
aggr_meta.task_to_task_index[task] = aggr_task_index
aggr_task_index += 1
num_episodes = 0
num_frames = 0
# add task_index anyway
task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
aggr_meta_chunk_idx = 0
aggr_meta_file_idx = 0
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
aggr_data_chunk_idx = 0
aggr_data_file_idx = 0
logging.info("Copy data and videos")
aggr_episode_index_shift = 0
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Copy data and videos")):
# cp data
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
data_path = meta.root / meta.get_data_file_path(episode_index)
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
aggr_videos_chunk_idx = {key: 0 for key in video_keys}
aggr_videos_file_idx = {key: 0 for key in video_keys}
# update episode_index and task_index
df = pd.read_parquet(data_path)
update_row_func = get_update_episode_and_task_func(
aggr_episode_index_shift, datasets_task_index_to_aggr_task_index[dataset_index]
for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
meta_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])])
for chunk_idx, file_idx in meta_chunk_file_ids:
path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
df = pd.read_parquet(path)
update_meta_func = get_update_meta_func(
aggr_meta_chunk_idx,
aggr_meta_file_idx,
aggr_data_chunk_idx,
aggr_data_file_idx,
aggr_videos_chunk_idx,
aggr_videos_file_idx,
num_frames,
)
df = df.apply(update_row_func, axis=1)
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_data_path)
df = df.apply(update_meta_func, axis=1)
aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx)
aggr_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_path)
aggr_meta_file_idx += 1
if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
aggr_meta_file_idx = 0
aggr_meta_chunk_idx += 1
# cp videos
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
for vid_key in meta.video_keys:
video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(video_path, aggr_video_path)
for key in video_keys:
video_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])])
for chunk_idx, file_idx in video_chunk_file_ids:
path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx)
aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key])
aggr_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(path), str(aggr_path))
# copy_command = f"cp {video_path} {aggr_video_path} &"
# subprocess.Popen(copy_command, shell=True)
# populate episodes
for episode_index, episode_dict in meta.episodes.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
episode_dict["episode_index"] = aggr_episode_index
aggr_meta.episodes[aggr_episode_index] = episode_dict
aggr_videos_file_idx[key] += 1
if aggr_videos_file_idx[key] >= DEFAULT_CHUNK_SIZE:
aggr_videos_file_idx[key] = 0
aggr_videos_chunk_idx[key] += 1
# populate episodes_stats
for episode_index, episode_stats in meta.episodes_stats.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
data_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])])
for chunk_idx, file_idx in data_chunk_file_ids:
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
df = pd.read_parquet(path)
update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
df = df.apply(update_data_func, axis=1)
# populate info
aggr_meta.info["total_episodes"] += meta.total_episodes
aggr_meta.info["total_frames"] += meta.total_frames
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx)
aggr_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_path)
aggr_episode_index_shift += meta.total_episodes
aggr_data_file_idx += 1
if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
aggr_data_file_idx = 0
aggr_data_chunk_idx += 1
num_episodes += meta.total_episodes
num_frames += meta.total_frames
logging.info("write meta data")
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
# create a new episodes jsonl with updated episode_index using write_episode
for episode_dict in aggr_meta.episodes.values():
write_episode(episode_dict, aggr_meta.root)
# create a new episode_stats jsonl with updated episode_index using write_episode_stats
for episode_index, episode_stats in aggr_meta.episodes_stats.items():
legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)
# create a new task jsonl with updated episode_index using write_task
for task_index, task in aggr_meta.tasks.items():
legacy_write_task(task_index, task, aggr_meta.root)
logging.info("write tasks")
write_tasks(aggr_meta.tasks, aggr_meta.root)
logging.info("write info")
aggr_meta.info["total_episodes"] = sum([meta.total_episodes for meta in all_metadata])
aggr_meta.info["total_frames"] = sum([meta.total_frames for meta in all_metadata])
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.total_episodes}"}
write_info(aggr_meta.info, aggr_meta.root)
logging.info("write stats")
aggr_meta.stats = aggregate_stats([meta.stats for meta in all_metadata])
write_stats(aggr_meta.stats, aggr_meta.root)
if __name__ == "__main__":
init_logging()

View File

@ -712,8 +712,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.meta.episodes["data/from_index"][ep_idx]
ep_end = self.meta.episodes["data/to_index"][ep_idx]
ep_start = self.meta.episodes["dataset_from_index"][ep_idx]
ep_end = self.meta.episodes["dataset_to_index"][ep_idx]
query_indices = {
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
@ -1017,8 +1017,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
metadata = {
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"data/from_index": latest_num_frames,
"data/to_index": latest_num_frames + ep_num_frames,
"dataset_from_index": latest_num_frames,
"dataset_to_index": latest_num_frames + ep_num_frames,
}
return metadata

View File

@ -27,6 +27,7 @@ from datasets import Dataset
from huggingface_hub import HfApi, snapshot_download
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
@ -46,6 +47,7 @@ from lerobot.common.datasets.utils import (
update_chunk_file_indices,
write_episodes,
write_info,
write_stats,
write_tasks,
)
@ -306,6 +308,9 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
)
write_episodes(ds_episodes, new_root)
stats = aggregate_stats(list(episodes_stats.values()))
write_stats(stats, new_root)
def convert_info(root, new_root):
info = load_info(root)
@ -330,9 +335,16 @@ def convert_dataset(
branch: str | None = None,
num_workers: int = 4,
):
root = HF_LEROBOT_HOME / repo_id
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
if old_root.is_dir() and root.is_dir():
shutil.rmtree(str(root))
shutil.move(str(old_root), str(root))
if new_root.is_dir():
shutil.rmtree(new_root)
@ -349,7 +361,7 @@ def convert_dataset(
episodes_videos_metadata = convert_videos(root, new_root)
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
shutil.move(str(root), str(root) + "_old")
shutil.move(str(root), str(old_root))
shutil.move(str(new_root), str(root))
# TODO(racdene)
@ -365,7 +377,7 @@ def convert_dataset(
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
LeRobotDataset(repo_id).push_to_hub()
# LeRobotDataset(repo_id).push_to_hub()
if __name__ == "__main__":

View File

@ -212,6 +212,7 @@ def tasks_factory():
def episodes_factory(tasks_factory, stats_factory):
def _create_episodes(
features: dict[str],
fps: int = DEFAULT_FPS,
total_episodes: int = 3,
total_frames: int = 400,
video_keys: list[str] | None = None,
@ -252,6 +253,8 @@ def episodes_factory(tasks_factory, stats_factory):
for video_key in video_keys:
d[f"videos/{video_key}/chunk_index"] = []
d[f"videos/{video_key}/file_index"] = []
d[f"videos/{video_key}/from_timestamp"] = []
d[f"videos/{video_key}/to_timestamp"] = []
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
@ -281,6 +284,8 @@ def episodes_factory(tasks_factory, stats_factory):
for video_key in video_keys:
d[f"videos/{video_key}/chunk_index"].append(0)
d[f"videos/{video_key}/file_index"].append(0)
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
# Add stats columns like "stats/action/max"
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
@ -306,7 +311,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
if features is None:
features = features_factory()
if episodes is None:
episodes = episodes_factory(features)
episodes = episodes_factory(features, fps)
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)
@ -379,6 +384,7 @@ def lerobot_dataset_metadata_factory(
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
@ -441,6 +447,7 @@ def lerobot_dataset_factory(
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,

View File

@ -73,6 +73,7 @@ def create_tasks(tasks_factory):
def create_episodes(episodes_factory):
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
if episodes is None:
# TODO(rcadene): add features, fps as arguments
episodes = episodes_factory()
write_episodes(episodes, dir)

View File

@ -62,6 +62,7 @@ def mock_snapshot_download_factory(
if episodes is None:
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
@ -121,6 +122,7 @@ def mock_snapshot_download_factory(
create_tasks(local_dir, tasks)
if has_episodes:
create_episodes(local_dir, episodes)
# TODO(rcadene): create_videos?
if has_data:
create_hf_dataset(local_dir, hf_dataset)

View File

@ -1,19 +1,29 @@
from lerobot.common.datasets.aggregate import aggregate_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
dataset_0 = lerobot_dataset_factory(
ds_0 = lerobot_dataset_factory(
root=tmp_path / "test_0",
repo_id=DUMMY_REPO_ID + "_0",
repo_id=f"{DUMMY_REPO_ID}_0",
total_episodes=10,
total_frames=400,
)
dataset_1 = lerobot_dataset_factory(
ds_1 = lerobot_dataset_factory(
root=tmp_path / "test_1",
repo_id=DUMMY_REPO_ID + "_1",
repo_id=f"{DUMMY_REPO_ID}_1",
total_episodes=10,
total_frames=400,
)
dataset_2 = aggregate_datasets([dataset_0, dataset_1])
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
aggr_root=tmp_path / "test_aggr"
)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
for item in aggr_ds:
pass