From fcfa20299ea610935b113748a68889ba33e358ab Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 30 May 2024 11:29:19 +0000 Subject: [PATCH] Add video_backend in config --- lerobot/common/datasets/factory.py | 1 + lerobot/common/datasets/lerobot_dataset.py | 3 ++ lerobot/common/datasets/video_utils.py | 37 +++++++++++----------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 7bdc2ca9..d402d42a 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -53,6 +53,7 @@ def make_dataset( cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps"), + video_backend=cfg.video_backend, ) if cfg.get("override_dataset_stats"): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 057e4770..a5653f39 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -44,6 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset): split: str = "train", transform: callable = None, delta_timestamps: dict[list[float]] | None = None, + video_backend: str | None = None, ): super().__init__() self.repo_id = repo_id @@ -65,6 +66,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.info = load_info(repo_id, version, root) if self.video: self.videos_dir = load_videos(repo_id, version, root) + self.video_backend = video_backend if video_backend is not None else "pyav" @property def fps(self) -> int: @@ -145,6 +147,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.video_frame_keys, self.videos_dir, self.tolerance_s, + self.video_backend, ) if self.transform is not None: diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index edfca918..0ac4ae89 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -27,7 +27,11 @@ from 
datasets.features.features import register_feature def load_from_videos( - item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float + item: dict[str, torch.Tensor], + video_frame_keys: list[str], + videos_dir: Path, + tolerance_s: float, + backend: str = "pyav", ): """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. @@ -46,14 +50,14 @@ def load_from_videos( raise NotImplementedError("All video paths are expected to be the same for now.") video_path = data_dir / paths[0] - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s) + frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) item[key] = frames else: # load one frame timestamps = [item[key]["timestamp"]] video_path = data_dir / item[key]["path"] - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s) + frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) item[key] = frames[0] return item @@ -63,11 +67,16 @@ def decode_video_frames_torchvision( video_path: str, timestamps: list[float], tolerance_s: float, - device: str = "cpu", + backend: str = "pyav", log_loaded_timestamps: bool = False, ): """Loads frames associated to the requested timestamps of a video + The backend can be either "pyav" (default) or "video_reader". + "video_reader" requires installing torchvision from source, see: + https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst + (note that you need to compile against ffmpeg<4.3) + Note: Video benefits from inter-frame compression. Instead of storing every frame individually, the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to that key frame. 
As a consequence, to access a requested frame, we need to load the preceding key frame, @@ -78,21 +87,9 @@ def decode_video_frames_torchvision( # set backend keyframes_only = False - if device == "cpu": - # explicitely use pyav - torchvision.set_video_backend("pyav") + torchvision.set_video_backend(backend) + if backend == "pyav": keyframes_only = True # pyav doesn't support accurate seek - elif device == "cuda": - # TODO(rcadene, aliberts): implement video decoding with GPU - # torchvision.set_video_backend("cuda") - # torchvision.set_video_backend("video_reader") - # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst - # check possible bug: https://github.com/pytorch/vision/issues/7745 - raise NotImplementedError( - "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`." - ) - else: - raise ValueError(device) # set a video stream reader # TODO(rcadene): also load audio stream at the same time @@ -120,7 +117,9 @@ def decode_video_frames_torchvision( if current_ts >= last_ts: break - reader.container.close() + if backend == "pyav": + reader.container.close() + reader = None query_ts = torch.tensor(timestamps)