Add episode chunks logic, move_videos & lfs tracking fix

parent 110264000f
commit c146ba936f

@@ -87,6 +87,7 @@ import argparse
import contextlib
import json
import math
import shutil
import subprocess
import warnings
from pathlib import Path

@@ -108,8 +109,15 @@ from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
V16 = "v1.6"
V20 = "v2.0"

PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
EPISODE_CHUNK_SIZE = 1000

CLEAN_GITATTRIBUTES = Path("data/.gitattributes")

VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
PARQUET_CHUNK_PATH = (
    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
VIDEO_CHUNK_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
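# Illustrative note (not part of this commit): with EPISODE_CHUNK_SIZE = 1000, episode 1042 of a
# 3000-episode dataset lands in chunk 1042 // 1000 = 1, and the templates above resolve to
# (using a hypothetical video_key "observation.image"):
#   PARQUET_CHUNK_PATH -> "data/chunk-001/train-01042-of-03000.parquet"
#   VIDEO_CHUNK_PATH   -> "videos/chunk-001/observation.image/episode_001042.mp4"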


def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:

@@ -229,23 +237,125 @@ def add_task_index_from_tasks_col(


def split_parquet_by_episodes(
    dataset: Dataset, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
    dataset: Dataset,
    keys: dict[str, list],
    total_episodes: int,
    total_chunks: int,
    output_dir: Path,
) -> list:
    (output_dir / "data").mkdir(exist_ok=True, parents=True)
    table = dataset.remove_columns(keys["video"])._data.table
    episode_lengths = []
    for episode_index in sorted(episode_indices):
        # Write each episode_index to a new parquet file
        filtered_table = table.filter(pc.equal(table["episode_index"], episode_index))
        episode_lengths.insert(episode_index, len(filtered_table))
        output_file = output_dir / PARQUET_PATH.format(
            episode_index=episode_index, total_episodes=total_episodes
        )
        pq.write_table(filtered_table, output_file)
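    # Write the new chunked layout: episodes are grouped into subdirectories of EPISODE_CHUNK_SIZE episodes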
    for ep_chunk in range(total_chunks):
        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)

        chunk_dir = "/".join(PARQUET_CHUNK_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
        (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
        for ep_idx in range(ep_chunk_start, ep_chunk_end):
            ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
            episode_lengths.insert(ep_idx, len(ep_table))
            output_file = output_dir / PARQUET_CHUNK_PATH.format(
                episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
            )
            pq.write_table(ep_table, output_file)

    return episode_lengths


def move_videos(
    repo_id: str,
    video_keys: list[str],
    total_episodes: int,
    total_chunks: int,
    work_dir: Path,
    branch: str = "main",
):
    """
    HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
    commands to fetch the git lfs video file references and move them into subdirectories, without having
    to actually download the videos.
    """
    _lfs_clone(repo_id, work_dir, branch)

    video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
    total_videos = len(video_files)
    assert total_videos == total_episodes * len(video_keys)

    fix_lfs_video_files_tracking(work_dir, video_files, CLEAN_GITATTRIBUTES)

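    # Move each episode video into its chunk subdirectory; only git-lfs pointer files are moved locally
    # since the clone skipped the LFS smudge step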
    video_dirs = sorted(work_dir.glob("videos*/"))
    for ep_chunk in range(total_chunks):
        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
        for vid_key in video_keys:
            chunk_dir = "/".join(VIDEO_CHUNK_PATH.split("/")[:-1]).format(
                episode_chunk=ep_chunk, video_key=vid_key
            )
            (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)

            for ep_idx in range(ep_chunk_start, ep_chunk_end):
                target_path = VIDEO_CHUNK_PATH.format(
                    episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
                )
                video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
                if len(video_dirs) == 1:
                    video_path = video_dirs[0] / video_file
                else:
                    for dir in video_dirs:
                        if (dir / video_file).is_file():
                            video_path = dir / video_file
                            break

                video_path.rename(work_dir / target_path)

    commit_message = "Move video files into chunk subdirectories"
    subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
    subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
    subprocess.run(["git", "push"], cwd=work_dir, check=True)


def fix_lfs_video_files_tracking(work_dir: Path, video_files: list[str], clean_gitattributes_path: Path):
    """
    HACK: This function fixes git lfs tracking, which was not properly set on some repos. In that case,
    there's no other option than to download the actual files and re-upload them with lfs tracking.
    """
    # _lfs_clone(repo_id, work_dir, branch)
    lfs_tracked_files = subprocess.run(
        ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
    )
    lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
    lfs_untracked_videos = [f for f in video_files if f not in lfs_tracked_files]

    if lfs_untracked_videos:
        shutil.copyfile(clean_gitattributes_path, work_dir / ".gitattributes")
        subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
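        # Re-stage the untracked videos in batches of 100 to keep each git command line short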
        for i in range(0, len(lfs_untracked_videos), 100):
            files = lfs_untracked_videos[i : i + 100]
            try:
                subprocess.run(
                    ["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True
                )
            except subprocess.CalledProcessError as e:
                print("git rm --cached ERROR:")
                print(e.stderr)
            subprocess.run(["git", "add", *files], cwd=work_dir, check=True)

        commit_message = "Track video files with git lfs"
        subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
        subprocess.run(["git", "push"], cwd=work_dir, check=True)


def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
    subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
    repo_url = f"https://huggingface.co/datasets/{repo_id}"
    env = {"GIT_LFS_SKIP_SMUDGE": "1"}  # Prevent downloading LFS files
    subprocess.run(
        ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
        check=True,
        env=env,
    )
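    # For reference, this is roughly equivalent to the following shell command
    # ("lerobot/pusht" is only a hypothetical example of repo_id):
    #   GIT_LFS_SKIP_SMUDGE=1 git clone --branch main --single-branch --depth 1 \
    #       https://huggingface.co/datasets/lerobot/pusht <work_dir>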


def _get_audio_info(video_path: Path | str) -> dict:
    ffprobe_audio_cmd = [
        "ffprobe",

@@ -323,16 +433,19 @@ def _get_video_info(video_path: Path | str) -> dict:
    return video_info


def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
    hub_api = HfApi()
    videos_info_dict = {"videos_path": VIDEO_PATH}
    for vid_key in video_keys:
        # Assumes first episode
        video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
        video_path = hub_api.hf_hub_download(
            repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
        )
        videos_info_dict[vid_key] = _get_video_info(video_path)
    videos_info_dict = {"videos_path": VIDEO_CHUNK_PATH}

    # Assumes first episode
    video_files = [
        VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
    ]
    hub_api.snapshot_download(
        repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
    )
    for vid_key, vid_path in zip(video_keys, video_files, strict=True):
        videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path)

    return videos_info_dict

@@ -399,6 +512,7 @@ def convert_dataset(
    tasks_path: Path | None = None,
    tasks_col: Path | None = None,
    robot_config: dict | None = None,
    test_branch: str | None = None,
):
    v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
    v1x_dir = local_dir / V16 / repo_id

@@ -408,8 +522,12 @@ def convert_dataset(

    hub_api = HfApi()
    hub_api.snapshot_download(
        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos/"
        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
    )
    branch = "main"
    if test_branch:
        branch = test_branch
        create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")

    metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
    dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")

@@ -423,10 +541,14 @@ def convert_dataset(
        single_task = None
        tasks_col = "language_instruction"

    # Episodes
    # Episodes & chunks
    episode_indices = sorted(dataset.unique("episode_index"))
    total_episodes = len(episode_indices)
    assert episode_indices == list(range(total_episodes))
    total_videos = total_episodes * len(keys["video"])
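    # Ceil division: e.g. with EPISODE_CHUNK_SIZE = 1000, 2500 episodes -> 3 chunks, 1000 episodes -> 1 chunk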
    total_chunks = total_episodes // EPISODE_CHUNK_SIZE
    if total_episodes % EPISODE_CHUNK_SIZE != 0:
        total_chunks += 1

    # Tasks
    if single_task:

@@ -448,25 +570,30 @@ def convert_dataset(
    task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
    write_json(task_json, v20_dir / "meta" / "tasks.json")

    # Split data into 1 parquet file by episode
    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)

    # Shapes
    sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
    image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}

    # Videos
    if len(keys["video"]) > 0:
        assert metadata_v1.get("video", False)
        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"])
        tmp_video_dir = local_dir / "videos" / V20 / repo_id
        tmp_video_dir.mkdir(parents=True, exist_ok=True)
        move_videos(repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, branch)
        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
        video_shapes = get_video_shapes(videos_info, keys["video"])
        for img_key in keys["video"]:
            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
            if "encoding" in metadata_v1:
                assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
    else:
        assert len(keys["video"]) == 0
        assert metadata_v1.get("video", 0) == 0
        videos_info = None
        video_shapes = {}

    # Split data into 1 parquet file by episode
    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)

    # Names
    if robot_config is not None:
        robot_type = robot_config["robot_type"]

@@ -495,11 +622,14 @@ def convert_dataset(
    # Assemble metadata v2.0
    metadata_v2_0 = {
        "codebase_version": V20,
        "data_path": PARQUET_PATH,
        "data_path": PARQUET_CHUNK_PATH,
        "robot_type": robot_type,
        "total_episodes": total_episodes,
        "total_frames": len(dataset),
        "total_tasks": len(tasks),
        "total_videos": total_videos,
        "total_chunks": total_chunks,
        "chunks_size": EPISODE_CHUNK_SIZE,
        "fps": metadata_v1["fps"],
        "splits": {"train": f"0:{total_episodes}"},
        "keys": keys["sequence"],

@@ -512,37 +642,31 @@ def convert_dataset(
    write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
    convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")

    #### TODO: delete
    # repo_id = f"aliberts/{repo_id.split('/')[1]}"
    # if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
    # hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
    # hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
    ####
    with contextlib.suppress(EntryNotFoundError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)

    with contextlib.suppress(EntryNotFoundError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")

    with contextlib.suppress(EntryNotFoundError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision="main")
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)

    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="data",
        folder_path=v20_dir / "data",
        repo_type="dataset",
        revision="main",
        revision=branch,
    )
    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="meta",
        folder_path=v20_dir / "meta",
        repo_type="dataset",
        revision="main",
        revision=branch,
    )

    card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
    push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
    create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
    push_dataset_card_to_hub(repo_id=repo_id, revision=branch, tags=repo_tags, text=card_text)
    if not test_branch:
        create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")

    # TODO:
    # - [X] Add shapes

@@ -555,7 +679,9 @@ def convert_dataset(
    # - [X] Add splits
    # - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
    # - [X] Handle multitask datasets
    # - [/] Add sanity checks (encoding, shapes)
    # - [X] Handle hf hub repo limits (add chunks logic)
    # - [X] Add test-branch
    # - [X] Add sanity checks (encoding, shapes)


def main():

@@ -601,6 +727,12 @@ def main():
        default=None,
        help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}",
    )
    parser.add_argument(
        "--test-branch",
        type=str,
        default=None,
        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
    )
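    # Example invocation (script name and --repo-id flag are assumptions, not shown in this diff):
    #   python convert_dataset_v1_to_v2.py --repo-id lerobot/pusht --test-branch v2.0.test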

    args = parser.parse_args()
    if not args.local_dir: