Progress on aggregate_datasets
This commit is contained in:
parent
54b5c805bf
commit
b0cca75e5e
|
@ -1,11 +1,13 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
|
||||
from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
|
@ -30,22 +32,46 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
|||
|
||||
return fps, robot_type, features
|
||||
|
||||
|
||||
def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_task_index):
|
||||
def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
|
||||
def _update(row):
|
||||
row["episode_index"] = row["episode_index"] + episode_index_to_add
|
||||
row["task_index"] = task_index_to_global_task_index[row["task_index"]]
|
||||
task = old_tasks.iloc[row["task_index"]].name
|
||||
row["task_index"] = new_tasks.loc[task].task_index.item()
|
||||
return row
|
||||
|
||||
return _update
|
||||
|
||||
def get_update_meta_func(
|
||||
meta_chunk_index_to_add,
|
||||
meta_file_index_to_add,
|
||||
data_chunk_index_to_add,
|
||||
data_file_index_to_add,
|
||||
videos_chunk_index_to_add,
|
||||
videos_file_index_to_add,
|
||||
frame_index_to_add,
|
||||
):
|
||||
def _update(row):
|
||||
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_chunk_index_to_add
|
||||
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_file_index_to_add
|
||||
row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add
|
||||
row["data/file_index"] = row["data/file_index"] + data_file_index_to_add
|
||||
for key in videos_chunk_index_to_add:
|
||||
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
|
||||
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key]
|
||||
row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add
|
||||
row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
|
||||
return row
|
||||
return _update
|
||||
|
||||
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
|
||||
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]=None, aggr_root=None):
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||
if roots is None:
|
||||
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||
else:
|
||||
all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)]
|
||||
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
|
||||
# Create resulting dataset folder
|
||||
aggr_meta = LeRobotDatasetMetadata.create(
|
||||
|
@ -55,95 +81,99 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
|
|||
features=features,
|
||||
root=aggr_root,
|
||||
)
|
||||
aggr_root = aggr_meta.root
|
||||
|
||||
logging.info("Find all tasks")
|
||||
# find all tasks, deduplicate them, create new task indices for each dataset
|
||||
# indexed by dataset index
|
||||
datasets_task_index_to_aggr_task_index = {}
|
||||
aggr_task_index = 0
|
||||
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
|
||||
task_index_to_aggr_task_index = {}
|
||||
unique_tasks = pd.concat([meta.tasks for meta in all_metadata]).index.unique()
|
||||
aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||
|
||||
for task_index, task in meta.tasks.items():
|
||||
if task not in aggr_meta.task_to_task_index:
|
||||
# add the task to aggr tasks mappings
|
||||
aggr_meta.tasks[aggr_task_index] = task
|
||||
aggr_meta.task_to_task_index[task] = aggr_task_index
|
||||
aggr_task_index += 1
|
||||
num_episodes = 0
|
||||
num_frames = 0
|
||||
|
||||
# add task_index anyway
|
||||
task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
|
||||
aggr_meta_chunk_idx = 0
|
||||
aggr_meta_file_idx = 0
|
||||
|
||||
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
|
||||
aggr_data_chunk_idx = 0
|
||||
aggr_data_file_idx = 0
|
||||
|
||||
logging.info("Copy data and videos")
|
||||
aggr_episode_index_shift = 0
|
||||
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Copy data and videos")):
|
||||
# cp data
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
data_path = meta.root / meta.get_data_file_path(episode_index)
|
||||
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
|
||||
aggr_videos_chunk_idx = {key: 0 for key in video_keys}
|
||||
aggr_videos_file_idx = {key: 0 for key in video_keys}
|
||||
|
||||
# update episode_index and task_index
|
||||
df = pd.read_parquet(data_path)
|
||||
update_row_func = get_update_episode_and_task_func(
|
||||
aggr_episode_index_shift, datasets_task_index_to_aggr_task_index[dataset_index]
|
||||
for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
|
||||
meta_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])])
|
||||
for chunk_idx, file_idx in meta_chunk_file_ids:
|
||||
path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
df = pd.read_parquet(path)
|
||||
update_meta_func = get_update_meta_func(
|
||||
aggr_meta_chunk_idx,
|
||||
aggr_meta_file_idx,
|
||||
aggr_data_chunk_idx,
|
||||
aggr_data_file_idx,
|
||||
aggr_videos_chunk_idx,
|
||||
aggr_videos_file_idx,
|
||||
num_frames,
|
||||
)
|
||||
df = df.apply(update_row_func, axis=1)
|
||||
|
||||
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(aggr_data_path)
|
||||
df = df.apply(update_meta_func, axis=1)
|
||||
|
||||
aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx)
|
||||
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(aggr_path)
|
||||
|
||||
aggr_meta_file_idx += 1
|
||||
if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
|
||||
aggr_meta_file_idx = 0
|
||||
aggr_meta_chunk_idx += 1
|
||||
|
||||
# cp videos
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
|
||||
aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
|
||||
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(video_path, aggr_video_path)
|
||||
for key in video_keys:
|
||||
video_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])])
|
||||
for chunk_idx, file_idx in video_chunk_file_ids:
|
||||
path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key])
|
||||
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(path), str(aggr_path))
|
||||
|
||||
# copy_command = f"cp {video_path} {aggr_video_path} &"
|
||||
# subprocess.Popen(copy_command, shell=True)
|
||||
|
||||
# populate episodes
|
||||
for episode_index, episode_dict in meta.episodes.items():
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
episode_dict["episode_index"] = aggr_episode_index
|
||||
aggr_meta.episodes[aggr_episode_index] = episode_dict
|
||||
aggr_videos_file_idx[key] += 1
|
||||
if aggr_videos_file_idx[key] >= DEFAULT_CHUNK_SIZE:
|
||||
aggr_videos_file_idx[key] = 0
|
||||
aggr_videos_chunk_idx[key] += 1
|
||||
|
||||
# populate episodes_stats
|
||||
for episode_index, episode_stats in meta.episodes_stats.items():
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
|
||||
data_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])])
|
||||
for chunk_idx, file_idx in data_chunk_file_ids:
|
||||
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
df = pd.read_parquet(path)
|
||||
update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
|
||||
df = df.apply(update_data_func, axis=1)
|
||||
|
||||
# populate info
|
||||
aggr_meta.info["total_episodes"] += meta.total_episodes
|
||||
aggr_meta.info["total_frames"] += meta.total_frames
|
||||
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
|
||||
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx)
|
||||
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(aggr_path)
|
||||
|
||||
aggr_episode_index_shift += meta.total_episodes
|
||||
aggr_data_file_idx += 1
|
||||
if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
|
||||
aggr_data_file_idx = 0
|
||||
aggr_data_chunk_idx += 1
|
||||
|
||||
num_episodes += meta.total_episodes
|
||||
num_frames += meta.total_frames
|
||||
|
||||
logging.info("write meta data")
|
||||
|
||||
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
|
||||
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
|
||||
|
||||
# create a new episodes jsonl with updated episode_index using write_episode
|
||||
for episode_dict in aggr_meta.episodes.values():
|
||||
write_episode(episode_dict, aggr_meta.root)
|
||||
|
||||
# create a new episode_stats jsonl with updated episode_index using write_episode_stats
|
||||
for episode_index, episode_stats in aggr_meta.episodes_stats.items():
|
||||
legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)
|
||||
|
||||
# create a new task jsonl with updated episode_index using write_task
|
||||
for task_index, task in aggr_meta.tasks.items():
|
||||
legacy_write_task(task_index, task, aggr_meta.root)
|
||||
logging.info("write tasks")
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
logging.info("write info")
|
||||
aggr_meta.info["total_episodes"] = sum([meta.total_episodes for meta in all_metadata])
|
||||
aggr_meta.info["total_frames"] = sum([meta.total_frames for meta in all_metadata])
|
||||
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.total_episodes}"}
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
logging.info("write stats")
|
||||
aggr_meta.stats = aggregate_stats([meta.stats for meta in all_metadata])
|
||||
write_stats(aggr_meta.stats, aggr_meta.root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
|
|
@ -712,8 +712,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.meta.episodes["data/from_index"][ep_idx]
|
||||
ep_end = self.meta.episodes["data/to_index"][ep_idx]
|
||||
ep_start = self.meta.episodes["dataset_from_index"][ep_idx]
|
||||
ep_end = self.meta.episodes["dataset_to_index"][ep_idx]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
|
@ -1017,8 +1017,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
metadata = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"data/from_index": latest_num_frames,
|
||||
"data/to_index": latest_num_frames + ep_num_frames,
|
||||
"dataset_from_index": latest_num_frames,
|
||||
"dataset_to_index": latest_num_frames + ep_num_frames,
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ from datasets import Dataset
|
|||
from huggingface_hub import HfApi, snapshot_download
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
|
@ -46,6 +47,7 @@ from lerobot.common.datasets.utils import (
|
|||
update_chunk_file_indices,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
|
||||
|
@ -306,6 +308,9 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
|||
)
|
||||
write_episodes(ds_episodes, new_root)
|
||||
|
||||
stats = aggregate_stats(list(episodes_stats.values()))
|
||||
write_stats(stats, new_root)
|
||||
|
||||
|
||||
def convert_info(root, new_root):
|
||||
info = load_info(root)
|
||||
|
@ -330,9 +335,16 @@ def convert_dataset(
|
|||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
|
||||
|
||||
root = HF_LEROBOT_HOME / repo_id
|
||||
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
|
||||
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
|
||||
|
||||
if old_root.is_dir() and root.is_dir():
|
||||
shutil.rmtree(str(root))
|
||||
shutil.move(str(old_root), str(root))
|
||||
|
||||
if new_root.is_dir():
|
||||
shutil.rmtree(new_root)
|
||||
|
||||
|
@ -349,7 +361,7 @@ def convert_dataset(
|
|||
episodes_videos_metadata = convert_videos(root, new_root)
|
||||
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
|
||||
|
||||
shutil.move(str(root), str(root) + "_old")
|
||||
shutil.move(str(root), str(old_root))
|
||||
shutil.move(str(new_root), str(root))
|
||||
|
||||
# TODO(racdene)
|
||||
|
@ -365,7 +377,7 @@ def convert_dataset(
|
|||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub()
|
||||
# LeRobotDataset(repo_id).push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -212,6 +212,7 @@ def tasks_factory():
|
|||
def episodes_factory(tasks_factory, stats_factory):
|
||||
def _create_episodes(
|
||||
features: dict[str],
|
||||
fps: int = DEFAULT_FPS,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
video_keys: list[str] | None = None,
|
||||
|
@ -252,6 +253,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
|||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"] = []
|
||||
d[f"videos/{video_key}/file_index"] = []
|
||||
d[f"videos/{video_key}/from_timestamp"] = []
|
||||
d[f"videos/{video_key}/to_timestamp"] = []
|
||||
|
||||
for stats_key in flatten_dict({"stats": stats_factory(features)}):
|
||||
d[stats_key] = []
|
||||
|
@ -281,6 +284,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
|||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"].append(0)
|
||||
d[f"videos/{video_key}/file_index"].append(0)
|
||||
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
|
||||
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
|
||||
|
||||
# Add stats columns like "stats/action/max"
|
||||
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
|
||||
|
@ -306,7 +311,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
|||
if features is None:
|
||||
features = features_factory()
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(features)
|
||||
episodes = episodes_factory(features, fps)
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
|
@ -379,6 +384,7 @@ def lerobot_dataset_metadata_factory(
|
|||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
|
@ -441,6 +447,7 @@ def lerobot_dataset_factory(
|
|||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes_metadata = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
|
|
|
@ -73,6 +73,7 @@ def create_tasks(tasks_factory):
|
|||
def create_episodes(episodes_factory):
|
||||
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
|
||||
if episodes is None:
|
||||
# TODO(rcadene): add features, fps as arguments
|
||||
episodes = episodes_factory()
|
||||
write_episodes(episodes, dir)
|
||||
|
||||
|
|
|
@ -62,6 +62,7 @@ def mock_snapshot_download_factory(
|
|||
if episodes is None:
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
|
@ -121,6 +122,7 @@ def mock_snapshot_download_factory(
|
|||
create_tasks(local_dir, tasks)
|
||||
if has_episodes:
|
||||
create_episodes(local_dir, episodes)
|
||||
# TODO(rcadene): create_videos?
|
||||
if has_data:
|
||||
create_hf_dataset(local_dir, hf_dataset)
|
||||
|
||||
|
|
|
@ -1,19 +1,29 @@
|
|||
from lerobot.common.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
dataset_0 = lerobot_dataset_factory(
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_0",
|
||||
repo_id=DUMMY_REPO_ID + "_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_0",
|
||||
total_episodes=10,
|
||||
total_frames=400,
|
||||
)
|
||||
dataset_1 = lerobot_dataset_factory(
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_1",
|
||||
repo_id=DUMMY_REPO_ID + "_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_1",
|
||||
total_episodes=10,
|
||||
total_frames=400,
|
||||
)
|
||||
|
||||
dataset_2 = aggregate_datasets([dataset_0, dataset_1])
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
|
||||
aggr_root=tmp_path / "test_aggr"
|
||||
)
|
||||
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
|
||||
for item in aggr_ds:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue