Add episode chunks logic, move_videos & lfs tracking fix

Simon Alibert 2024-10-16 23:34:54 +02:00
parent 110264000f
commit c146ba936f
1 changed file with 176 additions and 44 deletions


@@ -87,6 +87,7 @@ import argparse
import contextlib
import json
import math
import os
import shutil
import subprocess
import warnings
from pathlib import Path
@@ -108,8 +109,15 @@ from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
V16 = "v1.6"
V20 = "v2.0"
PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
EPISODE_CHUNK_SIZE = 1000
CLEAN_GITATTRIBUTES = Path("data/.gitattributes")
VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
PARQUET_CHUNK_PATH = (
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
VIDEO_CHUNK_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
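# For illustration (example values, not part of the commit): with the templates above,
# episode 1042 of a 2500-episode dataset and a camera key "observation.images.top" maps to
#   data/chunk-001/train-01042-of-02500.parquet
#   videos/chunk-001/observation.images.top/episode_001042.mp4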
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
@@ -229,23 +237,125 @@ def add_task_index_from_tasks_col(
def split_parquet_by_episodes(
dataset: Dataset, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
dataset: Dataset,
keys: dict[str, list],
total_episodes: int,
total_chunks: int,
output_dir: Path,
) -> list:
(output_dir / "data").mkdir(exist_ok=True, parents=True)
table = dataset.remove_columns(keys["video"])._data.table
episode_lengths = []
for episode_index in sorted(episode_indices):
# Write each episode_index to a new parquet file
filtered_table = table.filter(pc.equal(table["episode_index"], episode_index))
episode_lengths.insert(episode_index, len(filtered_table))
output_file = output_dir / PARQUET_PATH.format(
episode_index=episode_index, total_episodes=total_episodes
)
pq.write_table(filtered_table, output_file)
for ep_chunk in range(total_chunks):
ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(PARQUET_CHUNK_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / PARQUET_CHUNK_PATH.format(
episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
)
pq.write_table(ep_table, output_file)
return episode_lengths
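# Hypothetical usage sketch of the function above (values assumed, not in the commit):
#   lengths = split_parquet_by_episodes(dataset, keys, total_episodes=2500, total_chunks=3, output_dir=Path("/tmp/v2"))
#   assert sum(lengths) == len(dataset)  # every frame lands in exactly one episode file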
def move_videos(
repo_id: str,
video_keys: list[str],
total_episodes: int,
total_chunks: int,
work_dir: Path,
branch: str = "main",
):
"""
    HACK: Since HfApi() doesn't provide a way to move files within a repo, this function runs git
    commands to fetch the git-lfs video file references (pointer files) and moves them into
    subdirectories, without actually downloading the videos themselves.
"""
_lfs_clone(repo_id, work_dir, branch)
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
total_videos = len(video_files)
    # Sanity check: each camera key contributes exactly one video per episode.
    assert total_videos == total_episodes * len(video_keys)
fix_lfs_video_files_tracking(work_dir, video_files, CLEAN_GITATTRIBUTES)
video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks):
ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
for vid_key in video_keys:
chunk_dir = "/".join(VIDEO_CHUNK_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk, video_key=vid_key
)
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
target_path = VIDEO_CHUNK_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
                if len(video_dirs) == 1:
                    video_path = video_dirs[0] / video_file
                else:
                    # Some v1 repos split videos across several videos*/ folders: search them all.
                    video_path = None
                    for video_dir in video_dirs:
                        if (video_dir / video_file).is_file():
                            video_path = video_dir / video_file
                            break
                    if video_path is None:
                        raise FileNotFoundError(f"{video_file} not found in {video_dirs}")
                video_path.rename(work_dir / target_path)
commit_message = "Move video files into chunk subdirectories"
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, video_files: list[str], clean_gitattributes_path: Path):
"""
    HACK: This function fixes git-lfs tracking, which was not set up properly on some repos. In that
    case, there's no option other than to download the actual files and re-upload them with lfs tracking.
"""
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
lfs_untracked_videos = [f for f in video_files if f not in lfs_tracked_files]
if lfs_untracked_videos:
shutil.copyfile(clean_gitattributes_path, work_dir / ".gitattributes")
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
            try:
                subprocess.run(
                    ["git", "rm", "--cached", *files],
                    cwd=work_dir,
                    capture_output=True,
                    text=True,  # decode stderr so the error output prints as text, not bytes
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                print("git rm --cached ERROR:")
                print(e.stderr)
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
commit_message = "Track video files with git lfs"
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
repo_url = f"https://huggingface.co/datasets/{repo_id}"
    # Prevent downloading LFS files; merge with the current environment so git still sees PATH, HOME, etc.
    env = {**os.environ, "GIT_LFS_SKIP_SMUDGE": "1"}
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
check=True,
env=env,
)
def _get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
@@ -323,16 +433,19 @@ def _get_video_info(video_path: Path | str) -> dict:
return video_info
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
hub_api = HfApi()
videos_info_dict = {"videos_path": VIDEO_PATH}
for vid_key in video_keys:
# Assumes first episode
video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
video_path = hub_api.hf_hub_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
)
videos_info_dict[vid_key] = _get_video_info(video_path)
videos_info_dict = {"videos_path": VIDEO_CHUNK_PATH}
# Assumes first episode
video_files = [
VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
]
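    # A single snapshot_download with allow_patterns fetches all first-episode videos in one
    # call, rather than one hf_hub_download per camera key as before.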
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
)
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path)
return videos_info_dict
@@ -399,6 +512,7 @@ def convert_dataset(
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: dict | None = None,
test_branch: str | None = None,
):
v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
v1x_dir = local_dir / V16 / repo_id
@@ -408,8 +522,12 @@ def convert_dataset(
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos/"
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
)
branch = "main"
if test_branch:
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
@@ -423,10 +541,14 @@ def convert_dataset(
single_task = None
tasks_col = "language_instruction"
# Episodes
# Episodes & chunks
episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(keys["video"])
total_chunks = total_episodes // EPISODE_CHUNK_SIZE
if total_episodes % EPISODE_CHUNK_SIZE != 0:
total_chunks += 1
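    # Ceil division: e.g. 2500 episodes with EPISODE_CHUNK_SIZE = 1000 yields 3 chunks
    # (chunks 0-1 hold 1000 episodes each, chunk 2 holds the remaining 500).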
# Tasks
if single_task:
@@ -448,25 +570,30 @@ def convert_dataset(
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_json(task_json, v20_dir / "meta" / "tasks.json")
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)
# Shapes
sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
# Videos
if len(keys["video"]) > 0:
assert metadata_v1.get("video", False)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"])
tmp_video_dir = local_dir / "videos" / V20 / repo_id
tmp_video_dir.mkdir(parents=True, exist_ok=True)
move_videos(repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, branch)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
video_shapes = get_video_shapes(videos_info, keys["video"])
for img_key in keys["video"]:
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else:
assert len(keys["video"]) == 0
assert metadata_v1.get("video", 0) == 0
videos_info = None
video_shapes = {}
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)
# Names
if robot_config is not None:
robot_type = robot_config["robot_type"]
@@ -495,11 +622,14 @@ def convert_dataset(
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V20,
"data_path": PARQUET_PATH,
"data_path": PARQUET_CHUNK_PATH,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": len(dataset),
"total_tasks": len(tasks),
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": EPISODE_CHUNK_SIZE,
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"keys": keys["sequence"],
@@ -512,37 +642,31 @@ def convert_dataset(
write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")
#### TODO: delete
# repo_id = f"aliberts/{repo_id.split('/')[1]}"
# if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
# hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
# hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
####
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision="main")
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
folder_path=v20_dir / "data",
repo_type="dataset",
revision="main",
revision=branch,
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="meta",
folder_path=v20_dir / "meta",
repo_type="dataset",
revision="main",
revision=branch,
)
card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
push_dataset_card_to_hub(repo_id=repo_id, revision=branch, tags=repo_tags, text=card_text)
if not test_branch:
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
# TODO:
# - [X] Add shapes
@@ -555,7 +679,9 @@ def convert_dataset(
# - [X] Add splits
# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
# - [X] Handle multitask datasets
# - [/] Add sanity checks (encoding, shapes)
# - [X] Handle hf hub repo limits (add chunks logic)
# - [X] Add test-branch
# - [X] Add sanity checks (encoding, shapes)
def main():
@@ -601,6 +727,12 @@ def main():
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}",
)
parser.add_argument(
"--test-branch",
type=str,
default=None,
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
)
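    # Hypothetical invocation for testing a conversion (script path and other flags assumed):
    #   python convert_dataset_v1_to_v2.py --repo-id lerobot/pusht --test-branch v2.0.test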
args = parser.parse_args()
if not args.local_dir: