Add episode chunks logic, move_videos & lfs tracking fix
parent 110264000f
commit c146ba936f
@@ -87,6 +87,7 @@ import argparse
 import contextlib
 import json
 import math
+import shutil
 import subprocess
 import warnings
 from pathlib import Path
@@ -108,8 +109,15 @@ from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
 V16 = "v1.6"
 V20 = "v2.0"

-PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
-VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
+EPISODE_CHUNK_SIZE = 1000
+
+CLEAN_GITATTRIBUTES = Path("data/.gitattributes")
+
+VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+PARQUET_CHUNK_PATH = (
+    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
+)
+VIDEO_CHUNK_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"


 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
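The templates above drive the new layout: each episode's parquet file and videos now land in a chunk-NNN subdirectory, with EPISODE_CHUNK_SIZE (1000) episodes per chunk. A minimal sketch of how they resolve, assuming the constants defined in this hunk (the video key below is a hypothetical example):

    ep_idx, total_episodes = 1234, 3000
    ep_chunk = ep_idx // EPISODE_CHUNK_SIZE  # episode 1234 falls in chunk 1

    PARQUET_CHUNK_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes)
    # -> 'data/chunk-001/train-01234-of-03000.parquet'
    VIDEO_CHUNK_PATH.format(episode_chunk=ep_chunk, video_key="observation.images.top", episode_index=ep_idx)
    # -> 'videos/chunk-001/observation.images.top/episode_001234.mp4'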
@@ -229,23 +237,125 @@ def add_task_index_from_tasks_col(


 def split_parquet_by_episodes(
-    dataset: Dataset, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
+    dataset: Dataset,
+    keys: dict[str, list],
+    total_episodes: int,
+    total_chunks: int,
+    output_dir: Path,
 ) -> list:
-    (output_dir / "data").mkdir(exist_ok=True, parents=True)
     table = dataset.remove_columns(keys["video"])._data.table
     episode_lengths = []
-    for episode_index in sorted(episode_indices):
-        # Write each episode_index to a new parquet file
-        filtered_table = table.filter(pc.equal(table["episode_index"], episode_index))
-        episode_lengths.insert(episode_index, len(filtered_table))
-        output_file = output_dir / PARQUET_PATH.format(
-            episode_index=episode_index, total_episodes=total_episodes
-        )
-        pq.write_table(filtered_table, output_file)
+    for ep_chunk in range(total_chunks):
+        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        chunk_dir = "/".join(PARQUET_CHUNK_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
+        (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
+        for ep_idx in range(ep_chunk_start, ep_chunk_end):
+            ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+            episode_lengths.insert(ep_idx, len(ep_table))
+            output_file = output_dir / PARQUET_CHUNK_PATH.format(
+                episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
+            )
+            pq.write_table(ep_table, output_file)

     return episode_lengths


+def move_videos(
+    repo_id: str,
+    video_keys: list[str],
+    total_episodes: int,
+    total_chunks: int,
+    work_dir: Path,
+    branch: str = "main",
+):
+    """
+    HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
+    commands to fetch git lfs video files references to move them into subdirectories without having to
+    actually download them.
+    """
+    _lfs_clone(repo_id, work_dir, branch)
+
+    video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
+    total_videos = len(video_files)
+    assert total_videos == total_episodes * len(video_keys)
+
+    fix_lfs_video_files_tracking(work_dir, video_files, CLEAN_GITATTRIBUTES)
+
+    video_dirs = sorted(work_dir.glob("videos*/"))
+    for ep_chunk in range(total_chunks):
+        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        for vid_key in video_keys:
+            chunk_dir = "/".join(VIDEO_CHUNK_PATH.split("/")[:-1]).format(
+                episode_chunk=ep_chunk, video_key=vid_key
+            )
+            (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
+
+            for ep_idx in range(ep_chunk_start, ep_chunk_end):
+                target_path = VIDEO_CHUNK_PATH.format(
+                    episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
+                )
+                video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
+                if len(video_dirs) == 1:
+                    video_path = video_dirs[0] / video_file
+                else:
+                    for dir in video_dirs:
+                        if (dir / video_file).is_file():
+                            video_path = dir / video_file
+                            break
+
+                video_path.rename(work_dir / target_path)
+
+    commit_message = "Move video files into chunk subdirectories"
+    subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
+    subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
+    subprocess.run(["git", "push"], cwd=work_dir, check=True)
+
+
+def fix_lfs_video_files_tracking(work_dir: Path, video_files: list[str], clean_gitattributes_path: Path):
+    """
+    HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
+    there's no other option than to download the actual files and reupload them with lfs tracking.
+    """
+    # _lfs_clone(repo_id, work_dir, branch)
+    lfs_tracked_files = subprocess.run(
+        ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
+    )
+    lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
+    lfs_untracked_videos = [f for f in video_files if f not in lfs_tracked_files]
+
+    if lfs_untracked_videos:
+        shutil.copyfile(clean_gitattributes_path, work_dir / ".gitattributes")
+        subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
+        for i in range(0, len(lfs_untracked_videos), 100):
+            files = lfs_untracked_videos[i : i + 100]
+            try:
+                subprocess.run(
+                    ["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True
+                )
+            except subprocess.CalledProcessError as e:
+                print("git rm --cached ERROR:")
+                print(e.stderr)
+            subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
+
+        commit_message = "Track video files with git lfs"
+        subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
+        subprocess.run(["git", "push"], cwd=work_dir, check=True)
+
+
+def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
+    subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
+    repo_url = f"https://huggingface.co/datasets/{repo_id}"
+    env = {"GIT_LFS_SKIP_SMUDGE": "1"}  # Prevent downloading LFS files
+    subprocess.run(
+        ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
+        check=True,
+        env=env,
+    )
+
+
 def _get_audio_info(video_path: Path | str) -> dict:
     ffprobe_audio_cmd = [
         "ffprobe",
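A rough sketch of how these helpers fit together for one dataset, assuming the functions from this commit are in scope (repo id, camera key, and work dir below are hypothetical): _lfs_clone fetches only the LFS pointer files (GIT_LFS_SKIP_SMUDGE=1), fix_lfs_video_files_tracking re-adds any videos that were not LFS-tracked, and move_videos renames the pointers into videos/chunk-XXX/{video_key}/ before pushing the result back to the Hub.

    from pathlib import Path

    work_dir = Path("/tmp/videos/v2.0/lerobot/aloha_sim_insertion_human")  # hypothetical work dir
    work_dir.mkdir(parents=True, exist_ok=True)
    move_videos(
        repo_id="lerobot/aloha_sim_insertion_human",  # hypothetical repo
        video_keys=["observation.images.top"],  # hypothetical camera key
        total_episodes=50,
        total_chunks=1,
        work_dir=work_dir,
        branch="main",
    )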
@@ -323,16 +433,19 @@ def _get_video_info(video_path: Path | str) -> dict:
     return video_info


-def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
+def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
     hub_api = HfApi()
-    videos_info_dict = {"videos_path": VIDEO_PATH}
-    for vid_key in video_keys:
-        # Assumes first episode
-        video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
-        video_path = hub_api.hf_hub_download(
-            repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
-        )
-        videos_info_dict[vid_key] = _get_video_info(video_path)
+    videos_info_dict = {"videos_path": VIDEO_CHUNK_PATH}
+
+    # Assumes first episode
+    video_files = [
+        VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
+    ]
+    hub_api.snapshot_download(
+        repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
+    )
+    for vid_key, vid_path in zip(video_keys, video_files, strict=True):
+        videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path)

     return videos_info_dict

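get_videos_info now probes only the first episode of each video key. A sketch of the allow_patterns it builds, assuming VIDEO_CHUNK_PATH from this commit (the camera keys are hypothetical):

    video_keys = ["observation.images.top", "observation.images.wrist"]  # hypothetical keys
    video_files = [
        VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
    ]
    # ['videos/chunk-000/observation.images.top/episode_000000.mp4',
    #  'videos/chunk-000/observation.images.wrist/episode_000000.mp4']
    # snapshot_download() then fetches only these files from the given branch.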
@@ -399,6 +512,7 @@ def convert_dataset(
     tasks_path: Path | None = None,
     tasks_col: Path | None = None,
     robot_config: dict | None = None,
+    test_branch: str | None = None,
 ):
     v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
     v1x_dir = local_dir / V16 / repo_id
@@ -408,8 +522,12 @@ def convert_dataset(

     hub_api = HfApi()
     hub_api.snapshot_download(
-        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos/"
+        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
     )
+    branch = "main"
+    if test_branch:
+        branch = test_branch
+        create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")

     metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
     dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
@@ -423,10 +541,14 @@ def convert_dataset(
         single_task = None
         tasks_col = "language_instruction"

-    # Episodes
+    # Episodes & chunks
     episode_indices = sorted(dataset.unique("episode_index"))
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))
+    total_videos = total_episodes * len(keys["video"])
+    total_chunks = total_episodes // EPISODE_CHUNK_SIZE
+    if total_episodes % EPISODE_CHUNK_SIZE != 0:
+        total_chunks += 1

     # Tasks
     if single_task:
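The chunk count is a ceiling division of the episode count by the chunk size. A sketch of the equivalent one-liner, assuming the EPISODE_CHUNK_SIZE constant defined earlier and a hypothetical episode count:

    import math

    EPISODE_CHUNK_SIZE = 1000
    total_episodes = 2500  # hypothetical
    total_chunks = math.ceil(total_episodes / EPISODE_CHUNK_SIZE)  # 2500 episodes, 1000 per chunk -> 3 chunks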
@@ -448,25 +570,30 @@ def convert_dataset(
         task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
         write_json(task_json, v20_dir / "meta" / "tasks.json")

-    # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)
-
     # Shapes
     sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
     image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}

+    # Videos
     if len(keys["video"]) > 0:
         assert metadata_v1.get("video", False)
-        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"])
+        tmp_video_dir = local_dir / "videos" / V20 / repo_id
+        tmp_video_dir.mkdir(parents=True, exist_ok=True)
+        move_videos(repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, branch)
+        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
         video_shapes = get_video_shapes(videos_info, keys["video"])
         for img_key in keys["video"]:
             assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
             if "encoding" in metadata_v1:
                 assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
     else:
-        assert len(keys["video"]) == 0
+        assert metadata_v1.get("video", 0) == 0
         videos_info = None
         video_shapes = {}

+    # Split data into 1 parquet file by episode
+    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)
+
     # Names
     if robot_config is not None:
         robot_type = robot_config["robot_type"]
@@ -495,11 +622,14 @@ def convert_dataset(
     # Assemble metadata v2.0
     metadata_v2_0 = {
         "codebase_version": V20,
-        "data_path": PARQUET_PATH,
+        "data_path": PARQUET_CHUNK_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
         "total_frames": len(dataset),
         "total_tasks": len(tasks),
+        "total_videos": total_videos,
+        "total_chunks": total_chunks,
+        "chunks_size": EPISODE_CHUNK_SIZE,
         "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
         "keys": keys["sequence"],
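For reference, a sketch of how the new chunk-related entries would look in meta/info.json for a hypothetical 2500-episode, single-camera dataset (values are illustrative, not taken from a real dataset):

    info_excerpt = {
        "data_path": "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet",
        "total_episodes": 2500,
        "total_videos": 2500,  # total_episodes * number of video keys (1 here)
        "total_chunks": 3,     # ceil(2500 / 1000)
        "chunks_size": 1000,   # EPISODE_CHUNK_SIZE
    }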
@@ -512,37 +642,31 @@ def convert_dataset(
     write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
     convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")

-    #### TODO: delete
-    # repo_id = f"aliberts/{repo_id.split('/')[1]}"
-    # if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
-    #     hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
-    # hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
-    ####
-
     with contextlib.suppress(EntryNotFoundError):
-        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)

     with contextlib.suppress(EntryNotFoundError):
-        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision="main")
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)

     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="data",
         folder_path=v20_dir / "data",
         repo_type="dataset",
-        revision="main",
+        revision=branch,
     )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="meta",
         folder_path=v20_dir / "meta",
         repo_type="dataset",
-        revision="main",
+        revision=branch,
     )

     card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
-    push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
-    create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
+    push_dataset_card_to_hub(repo_id=repo_id, revision=branch, tags=repo_tags, text=card_text)
+    if not test_branch:
+        create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")

     # TODO:
     # - [X] Add shapes
@@ -555,7 +679,9 @@ def convert_dataset(
     # - [X] Add splits
     # - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
     # - [X] Handle multitask datasets
-    # - [/] Add sanity checks (encoding, shapes)
+    # - [X] Handle hf hub repo limits (add chunks logic)
+    # - [X] Add test-branch
+    # - [X] Add sanity checks (encoding, shapes)


 def main():
@@ -601,6 +727,12 @@ def main():
         default=None,
         help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}",
     )
+    parser.add_argument(
+        "--test-branch",
+        type=str,
+        default=None,
+        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
+    )

     args = parser.parse_args()
     if not args.local_dir:
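A sketch of a test run using the new flag. Only --test-branch and --local-dir appear in this diff; the script filename, the --repo-id flag, and the repo id below are assumptions for illustration.

    import subprocess

    # Hypothetical invocation: script name and --repo-id are assumed, not shown in this diff;
    # --test-branch and --local-dir come from this commit's argparse definitions.
    subprocess.run(
        [
            "python", "convert_dataset_v1_to_v2.py",  # assumed script name
            "--repo-id", "lerobot/aloha_sim_insertion_human",  # assumed flag, hypothetical repo
            "--local-dir", "/tmp/lerobot/aloha_sim_insertion_human",
            "--test-branch", "v2.0.test",
        ],
        check=True,
    )
    # With --test-branch set, data/meta uploads and the dataset card go to 'v2.0.test' instead of 'main',
    # and the final v2.0 branch is not created.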