[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-03 07:26:27 +00:00
parent a963dba256
commit a8fcd3512d
2 changed files with 21 additions and 21 deletions

View File

@ -462,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to video files are already present on local disk, they won't be downloaded again. Defaults to
True. True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
You can also use the 'pyav' decoder used by Torchvision. You can also use the 'pyav' decoder used by Torchvision.
""" """
super().__init__() super().__init__()
@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {} item = {}
for vid_key, query_ts in query_timestamps.items(): for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames( frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0) item[vid_key] = frames.squeeze(0)
return item return item

View File

@ -29,30 +29,32 @@ from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from torchcodec.decoders import VideoDecoder from torchcodec.decoders import VideoDecoder
def decode_video_frames( def decode_video_frames(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
tolerance_s: float, tolerance_s: float,
backend: str = "torchcodec", backend: str = "torchcodec",
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes video frames using the specified backend. Decodes video frames using the specified backend.
Args: Args:
video_path (Path): Path to the video file. video_path (Path): Path to the video file.
query_ts (list[float]): List of timestamps to extract frames. query_ts (list[float]): List of timestamps to extract frames.
Returns: Returns:
torch.Tensor: Decoded frames. torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend == "pyav":
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
Currently supports torchcodec on cpu and pyav.
"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend == "pyav":
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision( def decode_video_frames_torchvision(
video_path: Path | str, video_path: Path | str,