[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-03 07:26:27 +00:00
parent a963dba256
commit a8fcd3512d
2 changed files with 21 additions and 21 deletions

View File

@ -462,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
You can also use the 'pyav' decoder used by Torchvision.
"""
super().__init__()
@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(
video_path, query_ts, self.tolerance_s, self.video_backend
)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
item[vid_key] = frames.squeeze(0)
return item

View File

@ -29,30 +29,32 @@ from datasets.features.features import register_feature
from PIL import Image
from torchcodec.decoders import VideoDecoder
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "torchcodec",
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
query_ts (list[float]): List of timestamps to extract frames.
Returns:
torch.Tensor: Decoded frames.
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
query_ts (list[float]): List of timestamps to extract frames.
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend == "pyav":
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
Currently supports torchcodec on cpu and pyav.
"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend == "pyav":
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
video_path: Path | str,