add working data loading optimization
This commit is contained in:
parent
cf6e677485
commit
17572b3211
|
@ -57,12 +57,16 @@ from lerobot.common.datasets.video_utils import (
|
||||||
# let's define a custom fn
|
# let's define a custom fn
|
||||||
|
|
||||||
def custom_collate_fn(batch):
|
def custom_collate_fn(batch):
|
||||||
|
# always in the cuda, getitem is on cpu,
|
||||||
|
# then implement mixed
|
||||||
"""
|
"""
|
||||||
Custom collate function that decodes videos on GPU/CPU.
|
Custom collate function that decodes videos on GPU/CPU.
|
||||||
Converts the batch to a dictionary with keys representing each field.
|
Converts the batch to a dictionary with keys representing each field.
|
||||||
Returns a tensor for video frames instead of a list.
|
Returns a tensor for video frames instead of a list.
|
||||||
"""
|
"""
|
||||||
|
# know when it is called
|
||||||
final_batch = {}
|
final_batch = {}
|
||||||
|
is_main_process = torch.utils.data.get_worker_info() is None
|
||||||
|
|
||||||
# the batch is given as a list, we need to return a dict
|
# the batch is given as a list, we need to return a dict
|
||||||
for item in batch:
|
for item in batch:
|
||||||
|
@ -71,14 +75,17 @@ def custom_collate_fn(batch):
|
||||||
for vid_key, video_path in item["video_paths"].items():
|
for vid_key, video_path in item["video_paths"].items():
|
||||||
# decode video frames based on timestamps
|
# decode video frames based on timestamps
|
||||||
timestamps = item["query_timestamps"][vid_key]
|
timestamps = item["query_timestamps"][vid_key]
|
||||||
frames = decode_video_frames_torchvision(
|
|
||||||
|
# ✅ Use CUDA only in the main process
|
||||||
|
device = "cuda" if is_main_process else "cpu"
|
||||||
|
frames = decode_video_frames_torchcodec(
|
||||||
video_path=Path(video_path),
|
video_path=Path(video_path),
|
||||||
timestamps=timestamps,
|
timestamps=timestamps,
|
||||||
tolerance_s=0.02,
|
tolerance_s=0.02,
|
||||||
backend="pyav",
|
# backend="pyav",
|
||||||
log_loaded_timestamps=False,
|
log_loaded_timestamps=False,
|
||||||
|
device=device, # ✅ Keeps CUDA safe
|
||||||
)
|
)
|
||||||
|
|
||||||
# stack frames for this video key and add directly to the item
|
# stack frames for this video key and add directly to the item
|
||||||
item[vid_key] = frames
|
item[vid_key] = frames
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue