diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index bd9ca1b2..1e000439 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -462,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. - video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. + video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. You can also use the 'pyav' decoder used by Torchvision. """ super().__init__() @@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {} for vid_key, query_ts in query_timestamps.items(): video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames( - video_path, query_ts, self.tolerance_s, self.video_backend - ) + frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend) item[vid_key] = frames.squeeze(0) return item diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 27a60943..9e042d92 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -29,30 +29,32 @@ from datasets.features.features import register_feature from PIL import Image from torchcodec.decoders import VideoDecoder + def decode_video_frames( video_path: Path | str, timestamps: list[float], tolerance_s: float, backend: str = "torchcodec", ) -> torch.Tensor: - """ - Decodes video frames using the specified backend. - - Args: - video_path (Path): Path to the video file. - query_ts (list[float]): List of timestamps to extract frames. - - Returns: - torch.Tensor: Decoded frames. + """ + Decodes video frames using the specified backend. + + Args: + video_path (Path): Path to the video file. + query_ts (list[float]): List of timestamps to extract frames. + + Returns: + torch.Tensor: Decoded frames. + + Currently supports torchcodec on cpu and pyav. + """ + if backend == "torchcodec": + return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) + elif backend == "pyav": + return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + else: + raise ValueError(f"Unsupported video backend: {backend}") - Currently supports torchcodec on cpu and pyav. - """ - if backend == "torchcodec": - return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) - elif backend == "pyav": - return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) - else: - raise ValueError(f"Unsupported video backend: {backend}") def decode_video_frames_torchvision( video_path: Path | str,