Add visualize_dataset_html with `http.server` (#188)

This commit is contained in:
Remi 2024-08-08 20:19:06 +03:00 committed by GitHub
parent bc6384bb80
commit 2252b42337
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 785 additions and 68 deletions

View File

@ -108,8 +108,8 @@ def visualize_dataset(
web_port: int = 9090, web_port: int = 9090,
ws_port: int = 9087, ws_port: int = 9087,
save: bool = False, save: bool = False,
output_dir: Path | None = None,
root: Path | None = None, root: Path | None = None,
output_dir: Path | None = None,
) -> Path | None: ) -> Path | None:
if save: if save:
assert ( assert (
@ -209,6 +209,18 @@ def main():
required=True, required=True,
help="Episode to visualize.", help="Episode to visualize.",
) )
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument( parser.add_argument(
"--batch-size", "--batch-size",
type=int, type=int,
@ -254,17 +266,6 @@ def main():
"Visualize the data by running `rerun path/to/file.rrd` on your local machine." "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
), ),
) )
parser.add_argument(
"--output-dir",
type=str,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument(
"--root",
type=str,
help="Root directory for a dataset stored on a local machine.",
)
args = parser.parse_args() args = parser.parse_args()
visualize_dataset(**vars(args)) visualize_dataset(**vars(args))

View File

@ -0,0 +1,300 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesn't always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.
Example of usage:
- Visualize data stored on a local machine:
```bash
local$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht \
--episodes 7 3 5 1 4
```
"""
import argparse
import logging
import shutil
from pathlib import Path
import torch
import tqdm
from flask import Flask, redirect, render_template, url_for
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import init_logging
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in increasing order, the frame indices of a single episode."""

    def __init__(self, dataset, episode_index):
        """Store the frame range of episode `episode_index`.

        The range is read from `dataset.episode_data_index`: "from" gives the first
        frame index and "to" the end of the range (excluded).
        """
        boundaries = dataset.episode_data_index
        first_frame = boundaries["from"][episode_index].item()
        end_frame = boundaries["to"][episode_index].item()
        self.frame_ids = range(first_frame, end_frame)

    def __iter__(self):
        """Iterate over the frame indices of the episode."""
        return iter(self.frame_ids)

    def __len__(self):
        """Number of frames in the episode."""
        return len(self.frame_ids)
def run_server(
    dataset: LeRobotDataset,
    episodes: list[int],
    host: str,
    port: int,
    static_folder: Path,
    template_folder: Path,
):
    """Start a Flask web server to browse the episodes of `dataset`. Blocks until the server stops.

    Args:
        dataset: Dataset being visualized. Its `repo_id` is used to build the urls.
        episodes: Episode indices listed in the sidebar. The home page redirects to the first one.
        host: Host the http server binds to.
        port: Port the http server listens on.
        static_folder: Directory served under the "static" endpoint (videos and episode csv files).
        template_folder: Directory containing `visualize_dataset_template.html`.
    """
    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

    @app.route("/")
    def index():
        # home page redirects to the first episode page
        [dataset_namespace, dataset_name] = dataset.repo_id.split("/")
        first_episode_id = episodes[0]
        return redirect(
            url_for(
                "show_episode",
                dataset_namespace=dataset_namespace,
                dataset_name=dataset_name,
                episode_id=first_episode_id,
            )
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
    def show_episode(dataset_namespace, dataset_name, episode_id):
        # `dataset_namespace` and `dataset_name` only shape the url; the page is always
        # rendered from the dataset given to `run_server`.
        dataset_info = {
            "repo_id": dataset.repo_id,
            "num_samples": dataset.num_samples,
            "num_episodes": dataset.num_episodes,
            "fps": dataset.fps,
        }
        video_paths = get_episode_video_paths(dataset, episode_id)
        videos_info = [
            {"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
            for video_path in video_paths
        ]
        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
        return render_template(
            "visualize_dataset_template.html",
            episode_id=episode_id,
            episodes=episodes,
            dataset_info=dataset_info,
            videos_info=videos_info,
            ep_csv_url=ep_csv_url,
            has_policy=False,
        )

    app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int):
    """Return the name of the csv file holding the timeseries data of episode `episode_id`."""
    return f"episode_{episode_id}.csv"
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
    """Write a csv file containing timeseries data of an episode (e.g. state and action).

    This file will be loaded by Dygraph javascript to plot data in real time.

    Args:
        output_dir: Directory (a `Path`) receiving the file. Created if missing.
        file_name: Name of the csv file inside `output_dir`.
        episode_index: Index of the episode to export.
        dataset: Dataset exposing `episode_data_index` and `hf_dataset`.

    The first column is "timestamp". It is followed by one "state_i" column per dimension of
    "observation.state" and one "action_i" column per dimension of "action", for whichever of
    these two features the dataset has.
    """
    from_idx = int(dataset.episode_data_index["from"][episode_index])
    to_idx = int(dataset.episode_data_index["to"][episode_index])

    features = dataset.hf_dataset.features
    has_state = "observation.state" in features
    has_action = "action" in features

    columns = ["timestamp"]
    if has_state:
        columns.append("observation.state")
    if has_action:
        columns.append("action")
    data = dataset.hf_dataset.select_columns(columns)

    # init header of csv with state and action names
    header = ["timestamp"]
    if has_state or has_action:
        # A single decoded frame is enough to know the dimensions; this avoids
        # materializing whole columns of the dataset.
        first_frame = data[0]
        if has_state:
            header += [f"state_{i}" for i in range(len(first_frame["observation.state"]))]
        if has_action:
            header += [f"action_{i}" for i in range(len(first_frame["action"]))]

    rows = []
    for frame_idx in range(from_idx, to_idx):
        # Decode the frame once and reuse it for every column.
        frame = data[frame_idx]
        row = [frame["timestamp"].item()]
        if has_state:
            row += frame["observation.state"].tolist()
        if has_action:
            row += frame["action"].tolist()
        rows.append(row)

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w", encoding="utf-8") as f:
        f.write(",".join(header) + "\n")
        f.writelines(",".join(str(col) for col in row) + "\n" for row in rows)
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    """Return, for each key in `dataset.video_frame_keys`, the video path of episode `ep_index`."""
    # hack: the video path of the episode is read from the first frame of the episode
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()

    video_paths = []
    for key in dataset.video_frame_keys:
        first_frame = dataset.hf_dataset.select_columns(key)[first_frame_idx]
        video_paths.append(first_frame[key]["path"])
    return video_paths
def visualize_dataset_html(
    repo_id: str,
    root: Path | None = None,
    episodes: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    host: str = "127.0.0.1",
    port: int = 9090,
    force_override: bool = False,
) -> Path | None:
    """Export the timeseries of the selected episodes to csv files and optionally serve them in a web page.

    Args:
        repo_id: Identifier of the dataset, forwarded to `LeRobotDataset`.
        root: Root directory forwarded to `LeRobotDataset`.
        episodes: Episode indices to export. `None` selects every episode of the dataset.
        output_dir: Directory receiving the generated files. Defaults to
            `outputs/visualize_dataset_html/{repo_id}`.
        serve: Whether to start the web server once the csv files are written.
        host: Host the web server binds to.
        port: Port the web server listens on.
        force_override: Whether to delete `output_dir` first when it already exists.

    Raises:
        NotImplementedError: If the dataset is not a video dataset (`dataset.video` is falsy).
    """
    init_logging()

    dataset = LeRobotDataset(repo_id, root=root)

    if not dataset.video:
        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")

    if output_dir is None:
        output_dir = f"outputs/visualize_dataset_html/{repo_id}"

    output_dir = Path(output_dir)
    if output_dir.exists():
        if force_override:
            shutil.rmtree(output_dir)
        else:
            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Create a symlink from the dataset video folder containing mp4 files to the output directory
    # so that the http server can get access to the mp4 files.
    static_dir = output_dir / "static"
    static_dir.mkdir(parents=True, exist_ok=True)
    ln_videos_dir = static_dir / "videos"
    # `exists()` follows symlinks and is False for a dangling link (e.g. the dataset folder moved),
    # in which case `symlink_to` would fail because the link itself is still there: drop it first.
    if ln_videos_dir.is_symlink() and not ln_videos_dir.exists():
        ln_videos_dir.unlink()
    if not ln_videos_dir.exists():
        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())

    template_dir = Path(__file__).resolve().parent.parent / "templates"

    if episodes is None:
        episodes = list(range(dataset.num_episodes))

    logging.info("Writing CSV files")
    for episode_index in tqdm.tqdm(episodes):
        # write states and actions in a csv (it can be slow for big datasets)
        ep_csv_fname = get_ep_csv_fname(episode_index)
        # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)

    if serve:
        run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
    """Parse the command line arguments and forward them to `visualize_dataset_html`."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        # The default below mirrors the one applied by `visualize_dataset_html` when no directory is given.
        help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset_html/REPO_ID'.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it exists already.",
    )

    args = parser.parse_args()
    # argparse turns `--repo-id` into `repo_id`, etc., which matches the keyword arguments of the function
    visualize_dataset_html(**vars(args))


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,360 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<!-- # TODO(rcadene, mishig25): store the js files locally -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/alpinejs/3.13.5/cdn.min.js" defer></script>
<script src="https://cdn.jsdelivr.net/npm/dygraphs@2.2.1/dist/dygraph.min.js" type="text/javascript"></script>
<script src="https://cdn.tailwindcss.com"></script>
<title>{{ dataset_info.repo_id }} episode {{ episode_id }}</title>
</head>
<!-- Use [Alpine.js](https://alpinejs.dev), a lightweight and easy to learn JS framework -->
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
<body class="flex h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
    // Use the space bar to play and pause, instead of default action (e.g. scrolling)
    const { keyCode, key } = e;
    if (keyCode === 32 || key === ' ') {
        e.preventDefault();
        $refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
    }else if (key === 'ArrowDown' || key === 'ArrowUp'){
        // Move through the list of served episodes by position, so that navigation also works
        // when the episode ids are not sorted or not contiguous (e.g. episodes 7 3 5 1 4).
        const episodeIds = {{ episodes }};
        const currentPos = episodeIds.indexOf({{ episode_id }});
        const nextPos = key === 'ArrowDown' ? currentPos + 1 : currentPos - 1;
        if(currentPos !== -1 && nextPos >= 0 && nextPos < episodeIds.length){
            window.location.href = `./episode_${episodeIds[nextPos]}`;
        }
    }
}">
<!-- Sidebar -->
<div x-ref="sidebar" class="w-60 bg-slate-900 p-5 break-words max-h-screen overflow-y-auto">
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
<ul>
<li>
Number of samples/frames: {{ dataset_info.num_samples }}
</li>
<li>
Number of episodes: {{ dataset_info.num_episodes }}
</li>
<li>
Frames per second: {{ dataset_info.fps }}
</li>
</ul>
<p>Episodes:</p>
<ul class="ml-2">
{% for episode in episodes %}
<li class="font-mono text-sm mt-0.5">
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
Episode {{ episode }}
</a>
</li>
{% endfor %}
</ul>
</div>
<!-- Toggle sidebar button -->
<button class="flex items-center opacity-50 hover:opacity-100 mx-1"
@click="() => ($refs.sidebar.classList.toggle('hidden'))" title="Toggle sidebar">
<div class="bg-slate-500 w-2 h-10 rounded-full"></div>
</button>
<!-- Content -->
<div class="flex-1 max-h-screen flex flex-col gap-4 overflow-y-auto">
<h1 class="text-xl font-bold mt-4 font-mono">
Episode {{ episode_id }}
</h1>
<!-- Videos -->
<div class="flex flex-wrap gap-1">
{% for video_info in videos_info %}
<div class="max-w-96">
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<video autoplay muted loop type="video/mp4" class="min-w-64" @timeupdate="() => {
if (video.duration) {
const time = video.currentTime;
const pc = (100 / video.duration) * time;
$refs.slider.value = pc;
dygraphTime = time;
dygraphIndex = Math.floor(pc * dygraph.numRows() / 100);
dygraph.setSelection(dygraphIndex, undefined, true, true);
$refs.timer.textContent = formatTime(time) + ' / ' + formatTime(video.duration);
updateTimeQuery(time.toFixed(2));
}
}" @ended="() => {
$refs.btnPlay.classList.remove('hidden');
$refs.btnPause.classList.add('hidden');
}"
@loadedmetadata="() => ($refs.timer.textContent = formatTime(0) + ' / ' + formatTime(video.duration))">
<source src="{{ video_info.url }}">
Your browser does not support the video tag.
</video>
</div>
{% endfor %}
</div>
<!-- Shortcuts info -->
<div class="text-sm hidden md:block">
Hotkeys: <span class="font-mono">Space</span> to pause/unpause, <span class="font-mono">Arrow Down</span> to go to next episode, <span class="font-mono">Arrow Up</span> to go to previous episode.
</div>
<!-- Controllers -->
<div class="flex gap-1 text-3xl items-center">
<!-- Play button: hidden at start; the play and pause buttons swap visibility on click -->
<button x-ref="btnPlay" class="-rotate-90 hidden" title="Play. Toggle with Space" @click="() => {
    videos.forEach(video => video.play());
    $refs.btnPlay.classList.toggle('hidden');
    $refs.btnPause.classList.toggle('hidden');
}">🔽</button>
<button x-ref="btnPause" title="Pause. Toggle with Space" @click="() => {
videos.forEach(video => video.pause());
$refs.btnPlay.classList.toggle('hidden');
$refs.btnPause.classList.toggle('hidden');
}">⏸️</button>
<button title="Jump backward 5 seconds"
@click="() => (videos.forEach(video => (video.currentTime -= 5)))">⏪</button>
<button title="Jump forward 5 seconds"
@click="() => (videos.forEach(video => (video.currentTime += 5)))">⏩</button>
<button title="Rewind from start"
@click="() => (videos.forEach(video => (video.currentTime = 0.0)))">↩️</button>
<input x-ref="slider" max="100" min="0" step="1" type="range" value="0" class="w-80 mx-2" @input="() => {
const sliderValue = $refs.slider.value;
$refs.btnPause.click();
videos.forEach(video => {
const time = (video.duration * sliderValue) / 100;
video.currentTime = time;
});
}" />
<div x-ref="timer" class="font-mono text-sm border border-slate-500 rounded-lg px-1 py-0.5 shrink-0">0:00 /
0:00
</div>
</div>
<!-- Graph -->
<div class="flex gap-2 mb-4 flex-wrap">
<div>
<div id="graph" @mouseleave="() => {
dygraph.setSelection(dygraphIndex, undefined, true, true);
dygraphTime = video.currentTime;
}">
</div>
<p x-ref="graphTimer" class="font-mono ml-14 mt-4"
x-init="$watch('dygraphTime', value => ($refs.graphTimer.innerText = `Time: ${dygraphTime.toFixed(2)}s`))">
Time: 0.00s
</p>
</div>
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
<thead>
<tr>
<th></th>
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
<th class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2">
<input type="checkbox" :checked="isColumnChecked(colIndex)"
@change="toggleColumn(colIndex)">
<p x-text="`${columnNames[colIndex]}`"></p>
</div>
</th>
</template>
</tr>
</thead>
<tbody>
<template x-for="(row, rowIndex) in rows">
<tr class="odd:bg-gray-800 even:bg-gray-900">
<td class="border border-slate-700">
<div class="flex gap-x-2 w-24 font-semibold px-1">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
<p x-text="`Motor ${rowIndex}`"></p>
</div>
</td>
<template x-for="(cell, colIndex) in row">
<td x-show="cell" class="border border-slate-700">
<div class="flex gap-x-2 w-24 justify-between px-2">
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
<span x-text="`${cell.value.toFixed(2)}`"
:style="`color: ${cell.color}`"></span>
</div>
</td>
</template>
</tr>
</template>
</tbody>
</table>
<div id="labels" class="hidden">
</div>
</div>
</div>
<script>
function createAlpineData() {
return {
// state
dygraph: null,
currentFrameData: null,
columnNames: ["state", "action", "pred action"],
nColumns: {% if has_policy %}3{% else %}2{% endif %},
checked: [],
dygraphTime: 0.0,
dygraphIndex: 0,
videos: null,
video: null,
colors: null,
// alpine initialization
init() {
this.videos = document.querySelectorAll('video');
this.video = this.videos[0];
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
pixelsPerPoint: 0.01,
legend: 'always',
labelsDiv: document.getElementById('labels'),
labelsKMB: true,
strokeWidth: 1.5,
pointClickCallback: (event, point) => {
this.dygraphTime = point.xval;
this.updateTableValues(this.dygraphTime);
},
highlightCallback: (event, x, points, row, seriesName) => {
this.dygraphTime = x;
this.updateTableValues(this.dygraphTime);
},
drawCallback: (dygraph, is_initial) => {
if (is_initial) {
// dygraph initialization
this.dygraph.setSelection(this.dygraphIndex, undefined, true, true);
this.colors = this.dygraph.getColors();
this.checked = Array(this.colors.length).fill(true);
const seriesNames = this.dygraph.getLabels().slice(1);
const colors = [];
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
let lightnessIdx = 0;
const chunkSize = Math.ceil(seriesNames.length / this.nColumns);
for (let i = 0; i < seriesNames.length; i += chunkSize) {
const lightness = LIGHTNESS[lightnessIdx];
for (let hue = 0; hue < 360; hue += parseInt(360/chunkSize)) {
const color = `hsl(${hue}, 100%, ${lightness}%)`;
colors.push(color);
}
lightnessIdx += 1;
}
this.dygraph.updateOptions({ colors });
this.colors = colors;
this.updateTableValues();
let url = new URL(window.location.href);
let params = new URLSearchParams(url.search);
let time = params.get("t");
if(time){
time = parseFloat(time);
this.videos.forEach(video => (video.currentTime = time));
}
}
},
});
},
//#region Table Data
// turn dygraph's 1D data (at a given time t) to 2D data that whose columns names are defined in this.columnNames.
// 2d data view is used to create html table element.
get rows() {
if (!this.currentFrameData) {
return [];
}
const columnSize = Math.ceil(this.currentFrameData.length / this.nColumns);
return Array.from({
length: columnSize
}, (_, rowIndex) => {
const row = [
this.currentFrameData[rowIndex] || null,
this.currentFrameData[rowIndex + columnSize] || null,
];
if (this.nColumns === 3) {
row.push(this.currentFrameData[rowIndex + 2 * columnSize] || null)
}
return row;
});
},
isRowChecked(rowIndex) {
return this.rows[rowIndex].every(cell => cell && cell.checked);
},
isColumnChecked(colIndex) {
return this.rows.every(row => row[colIndex] && row[colIndex].checked);
},
toggleRow(rowIndex) {
const newState = !this.isRowChecked(rowIndex);
this.rows[rowIndex].forEach(cell => {
if (cell) cell.checked = newState;
});
this.updateTableValues();
},
toggleColumn(colIndex) {
const newState = !this.isColumnChecked(colIndex);
this.rows.forEach(row => {
if (row[colIndex]) row[colIndex].checked = newState;
});
this.updateTableValues();
},
// given time t, update the values in the html table with "data[t]"
updateTableValues(time) {
if (!this.colors) {
return;
}
let pc = (100 / this.video.duration) * (time === undefined ? this.video.currentTime : time);
if (isNaN(pc)) pc = 0;
const index = Math.floor(pc * this.dygraph.numRows() / 100);
// slice(1) to remove the timestamp point that we do not need
const labels = this.dygraph.getLabels().slice(1);
const values = this.dygraph.rawData_[index].slice(1);
const checkedNew = this.currentFrameData ? this.currentFrameData.map(cell => cell.checked) : Array(
this.colors.length).fill(true);
this.currentFrameData = labels.map((label, idx) => ({
label,
value: values[idx],
color: this.colors[idx],
checked: checkedNew[idx],
}));
const shouldUpdateVisibility = !this.checked.every((value, index) => value === checkedNew[index]);
if (shouldUpdateVisibility) {
this.checked = checkedNew;
this.dygraph.setVisibility(this.checked);
}
},
//#endregion
updateTimeQuery(time) {
let url = new URL(window.location.href);
let params = new URLSearchParams(url.search);
params.set("t", time);
url.search = params.toString();
window.history.replaceState({}, '', url.toString());
},
formatTime(time) {
var hours = Math.floor(time / 3600);
var minutes = Math.floor((time % 3600) / 60);
var seconds = Math.floor(time % 60);
return (hours > 0 ? hours + ':' : '') + (minutes < 10 ? '0' + minutes : minutes) + ':' + (seconds <
10 ?
'0' + seconds : seconds);
}
};
}
</script>
</body>
</html>

126
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "absl-py" name = "absl-py"
@ -192,6 +192,17 @@ charset-normalizer = ["charset-normalizer"]
html5lib = ["html5lib"] html5lib = ["html5lib"]
lxml = ["lxml"] lxml = ["lxml"]
[[package]]
name = "blinker"
version = "1.8.2"
description = "Fast, simple object-to-object and broadcast signaling"
optional = false
python-versions = ">=3.8"
files = [
{file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"},
{file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"},
]
[[package]] [[package]]
name = "certifi" name = "certifi"
version = "2024.7.4" version = "2024.7.4"
@ -584,17 +595,6 @@ files = [
{file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"}, {file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"},
] ]
[[package]]
name = "decorator"
version = "4.4.2"
description = "Decorators for Humans"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*"
files = [
{file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
]
[[package]] [[package]]
name = "deepdiff" name = "deepdiff"
version = "7.0.1" version = "7.0.1"
@ -795,6 +795,7 @@ files = [
{file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"}, {file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"},
{file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"}, {file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"},
{file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"}, {file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"},
@ -892,6 +893,28 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"]
typing = ["typing-extensions (>=4.8)"] typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "flask"
version = "3.0.3"
description = "A simple framework for building complex web applications."
optional = false
python-versions = ">=3.8"
files = [
{file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"},
{file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"},
]
[package.dependencies]
blinker = ">=1.6.2"
click = ">=8.1.3"
itsdangerous = ">=2.1.2"
Jinja2 = ">=3.1.2"
Werkzeug = ">=3.0.0"
[package.extras]
async = ["asgiref (>=3.2)"]
dotenv = ["python-dotenv"]
[[package]] [[package]]
name = "frozenlist" name = "frozenlist"
version = "1.4.1" version = "1.4.1"
@ -1550,6 +1573,17 @@ files = [
{file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
] ]
[[package]]
name = "itsdangerous"
version = "2.2.0"
description = "Safely pass data to untrusted environments and back."
optional = false
python-versions = ">=3.8"
files = [
{file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"},
{file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"},
]
[[package]] [[package]]
name = "jinja2" name = "jinja2"
version = "3.1.4" version = "3.1.4"
@ -1741,9 +1775,13 @@ files = [
{file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"}, {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
{file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"}, {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
{file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"}, {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
{file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"}, {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
@ -1901,30 +1939,6 @@ files = [
intel-openmp = "==2021.*" intel-openmp = "==2021.*"
tbb = "==2021.*" tbb = "==2021.*"
[[package]]
name = "moviepy"
version = "1.0.3"
description = "Video editing with Python"
optional = false
python-versions = "*"
files = [
{file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"},
]
[package.dependencies]
decorator = ">=4.0.2,<5.0"
imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""}
imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""}
numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""}
proglog = "<=1.0.0"
requests = ">=2.8.1,<3.0"
tqdm = ">=4.11.2,<5.0"
[package.extras]
doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"]
optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"]
test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"]
[[package]] [[package]]
name = "mpmath" name = "mpmath"
version = "1.3.0" version = "1.3.0"
@ -2696,20 +2710,6 @@ nodeenv = ">=0.11.1"
pyyaml = ">=5.1" pyyaml = ">=5.1"
virtualenv = ">=20.10.0" virtualenv = ">=20.10.0"
[[package]]
name = "proglog"
version = "0.1.10"
description = "Log and progress bar manager for console, notebooks, web..."
optional = false
python-versions = "*"
files = [
{file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"},
{file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"},
]
[package.dependencies]
tqdm = "*"
[[package]] [[package]]
name = "protobuf" name = "protobuf"
version = "5.27.2" version = "5.27.2"
@ -3276,6 +3276,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -3809,13 +3810,13 @@ test = ["pytest"]
[[package]] [[package]]
name = "setuptools" name = "setuptools"
version = "71.0.1" version = "71.0.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages" description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "setuptools-71.0.1-py3-none-any.whl", hash = "sha256:1eb8ef012efae7f6acbc53ec0abde4bc6746c43087fd215ee09e1df48998711f"}, {file = "setuptools-71.0.0-py3-none-any.whl", hash = "sha256:f06fbe978a91819d250a30e0dc4ca79df713d909e24438a42d0ec300fc52247f"},
{file = "setuptools-71.0.1.tar.gz", hash = "sha256:c51d7fd29843aa18dad362d4b4ecd917022131425438251f4e3d766c964dd1ad"}, {file = "setuptools-71.0.0.tar.gz", hash = "sha256:98da3b8aca443b9848a209ae4165e2edede62633219afa493a58fbba57f72e2e"},
] ]
[package.extras] [package.extras]
@ -4215,6 +4216,23 @@ perf = ["orjson"]
sweeps = ["sweeps (>=0.2.0)"] sweeps = ["sweeps (>=0.2.0)"]
workspaces = ["wandb-workspaces"] workspaces = ["wandb-workspaces"]
[[package]]
name = "werkzeug"
version = "3.0.3"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
files = [
{file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"},
{file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"},
]
[package.dependencies]
MarkupSafe = ">=2.1.1"
[package.extras]
watchdog = ["watchdog (>=2.3)"]
[[package]] [[package]]
name = "xxhash" name = "xxhash"
version = "3.4.1" version = "3.4.1"
@ -4485,4 +4503,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb" content-hash = "25d5a270d770d37b13a93bf72868d3b9e683f8af5252b6332ec926a26fd0c096"

View File

@ -57,13 +57,15 @@ pytest-cov = {version = ">=5.0.0", optional = true}
datasets = ">=2.19.0" datasets = ">=2.19.0"
imagecodecs = { version = ">=2024.1.1", optional = true } imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5" pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1" rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1" deepdiff = ">=7.0.1"
scikit-image = {version = ">=0.23.2", optional = true} flask = ">=3.0.3"
pandas = {version = ">=2.2.2", optional = true} pandas = {version = ">=2.2.2", optional = true}
scikit-image = {version = ">=0.23.2", optional = true}
dynamixel-sdk = {version = ">=3.7.31", optional = true} dynamixel-sdk = {version = ">=3.7.31", optional = true}
pynput = {version = ">=1.7.7", optional = true} pynput = {version = ">=1.7.7", optional = true}
# TODO(rcadene, salibert): setuptools 71.0.1 has a bug, so that version is excluded below
setuptools = {version = "!=71.0.1", optional = true}

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
@pytest.mark.parametrize("repo_id", ["lerobot/pusht"])
def test_visualize_dataset_html(tmpdir, repo_id):
    """Run the HTML visualizer for episode 0 without serving and check its CSV.

    The `tmpdir` argument is wrapped in a `pathlib.Path` and passed as the
    output directory. After the call, `static/episode_0.csv` must exist
    under that directory.
    """
    output_dir = Path(tmpdir)
    expected_csv = output_dir / "static" / "episode_0.csv"

    # serve=False is passed so the call is made without requesting serving.
    visualize_dataset_html(repo_id, episodes=[0], output_dir=output_dir, serve=False)

    assert expected_csv.exists()