add dependency
This commit is contained in:
parent
2f9cbfbc4f
commit
a963dba256
|
@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import (
|
|||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames_torchcodec,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
)
|
||||
|
@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
|
||||
You can also use the 'pyav' decoder used by Torchvision.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
|
@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else "pyav"
|
||||
self.video_backend = video_backend if video_backend else "torchcodec"
|
||||
self.delta_indices = None
|
||||
|
||||
# Unused attributes
|
||||
|
@ -707,7 +707,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames_torchcodec(video_path, query_ts, self.tolerance_s)
|
||||
frames = decode_video_frames(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
|
@ -1027,7 +1029,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
|
||||
return obj
|
||||
|
||||
|
||||
|
|
|
@ -29,6 +29,30 @@ from datasets.features.features import register_feature
|
|||
from PIL import Image
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "torchcodec",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
|
||||
Args:
|
||||
video_path (Path): Path to the video file.
|
||||
query_ts (list[float]): List of timestamps to extract frames.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames.
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: Path | str,
|
||||
|
|
Loading…
Reference in New Issue