fix bug
This commit is contained in:
parent
f585cec385
commit
cee77f3d4e
|
@ -80,7 +80,7 @@ def custom_collate_fn(batch):
|
||||||
log_loaded_timestamps=False,
|
log_loaded_timestamps=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# stack frames for this video key
|
# stack frames for this video key and add directly to the item
|
||||||
item[vid_key] = torch.stack(frames)
|
item[vid_key] = torch.stack(frames)
|
||||||
|
|
||||||
# add item data (both video and non-video) to final_batch
|
# add item data (both video and non-video) to final_batch
|
||||||
|
@ -89,6 +89,13 @@ def custom_collate_fn(batch):
|
||||||
final_batch[key] = []
|
final_batch[key] = []
|
||||||
final_batch[key].append(value)
|
final_batch[key].append(value)
|
||||||
|
|
||||||
|
# now, stack tensors for each key in final_batch
|
||||||
|
# this is needed to ensure that video frames (and any other tensor fields) are combined
|
||||||
|
# into a single tensor per field, rather than a list of tensors!
|
||||||
|
for key in final_batch:
|
||||||
|
if isinstance(final_batch[key][0], torch.Tensor):
|
||||||
|
final_batch[key] = torch.stack(final_batch[key]) # stack tensors if needed
|
||||||
|
|
||||||
return final_batch
|
return final_batch
|
||||||
|
|
||||||
def update_policy(
|
def update_policy(
|
||||||
|
|
Loading…
Reference in New Issue