From 2abef3bef9dabd4a244bb9f6adea4a698c928d9d Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:15:25 +0200 Subject: [PATCH] Enable `video_reader` backend (#220) Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> --- docker/lerobot-gpu-dev/Dockerfile | 6 +- .../_video_benchmark/capture_camera_feed.py | 90 +++++ .../_video_benchmark/run_video_benchmark.py | 317 ++++++++++-------- lerobot/common/datasets/factory.py | 2 + lerobot/common/datasets/lerobot_dataset.py | 7 + .../push_dataset_to_hub/cam_png_format.py | 101 ++++++ lerobot/common/datasets/video_utils.py | 44 +-- lerobot/configs/default.yaml | 1 + lerobot/scripts/push_dataset_to_hub.py | 7 +- poetry.lock | 107 +++--- pyproject.toml | 2 + 11 files changed, 464 insertions(+), 220 deletions(-) create mode 100644 lerobot/common/datasets/_video_benchmark/capture_camera_feed.py create mode 100644 lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py diff --git a/docker/lerobot-gpu-dev/Dockerfile b/docker/lerobot-gpu-dev/Dockerfile index e5c7d454..94b4f3ac 100644 --- a/docker/lerobot-gpu-dev/Dockerfile +++ b/docker/lerobot-gpu-dev/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.4.1-base-ubuntu22.04 +FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 # Configure image ARG PYTHON_VERSION=3.10 @@ -10,9 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ git git-lfs openssh-client \ nano vim less util-linux \ htop atop nvtop \ - sed gawk grep curl wget \ + sed gawk grep curl wget zip unzip \ tcpdump sysstat screen tmux \ - libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ + libglib2.0-0 libgl1-mesa-glx libegl1-mesa \ python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \ && apt-get clean && rm -rf /var/lib/apt/lists/* diff --git a/lerobot/common/datasets/_video_benchmark/capture_camera_feed.py b/lerobot/common/datasets/_video_benchmark/capture_camera_feed.py new file mode 100644 index 00000000..3b4c356a --- /dev/null +++ b/lerobot/common/datasets/_video_benchmark/capture_camera_feed.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Capture video feed from a camera as raw images.""" + +import argparse +import datetime as dt +from pathlib import Path + +import cv2 + + +def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int): + now = dt.datetime.now() + capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}" + if not capture_dir.exists(): + capture_dir.mkdir(parents=True, exist_ok=True) + + # Opens the default webcam + cap = cv2.VideoCapture(0) + if not cap.isOpened(): + print("Error: Could not open video stream.") + return + + cap.set(cv2.CAP_PROP_FPS, fps) + cap.set(cv2.CAP_PROP_FRAME_WIDTH, width) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height) + + frame_index = 0 + while True: + ret, frame = cap.read() + + if not ret: + print("Error: Could not read frame.") + break + + cv2.imshow("Video Stream", frame) + cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame) + frame_index += 1 + + # Break the loop on 'q' key press + if cv2.waitKey(1) & 0xFF == ord("q"): + break + + # Release the capture and destroy all windows + cap.release() + cv2.destroyAllWindows() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("outputs/cam_capture/"), + help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.", + ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="Frames Per Second of the capture.", + ) + parser.add_argument( + "--width", + type=int, + default=1280, + help="Width of the captured images.", + ) + parser.add_argument( + "--height", + type=int, + default=720, + help="Height of the captured images.", + ) + args = parser.parse_args() + display_and_save_video_stream(**vars(args)) diff --git a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py index 8be251dc..d92b5eef 100644 --- a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py +++ b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py @@ -13,6 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Assess the performance of video decoding in various configurations. + +This script will run different video decoding benchmarks where one parameter varies at a time. +These parameters and theirs values are specified in the BENCHMARKS dict. + +All of these benchmarks are evaluated within different timestamps modes corresponding to different frame-loading scenarios: + - `1_frame`: 1 single frame is loaded. + - `2_frames`: 2 consecutive frames are loaded. + - `2_frames_4_space`: 2 frames separated by 4 frames are loaded. + - `6_frames`: 6 consecutive frames are loaded. + +These values are more or less arbitrary and based on possible future usage. + +These benchmarks are run on the first episode of each dataset specified in DATASET_REPO_IDS. +Note: These datasets need to be image datasets, not video datasets. +""" + import json import random import shutil @@ -21,15 +38,38 @@ import time from pathlib import Path import einops -import numpy +import numpy as np import PIL import torch +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.video_utils import ( decode_video_frames_torchvision, ) +OUTPUT_DIR = Path("tmp/run_video_benchmark") +DRY_RUN = False + +DATASET_REPO_IDS = [ + "lerobot/pusht_image", + "aliberts/aloha_mobile_shrimp_image", + "aliberts/paris_street", + "aliberts/kitchen", +] +TIMESTAMPS_MODES = [ + "1_frame", + "2_frames", + "2_frames_4_space", + "6_frames", +] +BENCHMARKS = { + # "pix_fmt": ["yuv420p", "yuv444p"], + # "g": [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None], + # "crf": [0, 5, 10, 15, 20, None, 25, 30, 40, 50], + "backend": ["pyav", "video_reader"], +} + def get_directory_size(directory): total_size = 0 @@ -56,6 +96,10 @@ def run_video_benchmark( # TODO(rcadene): rewrite with hardcoding of original images and episodes dataset = LeRobotDataset(repo_id) + if dataset.video: + raise ValueError( + f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}" + ) # Get fps fps = dataset.fps @@ -68,10 +112,11 @@ def run_video_benchmark( if not imgs_dir.exists(): imgs_dir.mkdir(parents=True, exist_ok=True) hf_dataset = dataset.hf_dataset.with_format(None) - imgs_dataset = hf_dataset.select_columns("observation.image") + img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] + imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate(imgs_dataset): - img = item["observation.image"] + img = item[img_keys[0]] img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100) if i >= ep_num_images - 1: @@ -107,7 +152,7 @@ def run_video_benchmark( decoder = cfg["decoder"] decoder_kwgs = cfg["decoder_kwgs"] - device = cfg["device"] + backend = cfg["backend"] if decoder == "torchvision": decode_frames_fn = decode_video_frames_torchvision @@ -116,12 +161,12 @@ def run_video_benchmark( # Estimate average loading time - def load_original_frames(imgs_dir, timestamps): + def load_original_frames(imgs_dir, timestamps) -> torch.Tensor: frames = [] for ts in timestamps: idx = int(ts * fps) frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png") - frame = torch.from_numpy(numpy.array(frame)) + frame = torch.from_numpy(np.array(frame)) frame = frame.type(torch.float32) / 255 frame = einops.rearrange(frame, "h w c -> c h w") frames.append(frame) @@ -130,6 +175,9 @@ def run_video_benchmark( list_avg_load_time = [] list_avg_load_time_from_images = [] per_pixel_l2_errors = [] + psnr_values = [] + ssim_values = [] + mse_values = [] random.seed(seed) @@ -142,7 +190,7 @@ def run_video_benchmark( elif timestamps_mode == "2_frames": timestamps = [ts - 1 / fps, ts] elif timestamps_mode == "2_frames_4_space": - timestamps = [ts - 4 / fps, ts] + timestamps = [ts - 5 / fps, ts] elif timestamps_mode == "6_frames": timestamps = [ts - i / fps for i in range(6)][::-1] else: @@ -152,7 +200,7 @@ def run_video_benchmark( start_time_s = time.monotonic() frames = decode_frames_fn( - video_path, timestamps=timestamps, tolerance_s=1e-4, device=device, **decoder_kwgs + video_path, timestamps=timestamps, tolerance_s=1e-4, backend=backend, **decoder_kwgs ) avg_load_time = (time.monotonic() - start_time_s) / num_frames list_avg_load_time.append(avg_load_time) @@ -162,11 +210,19 @@ def run_video_benchmark( avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames list_avg_load_time_from_images.append(avg_load_time_from_images) - # Estimate average L2 error between original frames and decoded frames + # Estimate reconstruction error between original frames and decoded frames with various metrics for i, ts in enumerate(timestamps): # are_close = torch.allclose(frames[i], original_frames[i], atol=0.02) num_pixels = original_frames[i].numel() per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels + per_pixel_l2_errors.append(per_pixel_l2_error) + + frame_np, original_frame_np = frames[i].numpy(), original_frames[i].numpy() + psnr_values.append(peak_signal_noise_ratio(original_frame_np, frame_np, data_range=1.0)) + ssim_values.append( + structural_similarity(original_frame_np, frame_np, data_range=1.0, channel_axis=0) + ) + mse_values.append(mean_squared_error(original_frame_np, frame_np)) # save decoded frames if t == 0: @@ -179,15 +235,18 @@ def run_video_benchmark( original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png") original_frame.save(output_dir / f"original_frame_{i:06d}.png") - per_pixel_l2_errors.append(per_pixel_l2_error) - - avg_load_time = float(numpy.array(list_avg_load_time).mean()) - avg_load_time_from_images = float(numpy.array(list_avg_load_time_from_images).mean()) - avg_per_pixel_l2_error = float(numpy.array(per_pixel_l2_errors).mean()) + image_size = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:]) + avg_load_time = float(np.array(list_avg_load_time).mean()) + avg_load_time_from_images = float(np.array(list_avg_load_time_from_images).mean()) + avg_per_pixel_l2_error = float(np.array(per_pixel_l2_errors).mean()) + avg_psnr = float(np.mean(psnr_values)) + avg_ssim = float(np.mean(ssim_values)) + avg_mse = float(np.mean(mse_values)) # Save benchmark info info = { + "image_size": image_size, "sum_original_frames_size_bytes": sum_original_frames_size_bytes, "video_size_bytes": video_size_bytes, "avg_load_time_from_images": avg_load_time_from_images, @@ -195,6 +254,9 @@ def run_video_benchmark( "compression_factor": sum_original_frames_size_bytes / video_size_bytes, "load_time_factor": avg_load_time_from_images / avg_load_time, "avg_per_pixel_l2_error": avg_per_pixel_l2_error, + "avg_psnr": avg_psnr, + "avg_ssim": avg_ssim, + "avg_mse": avg_mse, } with open(output_dir / "info.json", "w") as f: @@ -234,138 +296,113 @@ def load_info(out_dir): return info -def main(): - out_dir = Path("tmp/run_video_benchmark") - dry_run = False - repo_ids = ["lerobot/pusht", "lerobot/umi_cup_in_the_wild"] - timestamps_modes = [ - "1_frame", - "2_frames", - "2_frames_4_space", - "6_frames", +def one_variable_study( + var_name: str, var_values: list, repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool +): + print(f"**`{var_name}`**") + headers = [ + "repo_id", + "image_size", + var_name, + "compression_factor", + "load_time_factor", + "avg_per_pixel_l2_error", + "avg_psnr", + "avg_ssim", + "avg_mse", ] - for timestamps_mode in timestamps_modes: - bench_dir = out_dir / timestamps_mode + rows = [] + base_cfg = { + "repo_id": None, + # video encoding + "g": 2, + "crf": None, + "pix_fmt": "yuv444p", + # video decoding + "backend": "pyav", + "decoder": "torchvision", + "decoder_kwgs": {}, + } + for repo_id in repo_ids: + for val in var_values: + cfg = base_cfg.copy() + cfg["repo_id"] = repo_id + cfg[var_name] = val + if not dry_run: + run_video_benchmark( + bench_dir / repo_id / f"torchvision_{var_name}_{val}", cfg, timestamps_mode + ) + info = load_info(bench_dir / repo_id / f"torchvision_{var_name}_{val}") + width, height = info["image_size"][0], info["image_size"][1] + rows.append( + [ + repo_id, + f"{width} x {height}", + val, + info["compression_factor"], + info["load_time_factor"], + info["avg_per_pixel_l2_error"], + info["avg_psnr"], + info["avg_ssim"], + info["avg_mse"], + ] + ) + display_markdown_table(headers, rows) + + +def best_study(repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool): + """Change the config once you deciced what's best based on one-variable-studies""" + print("**best**") + headers = [ + "repo_id", + "image_size", + "compression_factor", + "load_time_factor", + "avg_per_pixel_l2_error", + "avg_psnr", + "avg_ssim", + "avg_mse", + ] + rows = [] + for repo_id in repo_ids: + cfg = { + "repo_id": repo_id, + # video encoding + "g": 2, + "crf": None, + "pix_fmt": "yuv444p", + # video decoding + "backend": "video_reader", + "decoder": "torchvision", + "decoder_kwgs": {}, + } + if not dry_run: + run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode) + info = load_info(bench_dir / repo_id / "torchvision_best") + width, height = info["image_size"][0], info["image_size"][1] + rows.append( + [ + repo_id, + f"{width} x {height}", + info["compression_factor"], + info["load_time_factor"], + info["avg_per_pixel_l2_error"], + ] + ) + display_markdown_table(headers, rows) + + +def main(): + for timestamps_mode in TIMESTAMPS_MODES: + bench_dir = OUTPUT_DIR / timestamps_mode print(f"### `{timestamps_mode}`") print() - print("**`pix_fmt`**") - headers = ["repo_id", "pix_fmt", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - for pix_fmt in ["yuv420p", "yuv444p"]: - cfg = { - "repo_id": repo_id, - # video encoding - "g": 2, - "crf": None, - "pix_fmt": pix_fmt, - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / f"torchvision_{pix_fmt}", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / f"torchvision_{pix_fmt}") - rows.append( - [ - repo_id, - pix_fmt, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) + for name, values in BENCHMARKS.items(): + one_variable_study(name, values, DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN) - print("**`g`**") - headers = ["repo_id", "g", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - for g in [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None]: - cfg = { - "repo_id": repo_id, - # video encoding - "g": g, - "pix_fmt": "yuv444p", - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / f"torchvision_g_{g}", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / f"torchvision_g_{g}") - rows.append( - [ - repo_id, - g, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) - - print("**`crf`**") - headers = ["repo_id", "crf", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - for crf in [0, 5, 10, 15, 20, None, 25, 30, 40, 50]: - cfg = { - "repo_id": repo_id, - # video encoding - "g": 2, - "crf": crf, - "pix_fmt": "yuv444p", - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / f"torchvision_crf_{crf}", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / f"torchvision_crf_{crf}") - rows.append( - [ - repo_id, - crf, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) - - print("**best**") - headers = ["repo_id", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - cfg = { - "repo_id": repo_id, - # video encoding - "g": 2, - "crf": None, - "pix_fmt": "yuv444p", - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / "torchvision_best") - rows.append( - [ - repo_id, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) + # best_study(DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN) if __name__ == "__main__": diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 754bc91b..96a353fb 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -96,6 +96,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData split=split, delta_timestamps=cfg.training.get("delta_timestamps"), image_transforms=image_transforms, + video_backend=cfg.video_backend, ) else: dataset = MultiLeRobotDataset( @@ -103,6 +104,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData split=split, delta_timestamps=cfg.training.get("delta_timestamps"), image_transforms=image_transforms, + video_backend=cfg.video_backend, ) if cfg.get("override_dataset_stats"): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d680b987..61c35aa4 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -48,6 +48,7 @@ class LeRobotDataset(torch.utils.data.Dataset): split: str = "train", image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, + video_backend: str | None = None, ): super().__init__() self.repo_id = repo_id @@ -69,6 +70,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.info = load_info(repo_id, version, root) if self.video: self.videos_dir = load_videos(repo_id, version, root) + self.video_backend = video_backend if video_backend is not None else "pyav" @property def fps(self) -> int: @@ -149,6 +151,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.video_frame_keys, self.videos_dir, self.tolerance_s, + self.video_backend, ) if self.image_transforms is not None: @@ -188,6 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset): stats=None, info=None, videos_dir=None, + video_backend=None, ) -> "LeRobotDataset": """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem. @@ -210,6 +214,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.stats = stats obj.info = info if info is not None else {} obj.videos_dir = videos_dir + obj.video_backend = video_backend if video_backend is not None else "pyav" return obj @@ -228,6 +233,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): split: str = "train", image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, + video_backend: str | None = None, ): super().__init__() self.repo_ids = repo_ids @@ -241,6 +247,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): split=split, delta_timestamps=delta_timestamps, image_transforms=image_transforms, + video_backend=video_backend, ) for repo_id in repo_ids ] diff --git a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py new file mode 100644 index 00000000..4972e6b4 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains utilities to process raw data format of png images files recorded with capture_camera_feed.py +""" + +from pathlib import Path + +import torch +from datasets import Dataset, Features, Image, Value +from PIL import Image as PILImage + +from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes +from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch +from lerobot.common.datasets.video_utils import VideoFrame + + +def check_format(raw_dir: Path) -> bool: + image_paths = list(raw_dir.glob("frame_*.png")) + if len(image_paths) == 0: + raise ValueError + + +def load_from_raw(raw_dir: Path, fps: int, episodes: list[int] | None = None): + if episodes is not None: + # TODO(aliberts): add support for multi-episodes. + raise NotImplementedError() + + ep_dict = {} + ep_idx = 0 + + image_paths = sorted(raw_dir.glob("frame_*.png")) + num_frames = len(image_paths) + + ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths] + ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames) + ep_dict["frame_index"] = torch.arange(0, num_frames, 1) + ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps + + ep_dicts = [ep_dict] + data_dict = concatenate_episodes(ep_dicts) + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict + + +def to_hf_dataset(data_dict, video) -> Dataset: + features = {} + if video: + features["observation.image"] = VideoFrame() + else: + features["observation.image"] = Image() + + features["episode_index"] = Value(dtype="int64", id=None) + features["frame_index"] = Value(dtype="int64", id=None) + features["timestamp"] = Value(dtype="float32", id=None) + features["index"] = Value(dtype="int64", id=None) + + hf_dataset = Dataset.from_dict(data_dict, features=Features(features)) + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): + if video or episodes is not None: + # TODO(aliberts): support this + raise NotImplementedError + + # sanity check + check_format(raw_dir) + + if fps is None: + fps = 30 + + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) + hf_dataset = to_hf_dataset(data_dict, video) + episode_data_index = calculate_episode_data_index(hf_dataset) + info = { + "fps": fps, + "video": video, + } + return hf_dataset, episode_data_index, info diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index edfca918..fdc4fbe9 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -27,7 +27,11 @@ from datasets.features.features import register_feature def load_from_videos( - item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float + item: dict[str, torch.Tensor], + video_frame_keys: list[str], + videos_dir: Path, + tolerance_s: float, + backend: str = "pyav", ): """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. @@ -46,14 +50,14 @@ def load_from_videos( raise NotImplementedError("All video paths are expected to be the same for now.") video_path = data_dir / paths[0] - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s) + frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) item[key] = frames else: # load one frame timestamps = [item[key]["timestamp"]] video_path = data_dir / item[key]["path"] - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s) + frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) item[key] = frames[0] return item @@ -63,11 +67,23 @@ def decode_video_frames_torchvision( video_path: str, timestamps: list[float], tolerance_s: float, - device: str = "cpu", + backend: str = "pyav", log_loaded_timestamps: bool = False, ): """Loads frames associated to the requested timestamps of a video + The backend can be either "pyav" (default) or "video_reader". + "video_reader" requires installing torchvision from source, see: + https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst + (note that you need to compile against ffmpeg<4.3) + + While both use cpu, "video_reader" is faster than "pyav" but requires additional setup. + See our benchmark results for more info on performance: + https://github.com/huggingface/lerobot/pull/220 + + See torchvision doc for more info on these two backends: + https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend + Note: Video benefits from inter-frame compression. Instead of storing every frame individually, the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame, @@ -78,21 +94,9 @@ def decode_video_frames_torchvision( # set backend keyframes_only = False - if device == "cpu": - # explicitely use pyav - torchvision.set_video_backend("pyav") + torchvision.set_video_backend(backend) + if backend == "pyav": keyframes_only = True # pyav doesnt support accuracte seek - elif device == "cuda": - # TODO(rcadene, aliberts): implement video decoding with GPU - # torchvision.set_video_backend("cuda") - # torchvision.set_video_backend("video_reader") - # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst - # check possible bug: https://github.com/pytorch/vision/issues/7745 - raise NotImplementedError( - "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`." - ) - else: - raise ValueError(device) # set a video stream reader # TODO(rcadene): also load audio stream at the same time @@ -120,7 +124,9 @@ def decode_video_frames_torchvision( if current_ts >= last_ts: break - reader.container.close() + if backend == "pyav": + reader.container.close() + reader = None query_ts = torch.tensor(timestamps) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 6101df89..c479788b 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -28,6 +28,7 @@ seed: ??? # "dataset_index" into the returned item. The index mapping is made according to the order in which the # datsets are provided. dataset_repo_id: lerobot/pusht +video_backend: pyav training: offline_steps: ??? diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 18714a40..92a0cc45 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -55,7 +55,6 @@ from safetensors.torch import save_file from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw from lerobot.common.datasets.utils import flatten_dict @@ -70,6 +69,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str): from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format elif raw_format == "xarm_pkl": from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format + elif raw_format == "cam_png": + from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format else: raise ValueError( f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?" @@ -182,10 +183,6 @@ def push_dataset_to_hub( meta_data_dir = Path(cache_dir) / "meta_data" videos_dir = Path(cache_dir) / "videos" - # Download the raw dataset if available - if not raw_dir.exists(): - download_raw(raw_dir, dataset_id) - if raw_format is None: # TODO(rcadene, adilzouitine): implement auto_find_raw_format raise NotImplementedError() diff --git a/poetry.lock b/poetry.lock index e9d6e848..48c6c057 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -3805,31 +3805,31 @@ files = [ [[package]] name = "torch" -version = "2.3.0" +version = "2.3.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, - {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, - {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, - {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, - {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, - {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, - {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, - {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, - {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"}, - {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"}, - {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"}, - {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"}, - {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"}, - {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"}, - {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"}, - {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"}, - {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"}, - {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"}, - {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"}, - {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"}, + {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, + {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, + {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, + {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, + {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, + {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, + {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, + {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, + {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, + {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, + {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, + {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, + {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, + {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, + {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, + {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, + {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, + {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, + {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, + {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, ] [package.dependencies] @@ -3850,7 +3850,7 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \" nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} typing-extensions = ">=4.8.0" [package.extras] @@ -3859,37 +3859,37 @@ optree = ["optree (>=0.9.1)"] [[package]] name = "torchvision" -version = "0.18.0" +version = "0.18.1" description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.8" files = [ - {file = "torchvision-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dd61628a3d189c6852a12dc5ed4cd2eece66d2d67f35a866cb16f1dcb06c8c62"}, - {file = "torchvision-0.18.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:493c45f9937dad37aa1b64b14da17c7a589c72b91adc4837d431009cfe29bd53"}, - {file = "torchvision-0.18.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5337f6acfa1fe959d5cb340d01a00614d6b31ce7a4824ccb95435a85c5273b95"}, - {file = "torchvision-0.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:bd8e6f3b5beb49965f15c461302488edfa3d8c2d01d3bb79b150d6fb62711e3a"}, - {file = "torchvision-0.18.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6896a52168befe1105fb3c9335287390ed227e71d1e4ec4d68b62e8a3099fc09"}, - {file = "torchvision-0.18.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3d7955398d4ceaad77c487c2c44f6f7813112402c9bab8cd906d346005891048"}, - {file = "torchvision-0.18.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e5a24d620cea14a4bb89f24aa2b506230c0a16a3ada57fc53ad80cfd256a2128"}, - {file = "torchvision-0.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:6ad70ddfa879bda5ed886b2518fe562640e0059787cbd65cb2bffa7674541410"}, - {file = "torchvision-0.18.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:eb9d83c0e1dbb54ecb0fb04c87f786333e3a6fb8b9c400aca7c31081f9aa5707"}, - {file = "torchvision-0.18.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b657d052d146f24cb3b2a78219bfc82ae70a9706671c50f632528907d10cccec"}, - {file = "torchvision-0.18.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a964afbc7ddf50a46b941477f6c35729b416deedd139756befd488245e2e226d"}, - {file = "torchvision-0.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:7c770f0f748e0b17f57c0297508d7254f686cdf03fc2e2949f422b20574f4c0f"}, - {file = "torchvision-0.18.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2115a1906c015f5da9ceedc40a983313b0fd6e2c8a17108a92991706f51f6987"}, - {file = "torchvision-0.18.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:6323f7e5423ff2594d5891863b919deb9d0de95f01c36bf26fbd879036b6ed08"}, - {file = "torchvision-0.18.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:925d0a82cccf6f986c18b29b4392a942db65cbdb73c13a129c8493822eb9e36f"}, - {file = "torchvision-0.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:95b42d0dc599b47a01530c7439a5751e67e45b85e3a67113989cf7c7c70f2039"}, - {file = "torchvision-0.18.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:75e22ecf44a13b8f95b8ad421c0261282d859c61816badaca1959e073ccdd691"}, - {file = "torchvision-0.18.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:4c334b3e719ba0a9ba6e15d4aff1178f5e6d029174f346163fed525f0ccfffd3"}, - {file = "torchvision-0.18.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:36efd87001c6bee2383e043e46a025affb03179747c8f4777b9918527ffce756"}, - {file = "torchvision-0.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:ccc292e093771d5baacf5535ac4416306b6b5f15676341cd4d010d8542eace25"}, + {file = "torchvision-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3e694e54b0548dad99c12af6bf0c8e4f3350137d391dcd19af22a1c5f89322b3"}, + {file = "torchvision-0.18.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:0b3bda0aa5b416eeb547143b8eeaf17720bdba9cf516dc991aacb81811aa96a5"}, + {file = "torchvision-0.18.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:573ff523c739405edb085f65cb592f482d28a30e29b0be4c4ba08040b3ae785f"}, + {file = "torchvision-0.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:ef7bbbc60b38e831a75e547c66ca1784f2ac27100f9e4ddbe9614cef6cbcd942"}, + {file = "torchvision-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:80b5d794dd0fdba787adc22f1a367a5ead452327686473cb260dd94364bc56a6"}, + {file = "torchvision-0.18.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:9077cf590cdb3a5e8fdf5cdb71797f8c67713f974cf0228ecb17fcd670ab42f9"}, + {file = "torchvision-0.18.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ceb993a882f1ae7ae373ed39c28d7e3e802205b0e59a7ed84ef4028f0bba8d7f"}, + {file = "torchvision-0.18.1-cp311-cp311-win_amd64.whl", hash = "sha256:52f7436140045dc2239cdc502aa76b2bd8bd676d64244ff154d304aa69852046"}, + {file = "torchvision-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2be6f0bf7c455c89a51a1dbb6f668d36c6edc479f49ac912d745d10df5715657"}, + {file = "torchvision-0.18.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:f118d887bfde3a948a41d56587525401e5cac1b7db2eaca203324d6ed2b1caca"}, + {file = "torchvision-0.18.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:13d24d904f65e62d66a1e0c41faec630bc193867b8a4a01166769e8a8e8df8e9"}, + {file = "torchvision-0.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ed6340b69a63a625e512a66127210d412551d9c5f2ad2978130c6a45bf56cd4a"}, + {file = "torchvision-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b1c3864fa9378c88bce8ad0ef3599f4f25397897ce612e1c245c74b97092f35e"}, + {file = "torchvision-0.18.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:02085a2ffc7461f5c0edb07d6f3455ee1806561f37736b903da820067eea58c7"}, + {file = "torchvision-0.18.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9726c316a2501df8503e5a5dc46a631afd4c515a958972e5b7f7b9c87d2125c0"}, + {file = "torchvision-0.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:64a2662dbf30db9055d8b201d6e56f312a504e5ccd9d144c57c41622d3c524cb"}, + {file = "torchvision-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:975b8594c0f5288875408acbb74946eea786c5b008d129c0d045d0ead23742bc"}, + {file = "torchvision-0.18.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:da83c8bbd34d8bee48bfa1d1b40e0844bc3cba10ed825a5a8cbe3ce7b62264cd"}, + {file = "torchvision-0.18.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:54bfcd352abb396d5c9c237d200167c178bd136051b138e1e8ef46ce367c2773"}, + {file = "torchvision-0.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:5c8366a1aeee49e9ea9e64b30d199debdf06b1bd7610a76165eb5d7869c3bde5"}, ] [package.dependencies] numpy = "*" pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -torch = "2.3.0" +torch = "2.3.1" [package.extras] scipy = ["scipy"] @@ -3916,17 +3916,17 @@ telegram = ["requests"] [[package]] name = "triton" -version = "2.3.0" +version = "2.3.1" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, - {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, - {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, - {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"}, - {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, - {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, + {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, + {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, + {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, + {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, + {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, + {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, ] [package.dependencies] @@ -4301,9 +4301,10 @@ dora = ["gym-dora"] pusht = ["gym-pusht"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] +video-benchmark = ["scikit-image"] xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771" +content-hash = "61f99befbc2250fe59cb54119c3dbd3aa3c1dfe5d3d7790c6f7c4f90fe43112e" diff --git a/pyproject.toml b/pyproject.toml index 0c305218..15ac9c73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ pyav = ">=12.0.5" moviepy = ">=1.0.3" rerun-sdk = ">=0.15.1" deepdiff = ">=7.0.1" +scikit-image = {version = "^0.23.2", optional = true} [tool.poetry.extras] @@ -70,6 +71,7 @@ aloha = ["gym-aloha"] dev = ["pre-commit", "debugpy"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] +video_benchmark = ["scikit-image"] [tool.ruff] line-length = 110