add working data loading optimization
This commit is contained in:
parent
cf6e677485
commit
17572b3211
|
@ -57,12 +57,16 @@ from lerobot.common.datasets.video_utils import (
|
|||
# let's define a custom fn
|
||||
|
||||
def custom_collate_fn(batch):
|
||||
# always in the cuda, getitem is on cpu,
|
||||
# then implement mixed
|
||||
"""
|
||||
Custom collate function that decodes videos on GPU/CPU.
|
||||
Converts the batch to a dictionary with keys representing each field.
|
||||
Returns a tensor for video frames instead of a list.
|
||||
"""
|
||||
# know when it is called
|
||||
final_batch = {}
|
||||
is_main_process = torch.utils.data.get_worker_info() is None
|
||||
|
||||
# the batch is given as a list, we need to return a dict
|
||||
for item in batch:
|
||||
|
@ -71,14 +75,17 @@ def custom_collate_fn(batch):
|
|||
for vid_key, video_path in item["video_paths"].items():
|
||||
# decode video frames based on timestamps
|
||||
timestamps = item["query_timestamps"][vid_key]
|
||||
frames = decode_video_frames_torchvision(
|
||||
|
||||
# ✅ Use CUDA only in the main process
|
||||
device = "cuda" if is_main_process else "cpu"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path=Path(video_path),
|
||||
timestamps=timestamps,
|
||||
tolerance_s=0.02,
|
||||
backend="pyav",
|
||||
# backend="pyav",
|
||||
log_loaded_timestamps=False,
|
||||
device=device, # ✅ Keeps CUDA safe
|
||||
)
|
||||
|
||||
# stack frames for this video key and add directly to the item
|
||||
item[vid_key] = frames
|
||||
|
||||
|
|
Loading…
Reference in New Issue