Add video dataset by default
This commit is contained in:
parent
72bcfb9ee4
commit
d632f8cb51
|
@ -0,0 +1,186 @@
|
||||||
|
# Video benchmark
|
||||||
|
|
||||||
|
|
||||||
|
## Questions
|
||||||
|
|
||||||
|
What is the optimal trade-off between:
|
||||||
|
- maximizing loading time with random access,
|
||||||
|
- minimizing memory space on disk,
|
||||||
|
- maximizing success rate of policies?
|
||||||
|
|
||||||
|
How to encode videos?
|
||||||
|
- How much compression (`-crf`)? Low compression with `0`, normal compression with `20` or extreme with `56`?
|
||||||
|
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||||
|
- How many key frames (`-g`)? A key frame every `10` frames?
|
||||||
|
|
||||||
|
How to decode videos?
|
||||||
|
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||||
|
|
||||||
|
## Metrics
|
||||||
|
|
||||||
|
**Percentage of data compression (higher is better)**
|
||||||
|
`pc_compression` is the ratio of the memory space on disk taken by the original images to encode, to the memory space taken by the encoded video. For instance, `pc_compression=400%` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||||
|
|
||||||
|
**Percentage of loading time (lower is better)**
|
||||||
|
`pc_load_time` is the ratio of the time it takes to load original images at given timestamps, to the time it takes to decode the exact same frames from the video. Lower is better. For instance, `pc_load_time=120%` means that decoding from video is a bit slower than loading the original images.
|
||||||
|
|
||||||
|
**Average L2 error per pixel (lower is better)**
|
||||||
|
`avg_per_pixel_l2_error` is the average L2 error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||||
|
|
||||||
|
**Loss of a pretrained policy (higher is better)** (not available)
|
||||||
|
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||||
|
|
||||||
|
**Success rate after retraining (higher is better)** (not available)
|
||||||
|
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best.
|
||||||
|
|
||||||
|
|
||||||
|
## Variables
|
||||||
|
|
||||||
|
**Image content**
|
||||||
|
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this bechmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
|
||||||
|
|
||||||
|
**Requested timestamps**
|
||||||
|
In this benchmark, we focus on the loading time of random access, so we are not interested about sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `pc_load_time`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a setting where we load 2 consecutive frames with 4 frames of spacing.
|
||||||
|
|
||||||
|
**Data augmentations**
|
||||||
|
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robusts (e.g. robust to color changes, compression, etc.).
|
||||||
|
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
### Loading 2 consecutive frames with 4 frames spacing (Diffusion Policy setting)
|
||||||
|
|
||||||
|
**`decoder`**
|
||||||
|
| repo_id | decoder | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">torchvision</span> | 0.166 | 0.0000119 |
|
||||||
|
| lerobot/pusht | ffmpegio | 0.009 | 0.0001182 |
|
||||||
|
| lerobot/pusht | torchaudio | 0.138 | 0.0000359 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">torchvision</span> | 0.174 | 0.0000174 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | ffmpegio | 0.010 | 0.0000735 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | torchaudio | 0.154 | 0.0000340 |
|
||||||
|
|
||||||
|
**`pix_fmt`**
|
||||||
|
| repo_id | pix_fmt | pc_compression | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | yuv420p | 3.602 | 0.202 | 0.0000661 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">yuv444p</span> | 3.213 | 0.153 | 0.0000110 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | yuv420p | 8.879 | 0.202 | 0.0000332 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">yuv444p</span> | 8.517 | 0.165 | 0.0000175 |
|
||||||
|
|
||||||
|
**`g`**
|
||||||
|
| repo_id | g | pc_compression | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | 1 | 1.308 | 0.190 | 0.0000151 |
|
||||||
|
| lerobot/pusht | 5 | 2.739 | 0.184 | 0.0000123 |
|
||||||
|
| lerobot/pusht | 10 | 3.213 | 0.144 | 0.0000116 |
|
||||||
|
| lerobot/pusht | 15 | 3.460 | 0.137 | 0.0000112 |
|
||||||
|
| lerobot/pusht | 20 | 3.559 | 0.118 | 0.0000109 |
|
||||||
|
| lerobot/pusht | 30 | 3.697 | 0.104 | 0.0000117 |
|
||||||
|
| lerobot/pusht | 40 | 3.763 | 0.092 | 0.0000116 |
|
||||||
|
| lerobot/pusht | 60 | 3.925 | 0.068 | 0.0000117 |
|
||||||
|
| lerobot/pusht | 100 | 4.010 | 0.054 | 0.0000117 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 4.058 | 0.043 | 0.0000117 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 1 | 4.790 | 0.236 | 0.0000221 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 5 | 7.707 | 0.201 | 0.0000185 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 10 | 8.517 | 0.172 | 0.0000177 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 15 | 8.830 | 0.152 | 0.0000170 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 20 | 8.961 | 0.133 | 0.0000167 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 30 | 8.850 | 0.113 | 0.0000167 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 40 | 8.996 | 0.109 | 0.0000174 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 60 | 9.113 | 0.081 | 0.0000163 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 100 | 9.278 | 0.051 | 0.0000173 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 9.396 | 0.030 | 0.0000165 |
|
||||||
|
|
||||||
|
**`crf`**
|
||||||
|
| repo_id | crf | pc_compression | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | 0 | 4.529 | 0.041 | 0.0000035 |
|
||||||
|
| lerobot/pusht | 5 | 3.138 | 0.040 | 0.0000077 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">10</span> | 4.058 | 0.038 | 0.0000121 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">15</span> | 5.407 | 0.039 | 0.0000195 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">20</span> | 7.335 | 0.039 | 0.0000319 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 8.909 | 0.046 | 0.0000425 |
|
||||||
|
| lerobot/pusht | 25 | 10.213 | 0.039 | 0.0000519 |
|
||||||
|
| lerobot/pusht | 30 | 14.516 | 0.041 | 0.0000795 |
|
||||||
|
| lerobot/pusht | 40 | 23.546 | 0.041 | 0.0001557 |
|
||||||
|
| lerobot/pusht | 50 | 28.460 | 0.042 | 0.0002723 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 0 | 2.318 | 0.012 | 0.0000056 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 5 | 4.899 | 0.019 | 0.0000132 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">10</span> | 9.396 | 0.026 | 0.0000183 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">15</span> | 19.161 | 0.034 | 0.0000241 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">20</span> | 39.311 | 0.039 | 0.0000329 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 60.530 | 0.043 | 0.0000401 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 25 | 81.048 | 0.046 | 0.0000454 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 30 | 165.189 | 0.051 | 0.0000609 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 40 | 544.478 | 0.056 | 0.0001095 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 50 | 1109.556 | 0.072 | 0.0001815 |
|
||||||
|
|
||||||
|
|
||||||
|
### Loading 6 consecutive frames with no spacing (TDMPC setting)
|
||||||
|
|
||||||
|
**`decoder`**
|
||||||
|
| repo_id | decoder | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">torchvision</span> | 0.386 | 0.0000117 |
|
||||||
|
| lerobot/pusht | ffmpegio | 0.008 | 0.0000117 |
|
||||||
|
| lerobot/pusht | torchaudio | 0.184 | 0.0000356 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">torchvision</span> | 0.448 | 0.0000178 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | ffmpegio | 0.009 | 0.0000178 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | torchaudio | 0.149 | 0.0000349 |
|
||||||
|
|
||||||
|
**`pix_fmt`**
|
||||||
|
| repo_id | pix_fmt | pc_compression | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | yuv420p | 3.602 | 0.518 | 0.0000651 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">yuv444p</span> | 3.213 | 0.401 | 0.0000117 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | yuv420p | 8.879 | 0.578 | 0.0000334 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">yuv444p</span> | 8.517 | 0.479 | 0.0000178 |
|
||||||
|
|
||||||
|
**`g`**
|
||||||
|
| repo_id | g | pc_compression | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | 1 | 1.308 | 0.528 | 0.0000152 |
|
||||||
|
| lerobot/pusht | 5 | 2.739 | 0.483 | 0.0000124 |
|
||||||
|
| lerobot/pusht | 10 | 3.213 | 0.396 | 0.0000117 |
|
||||||
|
| lerobot/pusht | 15 | 3.460 | 0.379 | 0.0000118 |
|
||||||
|
| lerobot/pusht | 20 | 3.559 | 0.319 | 0.0000114 |
|
||||||
|
| lerobot/pusht | 30 | 3.697 | 0.278 | 0.0000116 |
|
||||||
|
| lerobot/pusht | 40 | 3.763 | 0.243 | 0.0000115 |
|
||||||
|
| lerobot/pusht | 60 | 3.925 | 0.186 | 0.0000118 |
|
||||||
|
| lerobot/pusht | 100 | 4.010 | 0.156 | 0.0000119 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 4.058 | 0.105 | 0.0000121 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 1 | 4.790 | 0.605 | 0.0000221 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 5 | 7.707 | 0.533 | 0.0000183 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 10 | 8.517 | 0.469 | 0.0000178 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 15 | 8.830 | 0.399 | 0.0000174 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 20 | 8.961 | 0.382 | 0.0000175 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 30 | 8.850 | 0.326 | 0.0000172 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 40 | 8.996 | 0.279 | 0.0000173 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 60 | 9.113 | 0.226 | 0.0000174 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 100 | 9.278 | 0.150 | 0.0000175 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 9.396 | 0.076 | 0.0000176 |
|
||||||
|
|
||||||
|
**`crf`**
|
||||||
|
| repo_id | crf | pc_compression | pc_load_time | avg_per_pixel_l2_error |
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
| lerobot/pusht | 0 | 4.529 | 0.108 | 0.0000035 |
|
||||||
|
| lerobot/pusht | 5 | 3.138 | 0.099 | 0.0000077 |
|
||||||
|
| lerobot/pusht | 10 | 4.058 | 0.091 | 0.0000121 |
|
||||||
|
| lerobot/pusht | 15 | 5.407 | 0.095 | 0.0000195 |
|
||||||
|
| lerobot/pusht | 20 | 7.335 | 0.100 | 0.0000318 |
|
||||||
|
| lerobot/pusht | <span style="color: #32CD32;">None</span> | 8.909 | 0.102 | 0.0000422 |
|
||||||
|
| lerobot/pusht | 25 | 10.213 | 0.102 | 0.0000517 |
|
||||||
|
| lerobot/pusht | 30 | 14.516 | 0.104 | 0.0000795 |
|
||||||
|
| lerobot/pusht | 40 | 23.546 | 0.106 | 0.0001555 |
|
||||||
|
| lerobot/pusht | 50 | 28.460 | 0.110 | 0.0002723 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 0 | 2.318 | 0.032 | 0.0000056 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 5 | 4.899 | 0.052 | 0.0000127 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">10</span> | 9.396 | 0.073 | 0.0000176 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">15</span> | 19.161 | 0.097 | 0.0000234 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">20</span> | 39.311 | 0.110 | 0.0000321 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">None</span> | 60.530 | 0.117 | 0.0000393 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 25 | 81.048 | 0.126 | 0.0000446 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 30 | 165.189 | 0.138 | 0.0000603 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 40 | 544.478 | 0.151 | 0.0001095 |
|
||||||
|
| lerobot/umi_cup_in_the_wild | 50 | 1109.556 | 0.167 | 0.0001817 |
|
|
@ -0,0 +1,230 @@
|
||||||
|
"""This file contains work-in-progress alternative to default decoding strategy."""
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames_ffmpegio(video_path, timestamps, device="cpu"):
|
||||||
|
# assert device == "cpu", f"Only CPU decoding is supported with ffmpegio, but device is {device}"
|
||||||
|
import einops
|
||||||
|
import ffmpegio
|
||||||
|
|
||||||
|
num_contiguous_frames = 1 # noqa: F841
|
||||||
|
image_format = "rgb24"
|
||||||
|
|
||||||
|
list_frames = []
|
||||||
|
for timestamp in timestamps:
|
||||||
|
kwargs = {
|
||||||
|
"ss": str(timestamp),
|
||||||
|
# vframes=num_contiguous_frames,
|
||||||
|
"pix_fmt": image_format,
|
||||||
|
# hwaccel=None if device == "cpu" else device, # ,
|
||||||
|
"show_log": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
kwargs["hwaccel_in"] = "cuda"
|
||||||
|
kwargs["hwaccel_output_format_in"] = "cuda"
|
||||||
|
|
||||||
|
fs, frames = ffmpegio.video.read(str(video_path), **kwargs)
|
||||||
|
list_frames.append(torch.from_numpy(frames))
|
||||||
|
frames = torch.cat(list_frames)
|
||||||
|
|
||||||
|
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||||
|
frames = frames.type(torch.float32) / 255
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def yuv_to_rgb(frames):
|
||||||
|
assert frames.dtype == torch.uint8
|
||||||
|
assert frames.ndim == 4
|
||||||
|
assert frames.shape[1] == 3
|
||||||
|
|
||||||
|
frames = frames.cpu().to(torch.float)
|
||||||
|
y = frames[..., 0, :, :]
|
||||||
|
u = frames[..., 1, :, :]
|
||||||
|
v = frames[..., 2, :, :]
|
||||||
|
|
||||||
|
y /= 255
|
||||||
|
u = u / 255 - 0.5
|
||||||
|
v = v / 255 - 0.5
|
||||||
|
|
||||||
|
r = y + 1.13983 * v
|
||||||
|
g = y + -0.39465 * u - 0.58060 * v
|
||||||
|
b = y + 2.03211 * u
|
||||||
|
|
||||||
|
rgb = torch.stack([r, g, b], 1)
|
||||||
|
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
|
||||||
|
return rgb
|
||||||
|
|
||||||
|
|
||||||
|
def yuv_to_rgb_cv2(frames, return_hwc=True):
|
||||||
|
assert frames.dtype == torch.uint8
|
||||||
|
assert frames.ndim == 4
|
||||||
|
assert frames.shape[1] == 3
|
||||||
|
frames = frames.cpu()
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||||
|
frames = frames.numpy()
|
||||||
|
frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
|
||||||
|
frames = [torch.from_numpy(frame) for frame in frames]
|
||||||
|
frames = torch.stack(frames)
|
||||||
|
if not return_hwc:
|
||||||
|
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames_torchaudio(video_path, timestamps, device="cpu"):
|
||||||
|
num_contiguous_frames = 1
|
||||||
|
width = None
|
||||||
|
height = None
|
||||||
|
# image_format = "rgb" # or "yuv"
|
||||||
|
# image_format = None
|
||||||
|
image_format = "yuv444p"
|
||||||
|
# image_format = "yuv444p"
|
||||||
|
# image_format = "rgb24"
|
||||||
|
frame_rate = None
|
||||||
|
|
||||||
|
scale_full_range_filter = False
|
||||||
|
|
||||||
|
filter_desc = []
|
||||||
|
|
||||||
|
video_stream_kwgs = {
|
||||||
|
"frames_per_chunk": num_contiguous_frames,
|
||||||
|
# "buffer_chunk_size": num_contiguous_frames,
|
||||||
|
}
|
||||||
|
|
||||||
|
# choice of decoder
|
||||||
|
if device == "cuda":
|
||||||
|
video_stream_kwgs["hw_accel"] = "cuda:0"
|
||||||
|
video_stream_kwgs["decoder"] = "h264_cuvid"
|
||||||
|
# video_stream_kwgs["decoder"] = "hevc_cuvid"
|
||||||
|
# video_stream_kwgs["decoder"] = "av1_cuvid"
|
||||||
|
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
|
||||||
|
else:
|
||||||
|
video_stream_kwgs["decoder"] = "h264"
|
||||||
|
# video_stream_kwgs["decoder"] = "hevc"
|
||||||
|
# video_stream_kwgs["decoder"] = "av1"
|
||||||
|
# video_stream_kwgs["decoder"] = "ffv1"
|
||||||
|
|
||||||
|
# resize
|
||||||
|
resize_width = width is not None
|
||||||
|
resize_height = height is not None
|
||||||
|
if resize_width or resize_height:
|
||||||
|
if device == "cuda":
|
||||||
|
assert resize_width and resize_height
|
||||||
|
video_stream_kwgs["decoder_option"] = {"resize": f"{width}x{height}"}
|
||||||
|
else:
|
||||||
|
scales = []
|
||||||
|
if resize_width:
|
||||||
|
scales.append(f"width={width}")
|
||||||
|
if resize_height:
|
||||||
|
scales.append(f"height={height}")
|
||||||
|
filter_desc.append(f"scale={':'.join(scales)}")
|
||||||
|
|
||||||
|
# choice of format
|
||||||
|
if image_format is not None:
|
||||||
|
if device == "cuda":
|
||||||
|
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
|
||||||
|
# filter_desc.append(f"scale=format={image_format}")
|
||||||
|
# filter_desc.append(f"scale_cuda=format={image_format}")
|
||||||
|
# filter_desc.append(f"scale_npp=format={image_format}")
|
||||||
|
filter_desc.append(f"format=pix_fmts={image_format}")
|
||||||
|
else:
|
||||||
|
filter_desc.append(f"format=pix_fmts={image_format}")
|
||||||
|
|
||||||
|
# choice of frame rate
|
||||||
|
if frame_rate is not None:
|
||||||
|
filter_desc.append(f"fps={frame_rate}")
|
||||||
|
|
||||||
|
# to set output scale [0-255] instead of [16-235]
|
||||||
|
if scale_full_range_filter:
|
||||||
|
filter_desc.append("scale=in_range=limited:out_range=full")
|
||||||
|
|
||||||
|
if len(filter_desc) > 0:
|
||||||
|
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
|
||||||
|
|
||||||
|
# create a stream and load a certain number of frame at a certain frame rate
|
||||||
|
# TODO(rcadene): make sure it's the most optimal way to do it
|
||||||
|
from torchaudio.io import StreamReader
|
||||||
|
|
||||||
|
print(video_stream_kwgs)
|
||||||
|
|
||||||
|
list_frames = []
|
||||||
|
for timestamp in timestamps:
|
||||||
|
s = StreamReader(str(video_path))
|
||||||
|
s.seek(timestamp)
|
||||||
|
s.add_video_stream(**video_stream_kwgs)
|
||||||
|
s.fill_buffer()
|
||||||
|
(frames,) = s.pop_chunks()
|
||||||
|
|
||||||
|
if "yuv" in image_format:
|
||||||
|
frames = yuv_to_rgb(frames)
|
||||||
|
|
||||||
|
assert frames.dtype == torch.uint8
|
||||||
|
frames = frames.type(torch.float32)
|
||||||
|
|
||||||
|
# if device == "cuda":
|
||||||
|
# The original data had limited range, which is 16-235, and torchaudio does not convert,
|
||||||
|
# while FFmpeg converts it to full range 0-255. So you can apply a linear transformation.
|
||||||
|
if not scale_full_range_filter:
|
||||||
|
frames -= 16
|
||||||
|
frames *= 255 / (235 - 16)
|
||||||
|
|
||||||
|
frames /= 255
|
||||||
|
|
||||||
|
frames = frames.clip(0, 1)
|
||||||
|
list_frames.append(frames)
|
||||||
|
|
||||||
|
frames = torch.cat(list_frames)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
# def _decode_frames_decord(video_path, timestamp):
|
||||||
|
# num_contiguous_frames = 1 # noqa: F841 TODO(rcadene): remove
|
||||||
|
# device = "cpu"
|
||||||
|
|
||||||
|
# from decord import VideoReader, cpu, gpu
|
||||||
|
|
||||||
|
# with open(str(video_path), "rb") as f:
|
||||||
|
# ctx = gpu if device == "cuda" else cpu
|
||||||
|
# vr = VideoReader(f, ctx=ctx(0)) # noqa: F841
|
||||||
|
# raise NotImplementedError("Convert `timestamp` into frame_id")
|
||||||
|
# # frame_id = frame_ids[0].item()
|
||||||
|
# # frames = vr.get_batch([frame_id])
|
||||||
|
# # frames = torch.from_numpy(frames.asnumpy())
|
||||||
|
# # frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||||
|
# # return frames
|
||||||
|
|
||||||
|
|
||||||
|
# def decode_frames_nvc(video_path, timestamps, device="cuda"):
|
||||||
|
# assert device == "cuda"
|
||||||
|
|
||||||
|
# import PyNvCodec as nvc
|
||||||
|
# import PytorchNvCodec as pnvc
|
||||||
|
|
||||||
|
# gpuID = 0
|
||||||
|
|
||||||
|
# nvDec = nvc.PyNvDecoder('path_to_video_file', gpuID)
|
||||||
|
# to_rgb = nvc.PySurfaceConverter(nvDec.Width(), nvDec.Height(), nvc.PixelFormat.NV12, nvc.PixelFormat.RGB, gpuID)
|
||||||
|
# to_planar = nvc.PySurfaceConverter(nvDec.Width(), nvDec.Height(), nvc.PixelFormat.RGB, nvc.PixelFormat.RGB_PLANAR, gpuID)
|
||||||
|
|
||||||
|
# while True:
|
||||||
|
# # Obtain NV12 decoded surface from decoder;
|
||||||
|
# rawSurface = nvDec.DecodeSingleSurface()
|
||||||
|
# if (rawSurface.Empty()):
|
||||||
|
# break
|
||||||
|
|
||||||
|
# # Convert to RGB interleaved;
|
||||||
|
# rgb_byte = to_rgb.Execute(rawSurface)
|
||||||
|
|
||||||
|
# # Convert to RGB planar because that's what to_tensor + normalize are doing;
|
||||||
|
# rgb_planar = to_planar.Execute(rgb_byte)
|
||||||
|
|
||||||
|
# # Create torch tensor from it and reshape because
|
||||||
|
# # pnvc.makefromDevicePtrUint8 creates just a chunk of CUDA memory
|
||||||
|
# # and then copies data from plane pointer to allocated chunk;
|
||||||
|
# surfPlane = rgb_planar.PlanePtr()
|
||||||
|
# surface_tensor = pnvc.makefromDevicePtrUint8(surfPlane.GpuMem(), surfPlane.Width(), surfPlane.Height(), surfPlane.Pitch(), surfPlane.ElemSize())
|
||||||
|
# surface_tensor.resize_(3, target_h, target_w)
|
|
@ -0,0 +1,337 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import numpy
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.datasets._video_benchmark._video_utils import (
|
||||||
|
decode_video_frames_ffmpegio,
|
||||||
|
decode_video_frames_torchaudio,
|
||||||
|
)
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.datasets.video_utils import (
|
||||||
|
decode_video_frames_torchvision,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_directory_size(directory):
|
||||||
|
total_size = 0
|
||||||
|
# Iterate over all files and subdirectories recursively
|
||||||
|
for item in directory.rglob("*"):
|
||||||
|
if item.is_file():
|
||||||
|
# Add the file size to the total
|
||||||
|
total_size += item.stat().st_size
|
||||||
|
return total_size
|
||||||
|
|
||||||
|
|
||||||
|
def run_video_benchmark(output_dir, cfg, seed=1337, timestamps_mode="diffusion"):
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
if output_dir.exists():
|
||||||
|
shutil.rmtree(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
repo_id = cfg["repo_id"]
|
||||||
|
|
||||||
|
# TODO(rcadene): rewrite with hardcoding of original images and episodes
|
||||||
|
dataset = LeRobotDataset(
|
||||||
|
repo_id,
|
||||||
|
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||||
|
)
|
||||||
|
# Get fps
|
||||||
|
fps = dataset.fps
|
||||||
|
|
||||||
|
# we only load first episode
|
||||||
|
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||||
|
|
||||||
|
# Save/Load image directory for the first episode
|
||||||
|
imgs_dir = Path(f"tmp/data/images/{repo_id}/observation.image_episode_000000")
|
||||||
|
if not imgs_dir.exists():
|
||||||
|
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||||
|
imgs_dataset = hf_dataset.select_columns("observation.image")
|
||||||
|
|
||||||
|
for i, item in enumerate(imgs_dataset):
|
||||||
|
img = item["observation.image"]
|
||||||
|
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||||
|
|
||||||
|
if i >= ep_num_images - 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
sum_original_frames_size_bytes = get_directory_size(imgs_dir)
|
||||||
|
|
||||||
|
# Encode images into video
|
||||||
|
video_path = output_dir / "episode_0.mp4"
|
||||||
|
|
||||||
|
g = cfg.get("g")
|
||||||
|
crf = cfg.get("crf")
|
||||||
|
pix_fmt = cfg["pix_fmt"]
|
||||||
|
|
||||||
|
ffmpeg_cmd = ""
|
||||||
|
ffmpeg_cmd += f"ffmpeg -r {fps} -f image2 "
|
||||||
|
ffmpeg_cmd += f"-i {str(imgs_dir / 'frame_%06d.png')} "
|
||||||
|
ffmpeg_cmd += "-vcodec libx264 "
|
||||||
|
if g is not None:
|
||||||
|
ffmpeg_cmd += f"-g {g} " # ensures at least 1 keyframe every 10 frames
|
||||||
|
# ffmpeg_cmd += "-keyint_min 10 " set a minimum of 10 frames between 2 key frames
|
||||||
|
# ffmpeg_cmd += "-sc_threshold 0 " disable scene change detection to lower the number of key frames
|
||||||
|
if crf is not None:
|
||||||
|
ffmpeg_cmd += f"-crf {crf} "
|
||||||
|
ffmpeg_cmd += f"-pix_fmt {pix_fmt} "
|
||||||
|
ffmpeg_cmd += f"{str(video_path)}"
|
||||||
|
subprocess.run(ffmpeg_cmd.split(" "), check=True)
|
||||||
|
|
||||||
|
video_size_bytes = video_path.stat().st_size
|
||||||
|
|
||||||
|
# Set decoder
|
||||||
|
|
||||||
|
decoder = cfg["decoder"]
|
||||||
|
decoder_kwgs = cfg["decoder_kwgs"]
|
||||||
|
device = cfg["device"]
|
||||||
|
|
||||||
|
if decoder == "torchaudio":
|
||||||
|
decode_frames_fn = decode_video_frames_torchaudio
|
||||||
|
elif decoder == "ffmpegio":
|
||||||
|
decode_frames_fn = decode_video_frames_ffmpegio
|
||||||
|
elif decoder == "torchvision":
|
||||||
|
decode_frames_fn = decode_video_frames_torchvision
|
||||||
|
else:
|
||||||
|
raise ValueError(decoder)
|
||||||
|
|
||||||
|
# Estimate average loading time
|
||||||
|
|
||||||
|
def load_original_frames(imgs_dir, timestamps):
|
||||||
|
frames = []
|
||||||
|
for ts in timestamps:
|
||||||
|
idx = int(ts * fps)
|
||||||
|
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||||
|
frame = torch.from_numpy(numpy.array(frame))
|
||||||
|
frame = frame.type(torch.float32) / 255
|
||||||
|
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
list_avg_load_time = []
|
||||||
|
list_avg_load_time_from_images = []
|
||||||
|
per_pixel_l2_errors = []
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
for t in range(50):
|
||||||
|
# test loading 2 frames that are 4 frames appart, which might be a common setting
|
||||||
|
ts = random.randint(fps, ep_num_images - fps) / fps
|
||||||
|
|
||||||
|
if timestamps_mode == "diffusion":
|
||||||
|
prev_ts = round(ts - 4 / fps, 4)
|
||||||
|
timestamps = [prev_ts, ts]
|
||||||
|
elif timestamps_mode == "tdmpc":
|
||||||
|
timestamps = [round(ts - i / fps, 4) for i in range(6)][::-1]
|
||||||
|
else:
|
||||||
|
raise ValueError(timestamps_mode)
|
||||||
|
|
||||||
|
num_frames = len(timestamps)
|
||||||
|
|
||||||
|
start_time_s = time.monotonic()
|
||||||
|
frames = decode_frames_fn(video_path, timestamps=timestamps, device=device, **decoder_kwgs)
|
||||||
|
avg_load_time = (time.monotonic() - start_time_s) / num_frames
|
||||||
|
list_avg_load_time.append(avg_load_time)
|
||||||
|
|
||||||
|
start_time_s = time.monotonic()
|
||||||
|
original_frames = load_original_frames(imgs_dir, timestamps)
|
||||||
|
avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames
|
||||||
|
list_avg_load_time_from_images.append(avg_load_time_from_images)
|
||||||
|
|
||||||
|
# Estimate average L2 error between original frames and decoded frames
|
||||||
|
for i, ts in enumerate(timestamps):
|
||||||
|
# are_close = torch.allclose(frames[i], original_frames[i], atol=0.02)
|
||||||
|
num_pixels = original_frames[i].numel()
|
||||||
|
per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels
|
||||||
|
|
||||||
|
# save decoded frames
|
||||||
|
if t == 0:
|
||||||
|
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||||
|
PIL.Image.fromarray(frame_hwc).save(output_dir / f"frame_{i:06d}.png")
|
||||||
|
|
||||||
|
# save original_frames
|
||||||
|
idx = int(ts * fps)
|
||||||
|
if t == 0:
|
||||||
|
original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||||
|
original_frame.save(output_dir / f"original_frame_{i:06d}.png")
|
||||||
|
|
||||||
|
per_pixel_l2_errors.append(per_pixel_l2_error)
|
||||||
|
|
||||||
|
avg_load_time = float(numpy.array(list_avg_load_time).mean())
|
||||||
|
avg_load_time_from_images = float(numpy.array(list_avg_load_time_from_images).mean())
|
||||||
|
avg_per_pixel_l2_error = float(numpy.array(per_pixel_l2_errors).mean())
|
||||||
|
|
||||||
|
# Save benchmark info
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"sum_original_frames_size_bytes": sum_original_frames_size_bytes,
|
||||||
|
"video_size_bytes": video_size_bytes,
|
||||||
|
"avg_load_time_from_images": avg_load_time_from_images,
|
||||||
|
"avg_load_time": avg_load_time,
|
||||||
|
"pc_compression": sum_original_frames_size_bytes / video_size_bytes,
|
||||||
|
"pc_load_time": avg_load_time_from_images / avg_load_time,
|
||||||
|
"avg_per_pixel_l2_error": avg_per_pixel_l2_error,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in info:
|
||||||
|
print(key, info[key])
|
||||||
|
|
||||||
|
with open(output_dir / "info.json", "w") as f:
|
||||||
|
json.dump(info, f)
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
dry_run = True
|
||||||
|
|
||||||
|
bench_dir = Path("tmp/2024_04_29_1049_6_timestamps")
|
||||||
|
|
||||||
|
def display_markdown_table(headers, rows):
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
new_row = []
|
||||||
|
for col in row:
|
||||||
|
if col is None:
|
||||||
|
new_col = "None"
|
||||||
|
elif isinstance(col, float):
|
||||||
|
new_col = f"{col:.3f}"
|
||||||
|
if new_col == "0.000":
|
||||||
|
new_col = f"{col:.7f}"
|
||||||
|
elif isinstance(col, int):
|
||||||
|
new_col = f"{col}"
|
||||||
|
else:
|
||||||
|
new_col = col
|
||||||
|
new_row.append(new_col)
|
||||||
|
rows[i] = new_row
|
||||||
|
|
||||||
|
header_line = "| " + " | ".join(headers) + " |"
|
||||||
|
separator_line = "| " + " | ".join(["---" for _ in headers]) + " |"
|
||||||
|
body_lines = ["| " + " | ".join(row) + " |" for row in rows]
|
||||||
|
markdown_table = "\n".join([header_line, separator_line] + body_lines)
|
||||||
|
print(markdown_table)
|
||||||
|
print()
|
||||||
|
|
||||||
|
def load_info(out_dir):
|
||||||
|
with open(out_dir / "info.json") as f:
|
||||||
|
info = json.load(f)
|
||||||
|
return info
|
||||||
|
|
||||||
|
repo_ids = ["lerobot/pusht", "lerobot/umi_cup_in_the_wild"]
|
||||||
|
|
||||||
|
# torchvision vs ffmpegio vs torchaudio
|
||||||
|
|
||||||
|
headers = ["repo_id", "decoder", "pc_load_time", "avg_per_pixel_l2_error"]
|
||||||
|
rows = []
|
||||||
|
for repo_id in repo_ids:
|
||||||
|
for decoder in ["torchvision", "ffmpegio", "torchaudio"]:
|
||||||
|
cfg = {
|
||||||
|
"repo_id": repo_id,
|
||||||
|
# video encoding
|
||||||
|
"g": 10,
|
||||||
|
"crf": 10,
|
||||||
|
"pix_fmt": "yuv444p",
|
||||||
|
# video decoding
|
||||||
|
"device": "cpu",
|
||||||
|
"decoder": decoder,
|
||||||
|
"decoder_kwgs": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
if not dry_run:
|
||||||
|
run_video_benchmark(bench_dir / repo_id / decoder, cfg=cfg)
|
||||||
|
info = load_info(bench_dir / repo_id / decoder)
|
||||||
|
rows.append([repo_id, decoder, info["pc_load_time"], info["avg_per_pixel_l2_error"]])
|
||||||
|
display_markdown_table(headers, rows)
|
||||||
|
|
||||||
|
# yuv444p vs yuv420p
|
||||||
|
|
||||||
|
headers = ["repo_id", "pix_fmt", "pc_compression", "pc_load_time", "avg_per_pixel_l2_error"]
|
||||||
|
rows = []
|
||||||
|
for repo_id in repo_ids:
|
||||||
|
for pix_fmt in ["yuv420p", "yuv444p"]:
|
||||||
|
cfg = {
|
||||||
|
"repo_id": repo_id,
|
||||||
|
# video encoding
|
||||||
|
"g": 10,
|
||||||
|
"crf": 10,
|
||||||
|
"pix_fmt": pix_fmt,
|
||||||
|
# video decoding
|
||||||
|
"device": "cpu",
|
||||||
|
"decoder": "torchvision",
|
||||||
|
"decoder_kwgs": {},
|
||||||
|
}
|
||||||
|
if not dry_run:
|
||||||
|
run_video_benchmark(bench_dir / repo_id / f"torchvision_{pix_fmt}", cfg=cfg)
|
||||||
|
info = load_info(bench_dir / repo_id / f"torchvision_{pix_fmt}")
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
repo_id,
|
||||||
|
pix_fmt,
|
||||||
|
info["pc_compression"],
|
||||||
|
info["pc_load_time"],
|
||||||
|
info["avg_per_pixel_l2_error"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
display_markdown_table(headers, rows)
|
||||||
|
|
||||||
|
# g
|
||||||
|
|
||||||
|
headers = ["repo_id", "g", "pc_compression", "pc_load_time", "avg_per_pixel_l2_error"]
|
||||||
|
rows = []
|
||||||
|
for repo_id in repo_ids:
|
||||||
|
for g in [1, 5, 10, 15, 20, 30, 40, 60, 100, None]:
|
||||||
|
cfg = {
|
||||||
|
"repo_id": repo_id,
|
||||||
|
# video encoding
|
||||||
|
"g": g,
|
||||||
|
"crf": 10,
|
||||||
|
"pix_fmt": "yuv444p",
|
||||||
|
# video decoding
|
||||||
|
"device": "cpu",
|
||||||
|
"decoder": "torchvision",
|
||||||
|
"decoder_kwgs": {},
|
||||||
|
}
|
||||||
|
if not dry_run:
|
||||||
|
run_video_benchmark(bench_dir / repo_id / f"torchvision_g_{g}", cfg=cfg)
|
||||||
|
info = load_info(bench_dir / repo_id / f"torchvision_g_{g}")
|
||||||
|
rows.append(
|
||||||
|
[repo_id, g, info["pc_compression"], info["pc_load_time"], info["avg_per_pixel_l2_error"]]
|
||||||
|
)
|
||||||
|
display_markdown_table(headers, rows)
|
||||||
|
|
||||||
|
# crf
|
||||||
|
|
||||||
|
headers = ["repo_id", "crf", "pc_compression", "pc_load_time", "avg_per_pixel_l2_error"]
|
||||||
|
rows = []
|
||||||
|
for repo_id in repo_ids:
|
||||||
|
for crf in [0, 5, 10, 15, 20, None, 25, 30, 40, 50]:
|
||||||
|
cfg = {
|
||||||
|
"repo_id": repo_id,
|
||||||
|
# video encoding
|
||||||
|
"g": None,
|
||||||
|
"crf": crf,
|
||||||
|
"pix_fmt": "yuv444p",
|
||||||
|
# video decoding
|
||||||
|
"device": "cpu",
|
||||||
|
"decoder": "torchvision",
|
||||||
|
"decoder_kwgs": {},
|
||||||
|
}
|
||||||
|
if not dry_run:
|
||||||
|
run_video_benchmark(bench_dir / repo_id / f"torchvision_crf_{crf}", cfg=cfg)
|
||||||
|
info = load_info(bench_dir / repo_id / f"torchvision_crf_{crf}")
|
||||||
|
rows.append(
|
||||||
|
[repo_id, crf, info["pc_compression"], info["pc_load_time"], info["avg_per_pixel_l2_error"]]
|
||||||
|
)
|
||||||
|
display_markdown_table(headers, rows)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -9,7 +9,9 @@ from lerobot.common.datasets.utils import (
|
||||||
load_info,
|
load_info,
|
||||||
load_previous_and_future_frames,
|
load_previous_and_future_frames,
|
||||||
load_stats,
|
load_stats,
|
||||||
|
load_videos,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.datasets.video_utils import load_from_videos
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDataset(torch.utils.data.Dataset):
|
class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
@ -30,18 +32,38 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
# load data from hub or locally when root is provided
|
# load data from hub or locally when root is provided
|
||||||
|
# TODO(rcadene, aliberts): implement faster transfer
|
||||||
|
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||||
self.stats = load_stats(repo_id, version, root)
|
self.stats = load_stats(repo_id, version, root)
|
||||||
self.info = load_info(repo_id, version, root)
|
self.info = load_info(repo_id, version, root)
|
||||||
|
if self.video:
|
||||||
|
self.videos_dir = load_videos(repo_id, version, root)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fps(self) -> int:
|
def fps(self) -> int:
|
||||||
return self.info["fps"]
|
return self.info["fps"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def video(self) -> int:
|
||||||
|
return self.info.get("video", False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_keys(self) -> list[str]:
|
def image_keys(self) -> list[str]:
|
||||||
return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
|
image_keys = []
|
||||||
|
for key, feats in self.hf_dataset.features.items():
|
||||||
|
if isinstance(feats, datasets.Image):
|
||||||
|
image_keys.append(key)
|
||||||
|
return image_keys + self.video_frame_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def video_frame_keys(self):
|
||||||
|
video_frame_keys = []
|
||||||
|
for key, feats in self.hf_dataset.features.items():
|
||||||
|
if isinstance(feats, datasets.Value) and feats.id == "video_frame":
|
||||||
|
video_frame_keys.append(key)
|
||||||
|
return video_frame_keys
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_samples(self) -> int:
|
def num_samples(self) -> int:
|
||||||
|
@ -66,6 +88,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.video:
|
||||||
|
item = load_from_videos(item, self.video_frame_keys, self.videos_dir)
|
||||||
|
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
item = self.transform(item)
|
item = self.transform(item)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import einops
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from datasets import Image, load_dataset, load_from_disk
|
from datasets import Image, load_dataset, load_from_disk
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
@ -127,6 +127,15 @@ def load_info(repo_id, version, root) -> dict:
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def load_videos(repo_id, version, root) -> Path:
|
||||||
|
if root is not None:
|
||||||
|
path = Path(root) / repo_id / "videos"
|
||||||
|
else:
|
||||||
|
path = snapshot_download(repo_id, allow_patterns="*.mp4", repo_type="dataset", revision=version)
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
def load_previous_and_future_frames(
|
def load_previous_and_future_frames(
|
||||||
item: dict[str, torch.Tensor],
|
item: dict[str, torch.Tensor],
|
||||||
hf_dataset: datasets.Dataset,
|
hf_dataset: datasets.Dataset,
|
||||||
|
@ -209,6 +218,7 @@ def load_previous_and_future_frames(
|
||||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
||||||
item[key] = torch.stack(item[key])
|
item[key] = torch.stack(item[key])
|
||||||
item[f"{key}_is_pad"] = is_pad
|
item[f"{key}_is_pad"] = is_pad
|
||||||
|
item[f"{key}_timestamp"] = query_ts
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
import itertools
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_videos(item, video_frame_keys, videos_dir):
|
||||||
|
for key in video_frame_keys:
|
||||||
|
ep_idx = item["episode_index"]
|
||||||
|
video_path = videos_dir / key / f"episode_{ep_idx:06d}.mp4"
|
||||||
|
|
||||||
|
if f"{key}_timestamp" in item:
|
||||||
|
# load multiple frames at once
|
||||||
|
timestamps = item[f"{key}_timestamp"]
|
||||||
|
item[key] = decode_video_frames_torchvision(video_path, timestamps)
|
||||||
|
else:
|
||||||
|
# load one frame
|
||||||
|
timestamps = [item["timestamp"]]
|
||||||
|
frames = decode_video_frames_torchvision(video_path, timestamps)
|
||||||
|
assert len(frames) == 1
|
||||||
|
item[key] = frames[0]
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames_torchvision(
|
||||||
|
video_path: str, timestamps: list[float], device: str = "cpu", log_loaded_timestamps: bool = False
|
||||||
|
):
|
||||||
|
"""Loads frames associated to the requested timestamps of a video
|
||||||
|
|
||||||
|
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||||
|
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||||
|
that key frame. As a consequence, to access a requested frame, we need to load the preceeding key frame,
|
||||||
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||||
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||||
|
"""
|
||||||
|
# set backend
|
||||||
|
if device == "cpu":
|
||||||
|
# explicitely use pyav
|
||||||
|
torchvision.set_video_backend("pyav")
|
||||||
|
elif device == "cuda":
|
||||||
|
# TODO(rcadene, aliberts): implement video decoding with GPU
|
||||||
|
# torchvision.set_video_backend("cuda")
|
||||||
|
# torchvision.set_video_backend("video_reader")
|
||||||
|
# requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||||
|
# check possible bug: https://github.com/pytorch/vision/issues/7745
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
raise ValueError(device)
|
||||||
|
|
||||||
|
# set a video stream reader
|
||||||
|
# TODO(rcadene): also load audio stream at the same time
|
||||||
|
reader = torchvision.io.VideoReader(str(video_path), "video")
|
||||||
|
|
||||||
|
# sanity preprocessing (e.g. 3.60000003 -> 3.6)
|
||||||
|
timestamps = [round(ts, 4) for ts in timestamps]
|
||||||
|
|
||||||
|
# set the first and last requested timestamps
|
||||||
|
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||||
|
first_ts = timestamps[0]
|
||||||
|
last_ts = timestamps[-1]
|
||||||
|
|
||||||
|
# access key frame of first requested frame, and load all frames until last requested frame
|
||||||
|
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||||
|
frames = []
|
||||||
|
for frame in itertools.takewhile(lambda x: x["pts"] <= last_ts, reader.seek(first_ts)):
|
||||||
|
# get timestamp of the loaded frame
|
||||||
|
ts = frame["pts"]
|
||||||
|
|
||||||
|
# if the loaded frame is not among the requested frames, we dont add it to the list of output frames
|
||||||
|
is_frame_requested = ts in timestamps
|
||||||
|
if is_frame_requested:
|
||||||
|
frames.append(frame["data"])
|
||||||
|
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
log = f"frame loaded at timestamp={ts:.4f}"
|
||||||
|
if is_frame_requested:
|
||||||
|
log += " requested"
|
||||||
|
print(log)
|
||||||
|
|
||||||
|
frames = torch.stack(frames)
|
||||||
|
|
||||||
|
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||||
|
frames = frames.type(torch.float32) / 255
|
||||||
|
|
||||||
|
assert len(timestamps) == frames.shape[0]
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
|
||||||
|
# For more info this setting, see: `lerobot/common/datasets/_video_benchmark/README.md`
|
||||||
|
video_path = Path(video_path)
|
||||||
|
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
ffmpeg_cmd = (
|
||||||
|
f"ffmpeg -r {fps} -f image2 "
|
||||||
|
f"-i {str(imgs_dir / 'frame_%06d.png')} "
|
||||||
|
"-vcodec libx264 "
|
||||||
|
"-pix_fmt yuv444p "
|
||||||
|
f"{str(video_path)}"
|
||||||
|
)
|
||||||
|
subprocess.run(ffmpeg_cmd.split(" "), check=True)
|
|
@ -255,8 +255,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--video",
|
"--video",
|
||||||
type=int,
|
type=int,
|
||||||
# TODO(rcadene): enable when video PR merges
|
default=1,
|
||||||
default=0,
|
|
||||||
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
|
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "absl-py"
|
name = "absl-py"
|
||||||
|
@ -3806,38 +3806,6 @@ typing-extensions = ">=4.8.0"
|
||||||
opt-einsum = ["opt-einsum (>=3.3)"]
|
opt-einsum = ["opt-einsum (>=3.3)"]
|
||||||
optree = ["optree (>=0.9.1)"]
|
optree = ["optree (>=0.9.1)"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "torchaudio"
|
|
||||||
version = "2.3.0"
|
|
||||||
description = "An audio package for PyTorch"
|
|
||||||
optional = false
|
|
||||||
python-versions = "*"
|
|
||||||
files = [
|
|
||||||
{file = "torchaudio-2.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:342108da83aa19a457c9a128b1206fadb603753b51cca022b9f585aac2f4754c"},
|
|
||||||
{file = "torchaudio-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:73fedb2c631e01fa10feaac308540b836aefe758e55ca3ee026335e5d01e8e30"},
|
|
||||||
{file = "torchaudio-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e5bb50b7a4874ed97086c9e516dd90b103d954edcb5ed4b36f4fc22c4000a5a7"},
|
|
||||||
{file = "torchaudio-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:b4cc9cef5c98ed37e9405c4e0b0e6413bc101f3f49d45dc4f1d4e927757fe41e"},
|
|
||||||
{file = "torchaudio-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:341ca3048ce6edcc731519b30187f0b13acb245c4efe16f925f69f9d533546e1"},
|
|
||||||
{file = "torchaudio-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:8f2e0a28740bb0ee66369f92c811f33c0a47e6fcfc2de9cee89746472d713906"},
|
|
||||||
{file = "torchaudio-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:61edb02ae9c0efea4399f9c1f899601136b24f35d430548284ea8eaf6ccbe3be"},
|
|
||||||
{file = "torchaudio-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:04bc960cf1aef3b469b095a432a25496bc28197850fc2d90b7b52d6b5255487b"},
|
|
||||||
{file = "torchaudio-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:535144a2fbba95fbb3b883224ffcf44788e4cecbabbe49c4a1ae3e7a74f71485"},
|
|
||||||
{file = "torchaudio-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:fb3f52ed1d63b272c240d9bf051705312cb172212051b8a6a2f64d42e3cc1633"},
|
|
||||||
{file = "torchaudio-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:668a8b694e5522cff28cd5e02d01aa1b75ce940aa9fb40480892bdc623b1735d"},
|
|
||||||
{file = "torchaudio-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:6c1f538018b85d7766835d042e555de2f096f7a69bba6b16031bf42a914dd9e1"},
|
|
||||||
{file = "torchaudio-2.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7ba93265455dc363385e98c0cfcaeb586b7401af8a2c824811ee1466134a4f30"},
|
|
||||||
{file = "torchaudio-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:21bb6d1b384fc8895133f01489133d575d4a715cd81734b89651fb0264bd8b80"},
|
|
||||||
{file = "torchaudio-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:ed1866f508dc689c4f682d330b2ed4c83108d35865e4fb89431819364d8ad9ed"},
|
|
||||||
{file = "torchaudio-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:a3cbb230e2bb38ad1a1dd74aea242a154a9f76ab819d9c058b2c5074a9f5d7d2"},
|
|
||||||
{file = "torchaudio-2.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f4b933776f20a36af5ddc57968fcb3da34dd03881db8d6760f3e1176803b9cf8"},
|
|
||||||
{file = "torchaudio-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:c5e63cc2dbf179088b6cdfd21ecdbb943aa003c780075aa440162f231ee72db2"},
|
|
||||||
{file = "torchaudio-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d243bb8a1ee263c2cdafb9feed1569c3742d8135731e8f7818de12f4e0c83e28"},
|
|
||||||
{file = "torchaudio-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6cd6d45cf8a45c89953e35434d9a461feb418e51e760adafc606a903dcbb9bd5"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
torch = "2.3.0"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "torchvision"
|
name = "torchvision"
|
||||||
version = "0.18.0"
|
version = "0.18.0"
|
||||||
|
@ -4299,4 +4267,4 @@ xarm = ["gym-xarm"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "0f72eb92ac8817a46f0659b4d72647a6b76f6e4ba762d11b280f8a88e6cd4371"
|
content-hash = "fab42b4be590cb2007934cd8f5a218f1f3da4f0b42cdff7e7724af518888d7b4"
|
||||||
|
|
|
@ -56,7 +56,6 @@ pytest = {version = "^8.1.0", optional = true}
|
||||||
pytest-cov = {version = "^5.0.0", optional = true}
|
pytest-cov = {version = "^5.0.0", optional = true}
|
||||||
datasets = "^2.19.0"
|
datasets = "^2.19.0"
|
||||||
imagecodecs = { version = "^2024.1.1", optional = true }
|
imagecodecs = { version = "^2024.1.1", optional = true }
|
||||||
torchaudio = "^2.3.0"
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
|
|
Loading…
Reference in New Issue