Add fixes for lfs tracking

This commit is contained in:
Simon Alibert 2024-10-17 12:58:48 +02:00
parent 50a75ad3fe
commit ad3f112d16
1 changed files with 58 additions and 32 deletions

View File

@ -85,6 +85,7 @@ python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
import argparse import argparse
import contextlib import contextlib
import filecmp
import json import json
import math import math
import shutil import shutil
@ -112,7 +113,7 @@ V20 = "v2.0"
EPISODE_CHUNK_SIZE = 1000 EPISODE_CHUNK_SIZE = 1000
CLEAN_GITATTRIBUTES = Path("data/.gitattributes") GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4" VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
PARQUET_CHUNK_PATH = ( PARQUET_CHUNK_PATH = (
@ -158,7 +159,7 @@ def load_json(fpath: Path) -> dict:
def write_json(data: dict, fpath: Path) -> None: def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True) fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f: with open(fpath, "w") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4, ensure_ascii=False)
def write_jsonlines(data: dict, fpath: Path) -> None: def write_jsonlines(data: dict, fpath: Path) -> None:
@ -274,8 +275,9 @@ def move_videos(
total_episodes: int, total_episodes: int,
total_chunks: int, total_chunks: int,
work_dir: Path, work_dir: Path,
clean_gittatributes: Path,
branch: str = "main", branch: str = "main",
): ) -> None:
""" """
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
commands to fetch git lfs video files references to move them into subdirectories without having to commands to fetch git lfs video files references to move them into subdirectories without having to
@ -283,11 +285,25 @@ def move_videos(
""" """
_lfs_clone(repo_id, work_dir, branch) _lfs_clone(repo_id, work_dir, branch)
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")] video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
total_videos = len(video_files) if len(video_files) == 0:
assert total_videos == total_episodes * len(video_keys) video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
videos_moved = True # Videos have already been moved
fix_lfs_video_files_tracking(work_dir, video_files, CLEAN_GITATTRIBUTES) assert len(video_files) == total_episodes * len(video_keys)
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
current_gittatributes = work_dir / ".gitattributes"
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
if lfs_untracked_videos:
fix_lfs_video_files_tracking(work_dir, video_files)
if videos_moved:
return
video_dirs = sorted(work_dir.glob("videos*/")) video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks): for ep_chunk in range(total_chunks):
@ -320,27 +336,15 @@ def move_videos(
subprocess.run(["git", "push"], cwd=work_dir, check=True) subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, video_files: list[str], clean_gitattributes_path: Path): def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
""" """
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case, HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking. there's no other option than to download the actual files and reupload them with lfs tracking.
""" """
# _lfs_clone(repo_id, work_dir, branch)
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
lfs_untracked_videos = [f for f in video_files if f not in lfs_tracked_files]
if lfs_untracked_videos:
shutil.copyfile(clean_gitattributes_path, work_dir / ".gitattributes")
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
for i in range(0, len(lfs_untracked_videos), 100): for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100] files = lfs_untracked_videos[i : i + 100]
try: try:
subprocess.run( subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True
)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:") print("git rm --cached ERROR:")
print(e.stderr) print(e.stderr)
@ -351,6 +355,13 @@ def fix_lfs_video_files_tracking(work_dir: Path, video_files: list[str], clean_g
subprocess.run(["git", "push"], cwd=work_dir, check=True) subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None: def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True) subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
repo_url = f"https://huggingface.co/datasets/{repo_id}" repo_url = f"https://huggingface.co/datasets/{repo_id}"
@ -362,6 +373,14 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
) )
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def _get_audio_info(video_path: Path | str) -> dict: def _get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [ ffprobe_audio_cmd = [
"ffprobe", "ffprobe",
@ -585,7 +604,14 @@ def convert_dataset(
assert metadata_v1.get("video", False) assert metadata_v1.get("video", False)
tmp_video_dir = local_dir / "videos" / V20 / repo_id tmp_video_dir = local_dir / "videos" / V20 / repo_id
tmp_video_dir.mkdir(parents=True, exist_ok=True) tmp_video_dir.mkdir(parents=True, exist_ok=True)
move_videos(repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, branch) clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
)
).absolute()
move_videos(
repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch) videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
video_shapes = get_video_shapes(videos_info, keys["video"]) video_shapes = get_video_shapes(videos_info, keys["video"])
for img_key in keys["video"]: for img_key in keys["video"]:
@ -735,7 +761,7 @@ def main():
"--local-dir", "--local-dir",
type=Path, type=Path,
default=None, default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}", help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
) )
parser.add_argument( parser.add_argument(
"--test-branch", "--test-branch",
@ -746,7 +772,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if not args.local_dir: if not args.local_dir:
args.local_dir = Path(f"/tmp/{args.repo_id}") args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
del args.robot_config, args.robot_overrides del args.robot_config, args.robot_overrides