add dependency

This commit is contained in:
root 2025-03-03 07:25:56 +00:00
parent 2f9cbfbc4f
commit a963dba256
2 changed files with 32 additions and 6 deletions

View File

@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import (
) )
from lerobot.common.datasets.video_utils import ( from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames_torchcodec, decode_video_frames,
encode_video_frames, encode_video_frames,
get_video_info, get_video_info,
) )
@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to video files are already present on local disk, they won't be downloaded again. Defaults to
True. True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
a single option which is the pyav decoder used by Torchvision. Defaults to pyav. You can also use the 'pyav' decoder used by Torchvision.
""" """
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "pyav" self.video_backend = video_backend if video_backend else "torchcodec"
self.delta_indices = None self.delta_indices = None
# Unused attributes # Unused attributes
@ -707,7 +707,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {} item = {}
for vid_key, query_ts in query_timestamps.items(): for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchcodec(video_path, query_ts, self.tolerance_s) frames = decode_video_frames(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0) item[vid_key] = frames.squeeze(0)
return item return item
@ -1027,7 +1029,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None obj.delta_timestamps = None
obj.delta_indices = None obj.delta_indices = None
obj.episode_data_index = None obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav" obj.video_backend = video_backend if video_backend is not None else "torchcodec"
return obj return obj

View File

@ -29,6 +29,30 @@ from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from torchcodec.decoders import VideoDecoder from torchcodec.decoders import VideoDecoder
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "torchcodec",
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
query_ts (list[float]): List of timestamps to extract frames.
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend == "pyav":
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision( def decode_video_frames_torchvision(
video_path: Path | str, video_path: Path | str,