This commit is contained in:
Jade Choghari 2025-02-22 13:27:45 +01:00
parent f585cec385
commit cee77f3d4e
1 changed files with 8 additions and 1 deletions

View File

@ -80,7 +80,7 @@ def custom_collate_fn(batch):
log_loaded_timestamps=False,
)
# stack frames for this video key
# stack frames for this video key and add directly to the item
item[vid_key] = torch.stack(frames)
# add item data (both video and non-video) to final_batch
@ -89,6 +89,13 @@ def custom_collate_fn(batch):
final_batch[key] = []
final_batch[key].append(value)
# now, stack tensors for each key in final_batch
# this is needed to ensure that video frames (and any other tensor fields) are combined
# into a single tensor per field, rather than a list of tensors!
for key in final_batch:
if isinstance(final_batch[key][0], torch.Tensor):
final_batch[key] = torch.stack(final_batch[key]) # stack tensors if needed
return final_batch
def update_policy(