fix bug
This commit is contained in:
parent
f585cec385
commit
cee77f3d4e
|
@ -80,7 +80,7 @@ def custom_collate_fn(batch):
|
|||
log_loaded_timestamps=False,
|
||||
)
|
||||
|
||||
# stack frames for this video key
|
||||
# stack frames for this video key and add directly to the item
|
||||
item[vid_key] = torch.stack(frames)
|
||||
|
||||
# add item data (both video and non-video) to final_batch
|
||||
|
@ -89,6 +89,13 @@ def custom_collate_fn(batch):
|
|||
final_batch[key] = []
|
||||
final_batch[key].append(value)
|
||||
|
||||
# now, stack tensors for each key in final_batch
|
||||
# this is needed to ensure that video frames (and any other tensor fields) are combined
|
||||
# into a single tensor per field, rather than a list of tensors!
|
||||
for key in final_batch:
|
||||
if isinstance(final_batch[key][0], torch.Tensor):
|
||||
final_batch[key] = torch.stack(final_batch[key]) # stack tensors if needed
|
||||
|
||||
return final_batch
|
||||
|
||||
def update_policy(
|
||||
|
|
Loading…
Reference in New Issue