fix(codec): hot-fix for default codec in linux arm platforms (#868)
This commit is contained in:
parent
9f0a8a49d0
commit
1c15bab70f
|
@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
|
||||||
### Decoding parameters
|
### Decoding parameters
|
||||||
**Decoder**
|
**Decoder**
|
||||||
We tested two video decoding backends from torchvision:
|
We tested two video decoding backends from torchvision:
|
||||||
- `pyav` (default)
|
- `pyav`
|
||||||
- `video_reader` (requires to build torchvision from source)
|
- `video_reader` (requires to build torchvision from source)
|
||||||
|
|
||||||
**Requested timestamps**
|
**Requested timestamps**
|
||||||
|
|
|
@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import (
|
||||||
VideoFrame,
|
VideoFrame,
|
||||||
decode_video_frames,
|
decode_video_frames,
|
||||||
encode_video_frames,
|
encode_video_frames,
|
||||||
|
get_safe_default_codec,
|
||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
|
@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||||
True.
|
True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
|
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.video_backend = video_backend if video_backend else "torchcodec"
|
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
|
||||||
# Unused attributes
|
# Unused attributes
|
||||||
|
@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.delta_timestamps = None
|
obj.delta_timestamps = None
|
||||||
obj.delta_indices = None
|
obj.delta_indices = None
|
||||||
obj.episode_data_index = None
|
obj.episode_data_index = None
|
||||||
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -27,14 +28,23 @@ import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from datasets.features.features import register_feature
|
from datasets.features.features import register_feature
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchcodec.decoders import VideoDecoder
|
|
||||||
|
|
||||||
|
def get_safe_default_codec():
|
||||||
|
if importlib.util.find_spec("torchcodec"):
|
||||||
|
return "torchcodec"
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||||
|
)
|
||||||
|
return "pyav"
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames(
|
def decode_video_frames(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
backend: str = "torchcodec",
|
backend: str | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Decodes video frames using the specified backend.
|
Decodes video frames using the specified backend.
|
||||||
|
@ -43,13 +53,15 @@ def decode_video_frames(
|
||||||
video_path (Path): Path to the video file.
|
video_path (Path): Path to the video file.
|
||||||
timestamps (list[float]): List of timestamps to extract frames.
|
timestamps (list[float]): List of timestamps to extract frames.
|
||||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Decoded frames.
|
torch.Tensor: Decoded frames.
|
||||||
|
|
||||||
Currently supports torchcodec on cpu and pyav.
|
Currently supports torchcodec on cpu and pyav.
|
||||||
"""
|
"""
|
||||||
|
if backend is None:
|
||||||
|
backend = get_safe_default_codec()
|
||||||
if backend == "torchcodec":
|
if backend == "torchcodec":
|
||||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||||
elif backend in ["pyav", "video_reader"]:
|
elif backend in ["pyav", "video_reader"]:
|
||||||
|
@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
|
||||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if importlib.util.find_spec("torchcodec"):
|
||||||
|
from torchcodec.decoders import VideoDecoder
|
||||||
|
else:
|
||||||
|
raise ImportError("torchcodec is required but not available.")
|
||||||
|
|
||||||
# initialize video decoder
|
# initialize video decoder
|
||||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||||
loaded_frames = []
|
loaded_frames = []
|
||||||
|
|
|
@ -20,6 +20,7 @@ from lerobot.common import (
|
||||||
policies, # noqa: F401
|
policies, # noqa: F401
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.transforms import ImageTransformsConfig
|
from lerobot.common.datasets.transforms import ImageTransformsConfig
|
||||||
|
from lerobot.common.datasets.video_utils import get_safe_default_codec
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -35,7 +36,7 @@ class DatasetConfig:
|
||||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||||
revision: str | None = None
|
revision: str | None = None
|
||||||
use_imagenet_stats: bool = True
|
use_imagenet_stats: bool = True
|
||||||
video_backend: str = "pyav"
|
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -69,7 +69,7 @@ dependencies = [
|
||||||
"rerun-sdk>=0.21.0",
|
"rerun-sdk>=0.21.0",
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
"torch>=2.2.1",
|
"torch>=2.2.1",
|
||||||
"torchcodec>=0.2.1",
|
"torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
"wandb>=0.16.3",
|
"wandb>=0.16.3",
|
||||||
"zarr>=2.17.0",
|
"zarr>=2.17.0",
|
||||||
|
|
Loading…
Reference in New Issue