add working data loading optimization

This commit is contained in:
Jade Choghari 2025-02-24 14:21:55 +01:00
parent cf6e677485
commit 17572b3211
1 changed files with 10 additions and 3 deletions

View File

@ -57,12 +57,16 @@ from lerobot.common.datasets.video_utils import (
# let's define a custom fn
def custom_collate_fn(batch):
# always in the cuda, getitem is on cpu,
# then implement mixed
"""
Custom collate function that decodes videos on GPU/CPU.
Converts the batch to a dictionary with keys representing each field.
Returns a tensor for video frames instead of a list.
"""
# know when it is called
final_batch = {}
is_main_process = torch.utils.data.get_worker_info() is None
# the batch is given as a list, we need to return a dict
for item in batch:
@ -71,14 +75,17 @@ def custom_collate_fn(batch):
for vid_key, video_path in item["video_paths"].items():
# decode video frames based on timestamps
timestamps = item["query_timestamps"][vid_key]
frames = decode_video_frames_torchvision(
# ✅ Use CUDA only in the main process
device = "cuda" if is_main_process else "cpu"
frames = decode_video_frames_torchcodec(
video_path=Path(video_path),
timestamps=timestamps,
tolerance_s=0.02,
backend="pyav",
# backend="pyav",
log_loaded_timestamps=False,
device=device, # ✅ Keeps CUDA safe
)
# stack frames for this video key and add directly to the item
item[vid_key] = frames