add good custom fn
parent cae49528ee
commit f585cec385
@@ -692,6 +692,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_idx = item["episode_index"].item()

         query_indices = None
+        # data logic
         if self.delta_indices is not None:
             current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
             query_indices, padding = self._get_query_indices(idx, current_ep_idx)
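The `# data logic` branch above resolves `self.delta_indices` (per-key frame offsets) into concrete frame indices for the current episode. The diff does not show `_get_query_indices` itself; below is a minimal sketch of what it plausibly does, assuming the dataset exposes per-episode `from`/`to` frame bounds in `self.episode_data_index` (that attribute name and the `*_is_pad` key convention are assumptions, not part of this commit):

import torch

def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict, dict]:
    # Assumed per-episode frame bounds; the attribute name is illustrative.
    ep_start = self.episode_data_index["from"][ep_idx].item()
    ep_end = self.episode_data_index["to"][ep_idx].item()
    # Clamp idx + delta into the episode so out-of-range offsets repeat edge frames.
    query_indices = {
        key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in deltas]
        for key, deltas in self.delta_indices.items()
    }
    # Flag the offsets that fell outside the episode and were clamped.
    padding = {
        f"{key}_is_pad": torch.BoolTensor(
            [not (ep_start <= idx + delta < ep_end) for delta in deltas]
        )
        for key, deltas in self.delta_indices.items()
    }
    return query_indices, padding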
@@ -700,6 +701,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             for key, val in query_result.items():
                 item[key] = val

+        # video logic
         if len(self.meta.video_keys) > 0:
             current_ts = item["timestamp"].item()
             query_timestamps = self._get_query_timestamps(current_ts, query_indices)
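Together with the `# data logic` block, the `# video logic` branch computes `query_timestamps` per camera key but leaves the frames undecoded, so each dataset item reaches the collate function below carrying paths and timestamps instead of pixels. A hypothetical example of the item layout that `custom_collate_fn` checks for (the camera key, path, and values are illustrative only):

item = {
    "episode_index": torch.tensor(3),
    "timestamp": torch.tensor(1.25),
    "action": torch.zeros(6),  # non-video fields pass through untouched
    # consumed by custom_collate_fn below:
    "video_paths": {"observation.images.top": "videos/episode_000003.mp4"},
    "query_timestamps": {"observation.images.top": [1.15, 1.20, 1.25]},
}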
@@ -59,54 +59,38 @@ from torchcodec.decoders import VideoDecoder


 def custom_collate_fn(batch):
     """
-    Custom collate function that decodes videos on CPU.
-    Ensures batch format remains unchanged.
+    Custom collate function that decodes videos on GPU/CPU.
+    Converts the batch to a dictionary with keys representing each field.
+    Returns a tensor for video frames instead of a list.
     """
-    batched_frames = {}  # Dictionary to hold video tensors
-    final_batch = {}  # Dictionary to hold the rest of the batch
+    final_batch = {}

-    # Initialize final_batch with all original keys (except video paths)
-    for key in batch[0].keys():
-        if key not in ["video_paths", "query_timestamps"]:  # Skip video-related fields
-            final_batch[key] = [item[key] for item in batch]
-
-    # Process video decoding
+    # the batch is given as a list, we need to return a dict
     for item in batch:
+        # process video decoding for each item
         if "video_paths" in item and "query_timestamps" in item:
             for vid_key, video_path in item["video_paths"].items():
-                decoder = VideoDecoder(str(video_path), device="cpu")  # CPU decoding
-                # frames = decoder.get_frames_played_at(item["query_timestamps"][vid_key]).data.float() / 255
+                # decode video frames based on timestamps
                 timestamps = item["query_timestamps"][vid_key]
                 frames = decode_video_frames_torchvision(
                     video_path=Path(video_path),
                     timestamps=timestamps,
-                    tolerance_s=0.02,  # Adjust tolerance if needed
-                    backend="pyav",  # Default backend (modify if needed)
+                    tolerance_s=0.02,
+                    backend="pyav",
                     log_loaded_timestamps=False,
                 )
-
-                if vid_key not in batched_frames:
-                    batched_frames[vid_key] = []
-                batched_frames[vid_key].append(frames)
-
-    # Convert lists to tensors where possible
-    for key in batched_frames:
-        batched_frames[key] = torch.stack(batched_frames[key])  # Stack tensors
-
-    for key in final_batch:
-        if isinstance(final_batch[key][0], torch.Tensor):
-            final_batch[key] = torch.stack(final_batch[key])
-
-    # **Fix: Ensure video_frames is a single tensor instead of a dictionary**
-    # hard coded this must change
-    if len(batched_frames) == 1:
-        final_batch["observation.images.top"] = list(batched_frames.values())[0]  # Direct tensor
-    else:
-        final_batch["observation.images.top"] = batched_frames  # Keep dict if multiple
+                # stack frames for this video key
+                item[vid_key] = torch.stack(frames)
+
+        # add item data (both video and non-video) to final_batch
+        for key, value in item.items():
+            if key not in final_batch:
+                final_batch[key] = []
+            final_batch[key].append(value)

     return final_batch


 def update_policy(
     train_metrics: MetricsTracker,
     policy: PreTrainedPolicy,
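A minimal usage sketch for the rewritten collate function, assuming `dataset` is a `LeRobotDataset` whose items carry `video_paths` and `query_timestamps` as above; the batch size and worker count are placeholders, not part of this commit:

import torch
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,                      # assumed LeRobotDataset instance
    batch_size=8,
    num_workers=4,                # frames are decoded in the worker processes
    collate_fn=custom_collate_fn,
)

batch = next(iter(loader))
# Each batch value is a list with one entry per sample; video keys hold the
# per-sample frame tensors stacked inside custom_collate_fn.
for key, values in batch.items():
    print(key, len(values), type(values[0]).__name__)

Note that the new function keeps every batch value as a per-sample list rather than stacking across samples; per the updated docstring, one apparent motivation for collating this way is to leave room for decoding the whole batch in one place on GPU or CPU.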