add good custom fn

Jade Choghari 2025-02-22 12:52:17 +01:00
parent cae49528ee
commit f585cec385
2 changed files with 18 additions and 32 deletions


@@ -692,6 +692,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_idx = item["episode_index"].item()
 
         query_indices = None
+        # data logic
         if self.delta_indices is not None:
             current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
             query_indices, padding = self._get_query_indices(idx, current_ep_idx)
@@ -700,6 +701,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             for key, val in query_result.items():
                 item[key] = val
 
+        # video logic
         if len(self.meta.video_keys) > 0:
             current_ts = item["timestamp"].item()
             query_timestamps = self._get_query_timestamps(current_ts, query_indices)
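Note: these hunks only add the two section comments. For the custom collate function in the next file to have anything to decode, __getitem__ must return the raw decoding inputs instead of decoded frames, in code not shown in this diff. A minimal sketch of that item contract, assuming the existing meta.get_video_file_path helper; everything below is illustrative and not part of the commit:

    # hypothetical: defer decoding by shipping paths + timestamps to the collate fn
    if len(self.meta.video_keys) > 0:
        current_ts = item["timestamp"].item()
        query_timestamps = self._get_query_timestamps(current_ts, query_indices)
        item["query_timestamps"] = query_timestamps
        item["video_paths"] = {
            vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
            for vid_key in self.meta.video_keys
        }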


@@ -59,54 +59,38 @@ from torchcodec.decoders import VideoDecoder
 def custom_collate_fn(batch):
     """
-    Custom collate function that decodes videos on CPU.
-    Ensures batch format remains unchanged.
+    Custom collate function that decodes videos on GPU/CPU.
+    Converts the batch to a dictionary with keys representing each field.
+    Returns a tensor for video frames instead of a list.
     """
-    batched_frames = {}  # Dictionary to hold video tensors
-    final_batch = {}  # Dictionary to hold the rest of the batch
-
-    # Initialize final_batch with all original keys (except video paths)
-    for key in batch[0].keys():
-        if key not in ["video_paths", "query_timestamps"]:  # Skip video-related fields
-            final_batch[key] = [item[key] for item in batch]
-
-    # Process video decoding
+    final_batch = {}
+
+    # the batch is given as a list, we need to return a dict
     for item in batch:
+        # process video decoding for each item
         if "video_paths" in item and "query_timestamps" in item:
             for vid_key, video_path in item["video_paths"].items():
-                decoder = VideoDecoder(str(video_path), device="cpu")  # CPU decoding
-                # frames = decoder.get_frames_played_at(item["query_timestamps"][vid_key]).data.float() / 255
+                # decode video frames based on timestamps
                 timestamps = item["query_timestamps"][vid_key]
                 frames = decode_video_frames_torchvision(
                     video_path=Path(video_path),
                     timestamps=timestamps,
-                    tolerance_s=0.02,  # Adjust tolerance if needed
-                    backend="pyav",  # Default backend (modify if needed)
+                    tolerance_s=0.02,
+                    backend="pyav",
                     log_loaded_timestamps=False,
                 )
-                if vid_key not in batched_frames:
-                    batched_frames[vid_key] = []
-                batched_frames[vid_key].append(frames)
-
-    # Convert lists to tensors where possible
-    for key in batched_frames:
-        batched_frames[key] = torch.stack(batched_frames[key])  # Stack tensors
-
-    for key in final_batch:
-        if isinstance(final_batch[key][0], torch.Tensor):
-            final_batch[key] = torch.stack(final_batch[key])
-
-    # **Fix: Ensure video_frames is a single tensor instead of a dictionary**
-    # hard coded this must change
-    if len(batched_frames) == 1:
-        final_batch["observation.images.top"] = list(batched_frames.values())[0]  # Direct tensor
-    else:
-        final_batch["observation.images.top"] = batched_frames  # Keep dict if multiple
+                # stack frames for this video key
+                item[vid_key] = torch.stack(frames)
+
+        # add item data (both video and non-video) to final_batch
+        for key, value in item.items():
+            if key not in final_batch:
+                final_batch[key] = []
+            final_batch[key].append(value)
 
     return final_batch
 
 
 def update_policy(
     train_metrics: MetricsTracker,
     policy: PreTrainedPolicy,
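
For context, a minimal sketch of how this collate function would be wired into training, assuming dataset is a LeRobotDataset whose items carry "video_paths" and "query_timestamps"; the wiring itself is not part of this commit:

    from torch.utils.data import DataLoader

    dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=4,
        collate_fn=custom_collate_fn,  # decode video frames in the workers, at collate time
    )
    batch = next(iter(dataloader))
    # batch is a dict: video keys map to lists of stacked frame tensors,
    # every other field to a list of the per-item values.

Moving decoding out of __getitem__ and into the collate function lets each DataLoader worker decode only the timestamps a batch actually needs, rather than paying the decode cost per sample up front.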