add working data loading optimization

This commit is contained in:
Jade Choghari 2025-02-24 14:21:55 +01:00
parent cf6e677485
commit 17572b3211
1 changed file with 10 additions and 3 deletions

View File

@@ -57,12 +57,16 @@ from lerobot.common.datasets.video_utils import (
# let's define a custom fn # let's define a custom fn
def custom_collate_fn(batch): def custom_collate_fn(batch):
# always in the cuda, getitem is on cpu,
# then implement mixed
""" """
Custom collate function that decodes videos on GPU/CPU. Custom collate function that decodes videos on GPU/CPU.
Converts the batch to a dictionary with keys representing each field. Converts the batch to a dictionary with keys representing each field.
Returns a tensor for video frames instead of a list. Returns a tensor for video frames instead of a list.
""" """
# know when it is called
final_batch = {} final_batch = {}
is_main_process = torch.utils.data.get_worker_info() is None
# the batch is given as a list, we need to return a dict # the batch is given as a list, we need to return a dict
for item in batch: for item in batch:
@@ -71,14 +75,17 @@ def custom_collate_fn(batch):
for vid_key, video_path in item["video_paths"].items(): for vid_key, video_path in item["video_paths"].items():
# decode video frames based on timestamps # decode video frames based on timestamps
timestamps = item["query_timestamps"][vid_key] timestamps = item["query_timestamps"][vid_key]
frames = decode_video_frames_torchvision(
# ✅ Use CUDA only in the main process
device = "cuda" if is_main_process else "cpu"
frames = decode_video_frames_torchcodec(
video_path=Path(video_path), video_path=Path(video_path),
timestamps=timestamps, timestamps=timestamps,
tolerance_s=0.02, tolerance_s=0.02,
backend="pyav", # backend="pyav",
log_loaded_timestamps=False, log_loaded_timestamps=False,
device=device, # ✅ Keeps CUDA safe
) )
# stack frames for this video key and add directly to the item # stack frames for this video key and add directly to the item
item[vid_key] = frames item[vid_key] = frames