Merge branch 'main' of github.com:huggingface/lerobot
This commit is contained in:
commit
e05066a88b
45
README.md
45
README.md
|
@ -58,6 +58,7 @@
|
|||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
|
||||
|
||||
|
||||
## Installation
|
||||
|
@ -339,7 +340,7 @@ with profile(
|
|||
## Citation
|
||||
|
||||
If you want, you can cite this work with:
|
||||
```
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
|
@ -347,3 +348,45 @@ If you want, you can cite this work with:
|
|||
year = {2024}
|
||||
}
|
||||
```
|
||||
|
||||
Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below:
|
||||
|
||||
- [Diffusion Policy](https://diffusion-policy.cs.columbia.edu)
|
||||
```bibtex
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
year = {2024},
|
||||
}
|
||||
```
|
||||
- [ACT or ALOHA](https://tonyzhaozh.github.io/aloha)
|
||||
```bibtex
|
||||
@article{zhao2023learning,
|
||||
title={Learning fine-grained bimanual manipulation with low-cost hardware},
|
||||
author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea},
|
||||
journal={arXiv preprint arXiv:2304.13705},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
- [TDMPC](https://www.nicklashansen.com/td-mpc/)
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={ICML},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
- [VQ-BeT](https://sjlee.cc/vq-bet/)
|
||||
```bibtex
|
||||
@article{lee2024behavior,
|
||||
title={Behavior generation with latent actions},
|
||||
author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal={arXiv preprint arXiv:2403.03181},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -0,0 +1,271 @@
|
|||
# Video benchmark
|
||||
|
||||
|
||||
## Questions
|
||||
What is the optimal trade-off between:
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies,
|
||||
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
|
||||
|
||||
How to encode videos?
|
||||
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
|
||||
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
|
||||
|
||||
|
||||
## Variables
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
### Encoding parameters
|
||||
| parameter | values |
|
||||
|-------------|--------------------------------------------------------------|
|
||||
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
|
||||
| **pix_fmt** | `yuv444p`, `yuv420p` |
|
||||
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
|
||||
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
|
||||
|
||||
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
|
||||
|
||||
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
|
||||
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
|
||||
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
|
||||
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
|
||||
|
||||
### Decoding parameters
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
- `pyav` (default)
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
|
||||
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
|
||||
- `1_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
|
||||
|
||||
## Metrics
|
||||
**Data compression ratio (lower is better)**
|
||||
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Loading time ratio (lower is better)**
|
||||
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average Mean Square Error (lower is better)**
|
||||
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Average Peak Signal to Noise Ratio (higher is better)**
|
||||
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
|
||||
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
|
||||
<!-- **Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
|
||||
|
||||
|
||||
## How the benchmark works
|
||||
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
|
||||
|
||||
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
|
||||
This gives a unique set of encoding parameters which is used to encode the episode.
|
||||
|
||||
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
|
||||
|
||||
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
|
||||
These are then all concatenated to a single table ready for analysis.
|
||||
|
||||
## Caveats
|
||||
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
- `torchaudio`
|
||||
- `ffmpegio`
|
||||
- `decord`
|
||||
- `nvc`
|
||||
|
||||
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
|
||||
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
|
||||
|
||||
|
||||
## Install
|
||||
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
|
||||
|
||||
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
|
||||
|
||||
|
||||
## Adding a video decoder
|
||||
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
|
||||
You can easily add a new decoder to benchmark by adding it to this function in the script:
|
||||
```diff
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
+ elif backend == ["your_decoder"]:
|
||||
+ return your_decoder_function(
|
||||
+ video_path, timestamps, tolerance_s, backend
|
||||
+ )
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
```
|
||||
|
||||
|
||||
## Example
|
||||
For a quick run, you can try these parameters:
|
||||
```bash
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
--crf 10 40 None \
|
||||
--timestamps-modes 1_frame 2_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 5 \
|
||||
--num-workers 5 \
|
||||
--save-frames 0
|
||||
```
|
||||
|
||||
|
||||
## Results
|
||||
|
||||
### Reproduce
|
||||
We ran the benchmark with the following parameters:
|
||||
```bash
|
||||
# h264 and h265 encodings
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
|
||||
# av1 encoding (only compatible with yuv420p and pyav decoder)
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
```
|
||||
|
||||
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
|
||||
|
||||
|
||||
### Parameters selected for LeRobotDataset
|
||||
Considering these results, we chose what we think is the best set of encoding parameter:
|
||||
- vcodec: `libsvtav1`
|
||||
- pix-fmt: `yuv420p`
|
||||
- g: `2`
|
||||
- crf: `30`
|
||||
|
||||
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
|
||||
|
||||
### Summary
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
|------------------------------------|------------|---------|-----------|-----------|-----------|
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
|------------------------------------|---------|---------|----------|---------|-----------|
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
|------------------------------------|----------|----------|--------------|----------|-----------|--------------|
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
|
@ -0,0 +1,490 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Assess the performance of video decoding in various configurations.
|
||||
|
||||
This script will benchmark different video encoding and decoding parameters.
|
||||
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.common.utils.benchmark import TimeBenchmark
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
("vcodec", "libx264"),
|
||||
("pix_fmt", "yuv444p"),
|
||||
("g", 2),
|
||||
("crf", None),
|
||||
# TODO(aliberts): Add fastdecode
|
||||
# ("fastdecode", 0),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
|
||||
def parse_int_or_none(value) -> int | None:
|
||||
if value.lower() == "none":
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
|
||||
|
||||
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if dataset.video:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory: Path) -> int:
|
||||
total_size = 0
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return torch.stack(frames)
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
case "1_frame":
|
||||
frame_indexes = [idx]
|
||||
case "2_frames":
|
||||
frame_indexes = [idx - 1, idx]
|
||||
case "2_frames_4_space":
|
||||
frame_indexes = [idx - 5, idx]
|
||||
case "6_frames":
|
||||
frame_indexes = [idx - i for i in range(6)][::-1]
|
||||
case _:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
timestamps_mode: str,
|
||||
backend: str,
|
||||
ep_num_images: int,
|
||||
fps: int,
|
||||
num_samples: int = 50,
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int):
|
||||
time_benchmark = TimeBenchmark()
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
"psnr_values": [],
|
||||
"ssim_values": [],
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
|
||||
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
|
||||
|
||||
return result
|
||||
|
||||
load_times_video_ms = []
|
||||
load_times_images_ms = []
|
||||
mse_values = []
|
||||
psnr_values = []
|
||||
ssim_values = []
|
||||
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
psnr_values.extend(result["psnr_values"])
|
||||
ssim_values.extend(result["ssim_values"])
|
||||
mse_values.extend(result["mse_values"])
|
||||
|
||||
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
|
||||
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
|
||||
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
|
||||
|
||||
return {
|
||||
"avg_load_time_video_ms": avg_load_time_video_ms,
|
||||
"avg_load_time_images_ms": avg_load_time_images_ms,
|
||||
"video_images_load_time_ratio": video_images_load_time_ratio,
|
||||
"avg_mse": float(np.mean(mse_values)),
|
||||
"avg_psnr": float(np.mean(psnr_values)),
|
||||
"avg_ssim": float(np.mean(ssim_values)),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_encoding_decoding(
|
||||
dataset: LeRobotDataset,
|
||||
video_path: Path,
|
||||
imgs_dir: Path,
|
||||
encoding_cfg: dict,
|
||||
decoding_cfg: dict,
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
overwrite: bool = False,
|
||||
seed: int = 1337,
|
||||
) -> list[dict]:
|
||||
fps = dataset.fps
|
||||
|
||||
if overwrite or not video_path.is_file():
|
||||
tqdm.write(f"encoding {video_path}")
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
video_codec=encoding_cfg["vcodec"],
|
||||
pixel_format=encoding_cfg["pix_fmt"],
|
||||
group_of_pictures_size=encoding_cfg.get("g"),
|
||||
constant_rate_factor=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
video_images_size_ratio = video_size_bytes / images_size_bytes
|
||||
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
timestamps_mode,
|
||||
backend,
|
||||
ep_num_images,
|
||||
fps,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
benchmark_row.update(
|
||||
**{
|
||||
"repo_id": dataset.repo_id,
|
||||
"resolution": f"{width} x {height}",
|
||||
"num_pixels": num_pixels,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"images_size_bytes": images_size_bytes,
|
||||
"video_images_size_ratio": video_images_size_ratio,
|
||||
"timestamps_mode": timestamps_mode,
|
||||
"backend": backend,
|
||||
},
|
||||
**encoding_cfg,
|
||||
)
|
||||
benchmark_table.append(benchmark_row)
|
||||
|
||||
return benchmark_table
|
||||
|
||||
|
||||
def main(
|
||||
output_dir: Path,
|
||||
repo_ids: list[str],
|
||||
vcodec: list[str],
|
||||
pix_fmt: list[str],
|
||||
g: list[int],
|
||||
crf: list[int],
|
||||
# fastdecode: list[int],
|
||||
timestamps_modes: list[str],
|
||||
backends: list[str],
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
):
|
||||
check_datasets_formats(repo_ids)
|
||||
encoding_benchmarks = {
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
# "fastdecode": fastdecode,
|
||||
}
|
||||
decoding_benchmarks = {
|
||||
"timestamps_modes": timestamps_modes,
|
||||
"backends": backends,
|
||||
}
|
||||
headers = ["repo_id", "resolution", "num_pixels"]
|
||||
headers += list(BASE_ENCODING.keys())
|
||||
headers += [
|
||||
"timestamps_mode",
|
||||
"backend",
|
||||
"video_size_bytes",
|
||||
"images_size_bytes",
|
||||
"video_images_size_ratio",
|
||||
"avg_load_time_video_ms",
|
||||
"avg_load_time_images_ms",
|
||||
"video_images_load_time_ratio",
|
||||
"avg_mse",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
]
|
||||
file_paths = []
|
||||
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
|
||||
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
|
||||
benchmark_table = []
|
||||
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
|
||||
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
now = dt.datetime.now()
|
||||
csv_path = (
|
||||
output_dir
|
||||
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
|
||||
)
|
||||
benchmark_df.to_csv(csv_path, header=True, index=False)
|
||||
file_paths.append(csv_path)
|
||||
del benchmark_df
|
||||
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/video_benchmark"),
|
||||
help="Directory where the video benchmark outputs are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-ids",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"aliberts/aloha_mobile_shrimp_image",
|
||||
"aliberts/paris_street",
|
||||
"aliberts/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["libx264", "libx265", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["yuv444p", "yuv420p"],
|
||||
help="Pixel formats (chroma subsampling) to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
||||
help="Group of pictures sizes to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
|
||||
help="Constant rate factors to be tested.",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--fastdecode",
|
||||
# type=int,
|
||||
# nargs="*",
|
||||
# default=[0, 1],
|
||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
||||
# "For libx264 and libx265, only 1 is possible. "
|
||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--timestamps-modes",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
],
|
||||
help="Timestamps scenarios to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["pyav", "video_reader"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of samples for each encoding x decoding config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of processes for parallelized sample processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
|
@ -8,7 +8,7 @@ ARG DEBIAN_FRONTEND=noninteractive
|
|||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
git git-lfs openssh-client \
|
||||
nano vim less util-linux \
|
||||
nano vim less util-linux tree \
|
||||
htop atop nvtop \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
tcpdump sysstat screen tmux \
|
||||
|
@ -16,6 +16,34 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install ffmpeg build dependencies. See:
|
||||
# https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu
|
||||
# TODO(aliberts): create image to build dependencies from source instead
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
autoconf automake yasm \
|
||||
libass-dev \
|
||||
libfreetype6-dev \
|
||||
libgnutls28-dev \
|
||||
libunistring-dev \
|
||||
libmp3lame-dev \
|
||||
libtool \
|
||||
libvorbis-dev \
|
||||
meson \
|
||||
ninja-build \
|
||||
pkg-config \
|
||||
texinfo \
|
||||
yasm \
|
||||
zlib1g-dev \
|
||||
nasm \
|
||||
libx264-dev \
|
||||
libx265-dev libnuma-dev \
|
||||
libvpx-dev \
|
||||
libfdk-aac-dev \
|
||||
libopus-dev \
|
||||
libsvtav1-dev libsvtav1enc-dev libsvtav1dec-dev \
|
||||
libdav1d-dev
|
||||
|
||||
|
||||
# Install gh cli tool
|
||||
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
|
||||
&& mkdir -p -m 755 /etc/apt/keyrings \
|
||||
|
|
|
@ -70,6 +70,8 @@ available_datasets_per_env = {
|
|||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted_image",
|
||||
],
|
||||
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
|
||||
# coupled with tests.
|
||||
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
|
||||
"xarm": [
|
||||
"lerobot/xarm_lift_medium",
|
||||
|
|
|
@ -1,334 +0,0 @@
|
|||
# Video benchmark
|
||||
|
||||
|
||||
## Questions
|
||||
|
||||
What is the optimal trade-off between:
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies?
|
||||
|
||||
How to encode videos?
|
||||
- How much compression (`-crf`)? Low compression with `0`, normal compression with `20` or extreme with `56`?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How many key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
|
||||
## Metrics
|
||||
|
||||
**Percentage of data compression (higher is better)**
|
||||
`compression_factor` is the ratio of the memory space on disk taken by the original images to encode, to the memory space taken by the encoded video. For instance, `compression_factor=4` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Percentage of loading time (higher is better)**
|
||||
`load_time_factor` is the ratio of the time it takes to load original images at given timestamps, to the time it takes to decode the exact same frames from the video. Higher is better. For instance, `load_time_factor=0.5` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average L2 error per pixel (lower is better)**
|
||||
`avg_per_pixel_l2_error` is the average L2 error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best.
|
||||
|
||||
|
||||
## Variables
|
||||
|
||||
**Image content**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this benchmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
|
||||
|
||||
**Requested timestamps**
|
||||
In this benchmark, we focus on the loading time of random access, so we are not interested in sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
|
||||
- `single_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `2_frames_4_space`: 2 consecutive frames with 4 frames of spacing (e.g `[t, t + 4 / fps]`),
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
|
||||
## Results
|
||||
|
||||
**`decoder`**
|
||||
| repo_id | decoder | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | <span style="color: #32CD32;">torchvision</span> | 0.166 | 0.0000119 |
|
||||
| lerobot/pusht | ffmpegio | 0.009 | 0.0001182 |
|
||||
| lerobot/pusht | torchaudio | 0.138 | 0.0000359 |
|
||||
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">torchvision</span> | 0.174 | 0.0000174 |
|
||||
| lerobot/umi_cup_in_the_wild | ffmpegio | 0.010 | 0.0000735 |
|
||||
| lerobot/umi_cup_in_the_wild | torchaudio | 0.154 | 0.0000340 |
|
||||
|
||||
### `1_frame`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.224 | 0.0000760 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.185 | 0.0000443 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.388 | 0.0000469 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.329 | 0.0000397 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.204 | 0.0000556 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.182 | 0.0000443 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.174 | 0.0000450 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.163 | 0.0000448 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.163 | 0.0000436 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.155 | 0.0000427 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.130 | 0.0000410 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.105 | 0.0000406 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.097 | 0.0000400 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.061 | 0.0000405 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.039 | 0.0000422 |
|
||||
| lerobot/pusht | None | 8.909 | 0.024 | 0.0000431 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.444 | 0.0000601 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.345 | 0.0000397 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.282 | 0.0000416 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.271 | 0.0000415 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.260 | 0.0000415 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.249 | 0.0000415 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.195 | 0.0000399 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.169 | 0.0000394 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.140 | 0.0000390 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.096 | 0.0000384 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.046 | 0.0000390 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.022 | 0.0000400 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.175 | 0.0000035 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.181 | 0.0000080 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.172 | 0.0000123 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.187 | 0.0000211 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.181 | 0.0000346 |
|
||||
| lerobot/pusht | None | 3.646 | 0.189 | 0.0000443 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.186 | 0.0000521 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.184 | 0.0000850 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.193 | 0.0001726 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.183 | 0.0002606 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.165 | 0.0000056 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.171 | 0.0000111 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.212 | 0.0000153 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.261 | 0.0000218 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.312 | 0.0000317 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.339 | 0.0000397 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.297 | 0.0000452 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.406 | 0.0000629 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.468 | 0.0001184 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.515 | 0.0001879 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.188 | 0.0000443 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.339 | 0.0000397 |
|
||||
|
||||
### `2_frames`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.314 | 0.0000799 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.303 | 0.0000496 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.642 | 0.0000503 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.529 | 0.0000436 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.308 | 0.0000599 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.279 | 0.0000496 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.259 | 0.0000498 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.243 | 0.0000501 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.235 | 0.0000492 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.230 | 0.0000481 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.194 | 0.0000468 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.152 | 0.0000460 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.151 | 0.0000455 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.095 | 0.0000454 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.062 | 0.0000472 |
|
||||
| lerobot/pusht | None | 8.909 | 0.037 | 0.0000479 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.638 | 0.0000625 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.537 | 0.0000436 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.493 | 0.0000437 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.458 | 0.0000446 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.438 | 0.0000445 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.424 | 0.0000444 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.345 | 0.0000435 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.313 | 0.0000417 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.264 | 0.0000421 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.185 | 0.0000414 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.090 | 0.0000420 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.042 | 0.0000424 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.302 | 0.0000097 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.287 | 0.0000142 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.283 | 0.0000184 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.305 | 0.0000268 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.285 | 0.0000402 |
|
||||
| lerobot/pusht | None | 3.646 | 0.285 | 0.0000496 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.293 | 0.0000572 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.293 | 0.0000893 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.319 | 0.0001762 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.304 | 0.0002626 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.235 | 0.0000112 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.261 | 0.0000166 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.333 | 0.0000207 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.406 | 0.0000267 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.489 | 0.0000361 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.537 | 0.0000436 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.578 | 0.0000487 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.453 | 0.0000655 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.767 | 0.0001192 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.816 | 0.0001881 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.283 | 0.0000496 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.543 | 0.0000436 |
|
||||
|
||||
### `2_frames_4_space`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.257 | 0.0000855 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.261 | 0.0000556 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.493 | 0.0000476 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.371 | 0.0000404 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.226 | 0.0000670 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.222 | 0.0000556 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.217 | 0.0000567 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.204 | 0.0000555 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.179 | 0.0000556 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.188 | 0.0000544 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.160 | 0.0000531 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.150 | 0.0000521 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.123 | 0.0000519 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.092 | 0.0000519 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.053 | 0.0000533 |
|
||||
| lerobot/pusht | None | 8.909 | 0.034 | 0.0000541 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.409 | 0.0000607 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.381 | 0.0000404 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.355 | 0.0000418 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.346 | 0.0000425 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.354 | 0.0000419 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.336 | 0.0000419 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.314 | 0.0000402 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.269 | 0.0000397 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.246 | 0.0000395 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.171 | 0.0000390 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.091 | 0.0000399 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.043 | 0.0000409 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.212 | 0.0000193 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.211 | 0.0000232 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.199 | 0.0000270 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.198 | 0.0000347 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.211 | 0.0000469 |
|
||||
| lerobot/pusht | None | 3.646 | 0.206 | 0.0000556 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.210 | 0.0000626 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.223 | 0.0000927 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.227 | 0.0001763 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.223 | 0.0002625 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.147 | 0.0000071 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.182 | 0.0000125 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.222 | 0.0000166 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.270 | 0.0000229 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.325 | 0.0000326 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.362 | 0.0000404 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.390 | 0.0000459 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.437 | 0.0000633 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.499 | 0.0001186 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.564 | 0.0001879 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.224 | 0.0000556 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.368 | 0.0000404 |
|
||||
|
||||
### `6_frames`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.660 | 0.0000839 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.546 | 0.0000542 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 1.225 | 0.0000497 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.908 | 0.0000428 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.552 | 0.0000646 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.534 | 0.0000542 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.563 | 0.0000546 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.537 | 0.0000545 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.477 | 0.0000532 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.515 | 0.0000530 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.410 | 0.0000512 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.405 | 0.0000503 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.345 | 0.0000500 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.247 | 0.0000496 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.147 | 0.0000510 |
|
||||
| lerobot/pusht | None | 8.909 | 0.100 | 0.0000519 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.997 | 0.0000620 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.911 | 0.0000428 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.869 | 0.0000433 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.874 | 0.0000438 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.864 | 0.0000439 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.834 | 0.0000440 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.781 | 0.0000421 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.679 | 0.0000411 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.652 | 0.0000410 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.465 | 0.0000404 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.245 | 0.0000413 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.116 | 0.0000417 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.534 | 0.0000163 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.524 | 0.0000205 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.510 | 0.0000245 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.512 | 0.0000324 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.508 | 0.0000452 |
|
||||
| lerobot/pusht | None | 3.646 | 0.518 | 0.0000542 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.534 | 0.0000616 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.530 | 0.0000927 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.552 | 0.0001777 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.564 | 0.0002644 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.401 | 0.0000101 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.499 | 0.0000156 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.599 | 0.0000197 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.704 | 0.0000258 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.834 | 0.0000352 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.925 | 0.0000428 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.978 | 0.0000480 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 1.088 | 0.0000648 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 1.324 | 0.0001190 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 1.436 | 0.0001880 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.546 | 0.0000542 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.934 | 0.0000428 |
|
|
@ -1,409 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Assess the performance of video decoding in various configurations.
|
||||
|
||||
This script will run different video decoding benchmarks where one parameter varies at a time.
|
||||
These parameters and theirs values are specified in the BENCHMARKS dict.
|
||||
|
||||
All of these benchmarks are evaluated within different timestamps modes corresponding to different frame-loading scenarios:
|
||||
- `1_frame`: 1 single frame is loaded.
|
||||
- `2_frames`: 2 consecutive frames are loaded.
|
||||
- `2_frames_4_space`: 2 frames separated by 4 frames are loaded.
|
||||
- `6_frames`: 6 consecutive frames are loaded.
|
||||
|
||||
These values are more or less arbitrary and based on possible future usage.
|
||||
|
||||
These benchmarks are run on the first episode of each dataset specified in DATASET_REPO_IDS.
|
||||
Note: These datasets need to be image datasets, not video datasets.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
)
|
||||
|
||||
OUTPUT_DIR = Path("tmp/run_video_benchmark")
|
||||
DRY_RUN = False
|
||||
|
||||
DATASET_REPO_IDS = [
|
||||
"lerobot/pusht_image",
|
||||
"aliberts/aloha_mobile_shrimp_image",
|
||||
"aliberts/paris_street",
|
||||
"aliberts/kitchen",
|
||||
]
|
||||
TIMESTAMPS_MODES = [
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
]
|
||||
BENCHMARKS = {
|
||||
# "pix_fmt": ["yuv420p", "yuv444p"],
|
||||
# "g": [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
||||
# "crf": [0, 5, 10, 15, 20, None, 25, 30, 40, 50],
|
||||
"backend": ["pyav", "video_reader"],
|
||||
}
|
||||
|
||||
|
||||
def get_directory_size(directory):
|
||||
total_size = 0
|
||||
# Iterate over all files and subdirectories recursively
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
# Add the file size to the total
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def run_video_benchmark(
|
||||
output_dir,
|
||||
cfg,
|
||||
timestamps_mode,
|
||||
seed=1337,
|
||||
):
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
repo_id = cfg["repo_id"]
|
||||
|
||||
# TODO(rcadene): rewrite with hardcoding of original images and episodes
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if dataset.video:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
# Get fps
|
||||
fps = dataset.fps
|
||||
|
||||
# we only load first episode
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
|
||||
# Save/Load image directory for the first episode
|
||||
imgs_dir = Path(f"tmp/data/images/{repo_id}/observation.image_episode_000000")
|
||||
if not imgs_dir.exists():
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(imgs_dataset):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
sum_original_frames_size_bytes = get_directory_size(imgs_dir)
|
||||
|
||||
# Encode images into video
|
||||
video_path = output_dir / "episode_0.mp4"
|
||||
|
||||
g = cfg.get("g")
|
||||
crf = cfg.get("crf")
|
||||
pix_fmt = cfg["pix_fmt"]
|
||||
|
||||
cmd = f"ffmpeg -r {fps} "
|
||||
cmd += "-f image2 "
|
||||
cmd += "-loglevel error "
|
||||
cmd += f"-i {str(imgs_dir / 'frame_%06d.png')} "
|
||||
cmd += "-vcodec libx264 "
|
||||
if g is not None:
|
||||
cmd += f"-g {g} " # ensures at least 1 keyframe every 10 frames
|
||||
# cmd += "-keyint_min 10 " set a minimum of 10 frames between 2 key frames
|
||||
# cmd += "-sc_threshold 0 " disable scene change detection to lower the number of key frames
|
||||
if crf is not None:
|
||||
cmd += f"-crf {crf} "
|
||||
cmd += f"-pix_fmt {pix_fmt} "
|
||||
cmd += f"{str(video_path)}"
|
||||
subprocess.run(cmd.split(" "), check=True)
|
||||
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
|
||||
# Set decoder
|
||||
|
||||
decoder = cfg["decoder"]
|
||||
decoder_kwgs = cfg["decoder_kwgs"]
|
||||
backend = cfg["backend"]
|
||||
|
||||
if decoder == "torchvision":
|
||||
decode_frames_fn = decode_video_frames_torchvision
|
||||
else:
|
||||
raise ValueError(decoder)
|
||||
|
||||
# Estimate average loading time
|
||||
|
||||
def load_original_frames(imgs_dir, timestamps) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
list_avg_load_time = []
|
||||
list_avg_load_time_from_images = []
|
||||
per_pixel_l2_errors = []
|
||||
psnr_values = []
|
||||
ssim_values = []
|
||||
mse_values = []
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
for t in range(50):
|
||||
# test loading 2 frames that are 4 frames appart, which might be a common setting
|
||||
ts = random.randint(fps, ep_num_images - fps) / fps
|
||||
|
||||
if timestamps_mode == "1_frame":
|
||||
timestamps = [ts]
|
||||
elif timestamps_mode == "2_frames":
|
||||
timestamps = [ts - 1 / fps, ts]
|
||||
elif timestamps_mode == "2_frames_4_space":
|
||||
timestamps = [ts - 5 / fps, ts]
|
||||
elif timestamps_mode == "6_frames":
|
||||
timestamps = [ts - i / fps for i in range(6)][::-1]
|
||||
else:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
num_frames = len(timestamps)
|
||||
|
||||
start_time_s = time.monotonic()
|
||||
frames = decode_frames_fn(
|
||||
video_path, timestamps=timestamps, tolerance_s=1e-4, backend=backend, **decoder_kwgs
|
||||
)
|
||||
avg_load_time = (time.monotonic() - start_time_s) / num_frames
|
||||
list_avg_load_time.append(avg_load_time)
|
||||
|
||||
start_time_s = time.monotonic()
|
||||
original_frames = load_original_frames(imgs_dir, timestamps)
|
||||
avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames
|
||||
list_avg_load_time_from_images.append(avg_load_time_from_images)
|
||||
|
||||
# Estimate reconstruction error between original frames and decoded frames with various metrics
|
||||
for i, ts in enumerate(timestamps):
|
||||
# are_close = torch.allclose(frames[i], original_frames[i], atol=0.02)
|
||||
num_pixels = original_frames[i].numel()
|
||||
per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels
|
||||
per_pixel_l2_errors.append(per_pixel_l2_error)
|
||||
|
||||
frame_np, original_frame_np = frames[i].numpy(), original_frames[i].numpy()
|
||||
psnr_values.append(peak_signal_noise_ratio(original_frame_np, frame_np, data_range=1.0))
|
||||
ssim_values.append(
|
||||
structural_similarity(original_frame_np, frame_np, data_range=1.0, channel_axis=0)
|
||||
)
|
||||
mse_values.append(mean_squared_error(original_frame_np, frame_np))
|
||||
|
||||
# save decoded frames
|
||||
if t == 0:
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(output_dir / f"frame_{i:06d}.png")
|
||||
|
||||
# save original_frames
|
||||
idx = int(ts * fps)
|
||||
if t == 0:
|
||||
original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
original_frame.save(output_dir / f"original_frame_{i:06d}.png")
|
||||
|
||||
image_size = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
|
||||
avg_load_time = float(np.array(list_avg_load_time).mean())
|
||||
avg_load_time_from_images = float(np.array(list_avg_load_time_from_images).mean())
|
||||
avg_per_pixel_l2_error = float(np.array(per_pixel_l2_errors).mean())
|
||||
avg_psnr = float(np.mean(psnr_values))
|
||||
avg_ssim = float(np.mean(ssim_values))
|
||||
avg_mse = float(np.mean(mse_values))
|
||||
|
||||
# Save benchmark info
|
||||
|
||||
info = {
|
||||
"image_size": image_size,
|
||||
"sum_original_frames_size_bytes": sum_original_frames_size_bytes,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"avg_load_time_from_images": avg_load_time_from_images,
|
||||
"avg_load_time": avg_load_time,
|
||||
"compression_factor": sum_original_frames_size_bytes / video_size_bytes,
|
||||
"load_time_factor": avg_load_time_from_images / avg_load_time,
|
||||
"avg_per_pixel_l2_error": avg_per_pixel_l2_error,
|
||||
"avg_psnr": avg_psnr,
|
||||
"avg_ssim": avg_ssim,
|
||||
"avg_mse": avg_mse,
|
||||
}
|
||||
|
||||
with open(output_dir / "info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def display_markdown_table(headers, rows):
|
||||
for i, row in enumerate(rows):
|
||||
new_row = []
|
||||
for col in row:
|
||||
if col is None:
|
||||
new_col = "None"
|
||||
elif isinstance(col, float):
|
||||
new_col = f"{col:.3f}"
|
||||
if new_col == "0.000":
|
||||
new_col = f"{col:.7f}"
|
||||
elif isinstance(col, int):
|
||||
new_col = f"{col}"
|
||||
else:
|
||||
new_col = col
|
||||
new_row.append(new_col)
|
||||
rows[i] = new_row
|
||||
|
||||
header_line = "| " + " | ".join(headers) + " |"
|
||||
separator_line = "| " + " | ".join(["---" for _ in headers]) + " |"
|
||||
body_lines = ["| " + " | ".join(row) + " |" for row in rows]
|
||||
markdown_table = "\n".join([header_line, separator_line] + body_lines)
|
||||
print(markdown_table)
|
||||
print()
|
||||
|
||||
|
||||
def load_info(out_dir):
|
||||
with open(out_dir / "info.json") as f:
|
||||
info = json.load(f)
|
||||
return info
|
||||
|
||||
|
||||
def one_variable_study(
|
||||
var_name: str, var_values: list, repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool
|
||||
):
|
||||
print(f"**`{var_name}`**")
|
||||
headers = [
|
||||
"repo_id",
|
||||
"image_size",
|
||||
var_name,
|
||||
"compression_factor",
|
||||
"load_time_factor",
|
||||
"avg_per_pixel_l2_error",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
"avg_mse",
|
||||
]
|
||||
rows = []
|
||||
base_cfg = {
|
||||
"repo_id": None,
|
||||
# video encoding
|
||||
"g": 2,
|
||||
"crf": None,
|
||||
"pix_fmt": "yuv444p",
|
||||
# video decoding
|
||||
"backend": "pyav",
|
||||
"decoder": "torchvision",
|
||||
"decoder_kwgs": {},
|
||||
}
|
||||
for repo_id in repo_ids:
|
||||
for val in var_values:
|
||||
cfg = base_cfg.copy()
|
||||
cfg["repo_id"] = repo_id
|
||||
cfg[var_name] = val
|
||||
if not dry_run:
|
||||
run_video_benchmark(
|
||||
bench_dir / repo_id / f"torchvision_{var_name}_{val}", cfg, timestamps_mode
|
||||
)
|
||||
info = load_info(bench_dir / repo_id / f"torchvision_{var_name}_{val}")
|
||||
width, height = info["image_size"][0], info["image_size"][1]
|
||||
rows.append(
|
||||
[
|
||||
repo_id,
|
||||
f"{width} x {height}",
|
||||
val,
|
||||
info["compression_factor"],
|
||||
info["load_time_factor"],
|
||||
info["avg_per_pixel_l2_error"],
|
||||
info["avg_psnr"],
|
||||
info["avg_ssim"],
|
||||
info["avg_mse"],
|
||||
]
|
||||
)
|
||||
display_markdown_table(headers, rows)
|
||||
|
||||
|
||||
def best_study(repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool):
|
||||
"""Change the config once you deciced what's best based on one-variable-studies"""
|
||||
print("**best**")
|
||||
headers = [
|
||||
"repo_id",
|
||||
"image_size",
|
||||
"compression_factor",
|
||||
"load_time_factor",
|
||||
"avg_per_pixel_l2_error",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
"avg_mse",
|
||||
]
|
||||
rows = []
|
||||
for repo_id in repo_ids:
|
||||
cfg = {
|
||||
"repo_id": repo_id,
|
||||
# video encoding
|
||||
"g": 2,
|
||||
"crf": None,
|
||||
"pix_fmt": "yuv444p",
|
||||
# video decoding
|
||||
"backend": "video_reader",
|
||||
"decoder": "torchvision",
|
||||
"decoder_kwgs": {},
|
||||
}
|
||||
if not dry_run:
|
||||
run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode)
|
||||
info = load_info(bench_dir / repo_id / "torchvision_best")
|
||||
width, height = info["image_size"][0], info["image_size"][1]
|
||||
rows.append(
|
||||
[
|
||||
repo_id,
|
||||
f"{width} x {height}",
|
||||
info["compression_factor"],
|
||||
info["load_time_factor"],
|
||||
info["avg_per_pixel_l2_error"],
|
||||
]
|
||||
)
|
||||
display_markdown_table(headers, rows)
|
||||
|
||||
|
||||
def main():
|
||||
for timestamps_mode in TIMESTAMPS_MODES:
|
||||
bench_dir = OUTPUT_DIR / timestamps_mode
|
||||
|
||||
print(f"### `{timestamps_mode}`")
|
||||
print()
|
||||
|
||||
for name, values in BENCHMARKS.items():
|
||||
one_variable_study(name, values, DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN)
|
||||
|
||||
# best_study(DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -36,7 +36,7 @@ from lerobot.common.datasets.utils import (
|
|||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
CODEBASE_VERSION = "v1.4"
|
||||
CODEBASE_VERSION = "v1.5"
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
|
|
@ -54,7 +54,14 @@ def check_format(raw_dir):
|
|||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
keypoints_instead_of_image: bool = False,
|
||||
):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
|
@ -105,10 +112,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
|||
assert (episode_ids[from_idx:to_idx] == ep_idx).all()
|
||||
|
||||
# get image
|
||||
image = imgs[from_idx:to_idx]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
if not keypoints_instead_of_image:
|
||||
image = imgs[from_idx:to_idx]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
# get state
|
||||
state = states[from_idx:to_idx]
|
||||
|
@ -116,9 +124,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
|||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
# get reward, success, done
|
||||
# get reward, success, done, and (maybe) keypoints
|
||||
reward = torch.zeros(num_frames)
|
||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||
if keypoints_instead_of_image:
|
||||
keypoints = torch.zeros(num_frames, 16) # 8 keypoints each with 2 coords
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
|
@ -134,7 +144,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
|||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
|
@ -142,33 +152,40 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
|||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||
success[i] = coverage > success_threshold
|
||||
if keypoints_instead_of_image:
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
imgs_array = [x.numpy() for x in image]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
if not keypoints_instead_of_image:
|
||||
imgs_array = [x.numpy() for x in image]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = agent_pos
|
||||
if keypoints_instead_of_image:
|
||||
ep_dict["observation.environment_state"] = keypoints
|
||||
ep_dict["action"] = actions[from_idx:to_idx]
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
|
@ -180,7 +197,6 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
|||
ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
|
||||
ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
|
@ -188,17 +204,23 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
|||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
|
||||
def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
|
||||
features = {}
|
||||
|
||||
if video:
|
||||
features["observation.image"] = VideoFrame()
|
||||
else:
|
||||
features["observation.image"] = Image()
|
||||
if not keypoints_instead_of_image:
|
||||
if video:
|
||||
features["observation.image"] = VideoFrame()
|
||||
else:
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if keypoints_instead_of_image:
|
||||
features["observation.environment_state"] = Sequence(
|
||||
length=data_dict["observation.environment_state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
@ -222,17 +244,21 @@ def from_raw_to_lerobot_format(
|
|||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
# Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
|
||||
# with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
|
||||
keypoints_instead_of_image = False
|
||||
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 10
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image)
|
||||
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
"video": video if not keypoints_instead_of_image else 0,
|
||||
}
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar
|
||||
|
@ -69,7 +70,7 @@ def decode_video_frames_torchvision(
|
|||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
log_loaded_timestamps: bool = False,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
|
||||
The backend can be either "pyav" (default) or "video_reader".
|
||||
|
@ -77,9 +78,8 @@ def decode_video_frames_torchvision(
|
|||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
(note that you need to compile against ffmpeg<4.3)
|
||||
|
||||
While both use cpu, "video_reader" is faster than "pyav" but requires additional setup.
|
||||
See our benchmark results for more info on performance:
|
||||
https://github.com/huggingface/lerobot/pull/220
|
||||
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
||||
For more info on video decoding, see `benchmark/video/README.md`
|
||||
|
||||
See torchvision doc for more info on these two backends:
|
||||
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
||||
|
@ -142,6 +142,10 @@ def decode_video_frames_torchvision(
|
|||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
|
@ -158,22 +162,52 @@ def decode_video_frames_torchvision(
|
|||
return closest_frames
|
||||
|
||||
|
||||
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
|
||||
"""More info on ffmpeg arguments tuning on `lerobot/common/datasets/_video_benchmark/README.md`"""
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
video_codec: str = "libsvtav1",
|
||||
pixel_format: str = "yuv420p",
|
||||
group_of_pictures_size: int | None = 2,
|
||||
constant_rate_factor: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: str | None = "error",
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
video_path = Path(video_path)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ffmpeg_cmd = (
|
||||
f"ffmpeg -r {fps} "
|
||||
"-f image2 "
|
||||
"-loglevel error "
|
||||
f"-i {str(imgs_dir / 'frame_%06d.png')} "
|
||||
"-vcodec libx264 "
|
||||
"-g 2 "
|
||||
"-pix_fmt yuv444p "
|
||||
f"{str(video_path)}"
|
||||
ffmpeg_args = OrderedDict(
|
||||
[
|
||||
("-f", "image2"),
|
||||
("-r", str(fps)),
|
||||
("-i", str(imgs_dir / "frame_%06d.png")),
|
||||
("-vcodec", video_codec),
|
||||
("-pix_fmt", pixel_format),
|
||||
]
|
||||
)
|
||||
subprocess.run(ffmpeg_cmd.split(" "), check=True)
|
||||
|
||||
if group_of_pictures_size is not None:
|
||||
ffmpeg_args["-g"] = str(group_of_pictures_size)
|
||||
|
||||
if constant_rate_factor is not None:
|
||||
ffmpeg_args["-crf"] = str(constant_rate_factor)
|
||||
|
||||
if fast_decode:
|
||||
key = "-svtav1-params" if video_codec == "libsvtav1" else "-tune"
|
||||
value = f"fast-decode={fast_decode}" if video_codec == "libsvtav1" else "fastdecode"
|
||||
ffmpeg_args[key] = value
|
||||
|
||||
if log_level is not None:
|
||||
ffmpeg_args["-loglevel"] = str(log_level)
|
||||
|
||||
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||
if overwrite:
|
||||
ffmpeg_args.append("-y")
|
||||
|
||||
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
||||
subprocess.run(ffmpeg_cmd, check=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -28,31 +28,35 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
return_observations[imgkey] = img
|
||||
|
||||
return_observations[imgkey] = img
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
observations["environment_state"]
|
||||
).float()
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
|
||||
return return_observations
|
||||
|
|
|
@ -28,7 +28,10 @@ class DiffusionConfig:
|
|||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
- Either:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
AND/OR
|
||||
- The key "observation.environment_state" is required as input.
|
||||
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- "action" is required as an output key.
|
||||
|
@ -155,26 +158,33 @@ class DiffusionConfig:
|
|||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if self.crop_shape is not None:
|
||||
|
||||
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
|
|
|
@ -83,16 +83,20 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.images": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if len(self.expected_image_keys) > 0:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
|
@ -117,7 +121,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
|
@ -137,7 +142,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
@ -161,15 +167,20 @@ class DiffusionModel(nn.Module):
|
|||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config,
|
||||
global_cond_dim=(
|
||||
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
|
||||
)
|
||||
* config.n_obs_steps,
|
||||
)
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
|
@ -219,24 +230,34 @@ class DiffusionModel(nn.Module):
|
|||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
# Extract image feature (first combine batch, sequence, and camera index dims).
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
|
||||
# dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
if self._use_images:
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self._use_env_state:
|
||||
global_cond_feats.append(batch["observation.environment_state"])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have:
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, environment_dim)
|
||||
}
|
||||
"""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
|
@ -260,13 +281,18 @@ class DiffusionModel(nn.Module):
|
|||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, environment_dim)
|
||||
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon)
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
|
||||
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
|
||||
assert "observation.images" in batch or "observation.environment_state" in batch
|
||||
n_obs_steps = batch["observation.state"].shape[1]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from collections import deque
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from math import ceil
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import time
|
||||
from contextlib import ContextDecorator
|
||||
|
||||
|
||||
class TimeBenchmark(ContextDecorator):
|
||||
"""
|
||||
Measures execution time using a context manager or decorator.
|
||||
|
||||
This class supports both context manager and decorator usage, and is thread-safe for multithreaded
|
||||
environments.
|
||||
|
||||
Args:
|
||||
print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
|
||||
to False.
|
||||
|
||||
Examples:
|
||||
|
||||
Using as a context manager:
|
||||
|
||||
>>> benchmark = TimeBenchmark()
|
||||
>>> with benchmark:
|
||||
... time.sleep(1)
|
||||
>>> print(f"Block took {benchmark.result:.4f} seconds")
|
||||
Block took approximately 1.0000 seconds
|
||||
|
||||
Using with multithreading:
|
||||
|
||||
```python
|
||||
import threading
|
||||
|
||||
benchmark = TimeBenchmark()
|
||||
|
||||
def context_manager_example():
|
||||
with benchmark:
|
||||
time.sleep(0.01)
|
||||
print(f"Block took {benchmark.result_ms:.2f} milliseconds")
|
||||
|
||||
threads = []
|
||||
for _ in range(3):
|
||||
t1 = threading.Thread(target=context_manager_example)
|
||||
threads.append(t1)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
```
|
||||
Expected output:
|
||||
Block took approximately 10.00 milliseconds
|
||||
Block took approximately 10.00 milliseconds
|
||||
Block took approximately 10.00 milliseconds
|
||||
"""
|
||||
|
||||
def __init__(self, print=False):
|
||||
self.local = threading.local()
|
||||
self.print_time = print
|
||||
|
||||
def __enter__(self):
|
||||
self.local.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.local.end_time = time.perf_counter()
|
||||
self.local.elapsed_time = self.local.end_time - self.local.start_time
|
||||
if self.print_time:
|
||||
print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
|
||||
return False
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return getattr(self.local, "elapsed_time", None)
|
||||
|
||||
@property
|
||||
def result_ms(self):
|
||||
return self.result * 1e3
|
|
@ -39,7 +39,7 @@ training:
|
|||
# `online_env_seed` is used for environments for online training data rollouts.
|
||||
online_env_seed: ???
|
||||
eval_freq: ???
|
||||
log_freq: 250
|
||||
log_freq: 200
|
||||
save_checkpoint: true
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: ???
|
||||
|
|
|
@ -10,11 +10,10 @@ override_dataset_stats:
|
|||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
offline_steps: 100000
|
||||
online_steps: 0
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
eval_freq: 20000
|
||||
save_freq: 20000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
|
|
|
@ -36,11 +36,10 @@ override_dataset_stats:
|
|||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
offline_steps: 100000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_freq: 20000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
|
|
|
@ -34,11 +34,10 @@ override_dataset_stats:
|
|||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
offline_steps: 100000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_freq: 20000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
|
|
|
@ -24,9 +24,8 @@ override_dataset_stats:
|
|||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
eval_freq: 25000
|
||||
save_freq: 25000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# @package _global_
|
||||
|
||||
# Defaults for training for the pusht_keypoints dataset.
|
||||
|
||||
# They keypoints are on the vertices of the rectangles that make up the PushT as documented in the PushT
|
||||
# environment:
|
||||
# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534
|
||||
# For completeness, the diagram is copied here:
|
||||
# 0───────────1
|
||||
# │ │
|
||||
# 3───4───5───2
|
||||
# │ │
|
||||
# │ │
|
||||
# │ │
|
||||
# │ │
|
||||
# 7───6
|
||||
|
||||
|
||||
# Note: The original work trains keypoints-only with conditioning via inpainting. Here, we encode the
|
||||
# observation along with the agent position and use the encoding as global conditioning for the denoising
|
||||
# U-Net.
|
||||
|
||||
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
|
||||
# https://github.com/huggingface/lerobot/pull/134 for more details.
|
||||
|
||||
seed: 100000
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 2
|
||||
horizon: 16
|
||||
n_action_steps: 8
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
crop_shape: [84, 84]
|
||||
crop_is_random: True
|
||||
pretrained_backbone_weights: null
|
||||
use_group_norm: True
|
||||
spatial_softmax_num_keypoints: 32
|
||||
# Unet.
|
||||
down_dims: [256, 512, 1024]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDIM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
prediction_type: epsilon # epsilon / sample
|
||||
clip_sample: True
|
||||
clip_sample_range: 1.0
|
||||
|
||||
# Inference
|
||||
num_inference_steps: 10 # if not provided, defaults to `num_train_timesteps`
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: false
|
|
@ -11,6 +11,7 @@ training:
|
|||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
log_freq: 100
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
|
|
|
@ -22,9 +22,8 @@ override_dataset_stats:
|
|||
training:
|
||||
offline_steps: 250000
|
||||
online_steps: 0
|
||||
eval_freq: 20000
|
||||
save_freq: 20000
|
||||
log_freq: 250
|
||||
eval_freq: 25000
|
||||
save_freq: 25000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
|
|
|
@ -40,6 +40,60 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
|||
--raw-format umi_zarr \
|
||||
--repo-id lerobot/umi_cup_in_the_wild
|
||||
```
|
||||
|
||||
**WARNING: Updating an existing dataset**
|
||||
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
from huggingface_hub import create_branch, hf_hub_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # makes it easier to see the print-out below
|
||||
|
||||
NEW_CODEBASE_VERSION = "v1.5" # REPLACE THIS WITH YOUR DESIRED VERSION
|
||||
|
||||
for repo_id in available_datasets:
|
||||
# First check if the newer version already exists.
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
|
||||
)
|
||||
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
|
||||
print("Exiting early")
|
||||
break
|
||||
except RepositoryNotFoundError:
|
||||
# Now create a branch.
|
||||
create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
|
||||
print(f"{repo_id} successfully updated")
|
||||
|
||||
```
|
||||
|
||||
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
|
||||
above, nor to be compatible with previous codebase versions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@ -222,6 +276,7 @@ def push_dataset_to_hub(
|
|||
# get the first episode
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||
episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
|
||||
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
|
||||
|
@ -316,7 +371,10 @@ def main():
|
|||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
|
||||
help=(
|
||||
"When provided, save tests artifacts into the given directory "
|
||||
"(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -47,7 +47,7 @@ huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
|||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||
gym-pusht = { version = ">=0.1.5", optional = true}
|
||||
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||
pre-commit = {version = ">=3.7.0", optional = true}
|
||||
|
@ -61,6 +61,8 @@ moviepy = ">=1.0.3"
|
|||
rerun-sdk = ">=0.15.1"
|
||||
deepdiff = ">=7.0.1"
|
||||
scikit-image = {version = "^0.23.2", optional = true}
|
||||
pandas = {version = "^2.2.2", optional = true}
|
||||
pytest-mock = {version = "^3.14.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
@ -69,9 +71,9 @@ pusht = ["gym-pusht"]
|
|||
xarm = ["gym-xarm"]
|
||||
aloha = ["gym-aloha"]
|
||||
dev = ["pre-commit", "debugpy"]
|
||||
test = ["pytest", "pytest-cov"]
|
||||
test = ["pytest", "pytest-cov", "pytest-mock"]
|
||||
umi = ["imagecodecs"]
|
||||
video_benchmark = ["scikit-image"]
|
||||
video_benchmark = ["scikit-image", "pandas"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d5883aa2c8ba2bcd8d047a77064112aa5d4c1c9b8595bb28935ec93ed53627e5
|
||||
oid sha256:52723265cba2ec839a5fcf75733813ecf91019ec0f7a49865fe233616e674583
|
||||
size 3056
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0eab443dd492d0e271094290ae3cec2c9b2f4a19d35434eb5952cb37b0d40890
|
||||
size 18272
|
||||
oid sha256:8552d4ac6b618a5b2741e174d51f1d4fc0e5f4e6cc7026bebdb6ed145373b042
|
||||
size 18320
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8c1a72239bb56a6c5714f18d849557c89feb858840e8f86689d017bb49551379
|
||||
oid sha256:a522c7815565f1f81a8bb5a853263405ab8c3b087ecbc7a3b004848891d77342
|
||||
size 247
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a1cd3db853d0f92e1696fe297c550200219d85befdeb5b5eacae4b10a74d9896
|
||||
size 136
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dbf25de102227dd2d8c3b6c61e1fc25a026d44f151161b88bc9a9eb101e942e4
|
||||
size 33
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:50b3c026da835560f9b87e7dfd28673e766bfb58d56c85002687d0a599b6fa43
|
||||
size 3304
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:958798d23a1690449744961f8c3ed934efe950c664e5fd729468959362840218
|
||||
size 20336
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:686d9d9bad8815d67597b997058d9853a04e5bdbe4eed038f4da9806f867af3d
|
||||
size 1098
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f22ee3500aca1bea0afdda429e841c57a3278dfea92c79bbbf5dac5f984ed648
|
||||
size 247
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d9cc073bcb335024500fe7c823f142a3b4f038ff458d8c47fb6a6918f8f6d5fd
|
||||
oid sha256:b99bbb7332557d47b108fd0262d911c99f5bfce30fa3e76dc802b927284135e7
|
||||
size 111338
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:58c50ef6413b6b3acb7ad280281cdd4eba553f7d3d0b4dad20c262025d610f2b
|
||||
oid sha256:0f63430455e1ca7a5fe28c81a15fc0eb82758035e6b3d623e7e7952e71cb262a
|
||||
size 111338
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bd1d26e983e2910ec170cd6ac1f4de4d7cb447ee24b516a74f42765d4894e048
|
||||
oid sha256:0b88c39db5b13da646fd5876bd765213569387591d30ec665d048ae1070db0b9
|
||||
size 111338
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e1247a9d4683520ed338f3fd410cc200999e4b82da573cd499095ba02037586f
|
||||
oid sha256:68eb245890f9537851ea7fb227472dcd4f1fa3820a7c3294a4989e2b9896d078
|
||||
size 111338
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b24f3c3d41428b768082eb3b02b5e22dc9540aa4dbe756d43be214d51e97adba
|
||||
oid sha256:00c74e17bbf7d428b0b0869f388d348820a938c417b3c888a1384980bb53d4d0
|
||||
size 111338
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5301dc61b585fbfbdb6ce681ffcd52fc26b64a3767567c228a9e4404f7bcb926
|
||||
oid sha256:a5a7f66704640ba18f756fc44c00721c77a406f412a3a9fcc1a2b1868c978444
|
||||
size 111338
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1ae7f6a7f4ee8340ec73b0e7f1e167046af2af0a22381e0cd3ff42f311e098e0
|
||||
size 794
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2eeb1b185b505450f8a2b6042537d65d2d8f5ee1396cf878a50d3d2aa3a22822
|
||||
size 794
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7f2bb24887f9d4c49ad562429f419b7b66f4310a59877104a98d3c5c6ddca996
|
||||
size 794
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a52fe583c816fdfb962111dd1ee1c113a5f4b9699246fab8648f89e056979f8e
|
||||
size 794
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:70dbf161581b860e255573eb1ef90f4defd134d8dcf0afea16099c859c4a8f85
|
||||
size 794
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:198abd0ec4231c13cadf707d553cba3860acbc74a073406ed184eab5495acdfa
|
||||
size 794
|
|
@ -211,7 +211,7 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
|
|||
|
||||
fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = raw_dir / "videos" / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, video_codec="libx264")
|
||||
|
||||
|
||||
def _mock_download_raw(raw_dir, repo_id):
|
||||
|
@ -229,6 +229,23 @@ def _mock_download_raw(raw_dir, repo_id):
|
|||
raise ValueError(repo_id)
|
||||
|
||||
|
||||
def _mock_encode_video_frames(*args, **kwargs):
|
||||
kwargs["video_codec"] = "libx264"
|
||||
return encode_video_frames(*args, **kwargs)
|
||||
|
||||
|
||||
def patch_encoder(raw_format, mocker):
|
||||
format_module_map = {
|
||||
"aloha_hdf5": "lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format.encode_video_frames",
|
||||
"pusht_zarr": "lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format.encode_video_frames",
|
||||
"xarm_pkl": "lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format.encode_video_frames",
|
||||
"umi_zarr": "lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format.encode_video_frames",
|
||||
}
|
||||
|
||||
if raw_format in format_module_map:
|
||||
mocker.patch(format_module_map[raw_format], side_effect=_mock_encode_video_frames)
|
||||
|
||||
|
||||
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
|
||||
with pytest.raises(ValueError):
|
||||
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
|
||||
|
@ -251,17 +268,21 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"required_packages, raw_format, repo_id",
|
||||
"required_packages, raw_format, repo_id, make_test_data",
|
||||
[
|
||||
(["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
|
||||
(None, "xarm_pkl", "lerobot/xarm_lift_medium"),
|
||||
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
|
||||
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
|
||||
(None, "dora_parquet", "cadene/wrist_gripper"),
|
||||
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", False),
|
||||
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", True),
|
||||
(None, "xarm_pkl", "lerobot/xarm_lift_medium", False),
|
||||
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False),
|
||||
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False),
|
||||
(None, "dora_parquet", "cadene/wrist_gripper", False),
|
||||
],
|
||||
)
|
||||
@require_package_arg
|
||||
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
|
||||
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data, mocker):
|
||||
# Patch `encode_video_frames` so that it uses 'libx264' instead of 'libsvtav1' for testing
|
||||
patch_encoder(raw_format, mocker)
|
||||
|
||||
num_episodes = 3
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
|
@ -278,6 +299,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
|||
local_dir=local_dir,
|
||||
force_override=False,
|
||||
cache_dir=tmpdir / "cache",
|
||||
tests_data_dir=tmpdir / "tests/data" if make_test_data else None,
|
||||
)
|
||||
|
||||
# minimal generic tests on the local directory containing LeRobotDataset
|
||||
|
@ -299,6 +321,20 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
|||
for cam_key in lerobot_dataset.camera_keys:
|
||||
assert cam_key in item
|
||||
|
||||
if make_test_data:
|
||||
# Check that only the first episode is selected.
|
||||
test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data")
|
||||
num_frames = sum(
|
||||
i == lerobot_dataset.hf_dataset["episode_index"][0]
|
||||
for i in lerobot_dataset.hf_dataset["episode_index"]
|
||||
).item()
|
||||
assert (
|
||||
test_dataset.hf_dataset["episode_index"]
|
||||
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
|
||||
)
|
||||
for k in ["from", "to"]:
|
||||
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw_format, repo_id",
|
||||
|
|
Loading…
Reference in New Issue