diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d5a780ac..8afabd1f 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {} for vid_key, query_ts in query_timestamps.items(): video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames_torchcodec( - video_path, query_ts, self.tolerance_s - ) + frames = decode_video_frames_torchcodec(video_path, query_ts, self.tolerance_s) item[vid_key] = frames.squeeze(0) return item diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 3090666e..587b11cb 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -29,6 +29,7 @@ from datasets.features.features import register_feature from PIL import Image from torchcodec.decoders import VideoDecoder + def decode_video_frames_torchvision( video_path: Path | str, timestamps: list[float], @@ -125,7 +126,8 @@ def decode_video_frames_torchvision( assert len(timestamps) == len(closest_frames) return closest_frames - + + def decode_video_frames_torchcodec( video_path: Path | str, timestamps: list[float], @@ -142,26 +144,26 @@ def decode_video_frames_torchcodec( # get metadata for frame information metadata = decoder.metadata average_fps = metadata.average_fps - + # convert timestamps to frame indices frame_indices = [round(ts * average_fps) for ts in timestamps] - + # retrieve frames based on indices frames_batch = decoder.get_frames_at(indices=frame_indices) - - for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds): + + for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False): loaded_frames.append(frame) loaded_ts.append(pts.item()) if log_loaded_timestamps: logging.info(f"Frame loaded at timestamp={pts:.4f}") - + query_ts = torch.tensor(timestamps) loaded_ts = torch.tensor(loaded_ts) - + # compute distances between each query timestamp and loaded timestamps dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) min_, argmin_ = dist.min(1) - + is_within_tol = min_ < tolerance_s assert is_within_tol.all(), ( f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." @@ -172,20 +174,21 @@ def decode_video_frames_torchcodec( f"\nloaded timestamps: {loaded_ts}" f"\nvideo: {video_path}" ) - + # get closest frames to the query timestamps closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) closest_ts = loaded_ts[argmin_] - + if log_loaded_timestamps: logging.info(f"{closest_ts=}") - + # convert to float32 in [0,1] range (channel first) closest_frames = closest_frames.type(torch.float32) / 255 - + assert len(timestamps) == len(closest_frames) return closest_frames + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str,