diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 59744c7e..8648eb1b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -57,12 +57,16 @@ from lerobot.common.datasets.video_utils import ( # let's define a custom fn def custom_collate_fn(batch): + # always in the cuda, getitem is on cpu, + # then implement mixed """ Custom collate function that decodes videos on GPU/CPU. Converts the batch to a dictionary with keys representing each field. Returns a tensor for video frames instead of a list. """ + # know when it is called final_batch = {} + is_main_process = torch.utils.data.get_worker_info() is None # the batch is given as a list, we need to return a dict for item in batch: @@ -71,14 +75,17 @@ def custom_collate_fn(batch): for vid_key, video_path in item["video_paths"].items(): # decode video frames based on timestamps timestamps = item["query_timestamps"][vid_key] - frames = decode_video_frames_torchvision( + + # ✅ Use CUDA only in the main process + device = "cuda" if is_main_process else "cpu" + frames = decode_video_frames_torchcodec( video_path=Path(video_path), timestamps=timestamps, tolerance_s=0.02, - backend="pyav", + # backend="pyav", log_loaded_timestamps=False, + device=device, # ✅ Keeps CUDA safe ) - # stack frames for this video key and add directly to the item item[vid_key] = frames