diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 8ed3318d..0b67f9e9 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -127,6 +127,67 @@ def decode_video_frames_torchvision( return closest_frames +def decode_video_frames_torchcodec( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + device: str = "cpu", + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + """Loads frames associated with the requested timestamps of a video using torchcodec.""" + video_path = str(video_path) + # initialize video decoder + from torchcodec.decoders import VideoDecoder + decoder = VideoDecoder(video_path, device=device) + loaded_frames = [] + loaded_ts = [] + # get metadata for frame information + metadata = decoder.metadata + average_fps = metadata.average_fps + + # convert timestamps to frame indices + frame_indices = [int(ts * average_fps) for ts in timestamps] + + # retrieve frames based on indices + frames_batch = decoder.get_frames_at(indices=frame_indices) + + for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds): + loaded_frames.append(frame) + loaded_ts.append(pts.item()) + if log_loaded_timestamps: + logging.info(f"Frame loaded at timestamp={pts:.4f}") + + query_ts = torch.tensor(timestamps) + loaded_ts = torch.tensor(loaded_ts) + + # compute distances between each query timestamp and loaded timestamps + dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) + min_, argmin_ = dist.min(1) + + is_within_tol = min_ < tolerance_s + assert is_within_tol.all(), ( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + "It means that the closest frame that can be loaded from the video is too far away in time." + "This might be due to synchronization issues with timestamps during data collection." + "To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + ) + + # get closest frames to the query timestamps + closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) + closest_ts = loaded_ts[argmin_] + + if log_loaded_timestamps: + logging.info(f"{closest_ts=}") + + # convert to float32 in [0,1] range (channel first) + closest_frames = closest_frames.type(torch.float32) / 255 + + assert len(timestamps) == len(closest_frames) + return closest_frames + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str,