fix(codec): hot-fix for default codec in linux arm platforms (#868)

This commit is contained in:
Steven Palma 2025-03-17 13:23:11 +01:00 committed by GitHub
parent 9f0a8a49d0
commit 1c15bab70f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 29 additions and 9 deletions

View File

@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
### Decoding parameters ### Decoding parameters
**Decoder** **Decoder**
We tested two video decoding backends from torchvision: We tested two video decoding backends from torchvision:
- `pyav` (default) - `pyav`
- `video_reader` (requires to build torchvision from source) - `video_reader` (requires to build torchvision from source)
**Requested timestamps** **Requested timestamps**

View File

@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames, decode_video_frames,
encode_video_frames, encode_video_frames,
get_safe_default_codec,
get_video_info, get_video_info,
) )
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to video files are already present on local disk, they won't be downloaded again. Defaults to
True. True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available in the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
""" """
super().__init__() super().__init__()
@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "torchcodec" self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None self.delta_indices = None
# Unused attributes # Unused attributes
@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None obj.delta_timestamps = None
obj.delta_indices = None obj.delta_indices = None
obj.episode_data_index = None obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "torchcodec" obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj return obj

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import json import json
import logging import logging
import subprocess import subprocess
@ -27,14 +28,23 @@ import torch
import torchvision import torchvision
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from torchcodec.decoders import VideoDecoder
def get_safe_default_codec() -> str:
    """Return the name of the default video decoding backend for this platform.

    Returns:
        str: "torchcodec" when the `torchcodec` package can be located in the current
            environment; otherwise "pyav", after logging a warning about the fallback.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is
    # loaded, so import it explicitly before calling `find_spec`.
    import importlib.util

    if importlib.util.find_spec("torchcodec") is not None:
        return "torchcodec"

    logging.warning(
        "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
    )
    return "pyav"
def decode_video_frames( def decode_video_frames(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
tolerance_s: float, tolerance_s: float,
backend: str = "torchcodec", backend: str | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes video frames using the specified backend. Decodes video frames using the specified backend.
@ -43,13 +53,15 @@ def decode_video_frames(
video_path (Path): Path to the video file. video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames. timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval. tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec". backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
Returns: Returns:
torch.Tensor: Decoded frames. torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav. Currently supports torchcodec on cpu and pyav.
""" """
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec": if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]: elif backend in ["pyav", "video_reader"]:
@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes. can be adjusted during encoding to take into account decoding time and video size in bytes.
""" """
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
# initialize video decoder # initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = [] loaded_frames = []

View File

@ -20,6 +20,7 @@ from lerobot.common import (
policies, # noqa: F401 policies, # noqa: F401
) )
from lerobot.common.datasets.transforms import ImageTransformsConfig from lerobot.common.datasets.transforms import ImageTransformsConfig
from lerobot.common.datasets.video_utils import get_safe_default_codec
@dataclass @dataclass
@ -35,7 +36,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None revision: str | None = None
use_imagenet_stats: bool = True use_imagenet_stats: bool = True
video_backend: str = "pyav" video_backend: str = field(default_factory=get_safe_default_codec)
@dataclass @dataclass

View File

@ -69,7 +69,7 @@ dependencies = [
"rerun-sdk>=0.21.0", "rerun-sdk>=0.21.0",
"termcolor>=2.4.0", "termcolor>=2.4.0",
"torch>=2.2.1", "torch>=2.2.1",
"torchcodec>=0.2.1", "torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"wandb>=0.16.3", "wandb>=0.16.3",
"zarr>=2.17.0", "zarr>=2.17.0",