diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index dc17a59d..2b6f7d85 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -80,7 +80,7 @@ def custom_collate_fn(batch): log_loaded_timestamps=False, ) - # stack frames for this video key + # stack frames for this video key and add directly to the item item[vid_key] = torch.stack(frames) # add item data (both video and non-video) to final_batch @@ -89,6 +89,13 @@ def custom_collate_fn(batch): final_batch[key] = [] final_batch[key].append(value) + # now, stack tensors for each key in final_batch + # this is needed to ensure that video frames (and any other tensor fields) are combined + # into a single tensor per field, rather than a list of tensors! + for key in final_batch: + if isinstance(final_batch[key][0], torch.Tensor): + final_batch[key] = torch.stack(final_batch[key]) # stack tensors if needed + return final_batch def update_policy(