Add fixes for lfs tracking

This commit is contained in:
Simon Alibert 2024-10-17 12:58:48 +02:00
parent 50a75ad3fe
commit ad3f112d16
1 changed file with 58 additions and 32 deletions

View File

@ -85,6 +85,7 @@ python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
import argparse
import contextlib
import filecmp
import json
import math
import shutil
@ -112,7 +113,7 @@ V20 = "v2.0"
EPISODE_CHUNK_SIZE = 1000
CLEAN_GITATTRIBUTES = Path("data/.gitattributes")
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
PARQUET_CHUNK_PATH = (
@ -158,7 +159,7 @@ def load_json(fpath: Path) -> dict:
def write_json(data: dict, fpath: Path) -> None:
    """Serialize ``data`` to ``fpath`` as a single JSON document indented by 4 spaces.

    Missing parent directories of ``fpath`` are created first. Non-ASCII characters
    are written verbatim (``ensure_ascii=False``) instead of as ``\\uXXXX`` escapes.

    Args:
        data: JSON-serializable mapping to write.
        fpath: Destination file; overwritten if it already exists.
    """
    fpath.parent.mkdir(exist_ok=True, parents=True)
    # With ensure_ascii=False the raw characters reach the file object, so the
    # encoding is pinned to UTF-8 rather than left to the platform locale
    # (a non-UTF-8 locale could raise UnicodeEncodeError or produce non-UTF-8 JSON).
    with open(fpath, "w", encoding="utf-8") as f:
        # Dump exactly once: a second dump on the same handle would append another
        # document and make the file unparseable.
        json.dump(data, f, indent=4, ensure_ascii=False)
def write_jsonlines(data: dict, fpath: Path) -> None:
@ -274,8 +275,9 @@ def move_videos(
total_episodes: int,
total_chunks: int,
work_dir: Path,
clean_gittatributes: Path,
branch: str = "main",
):
) -> None:
"""
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
commands to fetch git lfs video files references to move them into subdirectories without having to
@ -283,11 +285,25 @@ def move_videos(
"""
_lfs_clone(repo_id, work_dir, branch)
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
total_videos = len(video_files)
assert total_videos == total_episodes * len(video_keys)
if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
videos_moved = True # Videos have already been moved
fix_lfs_video_files_tracking(work_dir, video_files, CLEAN_GITATTRIBUTES)
assert len(video_files) == total_episodes * len(video_keys)
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
current_gittatributes = work_dir / ".gitattributes"
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
if lfs_untracked_videos:
fix_lfs_video_files_tracking(work_dir, video_files)
if videos_moved:
return
video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks):
@ -320,35 +336,30 @@ def move_videos(
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, video_files: list[str], clean_gitattributes_path: Path):
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
"""
# _lfs_clone(repo_id, work_dir, branch)
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
lfs_untracked_videos = [f for f in video_files if f not in lfs_tracked_files]
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
if lfs_untracked_videos:
shutil.copyfile(clean_gitattributes_path, work_dir / ".gitattributes")
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(
["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True
)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
commit_message = "Track video files with git lfs"
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
commit_message = "Track video files with git lfs"
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
    """Overwrite the repo's .gitattributes with the clean reference file and publish the change.

    The file at ``clean_gittatributes`` is copied over ``current_gittatributes``; the
    ".gitattributes" path is then staged, committed and pushed by running git inside
    ``work_dir``.

    Args:
        work_dir: Directory in which the git commands are executed.
        current_gittatributes: Destination path that gets overwritten.
        clean_gittatributes: Path of the reference file to copy from.

    Raises:
        subprocess.CalledProcessError: If any of the git commands exits with a non-zero status.
    """
    shutil.copyfile(clean_gittatributes, current_gittatributes)

    # Stage, commit, push -- in that order; each step must succeed before the next runs.
    git_steps = (
        ["git", "add", ".gitattributes"],
        ["git", "commit", "-m", "Fix .gitattributes"],
        ["git", "push"],
    )
    for step in git_steps:
        subprocess.run(step, cwd=work_dir, check=True)
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
@ -362,6 +373,14 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
)
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
    """Return the entries of ``video_files`` that git lfs does not list as tracked.

    Runs ``git lfs ls-files -n`` in ``work_dir`` and compares each entry of
    ``video_files`` against the lines of its output by exact string match.
    The order of ``video_files`` is preserved in the result.

    Args:
        work_dir: Directory in which the git lfs command is executed.
        video_files: Candidate paths; presumably relative to ``work_dir`` so they
            match the names git lfs prints -- verify against the caller.

    Raises:
        subprocess.CalledProcessError: If the git lfs command exits with a non-zero status.
    """
    tracked_names = set(
        subprocess.run(
            ["git", "lfs", "ls-files", "-n"],
            cwd=work_dir,
            capture_output=True,
            text=True,
            check=True,
        ).stdout.splitlines()
    )
    return [video for video in video_files if video not in tracked_names]
def _get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
@ -585,7 +604,14 @@ def convert_dataset(
assert metadata_v1.get("video", False)
tmp_video_dir = local_dir / "videos" / V20 / repo_id
tmp_video_dir.mkdir(parents=True, exist_ok=True)
move_videos(repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, branch)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
)
).absolute()
move_videos(
repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
video_shapes = get_video_shapes(videos_info, keys["video"])
for img_key in keys["video"]:
@ -735,7 +761,7 @@ def main():
"--local-dir",
type=Path,
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}",
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
)
parser.add_argument(
"--test-branch",
@ -746,7 +772,7 @@ def main():
args = parser.parse_args()
if not args.local_dir:
args.local_dir = Path(f"/tmp/{args.repo_id}")
args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
del args.robot_config, args.robot_overrides