Add video dataset by default

This commit is contained in:
Cadene 2024-04-30 08:43:25 +00:00
parent 72bcfb9ee4
commit d632f8cb51
9 changed files with 897 additions and 39 deletions

View File

@@ -0,0 +1,186 @@
# Video benchmark

## Questions

What is the optimal trade-off between:
- minimizing loading time with random access,
- minimizing memory space on disk,
- maximizing success rate of policies?

How to encode videos?
- How much compression (`-crf`)? Low compression with `0`, normal compression with `20`, or extreme compression with `56`?
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
- How many key frames (`-g`)? A key frame every `10` frames?

How to decode videos?
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
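
To make the encoding side concrete, here is a sketch of the kind of ffmpeg invocation being benchmarked, mirroring the benchmark script further below (paths and values are illustrative):

```python
import subprocess

fps = 10  # illustrative frame rate
ffmpeg_cmd = (
    f"ffmpeg -r {fps} -f image2 "
    "-i tmp/images/frame_%06d.png "  # input frames (hypothetical path)
    "-vcodec libx264 "
    "-g 10 "  # at most 10 frames between two key frames
    "-crf 10 "  # compression level
    "-pix_fmt yuv444p "  # pixel format
    "episode_0.mp4"
)
subprocess.run(ffmpeg_cmd.split(" "), check=True)
```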
## Metrics

**Percentage of data compression (higher is better)**
`pc_compression` is the ratio of the disk space taken by the original images to the disk space taken by the encoded video. For instance, `pc_compression=400%` means that the video takes 4 times less disk space than the original images.

**Percentage of loading time (higher is better)**
`pc_load_time` is the ratio of the time it takes to load the original images at given timestamps to the time it takes to decode the exact same frames from the video. For instance, `pc_load_time=120%` means that decoding from video is slightly faster than loading the original images, while `pc_load_time=50%` means it is twice as slow (see the sketch at the end of this section).

**Average L2 error per pixel (lower is better)**
`avg_per_pixel_l2_error` is the average L2 error between each decoded frame and its corresponding original image over all requested timestamps, divided by the number of pixels in the image so that values remain comparable across image sizes.

**Loss of a pretrained policy (lower is better)** (not available)
`loss_pretrained` is the result of evaluating, with the selected encoding/decoding settings, a policy pretrained on the original images. It is easier to interpret than `avg_per_pixel_l2_error`.

**Success rate after retraining (higher is better)** (not available)
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most expensive metric to obtain, but also the most meaningful.
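
For reference, the three available metrics are derived from raw measurements as sketched below; the variable names mirror `run_video_benchmark` further down:

```python
import numpy

def compute_metrics(sum_original_frames_size_bytes, video_size_bytes,
                    list_avg_load_time_from_images, list_avg_load_time, per_pixel_l2_errors):
    # per_pixel_l2_errors holds one L2 error per decoded frame, already divided by the number of pixels
    return {
        "pc_compression": sum_original_frames_size_bytes / video_size_bytes,
        "pc_load_time": float(numpy.mean(list_avg_load_time_from_images)) / float(numpy.mean(list_avg_load_time)),
        "avg_per_pixel_l2_error": float(numpy.mean(per_pixel_l2_errors)),
    }
```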
## Variables

**Image content**
We don't expect the same optimal settings for a dataset of images from a simulation, from the real world in an apartment, in a factory, outdoors, etc. Hence, we run this benchmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).

**Requested timestamps**
In this benchmark, we focus on the loading time of random access, so we are not interested in sequentially loading all the frames of a video as in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect `pc_load_time`. In fact, decoding a large number of consecutive frames from a video is expected to be faster than loading the same data from individual images. To reflect our robotics use case, we consider a setting where we load 2 frames that are 4 frames apart (see the sketch below).
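
For instance, the two timestamp patterns benchmarked below can be generated as follows (a sketch matching the benchmark script, where `ts` is the last requested timestamp in seconds and `fps` is the video frame rate):

```python
def get_timestamps(ts: float, fps: int, mode: str = "diffusion") -> list[float]:
    if mode == "diffusion":
        # 2 frames that are 4 frames apart (Diffusion Policy setting)
        return [round(ts - 4 / fps, 4), ts]
    elif mode == "tdmpc":
        # 6 consecutive frames with no spacing (TDMPC setting)
        return [round(ts - i / fps, 4) for i in range(6)][::-1]
    raise ValueError(mode)
```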
**Data augmentations**
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
## Results

### Loading 2 frames that are 4 frames apart (Diffusion Policy setting)
**`decoder`**
| repo_id | decoder | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- |
| lerobot/pusht | <span style="color: #32CD32;">torchvision</span> | 0.166 | 0.0000119 |
| lerobot/pusht | ffmpegio | 0.009 | 0.0001182 |
| lerobot/pusht | torchaudio | 0.138 | 0.0000359 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">torchvision</span> | 0.174 | 0.0000174 |
| lerobot/umi_cup_in_the_wild | ffmpegio | 0.010 | 0.0000735 |
| lerobot/umi_cup_in_the_wild | torchaudio | 0.154 | 0.0000340 |
**`pix_fmt`**
| repo_id | pix_fmt | pc_compression | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- | --- |
| lerobot/pusht | yuv420p | 3.602 | 0.202 | 0.0000661 |
| lerobot/pusht | <span style="color: #32CD32;">yuv444p</span> | 3.213 | 0.153 | 0.0000110 |
| lerobot/umi_cup_in_the_wild | yuv420p | 8.879 | 0.202 | 0.0000332 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">yuv444p</span> | 8.517 | 0.165 | 0.0000175 |
**`g`**
| repo_id | g | pc_compression | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- | --- |
| lerobot/pusht | 1 | 1.308 | 0.190 | 0.0000151 |
| lerobot/pusht | 5 | 2.739 | 0.184 | 0.0000123 |
| lerobot/pusht | 10 | 3.213 | 0.144 | 0.0000116 |
| lerobot/pusht | 15 | 3.460 | 0.137 | 0.0000112 |
| lerobot/pusht | 20 | 3.559 | 0.118 | 0.0000109 |
| lerobot/pusht | 30 | 3.697 | 0.104 | 0.0000117 |
| lerobot/pusht | 40 | 3.763 | 0.092 | 0.0000116 |
| lerobot/pusht | 60 | 3.925 | 0.068 | 0.0000117 |
| lerobot/pusht | 100 | 4.010 | 0.054 | 0.0000117 |
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 4.058 | 0.043 | 0.0000117 |
| lerobot/umi_cup_in_the_wild | 1 | 4.790 | 0.236 | 0.0000221 |
| lerobot/umi_cup_in_the_wild | 5 | 7.707 | 0.201 | 0.0000185 |
| lerobot/umi_cup_in_the_wild | 10 | 8.517 | 0.172 | 0.0000177 |
| lerobot/umi_cup_in_the_wild | 15 | 8.830 | 0.152 | 0.0000170 |
| lerobot/umi_cup_in_the_wild | 20 | 8.961 | 0.133 | 0.0000167 |
| lerobot/umi_cup_in_the_wild | 30 | 8.850 | 0.113 | 0.0000167 |
| lerobot/umi_cup_in_the_wild | 40 | 8.996 | 0.109 | 0.0000174 |
| lerobot/umi_cup_in_the_wild | 60 | 9.113 | 0.081 | 0.0000163 |
| lerobot/umi_cup_in_the_wild | 100 | 9.278 | 0.051 | 0.0000173 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 9.396 | 0.030 | 0.0000165 |
**`crf`**
| repo_id | crf | pc_compression | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- | --- |
| lerobot/pusht | 0 | 4.529 | 0.041 | 0.0000035 |
| lerobot/pusht | 5 | 3.138 | 0.040 | 0.0000077 |
| lerobot/pusht | <span style="color: #32CD32;">10</span> | 4.058 | 0.038 | 0.0000121 |
| lerobot/pusht | <span style="color: #32CD32;">15</span> | 5.407 | 0.039 | 0.0000195 |
| lerobot/pusht | <span style="color: #32CD32;">20</span> | 7.335 | 0.039 | 0.0000319 |
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 8.909 | 0.046 | 0.0000425 |
| lerobot/pusht | 25 | 10.213 | 0.039 | 0.0000519 |
| lerobot/pusht | 30 | 14.516 | 0.041 | 0.0000795 |
| lerobot/pusht | 40 | 23.546 | 0.041 | 0.0001557 |
| lerobot/pusht | 50 | 28.460 | 0.042 | 0.0002723 |
| lerobot/umi_cup_in_the_wild | 0 | 2.318 | 0.012 | 0.0000056 |
| lerobot/umi_cup_in_the_wild | 5 | 4.899 | 0.019 | 0.0000132 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">10</span> | 9.396 | 0.026 | 0.0000183 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">15</span> | 19.161 | 0.034 | 0.0000241 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">20</span> | 39.311 | 0.039 | 0.0000329 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 60.530 | 0.043 | 0.0000401 |
| lerobot/umi_cup_in_the_wild | 25 | 81.048 | 0.046 | 0.0000454 |
| lerobot/umi_cup_in_the_wild | 30 | 165.189 | 0.051 | 0.0000609 |
| lerobot/umi_cup_in_the_wild | 40 | 544.478 | 0.056 | 0.0001095 |
| lerobot/umi_cup_in_the_wild | 50 | 1109.556 | 0.072 | 0.0001815 |
### Loading 6 consecutive frames with no spacing (TDMPC setting)
**`decoder`**
| repo_id | decoder | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- |
| lerobot/pusht | <span style="color: #32CD32;">torchvision</span> | 0.386 | 0.0000117 |
| lerobot/pusht | ffmpegio | 0.008 | 0.0000117 |
| lerobot/pusht | torchaudio | 0.184 | 0.0000356 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">torchvision</span> | 0.448 | 0.0000178 |
| lerobot/umi_cup_in_the_wild | ffmpegio | 0.009 | 0.0000178 |
| lerobot/umi_cup_in_the_wild | torchaudio | 0.149 | 0.0000349 |
**`pix_fmt`**
| repo_id | pix_fmt | pc_compression | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- | --- |
| lerobot/pusht | yuv420p | 3.602 | 0.518 | 0.0000651 |
| lerobot/pusht | <span style="color: #32CD32;">yuv444p</span> | 3.213 | 0.401 | 0.0000117 |
| lerobot/umi_cup_in_the_wild | yuv420p | 8.879 | 0.578 | 0.0000334 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">yuv444p</span> | 8.517 | 0.479 | 0.0000178 |
**`g`**
| repo_id | g | pc_compression | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- | --- |
| lerobot/pusht | 1 | 1.308 | 0.528 | 0.0000152 |
| lerobot/pusht | 5 | 2.739 | 0.483 | 0.0000124 |
| lerobot/pusht | 10 | 3.213 | 0.396 | 0.0000117 |
| lerobot/pusht | 15 | 3.460 | 0.379 | 0.0000118 |
| lerobot/pusht | 20 | 3.559 | 0.319 | 0.0000114 |
| lerobot/pusht | 30 | 3.697 | 0.278 | 0.0000116 |
| lerobot/pusht | 40 | 3.763 | 0.243 | 0.0000115 |
| lerobot/pusht | 60 | 3.925 | 0.186 | 0.0000118 |
| lerobot/pusht | 100 | 4.010 | 0.156 | 0.0000119 |
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 4.058 | 0.105 | 0.0000121 |
| lerobot/umi_cup_in_the_wild | 1 | 4.790 | 0.605 | 0.0000221 |
| lerobot/umi_cup_in_the_wild | 5 | 7.707 | 0.533 | 0.0000183 |
| lerobot/umi_cup_in_the_wild | 10 | 8.517 | 0.469 | 0.0000178 |
| lerobot/umi_cup_in_the_wild | 15 | 8.830 | 0.399 | 0.0000174 |
| lerobot/umi_cup_in_the_wild | 20 | 8.961 | 0.382 | 0.0000175 |
| lerobot/umi_cup_in_the_wild | 30 | 8.850 | 0.326 | 0.0000172 |
| lerobot/umi_cup_in_the_wild | 40 | 8.996 | 0.279 | 0.0000173 |
| lerobot/umi_cup_in_the_wild | 60 | 9.113 | 0.226 | 0.0000174 |
| lerobot/umi_cup_in_the_wild | 100 | 9.278 | 0.150 | 0.0000175 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 9.396 | 0.076 | 0.0000176 |
**`crf`**
| repo_id | crf | pc_compression | pc_load_time | avg_per_pixel_l2_error |
| --- | --- | --- | --- | --- |
| lerobot/pusht | 0 | 4.529 | 0.108 | 0.0000035 |
| lerobot/pusht | 5 | 3.138 | 0.099 | 0.0000077 |
| lerobot/pusht | 10 | 4.058 | 0.091 | 0.0000121 |
| lerobot/pusht | 15 | 5.407 | 0.095 | 0.0000195 |
| lerobot/pusht | 20 | 7.335 | 0.100 | 0.0000318 |
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 8.909 | 0.102 | 0.0000422 |
| lerobot/pusht | 25 | 10.213 | 0.102 | 0.0000517 |
| lerobot/pusht | 30 | 14.516 | 0.104 | 0.0000795 |
| lerobot/pusht | 40 | 23.546 | 0.106 | 0.0001555 |
| lerobot/pusht | 50 | 28.460 | 0.110 | 0.0002723 |
| lerobot/umi_cup_in_the_wild | 0 | 2.318 | 0.032 | 0.0000056 |
| lerobot/umi_cup_in_the_wild | 5 | 4.899 | 0.052 | 0.0000127 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">10</span> | 9.396 | 0.073 | 0.0000176 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">15</span> | 19.161 | 0.097 | 0.0000234 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">20</span> | 39.311 | 0.110 | 0.0000321 |
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 60.530 | 0.117 | 0.0000393 |
| lerobot/umi_cup_in_the_wild | 25 | 81.048 | 0.126 | 0.0000446 |
| lerobot/umi_cup_in_the_wild | 30 | 165.189 | 0.138 | 0.0000603 |
| lerobot/umi_cup_in_the_wild | 40 | 544.478 | 0.151 | 0.0001095 |
| lerobot/umi_cup_in_the_wild | 50 | 1109.556 | 0.167 | 0.0001817 |

View File

@@ -0,0 +1,230 @@
"""This file contains work-in-progress alternative to default decoding strategy."""
import einops
import torch
def decode_video_frames_ffmpegio(video_path, timestamps, device="cpu"):
# assert device == "cpu", f"Only CPU decoding is supported with ffmpegio, but device is {device}"
import einops
import ffmpegio
num_contiguous_frames = 1 # noqa: F841
image_format = "rgb24"
list_frames = []
for timestamp in timestamps:
kwargs = {
"ss": str(timestamp),
# vframes=num_contiguous_frames,
"pix_fmt": image_format,
# hwaccel=None if device == "cpu" else device, # ,
"show_log": True,
}
if device == "cuda":
kwargs["hwaccel_in"] = "cuda"
kwargs["hwaccel_output_format_in"] = "cuda"
fs, frames = ffmpegio.video.read(str(video_path), **kwargs)
list_frames.append(torch.from_numpy(frames))
frames = torch.cat(list_frames)
frames = einops.rearrange(frames, "b h w c -> b c h w")
frames = frames.type(torch.float32) / 255
return frames
def yuv_to_rgb(frames):
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3

    frames = frames.cpu().to(torch.float)
    y = frames[..., 0, :, :]
    u = frames[..., 1, :, :]
    v = frames[..., 2, :, :]

    y /= 255
    u = u / 255 - 0.5
    v = v / 255 - 0.5

    # standard analog YUV -> RGB conversion coefficients
    r = y + 1.13983 * v
    g = y + -0.39465 * u - 0.58060 * v
    b = y + 2.03211 * u

    rgb = torch.stack([r, g, b], 1)
    rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
    return rgb
def yuv_to_rgb_cv2(frames, return_hwc=True):
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3

    import cv2

    frames = frames.cpu()
    frames = einops.rearrange(frames, "b c h w -> b h w c")
    frames = frames.numpy()
    frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
    frames = [torch.from_numpy(frame) for frame in frames]
    frames = torch.stack(frames)
    if not return_hwc:
        frames = einops.rearrange(frames, "b h w c -> b c h w")
    return frames
def decode_video_frames_torchaudio(video_path, timestamps, device="cpu"):
    num_contiguous_frames = 1
    width = None
    height = None
    # image_format = "rgb24"
    # image_format = None
    image_format = "yuv444p"
    frame_rate = None
    scale_full_range_filter = False

    filter_desc = []
    video_stream_kwgs = {
        "frames_per_chunk": num_contiguous_frames,
        # "buffer_chunk_size": num_contiguous_frames,
    }

    # choice of decoder
    if device == "cuda":
        video_stream_kwgs["hw_accel"] = "cuda:0"
        video_stream_kwgs["decoder"] = "h264_cuvid"
        # video_stream_kwgs["decoder"] = "hevc_cuvid"
        # video_stream_kwgs["decoder"] = "av1_cuvid"
        # video_stream_kwgs["decoder"] = "ffv1_cuvid"
    else:
        video_stream_kwgs["decoder"] = "h264"
        # video_stream_kwgs["decoder"] = "hevc"
        # video_stream_kwgs["decoder"] = "av1"
        # video_stream_kwgs["decoder"] = "ffv1"

    # resize
    resize_width = width is not None
    resize_height = height is not None
    if resize_width or resize_height:
        if device == "cuda":
            assert resize_width and resize_height
            video_stream_kwgs["decoder_option"] = {"resize": f"{width}x{height}"}
        else:
            scales = []
            if resize_width:
                scales.append(f"width={width}")
            if resize_height:
                scales.append(f"height={height}")
            filter_desc.append(f"scale={':'.join(scales)}")

    # choice of format
    if image_format is not None:
        if device == "cuda":
            # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
            # filter_desc.append(f"scale=format={image_format}")
            # filter_desc.append(f"scale_cuda=format={image_format}")
            # filter_desc.append(f"scale_npp=format={image_format}")
            filter_desc.append(f"format=pix_fmts={image_format}")
        else:
            filter_desc.append(f"format=pix_fmts={image_format}")

    # choice of frame rate
    if frame_rate is not None:
        filter_desc.append(f"fps={frame_rate}")

    # to set output scale [0-255] instead of [16-235]
    if scale_full_range_filter:
        filter_desc.append("scale=in_range=limited:out_range=full")

    if len(filter_desc) > 0:
        video_stream_kwgs["filter_desc"] = ",".join(filter_desc)

    # create a stream and load a certain number of frames at a certain frame rate
    # TODO(rcadene): make sure it's the most optimal way to do it
    from torchaudio.io import StreamReader

    print(video_stream_kwgs)

    list_frames = []
    for timestamp in timestamps:
        s = StreamReader(str(video_path))
        s.seek(timestamp)
        s.add_video_stream(**video_stream_kwgs)
        s.fill_buffer()
        (frames,) = s.pop_chunks()

        if "yuv" in image_format:
            frames = yuv_to_rgb(frames)

        assert frames.dtype == torch.uint8
        frames = frames.type(torch.float32)

        # The original data has limited range (16-235) and torchaudio does not convert it,
        # while FFmpeg converts it to full range (0-255), so we apply the linear transformation ourselves.
        if not scale_full_range_filter:
            frames -= 16
            frames *= 255 / (235 - 16)

        frames /= 255
        frames = frames.clip(0, 1)
        list_frames.append(frames)

    frames = torch.cat(list_frames)
    return frames
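

# Example usage of the decoders above (a sketch; the video path is hypothetical):
#
#     frames = decode_video_frames_torchaudio("episode_0.mp4", timestamps=[1.0, 1.2], device="cpu")
#     frames = decode_video_frames_ffmpegio("episode_0.mp4", timestamps=[1.0, 1.2], device="cpu")
#
# Both return a float32 tensor of shape (num_frames, c, h, w) with values in [0, 1].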
# def _decode_frames_decord(video_path, timestamp):
#     num_contiguous_frames = 1  # noqa: F841 TODO(rcadene): remove
#     device = "cpu"
#     from decord import VideoReader, cpu, gpu
#
#     with open(str(video_path), "rb") as f:
#         ctx = gpu if device == "cuda" else cpu
#         vr = VideoReader(f, ctx=ctx(0))  # noqa: F841
#     raise NotImplementedError("Convert `timestamp` into frame_id")
#     # frame_id = frame_ids[0].item()
#     # frames = vr.get_batch([frame_id])
#     # frames = torch.from_numpy(frames.asnumpy())
#     # frames = einops.rearrange(frames, "b h w c -> b c h w")
#     # return frames


# def decode_frames_nvc(video_path, timestamps, device="cuda"):
#     assert device == "cuda"
#     import PyNvCodec as nvc
#     import PytorchNvCodec as pnvc
#
#     gpuID = 0
#     nvDec = nvc.PyNvDecoder('path_to_video_file', gpuID)
#     to_rgb = nvc.PySurfaceConverter(nvDec.Width(), nvDec.Height(), nvc.PixelFormat.NV12, nvc.PixelFormat.RGB, gpuID)
#     to_planar = nvc.PySurfaceConverter(nvDec.Width(), nvDec.Height(), nvc.PixelFormat.RGB, nvc.PixelFormat.RGB_PLANAR, gpuID)
#
#     while True:
#         # Obtain NV12 decoded surface from decoder
#         rawSurface = nvDec.DecodeSingleSurface()
#         if rawSurface.Empty():
#             break
#         # Convert to RGB interleaved
#         rgb_byte = to_rgb.Execute(rawSurface)
#         # Convert to RGB planar because that's what to_tensor + normalize are doing
#         rgb_planar = to_planar.Execute(rgb_byte)
#         # Create a torch tensor from it and reshape, because pnvc.makefromDevicePtrUint8
#         # creates just a chunk of CUDA memory and then copies data from the plane pointer
#         # to the allocated chunk
#         surfPlane = rgb_planar.PlanePtr()
#         surface_tensor = pnvc.makefromDevicePtrUint8(surfPlane.GpuMem(), surfPlane.Width(), surfPlane.Height(), surfPlane.Pitch(), surfPlane.ElemSize())
#         surface_tensor.resize_(3, target_h, target_w)

View File

@@ -0,0 +1,337 @@
import json
import os
import random
import shutil
import subprocess
import time
from pathlib import Path

import einops
import numpy
import PIL.Image
import torch

from lerobot.common.datasets._video_benchmark._video_utils import (
    decode_video_frames_ffmpegio,
    decode_video_frames_torchaudio,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision,
)


def get_directory_size(directory):
    total_size = 0
    # Iterate over all files and subdirectories recursively
    for item in directory.rglob("*"):
        if item.is_file():
            # Add the file size to the total
            total_size += item.stat().st_size
    return total_size
def run_video_benchmark(output_dir, cfg, seed=1337, timestamps_mode="diffusion"):
    output_dir = Path(output_dir)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    repo_id = cfg["repo_id"]

    # TODO(rcadene): rewrite with hardcoding of original images and episodes
    dataset = LeRobotDataset(
        repo_id,
        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
    )

    # Get fps
    fps = dataset.fps

    # we only load the first episode
    ep_num_images = dataset.episode_data_index["to"][0].item()

    # Save/load the image directory for the first episode
    imgs_dir = Path(f"tmp/data/images/{repo_id}/observation.image_episode_000000")
    if not imgs_dir.exists():
        imgs_dir.mkdir(parents=True, exist_ok=True)
        hf_dataset = dataset.hf_dataset.with_format(None)
        imgs_dataset = hf_dataset.select_columns("observation.image")
        for i, item in enumerate(imgs_dataset):
            img = item["observation.image"]
            img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
            if i >= ep_num_images - 1:
                break

    sum_original_frames_size_bytes = get_directory_size(imgs_dir)

    # Encode images into a video
    video_path = output_dir / "episode_0.mp4"

    g = cfg.get("g")
    crf = cfg.get("crf")
    pix_fmt = cfg["pix_fmt"]

    ffmpeg_cmd = ""
    ffmpeg_cmd += f"ffmpeg -r {fps} -f image2 "
    ffmpeg_cmd += f"-i {str(imgs_dir / 'frame_%06d.png')} "
    ffmpeg_cmd += "-vcodec libx264 "
    if g is not None:
        ffmpeg_cmd += f"-g {g} "  # ensures at least 1 key frame every `g` frames
        # ffmpeg_cmd += "-keyint_min 10 "  # set a minimum of 10 frames between 2 key frames
        # ffmpeg_cmd += "-sc_threshold 0 "  # disable scene change detection to lower the number of key frames
    if crf is not None:
        ffmpeg_cmd += f"-crf {crf} "
    ffmpeg_cmd += f"-pix_fmt {pix_fmt} "
    ffmpeg_cmd += f"{str(video_path)}"
    subprocess.run(ffmpeg_cmd.split(" "), check=True)

    video_size_bytes = video_path.stat().st_size

    # Set decoder
    decoder = cfg["decoder"]
    decoder_kwgs = cfg["decoder_kwgs"]
    device = cfg["device"]

    if decoder == "torchaudio":
        decode_frames_fn = decode_video_frames_torchaudio
    elif decoder == "ffmpegio":
        decode_frames_fn = decode_video_frames_ffmpegio
    elif decoder == "torchvision":
        decode_frames_fn = decode_video_frames_torchvision
    else:
        raise ValueError(decoder)
    # Estimate average loading time

    def load_original_frames(imgs_dir, timestamps):
        frames = []
        for ts in timestamps:
            idx = int(ts * fps)
            frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
            frame = torch.from_numpy(numpy.array(frame))
            frame = frame.type(torch.float32) / 255
            frame = einops.rearrange(frame, "h w c -> c h w")
            frames.append(frame)
        return frames

    list_avg_load_time = []
    list_avg_load_time_from_images = []
    per_pixel_l2_errors = []

    random.seed(seed)

    for t in range(50):
        # test loading 2 frames that are 4 frames apart, which might be a common setting
        ts = random.randint(fps, ep_num_images - fps) / fps

        if timestamps_mode == "diffusion":
            prev_ts = round(ts - 4 / fps, 4)
            timestamps = [prev_ts, ts]
        elif timestamps_mode == "tdmpc":
            timestamps = [round(ts - i / fps, 4) for i in range(6)][::-1]
        else:
            raise ValueError(timestamps_mode)

        num_frames = len(timestamps)

        start_time_s = time.monotonic()
        frames = decode_frames_fn(video_path, timestamps=timestamps, device=device, **decoder_kwgs)
        avg_load_time = (time.monotonic() - start_time_s) / num_frames
        list_avg_load_time.append(avg_load_time)

        start_time_s = time.monotonic()
        original_frames = load_original_frames(imgs_dir, timestamps)
        avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames
        list_avg_load_time_from_images.append(avg_load_time_from_images)

        # Estimate the average L2 error between original frames and decoded frames
        for i, ts in enumerate(timestamps):
            # are_close = torch.allclose(frames[i], original_frames[i], atol=0.02)
            num_pixels = original_frames[i].numel()
            per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels

            # save decoded frames
            if t == 0:
                frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
                PIL.Image.fromarray(frame_hwc).save(output_dir / f"frame_{i:06d}.png")

            # save original frames
            idx = int(ts * fps)
            if t == 0:
                original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
                original_frame.save(output_dir / f"original_frame_{i:06d}.png")

            per_pixel_l2_errors.append(per_pixel_l2_error)

    avg_load_time = float(numpy.array(list_avg_load_time).mean())
    avg_load_time_from_images = float(numpy.array(list_avg_load_time_from_images).mean())
    avg_per_pixel_l2_error = float(numpy.array(per_pixel_l2_errors).mean())

    # Save benchmark info
    info = {
        "sum_original_frames_size_bytes": sum_original_frames_size_bytes,
        "video_size_bytes": video_size_bytes,
        "avg_load_time_from_images": avg_load_time_from_images,
        "avg_load_time": avg_load_time,
        "pc_compression": sum_original_frames_size_bytes / video_size_bytes,
        "pc_load_time": avg_load_time_from_images / avg_load_time,
        "avg_per_pixel_l2_error": avg_per_pixel_l2_error,
    }

    for key in info:
        print(key, info[key])

    with open(output_dir / "info.json", "w") as f:
        json.dump(info, f)

    return info
def main():
    dry_run = True
    bench_dir = Path("tmp/2024_04_29_1049_6_timestamps")

    def display_markdown_table(headers, rows):
        for i, row in enumerate(rows):
            new_row = []
            for col in row:
                if col is None:
                    new_col = "None"
                elif isinstance(col, float):
                    new_col = f"{col:.3f}"
                    if new_col == "0.000":
                        new_col = f"{col:.7f}"
                elif isinstance(col, int):
                    new_col = f"{col}"
                else:
                    new_col = col
                new_row.append(new_col)
            rows[i] = new_row

        header_line = "| " + " | ".join(headers) + " |"
        separator_line = "| " + " | ".join(["---" for _ in headers]) + " |"
        body_lines = ["| " + " | ".join(row) + " |" for row in rows]
        markdown_table = "\n".join([header_line, separator_line] + body_lines)
        print(markdown_table)
        print()

    def load_info(out_dir):
        with open(out_dir / "info.json") as f:
            info = json.load(f)
        return info
    repo_ids = ["lerobot/pusht", "lerobot/umi_cup_in_the_wild"]

    # torchvision vs ffmpegio vs torchaudio
    headers = ["repo_id", "decoder", "pc_load_time", "avg_per_pixel_l2_error"]
    rows = []
    for repo_id in repo_ids:
        for decoder in ["torchvision", "ffmpegio", "torchaudio"]:
            cfg = {
                "repo_id": repo_id,
                # video encoding
                "g": 10,
                "crf": 10,
                "pix_fmt": "yuv444p",
                # video decoding
                "device": "cpu",
                "decoder": decoder,
                "decoder_kwgs": {},
            }
            if not dry_run:
                run_video_benchmark(bench_dir / repo_id / decoder, cfg=cfg)
            info = load_info(bench_dir / repo_id / decoder)
            rows.append([repo_id, decoder, info["pc_load_time"], info["avg_per_pixel_l2_error"]])
    display_markdown_table(headers, rows)

    # yuv444p vs yuv420p
    headers = ["repo_id", "pix_fmt", "pc_compression", "pc_load_time", "avg_per_pixel_l2_error"]
    rows = []
    for repo_id in repo_ids:
        for pix_fmt in ["yuv420p", "yuv444p"]:
            cfg = {
                "repo_id": repo_id,
                # video encoding
                "g": 10,
                "crf": 10,
                "pix_fmt": pix_fmt,
                # video decoding
                "device": "cpu",
                "decoder": "torchvision",
                "decoder_kwgs": {},
            }
            if not dry_run:
                run_video_benchmark(bench_dir / repo_id / f"torchvision_{pix_fmt}", cfg=cfg)
            info = load_info(bench_dir / repo_id / f"torchvision_{pix_fmt}")
            rows.append(
                [
                    repo_id,
                    pix_fmt,
                    info["pc_compression"],
                    info["pc_load_time"],
                    info["avg_per_pixel_l2_error"],
                ]
            )
    display_markdown_table(headers, rows)

    # g
    headers = ["repo_id", "g", "pc_compression", "pc_load_time", "avg_per_pixel_l2_error"]
    rows = []
    for repo_id in repo_ids:
        for g in [1, 5, 10, 15, 20, 30, 40, 60, 100, None]:
            cfg = {
                "repo_id": repo_id,
                # video encoding
                "g": g,
                "crf": 10,
                "pix_fmt": "yuv444p",
                # video decoding
                "device": "cpu",
                "decoder": "torchvision",
                "decoder_kwgs": {},
            }
            if not dry_run:
                run_video_benchmark(bench_dir / repo_id / f"torchvision_g_{g}", cfg=cfg)
            info = load_info(bench_dir / repo_id / f"torchvision_g_{g}")
            rows.append(
                [repo_id, g, info["pc_compression"], info["pc_load_time"], info["avg_per_pixel_l2_error"]]
            )
    display_markdown_table(headers, rows)

    # crf
    headers = ["repo_id", "crf", "pc_compression", "pc_load_time", "avg_per_pixel_l2_error"]
    rows = []
    for repo_id in repo_ids:
        for crf in [0, 5, 10, 15, 20, None, 25, 30, 40, 50]:
            cfg = {
                "repo_id": repo_id,
                # video encoding
                "g": None,
                "crf": crf,
                "pix_fmt": "yuv444p",
                # video decoding
                "device": "cpu",
                "decoder": "torchvision",
                "decoder_kwgs": {},
            }
            if not dry_run:
                run_video_benchmark(bench_dir / repo_id / f"torchvision_crf_{crf}", cfg=cfg)
            info = load_info(bench_dir / repo_id / f"torchvision_crf_{crf}")
            rows.append(
                [repo_id, crf, info["pc_compression"], info["pc_load_time"], info["avg_per_pixel_l2_error"]]
            )
    display_markdown_table(headers, rows)


if __name__ == "__main__":
    main()

View File

@@ -9,7 +9,9 @@ from lerobot.common.datasets.utils import (
    load_info,
    load_previous_and_future_frames,
    load_stats,
    load_videos,
)
from lerobot.common.datasets.video_utils import load_from_videos
class LeRobotDataset(torch.utils.data.Dataset):
@@ -30,18 +32,38 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.transform = transform
        self.delta_timestamps = delta_timestamps
        # load data from hub or locally when root is provided
        # TODO(rcadene, aliberts): implement faster transfer
        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
        self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
        self.episode_data_index = load_episode_data_index(repo_id, version, root)
        self.stats = load_stats(repo_id, version, root)
        self.info = load_info(repo_id, version, root)
        if self.video:
            self.videos_dir = load_videos(repo_id, version, root)

    @property
    def fps(self) -> int:
        return self.info["fps"]

    @property
    def video(self) -> bool:
        return self.info.get("video", False)

    @property
    def image_keys(self) -> list[str]:
        image_keys = []
        for key, feats in self.hf_dataset.features.items():
            if isinstance(feats, datasets.Image):
                image_keys.append(key)
        return image_keys + self.video_frame_keys

    @property
    def video_frame_keys(self):
        video_frame_keys = []
        for key, feats in self.hf_dataset.features.items():
            if isinstance(feats, datasets.Value) and feats.id == "video_frame":
                video_frame_keys.append(key)
        return video_frame_keys

    @property
    def num_samples(self) -> int:
@@ -66,6 +88,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
            tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
        )

        if self.video:
            item = load_from_videos(item, self.video_frame_keys, self.videos_dir)

        if self.transform is not None:
            item = self.transform(item)
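
Put together, a video-backed dataset is used like any other `LeRobotDataset`; a sketch (the `delta_timestamps` values are illustrative and assume `fps=10`):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Frames for "observation.image" are decoded on the fly from mp4 files.
# Here delta_timestamps requests the current frame and the frame 0.4s (4 frames) earlier.
dataset = LeRobotDataset(
    "lerobot/pusht",
    delta_timestamps={"observation.image": [-0.4, 0.0]},
)
item = dataset[0]
print(item["observation.image"].shape)  # expected: (2, c, h, w), float32 in [0, 1]
```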

View File

@@ -8,7 +8,7 @@ import einops
import torch
import tqdm
from datasets import Image, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
@@ -127,6 +127,15 @@ def load_info(repo_id, version, root) -> dict:
    return info


def load_videos(repo_id, version, root) -> Path:
    if root is not None:
        path = Path(root) / repo_id / "videos"
    else:
        path = snapshot_download(repo_id, allow_patterns="*.mp4", repo_type="dataset", revision=version)
    return path


def load_previous_and_future_frames(
    item: dict[str, torch.Tensor],
    hf_dataset: datasets.Dataset,
@@ -209,6 +218,7 @@ def load_previous_and_future_frames(
        item[key] = hf_dataset.select_columns(key)[data_ids][key]
        item[key] = torch.stack(item[key])
        item[f"{key}_is_pad"] = is_pad
        item[f"{key}_timestamp"] = query_ts

    return item

View File

@@ -0,0 +1,104 @@
import itertools
import subprocess
from pathlib import Path

import torch
import torchvision


def load_from_videos(item, video_frame_keys, videos_dir):
    for key in video_frame_keys:
        ep_idx = item["episode_index"]
        video_path = videos_dir / key / f"episode_{ep_idx:06d}.mp4"

        if f"{key}_timestamp" in item:
            # load multiple frames at once
            timestamps = item[f"{key}_timestamp"]
            item[key] = decode_video_frames_torchvision(video_path, timestamps)
        else:
            # load one frame
            timestamps = [item["timestamp"]]
            frames = decode_video_frames_torchvision(video_path, timestamps)
            assert len(frames) == 1
            item[key] = frames[0]
    return item
def decode_video_frames_torchvision(
    video_path: str, timestamps: list[float], device: str = "cpu", log_loaded_timestamps: bool = False
):
    """Loads frames associated with the requested timestamps of a video.

    Note: Videos benefit from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to trade off decoding time against video size in bytes.
    """
    # set backend
    if device == "cpu":
        # explicitly use pyav
        torchvision.set_video_backend("pyav")
    elif device == "cuda":
        # TODO(rcadene, aliberts): implement video decoding with GPU
        # torchvision.set_video_backend("cuda")
        # torchvision.set_video_backend("video_reader")
        # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
        # check possible bug: https://github.com/pytorch/vision/issues/7745
        raise NotImplementedError()
    else:
        raise ValueError(device)

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(str(video_path), "video")

    # sanity preprocessing (e.g. 3.60000003 -> 3.6)
    timestamps = [round(ts, 4) for ts in timestamps]

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = timestamps[0]
    last_ts = timestamps[-1]

    # access the key frame of the first requested frame, and load all frames until the last requested frame
    # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
    frames = []
    for frame in itertools.takewhile(lambda x: x["pts"] <= last_ts, reader.seek(first_ts)):
        # get the timestamp of the loaded frame
        ts = frame["pts"]

        # if the loaded frame is not among the requested frames, we don't add it to the list of output frames
        is_frame_requested = ts in timestamps
        if is_frame_requested:
            frames.append(frame["data"])

        if log_loaded_timestamps:
            log = f"frame loaded at timestamp={ts:.4f}"
            if is_frame_requested:
                log += " requested"
            print(log)

    frames = torch.stack(frames)

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    frames = frames.type(torch.float32) / 255

    assert len(timestamps) == frames.shape[0]
    return frames
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
    # For more info on this setting, see: `lerobot/common/datasets/_video_benchmark/README.md`
    video_path = Path(video_path)
    video_path.parent.mkdir(parents=True, exist_ok=True)

    ffmpeg_cmd = (
        f"ffmpeg -r {fps} -f image2 "
        f"-i {str(imgs_dir / 'frame_%06d.png')} "
        "-vcodec libx264 "
        "-pix_fmt yuv444p "
        f"{str(video_path)}"
    )
    subprocess.run(ffmpeg_cmd.split(" "), check=True)
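
A possible round trip with these two helpers (a sketch; paths and fps are illustrative):

```python
from pathlib import Path

# encode a directory of frames saved as frame_000000.png, frame_000001.png, ...
encode_video_frames(Path("tmp/images"), Path("tmp/episode_0.mp4"), fps=10)

# decode the frames at 0.0s and 0.1s back into a float32 tensor in [0, 1], channel first
frames = decode_video_frames_torchvision("tmp/episode_0.mp4", timestamps=[0.0, 0.1])
print(frames.shape)  # expected: (2, c, h, w)
```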

View File

@@ -255,8 +255,7 @@ def main():
    parser.add_argument(
        "--video",
        type=int,
        default=1,
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 times faster loading time during training.",
    )
    parser.add_argument(
parser.add_argument(

poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -3806,38 +3806,6 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.9.1)"]
[[package]]
name = "torchaudio"
version = "2.3.0"
description = "An audio package for PyTorch"
optional = false
python-versions = "*"
files = [
{file = "torchaudio-2.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:342108da83aa19a457c9a128b1206fadb603753b51cca022b9f585aac2f4754c"},
{file = "torchaudio-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:73fedb2c631e01fa10feaac308540b836aefe758e55ca3ee026335e5d01e8e30"},
{file = "torchaudio-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e5bb50b7a4874ed97086c9e516dd90b103d954edcb5ed4b36f4fc22c4000a5a7"},
{file = "torchaudio-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:b4cc9cef5c98ed37e9405c4e0b0e6413bc101f3f49d45dc4f1d4e927757fe41e"},
{file = "torchaudio-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:341ca3048ce6edcc731519b30187f0b13acb245c4efe16f925f69f9d533546e1"},
{file = "torchaudio-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:8f2e0a28740bb0ee66369f92c811f33c0a47e6fcfc2de9cee89746472d713906"},
{file = "torchaudio-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:61edb02ae9c0efea4399f9c1f899601136b24f35d430548284ea8eaf6ccbe3be"},
{file = "torchaudio-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:04bc960cf1aef3b469b095a432a25496bc28197850fc2d90b7b52d6b5255487b"},
{file = "torchaudio-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:535144a2fbba95fbb3b883224ffcf44788e4cecbabbe49c4a1ae3e7a74f71485"},
{file = "torchaudio-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:fb3f52ed1d63b272c240d9bf051705312cb172212051b8a6a2f64d42e3cc1633"},
{file = "torchaudio-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:668a8b694e5522cff28cd5e02d01aa1b75ce940aa9fb40480892bdc623b1735d"},
{file = "torchaudio-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:6c1f538018b85d7766835d042e555de2f096f7a69bba6b16031bf42a914dd9e1"},
{file = "torchaudio-2.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7ba93265455dc363385e98c0cfcaeb586b7401af8a2c824811ee1466134a4f30"},
{file = "torchaudio-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:21bb6d1b384fc8895133f01489133d575d4a715cd81734b89651fb0264bd8b80"},
{file = "torchaudio-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:ed1866f508dc689c4f682d330b2ed4c83108d35865e4fb89431819364d8ad9ed"},
{file = "torchaudio-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:a3cbb230e2bb38ad1a1dd74aea242a154a9f76ab819d9c058b2c5074a9f5d7d2"},
{file = "torchaudio-2.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f4b933776f20a36af5ddc57968fcb3da34dd03881db8d6760f3e1176803b9cf8"},
{file = "torchaudio-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:c5e63cc2dbf179088b6cdfd21ecdbb943aa003c780075aa440162f231ee72db2"},
{file = "torchaudio-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d243bb8a1ee263c2cdafb9feed1569c3742d8135731e8f7818de12f4e0c83e28"},
{file = "torchaudio-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6cd6d45cf8a45c89953e35434d9a461feb418e51e760adafc606a903dcbb9bd5"},
]
[package.dependencies]
torch = "2.3.0"
[[package]]
name = "torchvision"
version = "0.18.0"
@@ -4299,4 +4267,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "0f72eb92ac8817a46f0659b4d72647a6b76f6e4ba762d11b280f8a88e6cd4371"
content-hash = "fab42b4be590cb2007934cd8f5a218f1f3da4f0b42cdff7e7724af518888d7b4"

View File

@@ -56,7 +56,6 @@ pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.19.0"
imagecodecs = { version = "^2024.1.1", optional = true }
torchaudio = "^2.3.0"
[tool.poetry.extras]