From f585cec38566a11a8422f7e413f1c34df33679a4 Mon Sep 17 00:00:00 2001
From: Jade Choghari
Date: Sat, 22 Feb 2025 12:52:17 +0100
Subject: [PATCH] add good custom fn

---
 lerobot/common/datasets/lerobot_dataset.py |  2 +
 lerobot/scripts/train.py                   | 48 ++++++++--------------
 2 files changed, 18 insertions(+), 32 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 371c334a..0986ca54 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -692,6 +692,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_idx = item["episode_index"].item()
 
         query_indices = None
+        # data logic: query neighboring frames for keys with delta timestamps
         if self.delta_indices is not None:
             current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
             query_indices, padding = self._get_query_indices(idx, current_ep_idx)
@@ -700,6 +701,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             for key, val in query_result.items():
                 item[key] = val
 
+        # video logic: resolve which timestamps each video key needs to decode
         if len(self.meta.video_keys) > 0:
             current_ts = item["timestamp"].item()
             query_timestamps = self._get_query_timestamps(current_ts, query_indices)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index b52fb10f..dc17a59d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -59,54 +59,38 @@ from torchcodec.decoders import VideoDecoder
 
 
 def custom_collate_fn(batch):
     """
-    Custom collate function that decodes videos on CPU.
-    Ensures batch format remains unchanged.
+    Custom collate function that decodes video frames on the CPU (pyav backend).
+    Collates the list of per-item dicts into a dict of lists keyed by field.
+    Video keys end up holding decoded frame tensors instead of file paths.
     """
-    batched_frames = {}  # Dictionary to hold video tensors
-    final_batch = {}  # Dictionary to hold the rest of the batch
+    final_batch = {}
 
-    # Initialize final_batch with all original keys (except video paths)
-    for key in batch[0].keys():
-        if key not in ["video_paths", "query_timestamps"]:  # Skip video-related fields
-            final_batch[key] = [item[key] for item in batch]
-
-    # Process video decoding
+    # the batch arrives as a list of per-item dicts; collate it into a dict of lists
     for item in batch:
+        # decode this item's videos before collating
         if "video_paths" in item and "query_timestamps" in item:
             for vid_key, video_path in item["video_paths"].items():
-                decoder = VideoDecoder(str(video_path), device="cpu")  # CPU decoding
-                # frames = decoder.get_frames_played_at(item["query_timestamps"][vid_key]).data.float() / 255
+                # decode the frames requested for this camera at the queried timestamps
                 timestamps = item["query_timestamps"][vid_key]
                 frames = decode_video_frames_torchvision(
                     video_path=Path(video_path),
                     timestamps=timestamps,
-                    tolerance_s=0.02,  # Adjust tolerance if needed
-                    backend="pyav",  # Default backend (modify if needed)
+                    tolerance_s=0.02,
+                    backend="pyav",
                     log_loaded_timestamps=False,
                 )
-                if vid_key not in batched_frames:
-                    batched_frames[vid_key] = []
-                batched_frames[vid_key].append(frames)
+                # decode_video_frames_torchvision already returns a stacked (T, C, H, W) tensor
+                item[vid_key] = frames
 
-    # Convert lists to tensors where possible
-    for key in batched_frames:
-        batched_frames[key] = torch.stack(batched_frames[key])  # Stack tensors
-
-    for key in final_batch:
-        if isinstance(final_batch[key][0], torch.Tensor):
-            final_batch[key] = torch.stack(final_batch[key])
-
-    # **Fix: Ensure video_frames is a single tensor instead of a dictionary**
-    # hard coded this must change
-    if len(batched_frames) == 1:
-        final_batch["observation.images.top"] = list(batched_frames.values())[0]  # Direct tensor
-    else:
-        final_batch["observation.images.top"] = batched_frames  # Keep dict if multiple
+        # add this item's fields (both video and non-video) to final_batch
+        for key, value in item.items():
+            if key not in final_batch:
+                final_batch[key] = []
+            final_batch[key].append(value)
 
     return final_batch
 
-
 def update_policy(
     train_metrics: MetricsTracker,
     policy: PreTrainedPolicy,
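
Reviewer sketch (not part of the patch): a minimal example of wiring
custom_collate_fn into a torch DataLoader. It assumes the dataset has been
set up, as in this patch, to return "video_paths" and "query_timestamps"
entries instead of decoded frames; the repo id, delta_timestamps, and batch
size below are placeholders.

    import torch
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.scripts.train import custom_collate_fn

    # placeholder dataset id and delta timestamps, for illustration only
    dataset = LeRobotDataset(
        "lerobot/aloha_sim_insertion_human",
        delta_timestamps={"observation.images.top": [-0.1, 0.0]},
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=4,
        collate_fn=custom_collate_fn,  # videos are decoded inside the workers
    )

    batch = next(iter(dataloader))
    # every key maps to a list of length batch_size; video keys hold
    # (n_frames, C, H, W) float tensors, so stack them before the policy:
    frames = torch.stack(batch["observation.images.top"])

Since custom_collate_fn returns lists rather than stacked tensors, callers
need to stack per key (as above) before feeding the batch to the policy.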