From 4e2dc91e59f59449961837b46eef65223b2cbaf2 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 2 Mar 2025 20:47:33 +0000 Subject: [PATCH] add torchcodec cpu --- lerobot/common/datasets/lerobot_dataset.py | 6 +-- lerobot/common/datasets/video_utils.py | 63 +++++++++++++++++++++- 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 5414c76d..d5a780ac 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import ( VideoFrame, - decode_video_frames_torchvision, + decode_video_frames_torchcodec, encode_video_frames, get_video_info, ) @@ -707,8 +707,8 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {} for vid_key, query_ts in query_timestamps.items(): video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames_torchvision( - video_path, query_ts, self.tolerance_s, self.video_backend + frames = decode_video_frames_torchcodec( + video_path, query_ts, self.tolerance_s ) item[vid_key] = frames.squeeze(0) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 9f043f96..3090666e 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -27,7 +27,7 @@ import torch import torchvision from datasets.features.features import register_feature from PIL import Image - +from torchcodec.decoders import VideoDecoder def decode_video_frames_torchvision( video_path: Path | str, @@ -125,7 +125,66 @@ def decode_video_frames_torchvision( assert len(timestamps) == len(closest_frames) return closest_frames - + +def decode_video_frames_torchcodec( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + device: str = "cpu", + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + """Loads frames associated with the requested timestamps of a video using torchcodec.""" + video_path = str(video_path) + # initialize video decoder + decoder = VideoDecoder(video_path, device=device) + loaded_frames = [] + loaded_ts = [] + # get metadata for frame information + metadata = decoder.metadata + average_fps = metadata.average_fps + + # convert timestamps to frame indices + frame_indices = [round(ts * average_fps) for ts in timestamps] + + # retrieve frames based on indices + frames_batch = decoder.get_frames_at(indices=frame_indices) + + for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds): + loaded_frames.append(frame) + loaded_ts.append(pts.item()) + if log_loaded_timestamps: + logging.info(f"Frame loaded at timestamp={pts:.4f}") + + query_ts = torch.tensor(timestamps) + loaded_ts = torch.tensor(loaded_ts) + + # compute distances between each query timestamp and loaded timestamps + dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) + min_, argmin_ = dist.min(1) + + is_within_tol = min_ < tolerance_s + assert is_within_tol.all(), ( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + "It means that the closest frame that can be loaded from the video is too far away in time." + "This might be due to synchronization issues with timestamps during data collection." + "To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + ) + + # get closest frames to the query timestamps + closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) + closest_ts = loaded_ts[argmin_] + + if log_loaded_timestamps: + logging.info(f"{closest_ts=}") + + # convert to float32 in [0,1] range (channel first) + closest_frames = closest_frames.type(torch.float32) / 255 + + assert len(timestamps) == len(closest_frames) + return closest_frames def encode_video_frames( imgs_dir: Path | str,