Progress on aggregate_datasets
This commit is contained in:
parent 54b5c805bf
commit b0cca75e5e
@@ -1,11 +1,13 @@
 import logging
+from pathlib import Path
 import shutil
 
 import pandas as pd
 import tqdm
 
+from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
-from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
+from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks
 from lerobot.common.utils.utils import init_logging
 
@@ -30,22 +32,46 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
 
     return fps, robot_type, features
 
-def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_task_index):
+def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
     def _update(row):
         row["episode_index"] = row["episode_index"] + episode_index_to_add
-        row["task_index"] = task_index_to_global_task_index[row["task_index"]]
+        task = old_tasks.iloc[row["task_index"]].name
+        row["task_index"] = new_tasks.loc[task].task_index.item()
         return row
 
     return _update
 
 
+def get_update_meta_func(
+    meta_chunk_index_to_add,
+    meta_file_index_to_add,
+    data_chunk_index_to_add,
+    data_file_index_to_add,
+    videos_chunk_index_to_add,
+    videos_file_index_to_add,
+    frame_index_to_add,
+):
+    def _update(row):
+        row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_chunk_index_to_add
+        row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_file_index_to_add
+        row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add
+        row["data/file_index"] = row["data/file_index"] + data_file_index_to_add
+        for key in videos_chunk_index_to_add:
+            row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
+            row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key]
+        row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add
+        row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
+        return row
+
+    return _update
+
+
-def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
+def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
     logging.info("Start aggregate_datasets")
 
-    all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
+    if roots is None:
+        all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
+    else:
+        all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)]
 
     fps, robot_type, features = validate_all_metadata(all_metadata)
+    video_keys = [key for key in features if features[key]["dtype"] == "video"]
 
     # Create resulting dataset folder
     aggr_meta = LeRobotDatasetMetadata.create(
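The reworked `get_update_episode_and_task_func` in the hunk above resolves task indices through the task string instead of a per-dataset lookup table: the source dataset's `tasks` frame is indexed positionally to recover the task text, and the aggregated `tasks` frame (indexed by task string) gives the new global index. A minimal standalone sketch of that remapping, using hypothetical two-task frames that are not part of the commit:

```python
import pandas as pd

# Hypothetical per-dataset tasks frame: indexed by task string, one row per task.
old_tasks = pd.DataFrame({"task_index": [0, 1]}, index=["pick cube", "place cube"])
# Hypothetical aggregated frame: the same tasks may land on different global indices after deduplication.
new_tasks = pd.DataFrame({"task_index": [3, 7]}, index=["pick cube", "place cube"])

row = {"episode_index": 2, "task_index": 1}

# Same two steps as the diff: positional lookup of the task string, then label lookup of its new index.
task = old_tasks.iloc[row["task_index"]].name               # -> "place cube"
row["task_index"] = new_tasks.loc[task].task_index.item()   # -> 7
print(row)
```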
@@ -55,95 +81,99 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
         features=features,
         root=aggr_root,
     )
+    aggr_root = aggr_meta.root
 
     logging.info("Find all tasks")
-    # find all tasks, deduplicate them, create new task indices for each dataset
-    # indexed by dataset index
-    datasets_task_index_to_aggr_task_index = {}
-    aggr_task_index = 0
-    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
-        task_index_to_aggr_task_index = {}
-
-        for task_index, task in meta.tasks.items():
-            if task not in aggr_meta.task_to_task_index:
-                # add the task to aggr tasks mappings
-                aggr_meta.tasks[aggr_task_index] = task
-                aggr_meta.task_to_task_index[task] = aggr_task_index
-                aggr_task_index += 1
-
-            # add task_index anyway
-            task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
-
-        datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
-
-    logging.info("Copy data and videos")
-    aggr_episode_index_shift = 0
-    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Copy data and videos")):
-        # cp data
-        for episode_index in range(meta.total_episodes):
-            aggr_episode_index = episode_index + aggr_episode_index_shift
-            data_path = meta.root / meta.get_data_file_path(episode_index)
-            aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
-
-            # update episode_index and task_index
-            df = pd.read_parquet(data_path)
-            update_row_func = get_update_episode_and_task_func(
-                aggr_episode_index_shift, datasets_task_index_to_aggr_task_index[dataset_index]
-            )
-            df = df.apply(update_row_func, axis=1)
-
-            aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
-            df.to_parquet(aggr_data_path)
+    unique_tasks = pd.concat([meta.tasks for meta in all_metadata]).index.unique()
+    aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
+
+    num_episodes = 0
+    num_frames = 0
+
+    aggr_meta_chunk_idx = 0
+    aggr_meta_file_idx = 0
+
+    aggr_data_chunk_idx = 0
+    aggr_data_file_idx = 0
+
+    aggr_videos_chunk_idx = {key: 0 for key in video_keys}
+    aggr_videos_file_idx = {key: 0 for key in video_keys}
+
+    for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
+        meta_chunk_file_ids = set([(c, f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])])
+        for chunk_idx, file_idx in meta_chunk_file_ids:
+            path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+            df = pd.read_parquet(path)
+            update_meta_func = get_update_meta_func(
+                aggr_meta_chunk_idx,
+                aggr_meta_file_idx,
+                aggr_data_chunk_idx,
+                aggr_data_file_idx,
+                aggr_videos_chunk_idx,
+                aggr_videos_file_idx,
+                num_frames,
+            )
+            df = df.apply(update_meta_func, axis=1)
+
+            aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx)
+            aggr_path.parent.mkdir(parents=True, exist_ok=True)
+            df.to_parquet(aggr_path)
+
+            aggr_meta_file_idx += 1
+            if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
+                aggr_meta_file_idx = 0
+                aggr_meta_chunk_idx += 1
 
         # cp videos
-        for episode_index in range(meta.total_episodes):
-            aggr_episode_index = episode_index + aggr_episode_index_shift
-            for vid_key in meta.video_keys:
-                video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
-                aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
-                aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
-                shutil.copy(video_path, aggr_video_path)
+        for key in video_keys:
+            video_chunk_file_ids = set([(c, f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])])
+            for chunk_idx, file_idx in video_chunk_file_ids:
+                path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx)
+                aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key])
+                aggr_path.parent.mkdir(parents=True, exist_ok=True)
+                shutil.copy(str(path), str(aggr_path))
 
                 # copy_command = f"cp {video_path} {aggr_video_path} &"
                 # subprocess.Popen(copy_command, shell=True)
 
-        # populate episodes
-        for episode_index, episode_dict in meta.episodes.items():
-            aggr_episode_index = episode_index + aggr_episode_index_shift
-            episode_dict["episode_index"] = aggr_episode_index
-            aggr_meta.episodes[aggr_episode_index] = episode_dict
-
-        # populate episodes_stats
-        for episode_index, episode_stats in meta.episodes_stats.items():
-            aggr_episode_index = episode_index + aggr_episode_index_shift
-            aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
-
-        # populate info
-        aggr_meta.info["total_episodes"] += meta.total_episodes
-        aggr_meta.info["total_frames"] += meta.total_frames
-        aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
-
-        aggr_episode_index_shift += meta.total_episodes
-
-    logging.info("write meta data")
-    aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
-    aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
-
-    # create a new episodes jsonl with updated episode_index using write_episode
-    for episode_dict in aggr_meta.episodes.values():
-        write_episode(episode_dict, aggr_meta.root)
-
-    # create a new episode_stats jsonl with updated episode_index using write_episode_stats
-    for episode_index, episode_stats in aggr_meta.episodes_stats.items():
-        legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)
-
-    # create a new task jsonl with updated episode_index using write_task
-    for task_index, task in aggr_meta.tasks.items():
-        legacy_write_task(task_index, task, aggr_meta.root)
-
+                aggr_videos_file_idx[key] += 1
+                if aggr_videos_file_idx[key] >= DEFAULT_CHUNK_SIZE:
+                    aggr_videos_file_idx[key] = 0
+                    aggr_videos_chunk_idx[key] += 1
+
+        data_chunk_file_ids = set([(c, f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])])
+        for chunk_idx, file_idx in data_chunk_file_ids:
+            path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+            df = pd.read_parquet(path)
+            update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
+            df = df.apply(update_data_func, axis=1)
+
+            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx)
+            aggr_path.parent.mkdir(parents=True, exist_ok=True)
+            df.to_parquet(aggr_path)
+
+            aggr_data_file_idx += 1
+            if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
+                aggr_data_file_idx = 0
+                aggr_data_chunk_idx += 1
+
+        num_episodes += meta.total_episodes
+        num_frames += meta.total_frames
+
+    logging.info("write tasks")
+    write_tasks(aggr_meta.tasks, aggr_meta.root)
+
+    logging.info("write info")
+    aggr_meta.info["total_episodes"] = sum([meta.total_episodes for meta in all_metadata])
+    aggr_meta.info["total_frames"] = sum([meta.total_frames for meta in all_metadata])
+    aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.total_episodes}"}
     write_info(aggr_meta.info, aggr_meta.root)
 
+    logging.info("write stats")
+    aggr_meta.stats = aggregate_stats([meta.stats for meta in all_metadata])
+    write_stats(aggr_meta.stats, aggr_meta.root)
+
 
 if __name__ == "__main__":
     init_logging()
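The new copy loop above repeats one bookkeeping pattern three times (episodes metadata, per-camera videos, data): write into the current `(chunk_index, file_index)` slot of the aggregated dataset, then advance the file index and roll over to a new chunk once `DEFAULT_CHUNK_SIZE` files have been written. A small sketch of that rollover on its own, with an assumed chunk size (the real value comes from `lerobot.common.datasets.utils.DEFAULT_CHUNK_SIZE`, and the helper name here is hypothetical; the `update_chunk_file_indices` utility imported in the conversion script below appears to cover the same logic):

```python
DEFAULT_CHUNK_SIZE = 1000  # assumed value, for illustration only


def bump_chunk_file(chunk_idx: int, file_idx: int) -> tuple[int, int]:
    """Advance to the next file slot, starting a new chunk when the current one is full."""
    file_idx += 1
    if file_idx >= DEFAULT_CHUNK_SIZE:
        file_idx = 0
        chunk_idx += 1
    return chunk_idx, file_idx


# e.g. with a chunk size of 1000: (0, 998) -> (0, 999) -> (1, 0)
print(bump_chunk_file(0, 998))  # (0, 999)
print(bump_chunk_file(0, 999))  # (1, 0)
```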
@@ -712,8 +712,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return get_hf_features_from_features(self.features)
 
     def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
-        ep_start = self.meta.episodes["data/from_index"][ep_idx]
-        ep_end = self.meta.episodes["data/to_index"][ep_idx]
+        ep_start = self.meta.episodes["dataset_from_index"][ep_idx]
+        ep_end = self.meta.episodes["dataset_to_index"][ep_idx]
         query_indices = {
             key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
             for key, delta_idx in self.delta_indices.items()
@@ -1017,8 +1017,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         metadata = {
             "data/chunk_index": chunk_idx,
             "data/file_index": file_idx,
-            "data/from_index": latest_num_frames,
-            "data/to_index": latest_num_frames + ep_num_frames,
+            "dataset_from_index": latest_num_frames,
+            "dataset_to_index": latest_num_frames + ep_num_frames,
         }
         return metadata
 
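Both hunks above follow the same rename: the per-episode frame range is now stored under `dataset_from_index` / `dataset_to_index` instead of `data/from_index` / `data/to_index`, keeping the `data/*` prefix for chunk and file placement only. A hedged sketch of what a single episodes-metadata row looks like with the renamed keys; the values and the camera key `observation.image` are made up, and the column set is inferred from `get_update_meta_func` and the test factories in this commit:

```python
# Hypothetical episodes-metadata row for one episode with a single "observation.image" camera.
episode_row = {
    "episode_index": 0,
    "meta/episodes/chunk_index": 0,
    "meta/episodes/file_index": 0,
    "data/chunk_index": 0,
    "data/file_index": 0,
    "dataset_from_index": 0,     # first frame of this episode in the dataset-wide frame numbering
    "dataset_to_index": 250,     # exclusive end, so the episode length is to - from
    "videos/observation.image/chunk_index": 0,
    "videos/observation.image/file_index": 0,
    "videos/observation.image/from_timestamp": 0.0,
    "videos/observation.image/to_timestamp": 8.33,  # 250 frames at 30 fps
}
print(episode_row["dataset_to_index"] - episode_row["dataset_from_index"])  # 250 frames
```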
@@ -27,6 +27,7 @@ from datasets import Dataset
 from huggingface_hub import HfApi, snapshot_download
 
 from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
@@ -46,6 +47,7 @@ from lerobot.common.datasets.utils import (
     update_chunk_file_indices,
     write_episodes,
     write_info,
+    write_stats,
     write_tasks,
 )
 
@@ -306,6 +308,9 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata):
     )
     write_episodes(ds_episodes, new_root)
 
+    stats = aggregate_stats(list(episodes_stats.values()))
+    write_stats(stats, new_root)
+
 
 def convert_info(root, new_root):
     info = load_info(root)
@@ -330,9 +335,16 @@ def convert_dataset(
     branch: str | None = None,
     num_workers: int = 4,
 ):
 
     root = HF_LEROBOT_HOME / repo_id
+    old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
     new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
 
+    if old_root.is_dir() and root.is_dir():
+        shutil.rmtree(str(root))
+        shutil.move(str(old_root), str(root))
+
     if new_root.is_dir():
         shutil.rmtree(new_root)
 
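The added `old_root` handling makes `convert_dataset` re-runnable: if an earlier run already moved the original dataset aside, it is restored before converting again. A small sketch of the three directory roles, inferred from the diff (the helper name is hypothetical; only the path expressions come from the commit):

```python
from pathlib import Path


def conversion_paths(hf_lerobot_home: Path, repo_id: str) -> dict[str, Path]:
    """Directory roles used during conversion, as inferred from the diff."""
    return {
        "root": hf_lerobot_home / repo_id,               # replaced by the converted dataset at the end
        "old_root": hf_lerobot_home / f"{repo_id}_old",  # backup of the original; restored if a re-run finds it
        "new_root": hf_lerobot_home / f"{repo_id}_v30",  # staging dir, moved onto `root` once conversion succeeds
    }
```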
@@ -349,7 +361,7 @@ def convert_dataset(
     episodes_videos_metadata = convert_videos(root, new_root)
     convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
 
-    shutil.move(str(root), str(root) + "_old")
+    shutil.move(str(root), str(old_root))
     shutil.move(str(new_root), str(root))
 
     # TODO(racdene)
@@ -365,7 +377,7 @@ def convert_dataset(
 
     hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
 
-    LeRobotDataset(repo_id).push_to_hub()
+    # LeRobotDataset(repo_id).push_to_hub()
 
 
 if __name__ == "__main__":
@@ -212,6 +212,7 @@ def tasks_factory():
 def episodes_factory(tasks_factory, stats_factory):
     def _create_episodes(
         features: dict[str],
+        fps: int = DEFAULT_FPS,
         total_episodes: int = 3,
         total_frames: int = 400,
         video_keys: list[str] | None = None,
@@ -252,6 +253,8 @@ def episodes_factory(tasks_factory, stats_factory):
         for video_key in video_keys:
             d[f"videos/{video_key}/chunk_index"] = []
             d[f"videos/{video_key}/file_index"] = []
+            d[f"videos/{video_key}/from_timestamp"] = []
+            d[f"videos/{video_key}/to_timestamp"] = []
 
         for stats_key in flatten_dict({"stats": stats_factory(features)}):
             d[stats_key] = []
@@ -281,6 +284,8 @@ def episodes_factory(tasks_factory, stats_factory):
             for video_key in video_keys:
                 d[f"videos/{video_key}/chunk_index"].append(0)
                 d[f"videos/{video_key}/file_index"].append(0)
+                d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
+                d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
 
             # Add stats columns like "stats/action/max"
             for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
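The two factory hunks above add per-camera `from_timestamp` / `to_timestamp` columns, computed from what appears to be the running frame count and the new `fps` argument. A quick worked example of those formulas, with made-up episode lengths and fps:

```python
fps = 30
lengths = [100, 150]  # frames per episode (hypothetical)

num_frames = 0
for ep_idx, length in enumerate(lengths):
    from_ts = num_frames / fps            # same expression as the factory: num_frames / fps
    to_ts = (num_frames + length) / fps   # (num_frames + lengths[ep_idx]) / fps
    print(ep_idx, round(from_ts, 3), round(to_ts, 3))
    num_frames += length
# -> 0 0.0 3.333
# -> 1 3.333 8.333
```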
@@ -306,7 +311,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
     if features is None:
         features = features_factory()
     if episodes is None:
-        episodes = episodes_factory(features)
+        episodes = episodes_factory(features, fps)
 
     timestamp_col = np.array([], dtype=np.float32)
     frame_index_col = np.array([], dtype=np.int64)
@@ -379,6 +384,7 @@ def lerobot_dataset_metadata_factory(
         video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
         episodes = episodes_factory(
             features=info["features"],
+            fps=info["fps"],
             total_episodes=info["total_episodes"],
             total_frames=info["total_frames"],
             video_keys=video_keys,
@@ -441,6 +447,7 @@ def lerobot_dataset_factory(
         video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
         episodes_metadata = episodes_factory(
             features=info["features"],
+            fps=info["fps"],
             total_episodes=info["total_episodes"],
             total_frames=info["total_frames"],
             video_keys=video_keys,
@@ -73,6 +73,7 @@ def create_tasks(tasks_factory):
 def create_episodes(episodes_factory):
     def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
         if episodes is None:
+            # TODO(rcadene): add features, fps as arguments
            episodes = episodes_factory()
         write_episodes(episodes, dir)
 
@@ -62,6 +62,7 @@ def mock_snapshot_download_factory(
     if episodes is None:
         episodes = episodes_factory(
             features=info["features"],
+            fps=info["fps"],
             total_episodes=info["total_episodes"],
             total_frames=info["total_frames"],
             tasks=tasks,
@@ -121,6 +122,7 @@ def mock_snapshot_download_factory(
         create_tasks(local_dir, tasks)
         if has_episodes:
             create_episodes(local_dir, episodes)
+        # TODO(rcadene): create_videos?
         if has_data:
             create_hf_dataset(local_dir, hf_dataset)
 
@@ -1,19 +1,29 @@
 from lerobot.common.datasets.aggregate import aggregate_datasets
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from tests.fixtures.constants import DUMMY_REPO_ID
 
 
 def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
-    dataset_0 = lerobot_dataset_factory(
+    ds_0 = lerobot_dataset_factory(
         root=tmp_path / "test_0",
-        repo_id=DUMMY_REPO_ID + "_0",
+        repo_id=f"{DUMMY_REPO_ID}_0",
         total_episodes=10,
         total_frames=400,
     )
-    dataset_1 = lerobot_dataset_factory(
+    ds_1 = lerobot_dataset_factory(
         root=tmp_path / "test_1",
-        repo_id=DUMMY_REPO_ID + "_1",
+        repo_id=f"{DUMMY_REPO_ID}_1",
         total_episodes=10,
         total_frames=400,
     )
 
-    dataset_2 = aggregate_datasets([dataset_0, dataset_1])
+    aggregate_datasets(
+        repo_ids=[ds_0.repo_id, ds_1.repo_id],
+        roots=[ds_0.root, ds_1.root],
+        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
+        aggr_root=tmp_path / "test_aggr"
+    )
+
+    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
+    for item in aggr_ds:
+        pass
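The updated test only checks that the aggregated dataset can be constructed and iterated end to end. A natural extension (hypothetical, not part of the commit) would be to also assert the aggregated counts at the end of the test, assuming `LeRobotDataset` exposes `num_episodes` and `num_frames` attributes as used elsewhere in the codebase:

```python
    # Hypothetical extra assertions for test_aggregate_datasets (attribute names assumed):
    assert aggr_ds.num_episodes == ds_0.num_episodes + ds_1.num_episodes  # 10 + 10
    assert aggr_ds.num_frames == ds_0.num_frames + ds_1.num_frames        # 400 + 400
```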