Enable `video_reader` backend (#220)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Authored by Simon Alibert on 2024-06-19 17:15:25 +02:00, committed by GitHub
parent 48951662f2
commit 2abef3bef9
11 changed files with 464 additions and 220 deletions

Dockerfile

@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.4.1-base-ubuntu22.04
+FROM nvidia/cuda:12.2.2-devel-ubuntu22.04

# Configure image
ARG PYTHON_VERSION=3.10

@@ -10,9 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
    git git-lfs openssh-client \
    nano vim less util-linux \
    htop atop nvtop \
-    sed gawk grep curl wget \
+    sed gawk grep curl wget zip unzip \
    tcpdump sysstat screen tmux \
-    libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
+    libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

capture_camera_feed.py (new file)

@@ -0,0 +1,90 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Capture video feed from a camera as raw images."""
import argparse
import datetime as dt
from pathlib import Path
import cv2
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
now = dt.datetime.now()
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
if not capture_dir.exists():
capture_dir.mkdir(parents=True, exist_ok=True)
# Opens the default webcam
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("Error: Could not open video stream.")
return
cap.set(cv2.CAP_PROP_FPS, fps)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
frame_index = 0
while True:
ret, frame = cap.read()
if not ret:
print("Error: Could not read frame.")
break
cv2.imshow("Video Stream", frame)
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
frame_index += 1
# Break the loop on 'q' key press
if cv2.waitKey(1) & 0xFF == ord("q"):
break
# Release the capture and destroy all windows
cap.release()
cv2.destroyAllWindows()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-dir",
type=Path,
default=Path("outputs/cam_capture/"),
help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.",
)
parser.add_argument(
"--fps",
type=int,
default=30,
help="Frames Per Second of the capture.",
)
parser.add_argument(
"--width",
type=int,
default=1280,
help="Width of the captured images.",
)
parser.add_argument(
"--height",
type=int,
default=720,
help="Height of the captured images.",
)
args = parser.parse_args()
display_and_save_video_stream(**vars(args))
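
The capture function can also be called programmatically, equivalent to running the script with its defaults; a sketch using only names defined above (the script name is the one referenced later by cam_png_format.py):

from pathlib import Path

# Equivalent to: python capture_camera_feed.py --fps 30 --width 1280 --height 720
display_and_save_video_stream(
    output_dir=Path("outputs/cam_capture/"),
    fps=30,
    width=1280,
    height=720,
)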

run_video_benchmark.py

@@ -13,6 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Assess the performance of video decoding in various configurations.
+
+This script will run different video decoding benchmarks where one parameter varies at a time.
+These parameters and their values are specified in the BENCHMARKS dict.
+
+All of these benchmarks are evaluated within different timestamps modes corresponding to different frame-loading scenarios:
+- `1_frame`: 1 single frame is loaded.
+- `2_frames`: 2 consecutive frames are loaded.
+- `2_frames_4_space`: 2 frames separated by 4 frames are loaded.
+- `6_frames`: 6 consecutive frames are loaded.
+
+These values are more or less arbitrary and based on possible future usage.
+
+These benchmarks are run on the first episode of each dataset specified in DATASET_REPO_IDS.
+Note: These datasets need to be image datasets, not video datasets.
+"""

import json
import random
import shutil

@@ -21,15 +38,38 @@ import time
from pathlib import Path

import einops
-import numpy
+import numpy as np
import PIL
import torch
+from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision,
)

+OUTPUT_DIR = Path("tmp/run_video_benchmark")
+DRY_RUN = False
+
+DATASET_REPO_IDS = [
+    "lerobot/pusht_image",
+    "aliberts/aloha_mobile_shrimp_image",
+    "aliberts/paris_street",
+    "aliberts/kitchen",
+]
+TIMESTAMPS_MODES = [
+    "1_frame",
+    "2_frames",
+    "2_frames_4_space",
+    "6_frames",
+]
+BENCHMARKS = {
+    # "pix_fmt": ["yuv420p", "yuv444p"],
+    # "g": [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
+    # "crf": [0, 5, 10, 15, 20, None, 25, 30, 40, 50],
+    "backend": ["pyav", "video_reader"],
+}
+

def get_directory_size(directory):
    total_size = 0

@@ -56,6 +96,10 @@ def run_video_benchmark(
    # TODO(rcadene): rewrite with hardcoding of original images and episodes
    dataset = LeRobotDataset(repo_id)

+    if dataset.video:
+        raise ValueError(
+            f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
+        )
+
    # Get fps
    fps = dataset.fps

@@ -68,10 +112,11 @@ def run_video_benchmark(
    if not imgs_dir.exists():
        imgs_dir.mkdir(parents=True, exist_ok=True)

    hf_dataset = dataset.hf_dataset.with_format(None)
-    imgs_dataset = hf_dataset.select_columns("observation.image")
+    img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
+    imgs_dataset = hf_dataset.select_columns(img_keys[0])

    for i, item in enumerate(imgs_dataset):
-        img = item["observation.image"]
+        img = item[img_keys[0]]
        img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)

        if i >= ep_num_images - 1:

@@ -107,7 +152,7 @@ def run_video_benchmark(
    decoder = cfg["decoder"]
    decoder_kwgs = cfg["decoder_kwgs"]
-    device = cfg["device"]
+    backend = cfg["backend"]

    if decoder == "torchvision":
        decode_frames_fn = decode_video_frames_torchvision

@@ -116,12 +161,12 @@ def run_video_benchmark(
    # Estimate average loading time

-    def load_original_frames(imgs_dir, timestamps):
+    def load_original_frames(imgs_dir, timestamps) -> torch.Tensor:
        frames = []
        for ts in timestamps:
            idx = int(ts * fps)
            frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
-            frame = torch.from_numpy(numpy.array(frame))
+            frame = torch.from_numpy(np.array(frame))
            frame = frame.type(torch.float32) / 255
            frame = einops.rearrange(frame, "h w c -> c h w")
            frames.append(frame)

@@ -130,6 +175,9 @@ def run_video_benchmark(
    list_avg_load_time = []
    list_avg_load_time_from_images = []
    per_pixel_l2_errors = []
+    psnr_values = []
+    ssim_values = []
+    mse_values = []

    random.seed(seed)

@@ -142,7 +190,7 @@ def run_video_benchmark(
        elif timestamps_mode == "2_frames":
            timestamps = [ts - 1 / fps, ts]
        elif timestamps_mode == "2_frames_4_space":
-            timestamps = [ts - 4 / fps, ts]
+            timestamps = [ts - 5 / fps, ts]
        elif timestamps_mode == "6_frames":
            timestamps = [ts - i / fps for i in range(6)][::-1]
        else:

@@ -152,7 +200,7 @@ def run_video_benchmark(
        start_time_s = time.monotonic()
        frames = decode_frames_fn(
-            video_path, timestamps=timestamps, tolerance_s=1e-4, device=device, **decoder_kwgs
+            video_path, timestamps=timestamps, tolerance_s=1e-4, backend=backend, **decoder_kwgs
        )
        avg_load_time = (time.monotonic() - start_time_s) / num_frames
        list_avg_load_time.append(avg_load_time)

@@ -162,11 +210,19 @@ def run_video_benchmark(
        avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames
        list_avg_load_time_from_images.append(avg_load_time_from_images)

-        # Estimate average L2 error between original frames and decoded frames
+        # Estimate reconstruction error between original frames and decoded frames with various metrics
        for i, ts in enumerate(timestamps):
            # are_close = torch.allclose(frames[i], original_frames[i], atol=0.02)
            num_pixels = original_frames[i].numel()
            per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels
+            per_pixel_l2_errors.append(per_pixel_l2_error)
+
+            frame_np, original_frame_np = frames[i].numpy(), original_frames[i].numpy()
+            psnr_values.append(peak_signal_noise_ratio(original_frame_np, frame_np, data_range=1.0))
+            ssim_values.append(
+                structural_similarity(original_frame_np, frame_np, data_range=1.0, channel_axis=0)
+            )
+            mse_values.append(mean_squared_error(original_frame_np, frame_np))

        # save decoded frames
        if t == 0:

@@ -179,15 +235,18 @@ def run_video_benchmark(
            original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
            original_frame.save(output_dir / f"original_frame_{i:06d}.png")

-            per_pixel_l2_errors.append(per_pixel_l2_error)
+    image_size = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])

-    avg_load_time = float(numpy.array(list_avg_load_time).mean())
-    avg_load_time_from_images = float(numpy.array(list_avg_load_time_from_images).mean())
-    avg_per_pixel_l2_error = float(numpy.array(per_pixel_l2_errors).mean())
+    avg_load_time = float(np.array(list_avg_load_time).mean())
+    avg_load_time_from_images = float(np.array(list_avg_load_time_from_images).mean())
+    avg_per_pixel_l2_error = float(np.array(per_pixel_l2_errors).mean())
+    avg_psnr = float(np.mean(psnr_values))
+    avg_ssim = float(np.mean(ssim_values))
+    avg_mse = float(np.mean(mse_values))

    # Save benchmark info
    info = {
+        "image_size": image_size,
        "sum_original_frames_size_bytes": sum_original_frames_size_bytes,
        "video_size_bytes": video_size_bytes,
        "avg_load_time_from_images": avg_load_time_from_images,

@@ -195,6 +254,9 @@ def run_video_benchmark(
        "compression_factor": sum_original_frames_size_bytes / video_size_bytes,
        "load_time_factor": avg_load_time_from_images / avg_load_time,
        "avg_per_pixel_l2_error": avg_per_pixel_l2_error,
+        "avg_psnr": avg_psnr,
+        "avg_ssim": avg_ssim,
+        "avg_mse": avg_mse,
    }

    with open(output_dir / "info.json", "w") as f:
@@ -234,138 +296,113 @@ def load_info(out_dir):
    return info


-def main():
-    out_dir = Path("tmp/run_video_benchmark")
-    dry_run = False
-    repo_ids = ["lerobot/pusht", "lerobot/umi_cup_in_the_wild"]
-    timestamps_modes = [
-        "1_frame",
-        "2_frames",
-        "2_frames_4_space",
-        "6_frames",
-    ]
-    for timestamps_mode in timestamps_modes:
-        bench_dir = out_dir / timestamps_mode
+def one_variable_study(
+    var_name: str, var_values: list, repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool
+):
+    print(f"**`{var_name}`**")
+    headers = [
+        "repo_id",
+        "image_size",
+        var_name,
+        "compression_factor",
+        "load_time_factor",
+        "avg_per_pixel_l2_error",
+        "avg_psnr",
+        "avg_ssim",
+        "avg_mse",
+    ]
+    rows = []
+    base_cfg = {
+        "repo_id": None,
+        # video encoding
+        "g": 2,
+        "crf": None,
+        "pix_fmt": "yuv444p",
+        # video decoding
+        "backend": "pyav",
+        "decoder": "torchvision",
+        "decoder_kwgs": {},
+    }
+    for repo_id in repo_ids:
+        for val in var_values:
+            cfg = base_cfg.copy()
+            cfg["repo_id"] = repo_id
+            cfg[var_name] = val
+            if not dry_run:
+                run_video_benchmark(
+                    bench_dir / repo_id / f"torchvision_{var_name}_{val}", cfg, timestamps_mode
+                )
+            info = load_info(bench_dir / repo_id / f"torchvision_{var_name}_{val}")
+            width, height = info["image_size"][0], info["image_size"][1]
+            rows.append(
+                [
+                    repo_id,
+                    f"{width} x {height}",
+                    val,
+                    info["compression_factor"],
+                    info["load_time_factor"],
+                    info["avg_per_pixel_l2_error"],
+                    info["avg_psnr"],
+                    info["avg_ssim"],
+                    info["avg_mse"],
+                ]
+            )
+    display_markdown_table(headers, rows)
+
+
+def best_study(repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool):
+    """Change the config once you decided what's best based on one-variable studies."""
+    print("**best**")
+    headers = [
+        "repo_id",
+        "image_size",
+        "compression_factor",
+        "load_time_factor",
+        "avg_per_pixel_l2_error",
+        "avg_psnr",
+        "avg_ssim",
+        "avg_mse",
+    ]
+    rows = []
+    for repo_id in repo_ids:
+        cfg = {
+            "repo_id": repo_id,
+            # video encoding
+            "g": 2,
+            "crf": None,
+            "pix_fmt": "yuv444p",
+            # video decoding
+            "backend": "video_reader",
+            "decoder": "torchvision",
+            "decoder_kwgs": {},
+        }
+        if not dry_run:
+            run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode)
+        info = load_info(bench_dir / repo_id / "torchvision_best")
+        width, height = info["image_size"][0], info["image_size"][1]
+        rows.append(
+            [
+                repo_id,
+                f"{width} x {height}",
+                info["compression_factor"],
+                info["load_time_factor"],
+                info["avg_per_pixel_l2_error"],
+                info["avg_psnr"],
+                info["avg_ssim"],
+                info["avg_mse"],
+            ]
+        )
+    display_markdown_table(headers, rows)
+
+
+def main():
+    for timestamps_mode in TIMESTAMPS_MODES:
+        bench_dir = OUTPUT_DIR / timestamps_mode

        print(f"### `{timestamps_mode}`")
        print()

-        print("**`pix_fmt`**")
-        headers = ["repo_id", "pix_fmt", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
-        rows = []
-        for repo_id in repo_ids:
-            for pix_fmt in ["yuv420p", "yuv444p"]:
-                cfg = {
-                    "repo_id": repo_id,
-                    # video encoding
-                    "g": 2,
-                    "crf": None,
-                    "pix_fmt": pix_fmt,
-                    # video decoding
-                    "device": "cpu",
-                    "decoder": "torchvision",
-                    "decoder_kwgs": {},
-                }
-                if not dry_run:
-                    run_video_benchmark(bench_dir / repo_id / f"torchvision_{pix_fmt}", cfg, timestamps_mode)
-                info = load_info(bench_dir / repo_id / f"torchvision_{pix_fmt}")
-                rows.append(
-                    [
-                        repo_id,
-                        pix_fmt,
-                        info["compression_factor"],
-                        info["load_time_factor"],
-                        info["avg_per_pixel_l2_error"],
-                    ]
-                )
-        display_markdown_table(headers, rows)
-
-        print("**`g`**")
-        headers = ["repo_id", "g", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
-        rows = []
-        for repo_id in repo_ids:
-            for g in [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None]:
-                cfg = {
-                    "repo_id": repo_id,
-                    # video encoding
-                    "g": g,
-                    "pix_fmt": "yuv444p",
-                    # video decoding
-                    "device": "cpu",
-                    "decoder": "torchvision",
-                    "decoder_kwgs": {},
-                }
-                if not dry_run:
-                    run_video_benchmark(bench_dir / repo_id / f"torchvision_g_{g}", cfg, timestamps_mode)
-                info = load_info(bench_dir / repo_id / f"torchvision_g_{g}")
-                rows.append(
-                    [
-                        repo_id,
-                        g,
-                        info["compression_factor"],
-                        info["load_time_factor"],
-                        info["avg_per_pixel_l2_error"],
-                    ]
-                )
-        display_markdown_table(headers, rows)
-
-        print("**`crf`**")
-        headers = ["repo_id", "crf", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
-        rows = []
-        for repo_id in repo_ids:
-            for crf in [0, 5, 10, 15, 20, None, 25, 30, 40, 50]:
-                cfg = {
-                    "repo_id": repo_id,
-                    # video encoding
-                    "g": 2,
-                    "crf": crf,
-                    "pix_fmt": "yuv444p",
-                    # video decoding
-                    "device": "cpu",
-                    "decoder": "torchvision",
-                    "decoder_kwgs": {},
-                }
-                if not dry_run:
-                    run_video_benchmark(bench_dir / repo_id / f"torchvision_crf_{crf}", cfg, timestamps_mode)
-                info = load_info(bench_dir / repo_id / f"torchvision_crf_{crf}")
-                rows.append(
-                    [
-                        repo_id,
-                        crf,
-                        info["compression_factor"],
-                        info["load_time_factor"],
-                        info["avg_per_pixel_l2_error"],
-                    ]
-                )
-        display_markdown_table(headers, rows)
-
-        print("**best**")
-        headers = ["repo_id", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
-        rows = []
-        for repo_id in repo_ids:
-            cfg = {
-                "repo_id": repo_id,
-                # video encoding
-                "g": 2,
-                "crf": None,
-                "pix_fmt": "yuv444p",
-                # video decoding
-                "device": "cpu",
-                "decoder": "torchvision",
-                "decoder_kwgs": {},
-            }
-            if not dry_run:
-                run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode)
-            info = load_info(bench_dir / repo_id / "torchvision_best")
-            rows.append(
-                [
-                    repo_id,
-                    info["compression_factor"],
-                    info["load_time_factor"],
-                    info["avg_per_pixel_l2_error"],
-                ]
-            )
-        display_markdown_table(headers, rows)
+        for name, values in BENCHMARKS.items():
+            one_variable_study(name, values, DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN)
+
+        # best_study(DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN)


if __name__ == "__main__":
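
To make the timestamps modes concrete, here is what each mode requests around a random base timestamp `ts`, mirroring the branches in run_video_benchmark above (illustrative values only; note `2_frames_4_space` now spans 5 frame periods, i.e. two frames with four frames in between):

fps = 30
ts = 1.0  # arbitrary base timestamp, for illustration

timestamps_by_mode = {
    "1_frame": [ts],
    "2_frames": [ts - 1 / fps, ts],
    "2_frames_4_space": [ts - 5 / fps, ts],
    "6_frames": [ts - i / fps for i in range(6)][::-1],
}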

lerobot/common/datasets/factory.py

@@ -96,6 +96,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
            split=split,
            delta_timestamps=cfg.training.get("delta_timestamps"),
            image_transforms=image_transforms,
+            video_backend=cfg.video_backend,
        )
    else:
        dataset = MultiLeRobotDataset(

@@ -103,6 +104,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
            split=split,
            delta_timestamps=cfg.training.get("delta_timestamps"),
            image_transforms=image_transforms,
+            video_backend=cfg.video_backend,
        )

    if cfg.get("override_dataset_stats"):

lerobot/common/datasets/lerobot_dataset.py

@@ -48,6 +48,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        split: str = "train",
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
+        video_backend: str | None = None,
    ):
        super().__init__()
        self.repo_id = repo_id

@@ -69,6 +70,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.info = load_info(repo_id, version, root)
        if self.video:
            self.videos_dir = load_videos(repo_id, version, root)
+            self.video_backend = video_backend if video_backend is not None else "pyav"

    @property
    def fps(self) -> int:

@@ -149,6 +151,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                self.video_frame_keys,
                self.videos_dir,
                self.tolerance_s,
+                self.video_backend,
            )

        if self.image_transforms is not None:

@@ -188,6 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        stats=None,
        info=None,
        videos_dir=None,
+        video_backend=None,
    ) -> "LeRobotDataset":
        """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.

@@ -210,6 +214,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        obj.stats = stats
        obj.info = info if info is not None else {}
        obj.videos_dir = videos_dir
+        obj.video_backend = video_backend if video_backend is not None else "pyav"

        return obj

@@ -228,6 +233,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        split: str = "train",
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
+        video_backend: str | None = None,
    ):
        super().__init__()
        self.repo_ids = repo_ids

@@ -241,6 +247,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                split=split,
                delta_timestamps=delta_timestamps,
                image_transforms=image_transforms,
+                video_backend=video_backend,
            )
            for repo_id in repo_ids
        ]
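
With `video_backend` threaded through both dataset classes, switching decoders is a single constructor argument; a minimal sketch (assuming a torchvision build that includes the video_reader extension):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# video_backend=None falls back to "pyav"
dataset = LeRobotDataset("lerobot/pusht", video_backend="video_reader")
item = dataset[0]  # video frames for this item are decoded with video_reader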

lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py (new file)

@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format of png images files recorded with capture_camera_feed.py
"""
from pathlib import Path
import torch
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame
def check_format(raw_dir: Path) -> bool:
image_paths = list(raw_dir.glob("frame_*.png"))
if len(image_paths) == 0:
raise ValueError
def load_from_raw(raw_dir: Path, fps: int, episodes: list[int] | None = None):
if episodes is not None:
# TODO(aliberts): add support for multi-episodes.
raise NotImplementedError()
ep_dict = {}
ep_idx = 0
image_paths = sorted(raw_dir.glob("frame_*.png"))
num_frames = len(image_paths)
ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths]
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dicts = [ep_dict]
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
if video:
features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
if video or episodes is not None:
# TODO(aliberts): support this
raise NotImplementedError
# sanity check
check_format(raw_dir)
if fps is None:
fps = 30
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,
}
return hf_dataset, episode_data_index, info
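
A hypothetical end-to-end use of this format, with made-up local paths; video encoding and episode selection aren't supported yet, hence video=False:

from pathlib import Path

hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("outputs/cam_capture/2024-01-01/12-00-00"),  # hypothetical capture folder
    videos_dir=Path("tmp/videos"),  # unused for image datasets
    fps=30,
    video=False,
)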

lerobot/common/datasets/video_utils.py

@@ -27,7 +27,11 @@ from datasets.features.features import register_feature

def load_from_videos(
-    item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
+    item: dict[str, torch.Tensor],
+    video_frame_keys: list[str],
+    videos_dir: Path,
+    tolerance_s: float,
+    backend: str = "pyav",
):
    """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
    in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.

@@ -46,14 +50,14 @@ def load_from_videos(
            raise NotImplementedError("All video paths are expected to be the same for now.")
        video_path = data_dir / paths[0]

-        frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+        frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
        item[key] = frames
    else:
        # load one frame
        timestamps = [item[key]["timestamp"]]
        video_path = data_dir / item[key]["path"]

-        frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+        frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
        item[key] = frames[0]

    return item

@@ -63,11 +67,23 @@ def decode_video_frames_torchvision(
    video_path: str,
    timestamps: list[float],
    tolerance_s: float,
-    device: str = "cpu",
+    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
):
    """Loads frames associated to the requested timestamps of a video

+    The backend can be either "pyav" (default) or "video_reader".
+    "video_reader" requires installing torchvision from source, see:
+    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
+    (note that you need to compile against ffmpeg<4.3)
+
+    While both run on CPU, "video_reader" is faster than "pyav" but requires additional setup.
+    See our benchmark results for more info on performance:
+    https://github.com/huggingface/lerobot/pull/220
+
+    See torchvision doc for more info on these two backends:
+    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
+
    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,

@@ -78,21 +94,9 @@ def decode_video_frames_torchvision(
    # set backend
    keyframes_only = False
-    if device == "cpu":
-        # explicitly use pyav
-        torchvision.set_video_backend("pyav")
-        keyframes_only = True  # pyav doesn't support accurate seek
-    elif device == "cuda":
-        # TODO(rcadene, aliberts): implement video decoding with GPU
-        # torchvision.set_video_backend("cuda")
-        # torchvision.set_video_backend("video_reader")
-        # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
-        # check possible bug: https://github.com/pytorch/vision/issues/7745
-        raise NotImplementedError(
-            "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`."
-        )
-    else:
-        raise ValueError(device)
+    torchvision.set_video_backend(backend)
+    if backend == "pyav":
+        keyframes_only = True  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time

@@ -120,7 +124,9 @@ def decode_video_frames_torchvision(
        if current_ts >= last_ts:
            break

-    reader.container.close()
+    if backend == "pyav":
+        reader.container.close()
+
    reader = None

    query_ts = torch.tensor(timestamps)
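
Frames can also be decoded directly through the new parameter; a sketch with a hypothetical video path (set_video_backend raises at runtime if torchvision lacks the video_reader extension):

from lerobot.common.datasets.video_utils import decode_video_frames_torchvision

frames = decode_video_frames_torchvision(
    video_path="videos/observation.image_episode_000000.mp4",  # hypothetical path
    timestamps=[0.0, 1 / 30],  # two consecutive frames at 30 fps
    tolerance_s=1e-4,
    backend="video_reader",
)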

lerobot/configs/default.yaml

@@ -28,6 +28,7 @@ seed: ???
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
dataset_repo_id: lerobot/pusht
+video_backend: pyav

training:
  offline_steps: ???
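
With this default in place, the backend should be switchable per run through the usual Hydra-style override, e.g. `python lerobot/scripts/train.py video_backend=video_reader` (command shown for illustration).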

lerobot/scripts/push_dataset_to_hub.py

@@ -55,7 +55,6 @@ from safetensors.torch import save_file

from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
-from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.utils import flatten_dict

@@ -70,6 +69,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
+    elif raw_format == "cam_png":
+        from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
    else:
        raise ValueError(
            f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"

@@ -182,10 +183,6 @@ def push_dataset_to_hub(
    meta_data_dir = Path(cache_dir) / "meta_data"
    videos_dir = Path(cache_dir) / "videos"

-    # Download the raw dataset if available
-    if not raw_dir.exists():
-        download_raw(raw_dir, dataset_id)
-
    if raw_format is None:
        # TODO(rcadene, adilzouitine): implement auto_find_raw_format
        raise NotImplementedError()

poetry.lock (generated, 107 changed lines)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

[[package]]
name = "absl-py"

@@ -3805,31 +3805,31 @@ files = [

[[package]]
name = "torch"
-version = "2.3.0"
+version = "2.3.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.8.0"
files = [
-    {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"},
-    {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"},
-    {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"},
-    {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"},
-    {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"},
-    {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"},
-    {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"},
-    {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"},
-    {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"},
-    {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"},
-    {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"},
-    {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"},
-    {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"},
-    {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"},
-    {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"},
-    {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"},
-    {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"},
-    {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"},
-    {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"},
-    {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"},
+    {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"},
+    {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"},
+    {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"},
+    {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"},
+    {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"},
+    {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"},
+    {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"},
+    {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"},
+    {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"},
+    {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"},
+    {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"},
+    {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"},
+    {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"},
+    {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"},
+    {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"},
+    {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"},
+    {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"},
+    {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"},
+    {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"},
+    {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"},
]

[package.dependencies]

@@ -3850,7 +3850,7 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
sympy = "*"
-triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""}
+triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""}
typing-extensions = ">=4.8.0"

[package.extras]

@@ -3859,37 +3859,37 @@ optree = ["optree (>=0.9.1)"]

[[package]]
name = "torchvision"
-version = "0.18.0"
+version = "0.18.1"
description = "image and video datasets and models for torch deep learning"
optional = false
python-versions = ">=3.8"
files = [
-    {file = "torchvision-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dd61628a3d189c6852a12dc5ed4cd2eece66d2d67f35a866cb16f1dcb06c8c62"},
-    {file = "torchvision-0.18.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:493c45f9937dad37aa1b64b14da17c7a589c72b91adc4837d431009cfe29bd53"},
-    {file = "torchvision-0.18.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5337f6acfa1fe959d5cb340d01a00614d6b31ce7a4824ccb95435a85c5273b95"},
-    {file = "torchvision-0.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:bd8e6f3b5beb49965f15c461302488edfa3d8c2d01d3bb79b150d6fb62711e3a"},
-    {file = "torchvision-0.18.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6896a52168befe1105fb3c9335287390ed227e71d1e4ec4d68b62e8a3099fc09"},
-    {file = "torchvision-0.18.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3d7955398d4ceaad77c487c2c44f6f7813112402c9bab8cd906d346005891048"},
-    {file = "torchvision-0.18.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e5a24d620cea14a4bb89f24aa2b506230c0a16a3ada57fc53ad80cfd256a2128"},
-    {file = "torchvision-0.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:6ad70ddfa879bda5ed886b2518fe562640e0059787cbd65cb2bffa7674541410"},
-    {file = "torchvision-0.18.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:eb9d83c0e1dbb54ecb0fb04c87f786333e3a6fb8b9c400aca7c31081f9aa5707"},
-    {file = "torchvision-0.18.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b657d052d146f24cb3b2a78219bfc82ae70a9706671c50f632528907d10cccec"},
-    {file = "torchvision-0.18.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a964afbc7ddf50a46b941477f6c35729b416deedd139756befd488245e2e226d"},
-    {file = "torchvision-0.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:7c770f0f748e0b17f57c0297508d7254f686cdf03fc2e2949f422b20574f4c0f"},
-    {file = "torchvision-0.18.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2115a1906c015f5da9ceedc40a983313b0fd6e2c8a17108a92991706f51f6987"},
-    {file = "torchvision-0.18.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:6323f7e5423ff2594d5891863b919deb9d0de95f01c36bf26fbd879036b6ed08"},
-    {file = "torchvision-0.18.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:925d0a82cccf6f986c18b29b4392a942db65cbdb73c13a129c8493822eb9e36f"},
-    {file = "torchvision-0.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:95b42d0dc599b47a01530c7439a5751e67e45b85e3a67113989cf7c7c70f2039"},
-    {file = "torchvision-0.18.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:75e22ecf44a13b8f95b8ad421c0261282d859c61816badaca1959e073ccdd691"},
-    {file = "torchvision-0.18.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:4c334b3e719ba0a9ba6e15d4aff1178f5e6d029174f346163fed525f0ccfffd3"},
-    {file = "torchvision-0.18.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:36efd87001c6bee2383e043e46a025affb03179747c8f4777b9918527ffce756"},
-    {file = "torchvision-0.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:ccc292e093771d5baacf5535ac4416306b6b5f15676341cd4d010d8542eace25"},
+    {file = "torchvision-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3e694e54b0548dad99c12af6bf0c8e4f3350137d391dcd19af22a1c5f89322b3"},
+    {file = "torchvision-0.18.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:0b3bda0aa5b416eeb547143b8eeaf17720bdba9cf516dc991aacb81811aa96a5"},
+    {file = "torchvision-0.18.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:573ff523c739405edb085f65cb592f482d28a30e29b0be4c4ba08040b3ae785f"},
+    {file = "torchvision-0.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:ef7bbbc60b38e831a75e547c66ca1784f2ac27100f9e4ddbe9614cef6cbcd942"},
+    {file = "torchvision-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:80b5d794dd0fdba787adc22f1a367a5ead452327686473cb260dd94364bc56a6"},
+    {file = "torchvision-0.18.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:9077cf590cdb3a5e8fdf5cdb71797f8c67713f974cf0228ecb17fcd670ab42f9"},
+    {file = "torchvision-0.18.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ceb993a882f1ae7ae373ed39c28d7e3e802205b0e59a7ed84ef4028f0bba8d7f"},
+    {file = "torchvision-0.18.1-cp311-cp311-win_amd64.whl", hash = "sha256:52f7436140045dc2239cdc502aa76b2bd8bd676d64244ff154d304aa69852046"},
+    {file = "torchvision-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2be6f0bf7c455c89a51a1dbb6f668d36c6edc479f49ac912d745d10df5715657"},
+    {file = "torchvision-0.18.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:f118d887bfde3a948a41d56587525401e5cac1b7db2eaca203324d6ed2b1caca"},
+    {file = "torchvision-0.18.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:13d24d904f65e62d66a1e0c41faec630bc193867b8a4a01166769e8a8e8df8e9"},
+    {file = "torchvision-0.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ed6340b69a63a625e512a66127210d412551d9c5f2ad2978130c6a45bf56cd4a"},
+    {file = "torchvision-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b1c3864fa9378c88bce8ad0ef3599f4f25397897ce612e1c245c74b97092f35e"},
+    {file = "torchvision-0.18.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:02085a2ffc7461f5c0edb07d6f3455ee1806561f37736b903da820067eea58c7"},
+    {file = "torchvision-0.18.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9726c316a2501df8503e5a5dc46a631afd4c515a958972e5b7f7b9c87d2125c0"},
+    {file = "torchvision-0.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:64a2662dbf30db9055d8b201d6e56f312a504e5ccd9d144c57c41622d3c524cb"},
+    {file = "torchvision-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:975b8594c0f5288875408acbb74946eea786c5b008d129c0d045d0ead23742bc"},
+    {file = "torchvision-0.18.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:da83c8bbd34d8bee48bfa1d1b40e0844bc3cba10ed825a5a8cbe3ce7b62264cd"},
+    {file = "torchvision-0.18.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:54bfcd352abb396d5c9c237d200167c178bd136051b138e1e8ef46ce367c2773"},
+    {file = "torchvision-0.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:5c8366a1aeee49e9ea9e64b30d199debdf06b1bd7610a76165eb5d7869c3bde5"},
]

[package.dependencies]
numpy = "*"
pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0"
-torch = "2.3.0"
+torch = "2.3.1"

[package.extras]
scipy = ["scipy"]

@@ -3916,17 +3916,17 @@ telegram = ["requests"]

[[package]]
name = "triton"
-version = "2.3.0"
+version = "2.3.1"
description = "A language and compiler for custom Deep Learning operations"
optional = false
python-versions = "*"
files = [
-    {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"},
-    {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"},
-    {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"},
-    {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"},
-    {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"},
-    {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"},
+    {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"},
+    {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"},
+    {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"},
+    {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"},
+    {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"},
+    {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"},
]

[package.dependencies]

@@ -4301,9 +4301,10 @@ dora = ["gym-dora"]
pusht = ["gym-pusht"]
test = ["pytest", "pytest-cov"]
umi = ["imagecodecs"]
+video-benchmark = ["scikit-image"]
xarm = ["gym-xarm"]

[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
-content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771"
+content-hash = "61f99befbc2250fe59cb54119c3dbd3aa3c1dfe5d3d7790c6f7c4f90fe43112e"

pyproject.toml

@@ -60,6 +60,7 @@ pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1"
+scikit-image = {version = "^0.23.2", optional = true}

[tool.poetry.extras]

@@ -70,6 +71,7 @@ aloha = ["gym-aloha"]
dev = ["pre-commit", "debugpy"]
test = ["pytest", "pytest-cov"]
umi = ["imagecodecs"]
+video_benchmark = ["scikit-image"]

[tool.ruff]
line-length = 110
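
Since scikit-image is declared optional, it presumably needs to be pulled in explicitly for the benchmark's image-quality metrics, e.g. with `poetry install --extras video_benchmark` or `pip install -e ".[video_benchmark]"`.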