[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
a963dba256
commit
a8fcd3512d
|
@ -462,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||||
True.
|
True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
|
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
|
||||||
You can also use the 'pyav' decoder used by Torchvision.
|
You can also use the 'pyav' decoder used by Torchvision.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
item = {}
|
item = {}
|
||||||
for vid_key, query_ts in query_timestamps.items():
|
for vid_key, query_ts in query_timestamps.items():
|
||||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||||
frames = decode_video_frames(
|
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
||||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
|
||||||
)
|
|
||||||
item[vid_key] = frames.squeeze(0)
|
item[vid_key] = frames.squeeze(0)
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
|
@ -29,30 +29,32 @@ from datasets.features.features import register_feature
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchcodec.decoders import VideoDecoder
|
from torchcodec.decoders import VideoDecoder
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames(
|
def decode_video_frames(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
backend: str = "torchcodec",
|
backend: str = "torchcodec",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Decodes video frames using the specified backend.
|
Decodes video frames using the specified backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path (Path): Path to the video file.
|
video_path (Path): Path to the video file.
|
||||||
query_ts (list[float]): List of timestamps to extract frames.
|
query_ts (list[float]): List of timestamps to extract frames.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Decoded frames.
|
torch.Tensor: Decoded frames.
|
||||||
|
|
||||||
|
Currently supports torchcodec on cpu and pyav.
|
||||||
|
"""
|
||||||
|
if backend == "torchcodec":
|
||||||
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||||
|
elif backend == "pyav":
|
||||||
|
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
Currently supports torchcodec on cpu and pyav.
|
|
||||||
"""
|
|
||||||
if backend == "torchcodec":
|
|
||||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
|
||||||
elif backend == "pyav":
|
|
||||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported video backend: {backend}")
|
|
||||||
|
|
||||||
def decode_video_frames_torchvision(
|
def decode_video_frames_torchvision(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
|
|
Loading…
Reference in New Issue