Add video_backend in config
commit fcfa20299e
parent 3059e25a55
@@ -53,6 +53,7 @@ def make_dataset(
         cfg.dataset_repo_id,
         split=split,
         delta_timestamps=cfg.training.get("delta_timestamps"),
+        video_backend=cfg.video_backend,
     )

     if cfg.get("override_dataset_stats"):
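For context, a minimal sketch of how the new option flows in from the config. The OmegaConf construction and the "lerobot/pusht" repo id are illustrative assumptions; the only behavior taken from the hunk above is that make_dataset() reads cfg.video_backend and forwards it to LeRobotDataset.

    # Hypothetical config sketch, not the repository's actual config file.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "dataset_repo_id": "lerobot/pusht",   # assumed example repo id
            "video_backend": "pyav",              # or "video_reader" (torchvision built from source)
            "training": {"delta_timestamps": None},
        }
    )
    # make_dataset(cfg) would pass cfg.video_backend along as the video_backend kwarg.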
@@ -44,6 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         split: str = "train",
         transform: callable = None,
         delta_timestamps: dict[list[float]] | None = None,
+        video_backend: str | None = None,
     ):
         super().__init__()
         self.repo_id = repo_id
@@ -65,6 +66,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.info = load_info(repo_id, version, root)
         if self.video:
             self.videos_dir = load_videos(repo_id, version, root)
+            self.video_backend = video_backend if video_backend is not None else "pyav"

     @property
     def fps(self) -> int:
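A minimal usage sketch of the new constructor argument. The import path and repo id are assumptions; omitting video_backend (or passing None) falls back to "pyav" as added in the hunk above.

    # Minimal sketch, assuming this import path for LeRobotDataset.
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Explicitly request the default backend; video_backend=None resolves to "pyav".
    dataset = LeRobotDataset("lerobot/pusht", video_backend="pyav")
    item = dataset[0]  # video frames are decoded with the selected backend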
@@ -145,6 +147,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self.video_frame_keys,
                 self.videos_dir,
                 self.tolerance_s,
+                self.video_backend,
             )

         if self.transform is not None:
@@ -27,7 +27,11 @@ from datasets.features.features import register_feature


 def load_from_videos(
-    item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
+    item: dict[str, torch.Tensor],
+    video_frame_keys: list[str],
+    videos_dir: Path,
+    tolerance_s: float,
+    backend: str = "pyav",
 ):
     """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
     in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
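The docstring's worker warning is about where decoding happens. Below is a hedged sketch of the intended pattern, with frames decoded inside DataLoader worker processes rather than in a second loader in the main process; the dataset construction details are assumptions, as above.

    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # assumed import path

    dataset = LeRobotDataset("lerobot/pusht", video_backend="pyav")

    # __getitem__ calls load_from_videos, so with num_workers > 0 decoding runs in the
    # worker processes; avoid a second num_workers=0 loader decoding in the main process.
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4)
    batch = next(iter(loader))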
@@ -46,14 +50,14 @@ def load_from_videos(
                 raise NotImplementedError("All video paths are expected to be the same for now.")
             video_path = data_dir / paths[0]

-            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
             item[key] = frames
         else:
             # load one frame
             timestamps = [item[key]["timestamp"]]
             video_path = data_dir / item[key]["path"]

-            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
             item[key] = frames[0]

     return item
@@ -63,11 +67,16 @@ def decode_video_frames_torchvision(
     video_path: str,
     timestamps: list[float],
     tolerance_s: float,
-    device: str = "cpu",
+    backend: str = "pyav",
     log_loaded_timestamps: bool = False,
 ):
     """Loads frames associated to the requested timestamps of a video

+    The backend can be either "pyav" (default) or "video_reader".
+    "video_reader" requires installing torchvision from source, see:
+    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
+    (note that you need to compile against ffmpeg<4.3)
+
     Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
     the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
     that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
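To make the new docstring note concrete, here is a small hedged availability check. torchvision.set_video_backend is the real torchvision entry point; the try/except fallback logic itself is illustrative.

    import torchvision

    # "video_reader" exists only when torchvision was compiled from source (against ffmpeg<4.3);
    # otherwise set_video_backend raises and we fall back to the bundled "pyav" backend.
    try:
        torchvision.set_video_backend("video_reader")
        backend = "video_reader"
    except RuntimeError:
        torchvision.set_video_backend("pyav")
        backend = "pyav"
    print(f"decoding with backend: {backend}")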
@@ -78,21 +87,9 @@ def decode_video_frames_torchvision(

     # set backend
     keyframes_only = False
-    if device == "cpu":
-        # explicitely use pyav
-        torchvision.set_video_backend("pyav")
-        keyframes_only = True  # pyav doesnt support accuracte seek
-    elif device == "cuda":
-        # TODO(rcadene, aliberts): implement video decoding with GPU
-        # torchvision.set_video_backend("cuda")
-        # torchvision.set_video_backend("video_reader")
-        # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
-        # check possible bug: https://github.com/pytorch/vision/issues/7745
-        raise NotImplementedError(
-            "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`."
-        )
-    else:
-        raise ValueError(device)
+    torchvision.set_video_backend(backend)
+    if backend == "pyav":
+        keyframes_only = True  # pyav doesn't support accurate seek

     # set a video stream reader
     # TODO(rcadene): also load audio stream at the same time
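A hedged example call into the updated decoder. The file path, timestamps, and tolerance are made up, and the shape comment describes the stacked-frames output informally rather than as a guaranteed contract.

    # Illustrative call, assuming a local MP4 from a LeRobot video dataset.
    frames = decode_video_frames_torchvision(
        video_path="videos/observation.image_episode_000000.mp4",  # hypothetical path
        timestamps=[0.0, 0.1, 0.2],
        tolerance_s=1e-4,
        backend="pyav",  # keyframes-only seeking; "video_reader" supports accurate seeking
    )
    print(frames.shape)  # e.g. (3, C, H, W): one decoded frame per requested timestamp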
@@ -120,7 +117,9 @@ def decode_video_frames_torchvision(
         if current_ts >= last_ts:
             break

-    reader.container.close()
+    if backend == "pyav":
+        reader.container.close()
+
     reader = None

     query_ts = torch.tensor(timestamps)