Add video_backend in config

Simon Alibert 2024-05-30 11:29:19 +00:00
parent 3059e25a55
commit fcfa20299e
3 changed files with 22 additions and 19 deletions
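The change threads a single `video_backend` string from the training config (`cfg.video_backend`, read in the factory below) down to the torchvision decoder, defaulting to "pyav". A minimal usage sketch, assuming the import path of `LeRobotDataset` (file names are not shown in this view) and a public dataset repo id:

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # module path assumed

# Build a dataset that decodes video frames with the default PyAV backend.
# Pass video_backend="video_reader" only if torchvision was compiled from source
# (see the docstring added further down in this diff).
dataset = LeRobotDataset(
    "lerobot/pusht",        # example repo id; any video dataset works
    video_backend="pyav",
)
item = dataset[0]           # frames for this index are decoded with the selected backend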

View File

@@ -53,6 +53,7 @@ def make_dataset(
         cfg.dataset_repo_id,
         split=split,
         delta_timestamps=cfg.training.get("delta_timestamps"),
+        video_backend=cfg.video_backend,
     )

     if cfg.get("override_dataset_stats"):

View File

@@ -44,6 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         split: str = "train",
         transform: callable = None,
         delta_timestamps: dict[list[float]] | None = None,
+        video_backend: str | None = None,
     ):
         super().__init__()
         self.repo_id = repo_id
@@ -65,6 +66,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.info = load_info(repo_id, version, root)
         if self.video:
             self.videos_dir = load_videos(repo_id, version, root)
+        self.video_backend = video_backend if video_backend is not None else "pyav"

     @property
     def fps(self) -> int:
@@ -145,6 +147,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self.video_frame_keys,
                 self.videos_dir,
                 self.tolerance_s,
+                self.video_backend,
             )

         if self.transform is not None:
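Because the constructor now falls back to "pyav" when `video_backend` is None, existing call sites that do not pass the argument keep their previous behavior. A small illustrative check (module path assumed, as in the sketch above):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # module path assumed

# Both datasets resolve to the same backend; only the second states it explicitly.
ds_default = LeRobotDataset("lerobot/pusht")
ds_explicit = LeRobotDataset("lerobot/pusht", video_backend="pyav")
assert ds_default.video_backend == ds_explicit.video_backend == "pyav"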

View File

@@ -27,7 +27,11 @@ from datasets.features.features import register_feature


 def load_from_videos(
-    item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
+    item: dict[str, torch.Tensor],
+    video_frame_keys: list[str],
+    videos_dir: Path,
+    tolerance_s: float,
+    backend: str = "pyav",
 ):
     """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
     in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
@@ -46,14 +50,14 @@ def load_from_videos(
                 raise NotImplementedError("All video paths are expected to be the same for now.")
             video_path = data_dir / paths[0]

-            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
             item[key] = frames
         else:
             # load one frame
             timestamps = [item[key]["timestamp"]]
             video_path = data_dir / item[key]["path"]

-            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
             item[key] = frames[0]

     return item
@@ -63,11 +67,16 @@ def decode_video_frames_torchvision(
     video_path: str,
     timestamps: list[float],
     tolerance_s: float,
-    device: str = "cpu",
+    backend: str = "pyav",
     log_loaded_timestamps: bool = False,
 ):
     """Loads frames associated with the requested timestamps of a video.

+    The backend can be either "pyav" (default) or "video_reader".
+    "video_reader" requires installing torchvision from source, see:
+    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
+    (note that you need to compile against ffmpeg<4.3)
+
     Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
     the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
     that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
@@ -78,21 +87,9 @@ def decode_video_frames_torchvision(
     # set backend
     keyframes_only = False
-    if device == "cpu":
-        # explicitely use pyav
-        torchvision.set_video_backend("pyav")
+    torchvision.set_video_backend(backend)
+    if backend == "pyav":
         keyframes_only = True  # pyav doesn't support accurate seek
-    elif device == "cuda":
-        # TODO(rcadene, aliberts): implement video decoding with GPU
-        # torchvision.set_video_backend("cuda")
-        # torchvision.set_video_backend("video_reader")
-        # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
-        # check possible bug: https://github.com/pytorch/vision/issues/7745
-        raise NotImplementedError(
-            "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`."
-        )
-    else:
-        raise ValueError(device)

     # set a video stream reader
     # TODO(rcadene): also load audio stream at the same time
@@ -120,7 +117,9 @@ def decode_video_frames_torchvision(
         if current_ts >= last_ts:
             break

-    reader.container.close()
+    if backend == "pyav":
+        reader.container.close()
+
     reader = None

     query_ts = torch.tensor(timestamps)
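For completeness, the decoder can also be exercised directly with the new argument. A minimal sketch under assumptions: the module path and the local video path are placeholders, and the signature is taken from the diff above:

from lerobot.common.datasets.video_utils import decode_video_frames_torchvision  # module path assumed

# Decode two frames near the start of a clip with the default PyAV backend.
# "video_reader" would require a source build of torchvision (ffmpeg<4.3), as noted in the docstring.
frames = decode_video_frames_torchvision(
    video_path="path/to/episode.mp4",  # placeholder path
    timestamps=[0.0, 0.1],
    tolerance_s=1e-4,
    backend="pyav",
)
print(frames.shape)  # one decoded frame per requested timestamp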