#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import json
import logging
import subprocess
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar

import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image


def get_safe_default_codec():
    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"
    else:
        logging.warning(
            "'torchcodec' is not available on your platform, falling back to 'pyav' as a default decoder"
        )
        return "pyav"


def decode_video_frames(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str | None = None,
) -> torch.Tensor:
    """
    Decodes video frames using the specified backend.

    Args:
        video_path (Path): Path to the video file.
        timestamps (list[float]): List of timestamps to extract frames.
        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
            on the platform; otherwise, defaults to "pyav".

    Returns:
        torch.Tensor: Decoded frames.

    Currently supports torchcodec on cpu and pyav.
    """
    if backend is None:
        backend = get_safe_default_codec()
    if backend == "torchcodec":
        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
    elif backend in ["pyav", "video_reader"]:
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
    else:
        raise ValueError(f"Unsupported video backend: {backend}")
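
# Hedged usage sketch for `decode_video_frames` (the path, timestamps and tolerance below are
# hypothetical, for illustration only; they are not defined in this module):
#
#     frames = decode_video_frames("videos/episode_0.mp4", timestamps=[0.0, 0.5, 1.0], tolerance_s=1e-4)
#     # `frames` is a float32 tensor in [0, 1], shaped (num_timestamps, channels, height, width).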


def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video.

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    video_path = str(video_path)

    # set backend
    keyframes_only = False
    torchvision.set_video_backend(backend)
    if backend == "pyav":
        keyframes_only = True  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    # access closest key frame of the first requested frame
    # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
    # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
    reader.seek(first_ts, keyframes_only=keyframes_only)

    # load all frames until last requested frame
    loaded_frames = []
    loaded_ts = []
    for frame in reader:
        current_ts = frame["pts"]
        if log_loaded_timestamps:
            logging.info(f"frame loaded at timestamp={current_ts:.4f}")
        loaded_frames.append(frame["data"])
        loaded_ts.append(current_ts)
        if current_ts >= last_ts:
            break

    if backend == "pyav":
        reader.container.close()

    reader = None

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        " It means that the closest frame that can be loaded from the video is too far away in time."
        " This might be due to synchronization issues with timestamps during data collection."
        " To be safe, we advise ignoring this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
        f"\nbackend: {backend}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)

    return closest_frames
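
# Minimal sketch of the nearest-frame matching performed above, on made-up timestamps
# (values are illustrative only):
#
#     query_ts = torch.tensor([0.10, 0.50])
#     loaded_ts = torch.tensor([0.00, 0.12, 0.48, 0.52])
#     dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)  # L1 distances, shape (2, 4)
#     min_, argmin_ = dist.min(1)  # min_ == [0.02, 0.02], argmin_ == [1, 2]
#     # loaded frames 1 and 2 are returned, provided min_ < tolerance_s everywhere.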


def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    if importlib.util.find_spec("torchcodec"):
        from torchcodec.decoders import VideoDecoder
    else:
        raise ImportError("torchcodec is required but not available.")

    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
    loaded_frames = []
    loaded_ts = []
    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        " It means that the closest frame that can be loaded from the video is too far away in time."
        " This might be due to synchronization issues with timestamps during data collection."
        " To be safe, we advise ignoring this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)

    return closest_frames
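
# Worked example of the timestamp-to-index conversion used above, assuming an
# illustrative average_fps of 30 (not read from any real video):
#
#     timestamps = [0.0, 0.5, 1.0]
#     frame_indices = [round(ts * 30) for ts in timestamps]  # [0, 15, 30]
#     # decoder.get_frames_at(indices=frame_indices) then returns those frames together
#     # with their actual presentation timestamps (pts_seconds).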


def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: str | None = "error",
    overwrite: bool = False,
) -> None:
    """More info on ffmpeg argument tuning in `benchmark/video/README.md`"""
    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)
    video_path.parent.mkdir(parents=True, exist_ok=True)

    ffmpeg_args = OrderedDict(
        [
            ("-f", "image2"),
            ("-r", str(fps)),
            ("-i", str(imgs_dir / "frame_%06d.png")),
            ("-vcodec", vcodec),
            ("-pix_fmt", pix_fmt),
        ]
    )

    if g is not None:
        ffmpeg_args["-g"] = str(g)

    if crf is not None:
        ffmpeg_args["-crf"] = str(crf)

    if fast_decode:
        key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        ffmpeg_args[key] = value

    if log_level is not None:
        ffmpeg_args["-loglevel"] = str(log_level)

    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
    if overwrite:
        ffmpeg_args.append("-y")

    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)

    if not video_path.exists():
        raise OSError(
            f"Video encoding did not work. File not found: {video_path}. "
            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
        )
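
# With the default arguments, the assembled command is roughly the following
# (imgs_dir, video_path and fps are placeholders; flag order follows the dict above):
#
#     ffmpeg -f image2 -r 30 -i imgs_dir/frame_%06d.png -vcodec libsvtav1 -pix_fmt yuv420p \
#         -g 2 -crf 30 -loglevel error video_path.mp4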


@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Provides a type for a dataset containing video frames.

    Example:
    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        return self.pa_type


with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        "'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    # to make VideoFrame available in HuggingFace `datasets`
    register_feature(VideoFrame, "VideoFrame")


def get_audio_info(video_path: Path | str) -> dict:
    ffprobe_audio_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "a:0",
        "-show_entries",
        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    audio_stream_info = info["streams"][0] if info.get("streams") else None
    if audio_stream_info is None:
        return {"has_audio": False}

    # Return the stream information, defaulting to None for fields missing from the ffprobe output
    return {
        "has_audio": True,
        "audio.channels": audio_stream_info.get("channels", None),
        "audio.codec": audio_stream_info.get("codec_name", None),
        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
        "audio.sample_rate": int(audio_stream_info["sample_rate"])
        if audio_stream_info.get("sample_rate")
        else None,
        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
    }
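
# Illustrative return value for a video with an AAC stereo track (the numbers are
# made up for this sketch; actual values come from ffprobe):
#
#     {
#         "has_audio": True,
#         "audio.channels": 2,
#         "audio.codec": "aac",
#         "audio.bit_rate": 128000,
#         "audio.sample_rate": 48000,
#         "audio.bit_depth": None,
#         "audio.channel_layout": "stereo",
#     }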


def get_video_info(video_path: Path | str) -> dict:
    ffprobe_video_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    video_stream_info = info["streams"][0]

    # Calculate fps from r_frame_rate
    r_frame_rate = video_stream_info["r_frame_rate"]
    num, denom = map(int, r_frame_rate.split("/"))
    fps = num / denom

    pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])

    video_info = {
        "video.fps": fps,
        "video.height": video_stream_info["height"],
        "video.width": video_stream_info["width"],
        "video.channels": pixel_channels,
        "video.codec": video_stream_info["codec_name"],
        "video.pix_fmt": video_stream_info["pix_fmt"],
        "video.is_depth_map": False,
        **get_audio_info(video_path),
    }

    return video_info
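
# Illustrative return value (numbers are made up for this sketch; the audio keys
# come from `get_audio_info`, merged in via `**`):
#
#     {
#         "video.fps": 30.0,
#         "video.height": 480,
#         "video.width": 640,
#         "video.channels": 3,
#         "video.codec": "av1",
#         "video.pix_fmt": "yuv420p",
#         "video.is_depth_map": False,
#         "has_audio": False,
#     }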


def get_video_pixel_channels(pix_fmt: str) -> int:
    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
        return 1
    elif "rgba" in pix_fmt or "yuva" in pix_fmt:
        return 4
    elif "rgb" in pix_fmt or "yuv" in pix_fmt:
        return 3
    else:
        raise ValueError("Unknown format")
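
# A few example mappings (the pixel format strings are standard ffmpeg names):
#
#     get_video_pixel_channels("gray")     # -> 1
#     get_video_pixel_channels("rgba")     # -> 4
#     get_video_pixel_channels("yuv420p")  # -> 3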


def get_image_pixel_channels(image: Image.Image) -> int:
    if image.mode == "L":
        return 1  # Grayscale
    elif image.mode == "LA":
        return 2  # Grayscale + Alpha
    elif image.mode == "RGB":
        return 3  # RGB
    elif image.mode == "RGBA":
        return 4  # RGBA
    else:
        raise ValueError("Unknown format")