Add torchcodec cpu (#798)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Jade Choghari 2025-03-14 18:53:42 +03:00 committed by GitHub
parent 974028bd28
commit 0e98c6ee96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 107 additions and 10 deletions

View File

@ -126,7 +126,7 @@ jobs:
# portaudio19-dev is needed to install pyaudio # portaudio19-dev is needed to install pyaudio
run: | run: |
sudo apt-get update && \ sudo apt-get update && \
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python - name: Install uv and python
uses: astral-sh/setup-uv@v5 uses: astral-sh/setup-uv@v5

View File

@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
def check_datasets_formats(repo_ids: list) -> None: def check_datasets_formats(repo_ids: list) -> None:
for repo_id in repo_ids: for repo_id in repo_ids:
dataset = LeRobotDataset(repo_id) dataset = LeRobotDataset(repo_id)
if dataset.video: if len(dataset.meta.video_keys) > 0:
raise ValueError( raise ValueError(
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}" f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
) )

View File

@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import (
) )
from lerobot.common.datasets.video_utils import ( from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames_torchvision, decode_video_frames,
encode_video_frames, encode_video_frames,
get_video_info, get_video_info,
) )
@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to video files are already present on local disk, they won't be downloaded again. Defaults to
True. True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
a single option which is the pyav decoder used by Torchvision. Defaults to pyav. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
""" """
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "pyav" self.video_backend = video_backend if video_backend else "torchcodec"
self.delta_indices = None self.delta_indices = None
# Unused attributes # Unused attributes
@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {} item = {}
for vid_key, query_ts in query_timestamps.items(): for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchvision( frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0) item[vid_key] = frames.squeeze(0)
return item return item
@ -1029,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None obj.delta_timestamps = None
obj.delta_indices = None obj.delta_indices = None
obj.episode_data_index = None obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav" obj.video_backend = video_backend if video_backend is not None else "torchcodec"
return obj return obj

View File

@ -27,6 +27,35 @@ import torch
import torchvision import torchvision
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from torchcodec.decoders import VideoDecoder
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "torchcodec",
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision( def decode_video_frames_torchvision(
@ -127,6 +156,75 @@ def decode_video_frames_torchvision(
return closest_frames return closest_frames
def decode_video_frames_torchcodec(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
device: str = "cpu",
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
# initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = []
loaded_ts = []
# get metadata for frame information
metadata = decoder.metadata
average_fps = metadata.average_fps
# convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps]
# retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices)
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
loaded_frames.append(frame)
loaded_ts.append(pts.item())
if log_loaded_timestamps:
logging.info(f"Frame loaded at timestamp={pts:.4f}")
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
# compute distances between each query timestamp and loaded timestamps
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
# convert to float32 in [0,1] range (channel first)
closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == len(closest_frames)
return closest_frames
def encode_video_frames( def encode_video_frames(
imgs_dir: Path | str, imgs_dir: Path | str,
video_path: Path | str, video_path: Path | str,

View File

@ -69,6 +69,7 @@ dependencies = [
"rerun-sdk>=0.21.0", "rerun-sdk>=0.21.0",
"termcolor>=2.4.0", "termcolor>=2.4.0",
"torch>=2.2.1", "torch>=2.2.1",
"torchcodec>=0.2.1",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"wandb>=0.16.3", "wandb>=0.16.3",
"zarr>=2.17.0", "zarr>=2.17.0",