Add torchcodec cpu (#798)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Remi <re.cadene@gmail.com> Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
parent
974028bd28
commit
0e98c6ee96
|
@ -126,7 +126,7 @@ jobs:
|
||||||
# portaudio19-dev is needed to install pyaudio
|
# portaudio19-dev is needed to install pyaudio
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update && \
|
sudo apt-get update && \
|
||||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||||
|
|
||||||
- name: Install uv and python
|
- name: Install uv and python
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v5
|
||||||
|
|
|
@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
|
||||||
def check_datasets_formats(repo_ids: list) -> None:
|
def check_datasets_formats(repo_ids: list) -> None:
|
||||||
for repo_id in repo_ids:
|
for repo_id in repo_ids:
|
||||||
dataset = LeRobotDataset(repo_id)
|
dataset = LeRobotDataset(repo_id)
|
||||||
if dataset.video:
|
if len(dataset.meta.video_keys) > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import (
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import (
|
from lerobot.common.datasets.video_utils import (
|
||||||
VideoFrame,
|
VideoFrame,
|
||||||
decode_video_frames_torchvision,
|
decode_video_frames,
|
||||||
encode_video_frames,
|
encode_video_frames,
|
||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
|
@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||||
True.
|
True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
|
||||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
|
@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.video_backend = video_backend if video_backend else "pyav"
|
self.video_backend = video_backend if video_backend else "torchcodec"
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
|
||||||
# Unused attributes
|
# Unused attributes
|
||||||
|
@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
item = {}
|
item = {}
|
||||||
for vid_key, query_ts in query_timestamps.items():
|
for vid_key, query_ts in query_timestamps.items():
|
||||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||||
frames = decode_video_frames_torchvision(
|
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
||||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
|
||||||
)
|
|
||||||
item[vid_key] = frames.squeeze(0)
|
item[vid_key] = frames.squeeze(0)
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
@ -1029,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.delta_timestamps = None
|
obj.delta_timestamps = None
|
||||||
obj.delta_indices = None
|
obj.delta_indices = None
|
||||||
obj.episode_data_index = None
|
obj.episode_data_index = None
|
||||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,35 @@ import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from datasets.features.features import register_feature
|
from datasets.features.features import register_feature
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torchcodec.decoders import VideoDecoder
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames(
|
||||||
|
video_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
tolerance_s: float,
|
||||||
|
backend: str = "torchcodec",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Decodes video frames using the specified backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path (Path): Path to the video file.
|
||||||
|
timestamps (list[float]): List of timestamps to extract frames.
|
||||||
|
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||||
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Decoded frames.
|
||||||
|
|
||||||
|
Currently supports torchcodec on cpu and pyav.
|
||||||
|
"""
|
||||||
|
if backend == "torchcodec":
|
||||||
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||||
|
elif backend in ["pyav", "video_reader"]:
|
||||||
|
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames_torchvision(
|
def decode_video_frames_torchvision(
|
||||||
|
@ -127,6 +156,75 @@ def decode_video_frames_torchvision(
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames_torchcodec(
|
||||||
|
video_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
tolerance_s: float,
|
||||||
|
device: str = "cpu",
|
||||||
|
log_loaded_timestamps: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||||
|
|
||||||
|
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||||
|
|
||||||
|
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||||
|
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||||
|
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||||
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||||
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||||
|
"""
|
||||||
|
# initialize video decoder
|
||||||
|
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||||
|
loaded_frames = []
|
||||||
|
loaded_ts = []
|
||||||
|
# get metadata for frame information
|
||||||
|
metadata = decoder.metadata
|
||||||
|
average_fps = metadata.average_fps
|
||||||
|
|
||||||
|
# convert timestamps to frame indices
|
||||||
|
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||||
|
|
||||||
|
# retrieve frames based on indices
|
||||||
|
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||||
|
|
||||||
|
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
||||||
|
loaded_frames.append(frame)
|
||||||
|
loaded_ts.append(pts.item())
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||||
|
|
||||||
|
query_ts = torch.tensor(timestamps)
|
||||||
|
loaded_ts = torch.tensor(loaded_ts)
|
||||||
|
|
||||||
|
# compute distances between each query timestamp and loaded timestamps
|
||||||
|
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||||
|
min_, argmin_ = dist.min(1)
|
||||||
|
|
||||||
|
is_within_tol = min_ < tolerance_s
|
||||||
|
assert is_within_tol.all(), (
|
||||||
|
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||||
|
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||||
|
"This might be due to synchronization issues with timestamps during data collection."
|
||||||
|
"To be safe, we advise to ignore this item during training."
|
||||||
|
f"\nqueried timestamps: {query_ts}"
|
||||||
|
f"\nloaded timestamps: {loaded_ts}"
|
||||||
|
f"\nvideo: {video_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# get closest frames to the query timestamps
|
||||||
|
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||||
|
closest_ts = loaded_ts[argmin_]
|
||||||
|
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logging.info(f"{closest_ts=}")
|
||||||
|
|
||||||
|
# convert to float32 in [0,1] range (channel first)
|
||||||
|
closest_frames = closest_frames.type(torch.float32) / 255
|
||||||
|
|
||||||
|
assert len(timestamps) == len(closest_frames)
|
||||||
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
def encode_video_frames(
|
def encode_video_frames(
|
||||||
imgs_dir: Path | str,
|
imgs_dir: Path | str,
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
|
|
|
@ -69,6 +69,7 @@ dependencies = [
|
||||||
"rerun-sdk>=0.21.0",
|
"rerun-sdk>=0.21.0",
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
"torch>=2.2.1",
|
"torch>=2.2.1",
|
||||||
|
"torchcodec>=0.2.1",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
"wandb>=0.16.3",
|
"wandb>=0.16.3",
|
||||||
"zarr>=2.17.0",
|
"zarr>=2.17.0",
|
||||||
|
|
Loading…
Reference in New Issue