From a963dba2568986269a60cc63443d32d5d28ff6db Mon Sep 17 00:00:00 2001 From: root Date: Mon, 3 Mar 2025 07:25:56 +0000 Subject: [PATCH] add dependency --- lerobot/common/datasets/lerobot_dataset.py | 14 +++++++------ lerobot/common/datasets/video_utils.py | 24 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 8afabd1f..bd9ca1b2 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import ( VideoFrame, - decode_video_frames_torchcodec, + decode_video_frames, encode_video_frames, get_video_info, ) @@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. - video_backend (str | None, optional): Video backend to use for decoding videos. There is currently - a single option which is the pyav decoder used by Torchvision. Defaults to pyav. + video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. + You can also use the 'pyav' decoder used by Torchvision. """ super().__init__() self.repo_id = repo_id @@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = video_backend if video_backend else "pyav" + self.video_backend = video_backend if video_backend else "torchcodec" self.delta_indices = None # Unused attributes @@ -707,7 +707,9 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {} for vid_key, query_ts in query_timestamps.items(): video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames_torchcodec(video_path, query_ts, self.tolerance_s) + frames = decode_video_frames( + video_path, query_ts, self.tolerance_s, self.video_backend + ) item[vid_key] = frames.squeeze(0) return item @@ -1027,7 +1029,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.delta_indices = None obj.episode_data_index = None - obj.video_backend = video_backend if video_backend is not None else "pyav" + obj.video_backend = video_backend if video_backend is not None else "torchcodec" return obj diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 587b11cb..27a60943 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -29,6 +29,30 @@ from datasets.features.features import register_feature from PIL import Image from torchcodec.decoders import VideoDecoder +def decode_video_frames( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + backend: str = "torchcodec", +) -> torch.Tensor: + """ + Decodes video frames using the specified backend. + + Args: + video_path (Path): Path to the video file. + query_ts (list[float]): List of timestamps to extract frames. + + Returns: + torch.Tensor: Decoded frames. + + Currently supports torchcodec on cpu and pyav. + """ + if backend == "torchcodec": + return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) + elif backend == "pyav": + return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + else: + raise ValueError(f"Unsupported video backend: {backend}") def decode_video_frames_torchvision( video_path: Path | str,