diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index c38d570d..7e04d2a9 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -13,16 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import glob import importlib -import json import logging -import subprocess import warnings -from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path from typing import Any, ClassVar +import av import pyarrow as pa import torch import torchvision @@ -252,51 +251,68 @@ def encode_video_frames( g: int | None = 2, crf: int | None = 30, fast_decode: int = 0, - log_level: str | None = "error", + log_level: int | None = av.logging.ERROR, overwrite: bool = False, ) -> None: """More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" video_path = Path(video_path) imgs_dir = Path(imgs_dir) + + if video_path.exists() and not overwrite: + raise FileExistsError( + f"Video file already exists at {video_path}. Use `overwrite=True` to overwrite it." + ) + video_path.parent.mkdir(parents=True, exist_ok=True) - ffmpeg_args = OrderedDict( - [ - ("-f", "image2"), - ("-r", str(fps)), - ("-i", str(imgs_dir / "frame_%06d.png")), - ("-vcodec", vcodec), - ("-pix_fmt", pix_fmt), - ] + # Get input frames + template = "frame_" + ("[0-9]" * 6) + ".png" + input_list = sorted( + glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0]) ) + # Define video output options + video_options = {"pix_fmt": pix_fmt} + if g is not None: - ffmpeg_args["-g"] = str(g) + video_options["g"] = str(g) if crf is not None: - ffmpeg_args["-crf"] = str(crf) + video_options["crf"] = str(crf) if fast_decode: - key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune" + key = "svtav1-params" if vcodec == "libsvtav1" else "tune" value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" - ffmpeg_args[key] = value + video_options[key] = value + # Set logging level if log_level is not None: - ffmpeg_args["-loglevel"] = str(log_level) + # "While less efficient, it is generally preferable to modify logging with Python’s logging" + logging.getLogger("libav").setLevel(log_level) - ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] - if overwrite: - ffmpeg_args.append("-y") + # Create and open output file (overwrite by default) + with av.open(str(video_path), "w", format="mp4") as output: + output_stream = output.add_stream(vcodec, fps, options=video_options) - ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] - # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal - subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) + # Loop through input frames and encode them + for input in input_list: + input_image = Image.open(input).convert("RGB") + input_frame = av.VideoFrame.from_image(input_image) + packet = output_stream.encode(input_frame) + if packet: + output.mux(packet) + + # Flush the encoder + packet = output_stream.encode() + if packet: + output.mux(packet) + + # Reset logging level + if log_level is not None: + av.logging.restore_default_callback() if not video_path.exists(): - raise OSError( - f"Video encoding did not work. File not found: {video_path}. " - f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" - ) + raise OSError(f"Video encoding did not work. File not found: {video_path}.") @dataclass @@ -332,78 +348,68 @@ with warnings.catch_warnings(): def get_audio_info(video_path: Path | str) -> dict: - ffprobe_audio_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "a:0", - "-show_entries", - "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") + # Set logging level + logging.getLogger("libav").setLevel(av.logging.ERROR) - info = json.loads(result.stdout) - audio_stream_info = info["streams"][0] if info.get("streams") else None - if audio_stream_info is None: - return {"has_audio": False} + # Getting audio stream information + audio_info = {} + with av.open(str(video_path), "r") as audio_file: + try: + audio_stream = audio_file.streams.audio[0] + except IndexError: + # Reset logging level + av.logging.restore_default_callback() + return {"has_audio": False} - # Return the information, defaulting to None if no audio stream is present - return { - "has_audio": True, - "audio.channels": audio_stream_info.get("channels", None), - "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, - "audio.sample_rate": int(audio_stream_info["sample_rate"]) - if audio_stream_info.get("sample_rate") - else None, - "audio.bit_depth": audio_stream_info.get("bit_depth", None), - "audio.channel_layout": audio_stream_info.get("channel_layout", None), - } + audio_info["audio.channels"] = audio_stream.channels + audio_info["audio.codec"] = audio_stream.codec.canonical_name + audio_info["audio.bit_rate"] = ( + audio_stream.bit_rate + ) # In an ideal loseless case : bit depth x sample rate x channels = bit rate. In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied. + audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second + audio_info["audio.bit_depth"] = ( + audio_stream.format.bits + ) # In an ideal loseless case : fixed number of bits per sample. In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate). + audio_info["audio.channel_layout"] = audio_stream.layout.name + audio_info["has_audio"] = True + + # Reset logging level + av.logging.restore_default_callback() + + return audio_info def get_video_info(video_path: Path | str) -> dict: - ffprobe_video_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "v:0", - "-show_entries", - "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") + # Set logging level + logging.getLogger("libav").setLevel(av.logging.ERROR) - info = json.loads(result.stdout) - video_stream_info = info["streams"][0] + # Getting video stream information + video_info = {} + with av.open(str(video_path), "r") as video_file: + try: + video_stream = video_file.streams.video[0] + except IndexError: + # Reset logging level + av.logging.restore_default_callback() + return {} - # Calculate fps from r_frame_rate - r_frame_rate = video_stream_info["r_frame_rate"] - num, denom = map(int, r_frame_rate.split("/")) - fps = num / denom + video_info["video.height"] = video_stream.height + video_info["video.width"] = video_stream.width + video_info["video.codec"] = video_stream.codec.canonical_name + video_info["video.pix_fmt"] = video_stream.pix_fmt + video_info["video.is_depth_map"] = False - pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"]) + # Calculate fps from r_frame_rate + video_info["video.fps"] = int(video_stream.base_rate) - video_info = { - "video.fps": fps, - "video.height": video_stream_info["height"], - "video.width": video_stream_info["width"], - "video.channels": pixel_channels, - "video.codec": video_stream_info["codec_name"], - "video.pix_fmt": video_stream_info["pix_fmt"], - "video.is_depth_map": False, - **get_audio_info(video_path), - } + pixel_channels = get_video_pixel_channels(video_stream.pix_fmt) + video_info["video.channels"] = pixel_channels + + # Reset logging level + av.logging.restore_default_callback() + + # Adding audio stream information + video_info.update(**get_audio_info(video_path)) return video_info