#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import json
import logging
import subprocess
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar

import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image


def get_safe_default_codec() -> str:
    """Returns "torchcodec" when it is installed, otherwise falls back to "pyav"."""
    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"
    logging.warning(
        "'torchcodec' is not available on this platform, falling back to 'pyav' as the default decoder"
    )
    return "pyav"


def decode_video_frames(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str | None = None,
) -> torch.Tensor:
    """
    Decodes video frames using the specified backend.

    Args:
        video_path (Path): Path to the video file.
        timestamps (list[float]): List of timestamps to extract frames.
        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when
            available on the platform; otherwise, defaults to "pyav".

    Returns:
        torch.Tensor: Decoded frames.

    Currently supports torchcodec (on cpu) and pyav.
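
    Example (a minimal sketch; the video path, timestamps and tolerance below are hypothetical):

    ```python
    frames = decode_video_frames("videos/episode_0.mp4", timestamps=[0.0, 0.5], tolerance_s=1e-4)
    # frames is a float32 tensor in [0, 1], channel first: (num_timestamps, c, h, w)
    ```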
    """
    if backend is None:
        backend = get_safe_default_codec()
    if backend == "torchcodec":
        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
    elif backend in ["pyav", "video_reader"]:
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
    else:
        raise ValueError(f"Unsupported video backend: {backend}")


def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video.

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to trade off decoding time against video size in bytes.
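
    Example (a minimal sketch; the path, timestamps and tolerance below are hypothetical):

    ```python
    frames = decode_video_frames_torchvision(
        "videos/episode_0.mp4", timestamps=[1.0, 1.5], tolerance_s=1e-4, backend="pyav"
    )
    ```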
    """
    video_path = str(video_path)

    # set backend
    keyframes_only = False
    torchvision.set_video_backend(backend)
    if backend == "pyav":
        keyframes_only = True  # pyav doesn't support accurate seeking

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    # access closest key frame of the first requested frame
    # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
    # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
    reader.seek(first_ts, keyframes_only=keyframes_only)

    # load all frames until last requested frame
    loaded_frames = []
    loaded_ts = []
    for frame in reader:
        current_ts = frame["pts"]
        if log_loaded_timestamps:
            logging.info(f"frame loaded at timestamp={current_ts:.4f}")
        loaded_frames.append(frame["data"])
        loaded_ts.append(current_ts)
        if current_ts >= last_ts:
            break

    if backend == "pyav":
        reader.container.close()

    reader = None

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=}). "
        "It means that the closest frame that can be loaded from the video is too far away in time. "
        "This might be due to synchronization issues with timestamps during data collection. "
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
        f"\nbackend: {backend}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames


def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to trade off decoding time against video size in bytes.
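
    Example (a minimal sketch; the path, timestamps and tolerance below are hypothetical):

    ```python
    frames = decode_video_frames_torchcodec("videos/episode_0.mp4", timestamps=[0.0, 0.5], tolerance_s=1e-4)
    ```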
    """
    if importlib.util.find_spec("torchcodec"):
        from torchcodec.decoders import VideoDecoder
    else:
        raise ImportError("torchcodec is required but not available.")

    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
    loaded_frames = []
    loaded_ts = []
    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=}). "
        "It means that the closest frame that can be loaded from the video is too far away in time. "
        "This might be due to synchronization issues with timestamps during data collection. "
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames


def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: str | None = "error",
    overwrite: bool = False,
) -> None:
    """Encodes a directory of images named `frame_%06d.png` into a video with ffmpeg.

    More info on tuning the ffmpeg arguments in `benchmark/video/README.md`.
    """
    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)
    video_path.parent.mkdir(parents=True, exist_ok=True)

    ffmpeg_args = OrderedDict(
        [
            ("-f", "image2"),
            ("-r", str(fps)),
            ("-i", str(imgs_dir / "frame_%06d.png")),
            ("-vcodec", vcodec),
            ("-pix_fmt", pix_fmt),
        ]
    )

    if g is not None:
        ffmpeg_args["-g"] = str(g)

    if crf is not None:
        ffmpeg_args["-crf"] = str(crf)

    if fast_decode:
        key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        ffmpeg_args[key] = value

    if log_level is not None:
        ffmpeg_args["-loglevel"] = str(log_level)

    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
    if overwrite:
        ffmpeg_args.append("-y")

    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)

    if not video_path.exists():
        raise OSError(
            f"Video encoding did not work. File not found: {video_path}. "
            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
        )


@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Provides a type for a dataset containing video frames.

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        return self.pa_type


with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        "'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    # to make VideoFrame available in HuggingFace `datasets`
    register_feature(VideoFrame, "VideoFrame")


def get_audio_info(video_path: Path | str) -> dict:
    """Returns metadata for the first audio stream of a video, probed with ffprobe,
    or `{"has_audio": False}` when the file has no audio stream.
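
    Example (a sketch; the path is hypothetical and the values depend on the file):

    ```python
    get_audio_info("videos/episode_0.mp4")
    # e.g. {"has_audio": True, "audio.channels": 2, "audio.codec": "aac", ...}
    ```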
    """
    ffprobe_audio_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "a:0",
        "-show_entries",
        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    audio_stream_info = info["streams"][0] if info.get("streams") else None
    if audio_stream_info is None:
        return {"has_audio": False}

    # Return the stream information, defaulting to None for any missing field
    return {
        "has_audio": True,
        "audio.channels": audio_stream_info.get("channels", None),
        "audio.codec": audio_stream_info.get("codec_name", None),
        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
        "audio.sample_rate": int(audio_stream_info["sample_rate"])
        if audio_stream_info.get("sample_rate")
        else None,
        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
    }


def get_video_info(video_path: Path | str) -> dict:
    """Returns metadata for the first video stream of a video (fps, resolution, codec, ...),
    probed with ffprobe and merged with the audio info from `get_audio_info`.
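
    Example (a sketch; the path is hypothetical and the values depend on the file):

    ```python
    get_video_info("videos/episode_0.mp4")
    # e.g. {"video.fps": 30.0, "video.height": 480, "video.width": 640, "video.channels": 3, ...}
    ```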
    """
    ffprobe_video_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    if not info.get("streams"):
        raise ValueError(f"No video stream found in: {video_path}")
    video_stream_info = info["streams"][0]

    # Calculate fps from r_frame_rate
    r_frame_rate = video_stream_info["r_frame_rate"]
    num, denom = map(int, r_frame_rate.split("/"))
    fps = num / denom

    pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])

    video_info = {
        "video.fps": fps,
        "video.height": video_stream_info["height"],
        "video.width": video_stream_info["width"],
        "video.channels": pixel_channels,
        "video.codec": video_stream_info["codec_name"],
        "video.pix_fmt": video_stream_info["pix_fmt"],
        "video.is_depth_map": False,
        **get_audio_info(video_path),
    }

    return video_info


def get_video_pixel_channels(pix_fmt: str) -> int:
    """Infers the number of color channels from an ffmpeg pixel format string.
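
    Example:

    ```python
    get_video_pixel_channels("yuv420p")  # -> 3
    get_video_pixel_channels("rgba")  # -> 4
    ```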
    """
    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
        return 1
    elif "rgba" in pix_fmt or "yuva" in pix_fmt:
        return 4
    elif "rgb" in pix_fmt or "yuv" in pix_fmt:
        return 3
    else:
        raise ValueError(f"Unknown pixel format: {pix_fmt}")


def get_image_pixel_channels(image: Image.Image) -> int:
    """Returns the number of channels of a PIL image based on its mode.
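
    Example:

    ```python
    get_image_pixel_channels(Image.new("RGB", (1, 1)))  # -> 3
    ```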
    """
    if image.mode == "L":
        return 1  # Grayscale
    elif image.mode == "LA":
        return 2  # Grayscale + Alpha
    elif image.mode == "RGB":
        return 3  # RGB
    elif image.mode == "RGBA":
        return 4  # RGBA
    else:
        raise ValueError(f"Unknown image mode: {image.mode}")