Add video_backend in config
This commit is contained in:
parent
3059e25a55
commit
fcfa20299e
|
@ -53,6 +53,7 @@ def make_dataset(
|
|||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
|
|
@ -44,6 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
|
@ -65,6 +66,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.info = load_info(repo_id, version, root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, version, root)
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
|
@ -145,6 +147,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.video_frame_keys,
|
||||
self.videos_dir,
|
||||
self.tolerance_s,
|
||||
self.video_backend,
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
|
|
|
@ -27,7 +27,11 @@ from datasets.features.features import register_feature
|
|||
|
||||
|
||||
def load_from_videos(
|
||||
item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
|
||||
item: dict[str, torch.Tensor],
|
||||
video_frame_keys: list[str],
|
||||
videos_dir: Path,
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
):
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
|
||||
|
@ -46,14 +50,14 @@ def load_from_videos(
|
|||
raise NotImplementedError("All video paths are expected to be the same for now.")
|
||||
video_path = data_dir / paths[0]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames
|
||||
else:
|
||||
# load one frame
|
||||
timestamps = [item[key]["timestamp"]]
|
||||
video_path = data_dir / item[key]["path"]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames[0]
|
||||
|
||||
return item
|
||||
|
@ -63,11 +67,16 @@ def decode_video_frames_torchvision(
|
|||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
device: str = "cpu",
|
||||
backend: str = "pyav",
|
||||
log_loaded_timestamps: bool = False,
|
||||
):
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
|
||||
The backend can be either "pyav" (default) or "video_reader".
|
||||
"video_reader" requires installing torchvision from source, see:
|
||||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
(note that you need to compile against ffmpeg<4.3)
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
|
@ -78,21 +87,9 @@ def decode_video_frames_torchvision(
|
|||
|
||||
# set backend
|
||||
keyframes_only = False
|
||||
if device == "cpu":
|
||||
# explicitely use pyav
|
||||
torchvision.set_video_backend("pyav")
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesnt support accuracte seek
|
||||
elif device == "cuda":
|
||||
# TODO(rcadene, aliberts): implement video decoding with GPU
|
||||
# torchvision.set_video_backend("cuda")
|
||||
# torchvision.set_video_backend("video_reader")
|
||||
# requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
# check possible bug: https://github.com/pytorch/vision/issues/7745
|
||||
raise NotImplementedError(
|
||||
"Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`."
|
||||
)
|
||||
else:
|
||||
raise ValueError(device)
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
|
@ -120,7 +117,9 @@ def decode_video_frames_torchvision(
|
|||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
reader.container.close()
|
||||
if backend == "pyav":
|
||||
reader.container.close()
|
||||
|
||||
reader = None
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
|
|
Loading…
Reference in New Issue