lerobot/lerobot/common/datasets/video_utils.py

165 lines
5.8 KiB
Python

import logging
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
def load_from_videos(item, video_frame_keys, videos_dir):
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
data_dir = videos_dir.parent
for key in video_frame_keys:
ep_idx = item["episode_index"]
video_path = data_dir / key / f"episode_{ep_idx:06d}.mp4"
if isinstance(item[key], list):
# load multiple frames at once
timestamps = [frame["timestamp"] for frame in item[key]]
paths = [frame["path"] for frame in item[key]]
if len(set(paths)) == 1:
raise NotImplementedError("All video paths are expected to be the same for now.")
video_path = data_dir / paths[0]
frames = decode_video_frames_torchvision(video_path, timestamps)
assert len(frames) == len(timestamps)
item[key] = frames
else:
# load one frame
timestamps = [item[key]["timestamp"]]
video_path = data_dir / item[key]["path"]
frames = decode_video_frames_torchvision(video_path, timestamps)
assert len(frames) == 1
item[key] = frames[0]
return item
def decode_video_frames_torchvision(
video_path: str, timestamps: list[float], device: str = "cpu", log_loaded_timestamps: bool = False
):
"""Loads frames associated to the requested timestamps of a video
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
that key frame. As a consequence, to access a requested frame, we need to load the preceeding key frame,
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
video_path = str(video_path)
# set backend
if device == "cpu":
# explicitely use pyav
torchvision.set_video_backend("pyav")
elif device == "cuda":
# TODO(rcadene, aliberts): implement video decoding with GPU
# torchvision.set_video_backend("cuda")
# torchvision.set_video_backend("video_reader")
# requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
# check possible bug: https://github.com/pytorch/vision/issues/7745
raise NotImplementedError()
else:
raise ValueError(device)
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time
reader = torchvision.io.VideoReader(video_path, "video")
def round_timestamp(ts):
# sanity preprocessing (e.g. 3.60000003 -> 3.6000, 0.0666666667 -> 0.0667)
return round(ts, 4)
timestamps = [round_timestamp(ts) for ts in timestamps]
# set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = timestamps[0]
last_ts = timestamps[-1]
# access key frame of first requested frame, and load all frames until last requested frame
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts)
frames = []
for frame in reader:
# get timestamp of the loaded frame
ts = round_timestamp(frame["pts"])
# if the loaded frame is not among the requested frames, we dont add it to the list of output frames
is_frame_requested = ts in timestamps
if is_frame_requested:
frames.append(frame["data"])
if log_loaded_timestamps:
log = f"frame loaded at timestamp={ts:.4f}"
if is_frame_requested:
log += " requested"
logging.info(log)
if len(timestamps) == len(frames):
break
# hard stop
assert (
frame["pts"] >= last_ts
), f"Not enough frames have been loaded in [{first_ts}, {last_ts}]. {len(timestamps)} expected, but only {len(frames)} loaded."
frames = torch.stack(frames)
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
frames = frames.type(torch.float32) / 255
assert len(timestamps) == frames.shape[0]
return frames
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
# For more info this setting, see: `lerobot/common/datasets/_video_benchmark/README.md`
video_path = Path(video_path)
video_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_cmd = (
f"ffmpeg -r {fps} "
"-f image2 "
"-loglevel error "
f"-i {str(imgs_dir / 'frame_%06d.png')} "
"-vcodec libx264 "
"-pix_fmt yuv444p "
f"{str(video_path)}"
)
subprocess.run(ffmpeg_cmd.split(" "), check=True)
@dataclass
class VideoFrame:
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
"""
Provides a type for a dataset containing video frames.
Example:
```python
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
features = {"image": VideoFrame()}
Dataset.from_dict(data_dict, features=Features(features))
```
"""
pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
_type: str = field(default="VideoFrame", init=False, repr=False)
def __call__(self):
return self.pa_type
# to make it available in HuggingFace `datasets`
register_feature(VideoFrame, "VideoFrame")