Add visualize_dataset_html with `http.server` (#188)
This commit is contained in:
parent
bc6384bb80
commit
2252b42337
|
@ -108,8 +108,8 @@ def visualize_dataset(
|
|||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
root: Path | None = None,
|
||||
output_dir: Path | None = None,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert (
|
||||
|
@ -209,6 +209,18 @@ def main():
|
|||
required=True,
|
||||
help="Episode to visualize.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Directory path to write a .rrd file when `--save 1` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
|
@ -254,17 +266,6 @@ def main():
|
|||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
help="Directory path to write a .rrd file when `--save 1` is set.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
help="Root directory for a dataset stored on a local machine.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset(**vars(args))
|
||||
|
|
|
@ -0,0 +1,300 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||
|
||||
Note: The last frame of the episode doesnt always correspond to a final state.
|
||||
That's because our datasets are composed of transition from state to state up to
|
||||
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
||||
However, there might not be a transition from a final state to another state.
|
||||
|
||||
Note: This script aims to visualize the data used to train the neural networks.
|
||||
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
||||
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
|
||||
save disk space. The compression factor applied has been tuned to not affect success rate.
|
||||
|
||||
Example of usage:
|
||||
|
||||
- Visualize data stored on a local machine:
|
||||
```bash
|
||||
local$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine with a local viewer:
|
||||
```bash
|
||||
distant$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Select episodes to visualize:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, episode_index):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def run_server(
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int],
|
||||
host: str,
|
||||
port: str,
|
||||
static_folder: Path,
|
||||
template_folder: Path,
|
||||
):
|
||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
# home page redirects to the first episode page
|
||||
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
|
||||
first_episode_id = episodes[0]
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
dataset_namespace=dataset_namespace,
|
||||
dataset_name=dataset_name,
|
||||
episode_id=first_episode_id,
|
||||
)
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id):
|
||||
dataset_info = {
|
||||
"repo_id": dataset.repo_id,
|
||||
"num_samples": dataset.num_samples,
|
||||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
video_paths = get_episode_video_paths(dataset, episode_id)
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
|
||||
return render_template(
|
||||
"visualize_dataset_template.html",
|
||||
episode_id=episode_id,
|
||||
episodes=episodes,
|
||||
dataset_info=dataset_info,
|
||||
videos_info=videos_info,
|
||||
ep_csv_url=ep_csv_url,
|
||||
has_policy=False,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
|
||||
|
||||
def get_ep_csv_fname(episode_id: int):
|
||||
ep_csv_fname = f"episode_{episode_id}.csv"
|
||||
return ep_csv_fname
|
||||
|
||||
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
|
||||
has_state = "observation.state" in dataset.hf_dataset.features
|
||||
has_action = "action" in dataset.hf_dataset.features
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = len(dataset.hf_dataset["observation.state"][0])
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = len(dataset.hf_dataset["action"][0])
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
|
||||
columns = ["timestamp"]
|
||||
if has_state:
|
||||
columns += ["observation.state"]
|
||||
if has_action:
|
||||
columns += ["action"]
|
||||
|
||||
rows = []
|
||||
data = dataset.hf_dataset.select_columns(columns)
|
||||
for i in range(from_idx, to_idx):
|
||||
row = [data[i]["timestamp"].item()]
|
||||
if has_state:
|
||||
row += data[i]["observation.state"].tolist()
|
||||
if has_action:
|
||||
row += data[i]["action"].tolist()
|
||||
rows.append(row)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / file_name, "w") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
for row in rows:
|
||||
row_str = [str(col) for col in row]
|
||||
f.write(",".join(row_str) + "\n")
|
||||
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||
for key in dataset.video_frame_keys
|
||||
]
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
episodes: list[int] = None,
|
||||
output_dir: Path | None = None,
|
||||
serve: bool = True,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9090,
|
||||
force_override: bool = False,
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if not dataset.video:
|
||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
static_dir = output_dir / "static"
|
||||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
|
||||
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
|
||||
if episodes is None:
|
||||
episodes = list(range(dataset.num_episodes))
|
||||
|
||||
logging.info("Writing CSV files")
|
||||
for episode_index in tqdm.tqdm(episodes):
|
||||
# write states and actions in a csv (it can be slow for big datasets)
|
||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serve",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch web server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Web host used by the http server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=9090,
|
||||
help="Web port used by the http server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-override",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Delete the output directory if it exists already.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset_html(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,360 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<!-- # TODO(rcadene, mishig25): store the js files locally -->
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/alpinejs/3.13.5/cdn.min.js" defer></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/dygraphs@2.2.1/dist/dygraph.min.js" type="text/javascript"></script>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<title>{{ dataset_info.repo_id }} episode {{ episode_id }}</title>
|
||||
</head>
|
||||
|
||||
<!-- Use [Alpin.js](https://alpinejs.dev), a lightweight and easy to learn JS framework -->
|
||||
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
|
||||
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
|
||||
<body class="flex h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
|
||||
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
|
||||
const { keyCode, key } = e;
|
||||
if (keyCode === 32 || key === ' ') {
|
||||
e.preventDefault();
|
||||
$refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
|
||||
}else if (key === 'ArrowDown' || key === 'ArrowUp'){
|
||||
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
|
||||
const lowestEpisodeId = {{ episodes }}.at(0);
|
||||
const highestEpisodeId = {{ episodes }}.at(-1);
|
||||
if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
|
||||
window.location.href = `./episode_${nextEpisodeId}`;
|
||||
}
|
||||
}
|
||||
}">
|
||||
<!-- Sidebar -->
|
||||
<div x-ref="sidebar" class="w-60 bg-slate-900 p-5 break-words max-h-screen overflow-y-auto">
|
||||
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
|
||||
|
||||
<ul>
|
||||
<li>
|
||||
Number of samples/frames: {{ dataset_info.num_samples }}
|
||||
</li>
|
||||
<li>
|
||||
Number of episodes: {{ dataset_info.num_episodes }}
|
||||
</li>
|
||||
<li>
|
||||
Frames per second: {{ dataset_info.fps }}
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
<p>Episodes:</p>
|
||||
<ul class="ml-2">
|
||||
{% for episode in episodes %}
|
||||
<li class="font-mono text-sm mt-0.5">
|
||||
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
|
||||
Episode {{ episode }}
|
||||
</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- Toggle sidebar button -->
|
||||
<button class="flex items-center opacity-50 hover:opacity-100 mx-1"
|
||||
@click="() => ($refs.sidebar.classList.toggle('hidden'))" title="Toggle sidebar">
|
||||
<div class="bg-slate-500 w-2 h-10 rounded-full"></div>
|
||||
</button>
|
||||
|
||||
<!-- Content -->
|
||||
<div class="flex-1 max-h-screen flex flex-col gap-4 overflow-y-auto">
|
||||
<h1 class="text-xl font-bold mt-4 font-mono">
|
||||
Episode {{ episode_id }}
|
||||
</h1>
|
||||
|
||||
<!-- Videos -->
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{% for video_info in videos_info %}
|
||||
<div class="max-w-96">
|
||||
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<video autoplay muted loop type="video/mp4" class="min-w-64" @timeupdate="() => {
|
||||
if (video.duration) {
|
||||
const time = video.currentTime;
|
||||
const pc = (100 / video.duration) * time;
|
||||
$refs.slider.value = pc;
|
||||
dygraphTime = time;
|
||||
dygraphIndex = Math.floor(pc * dygraph.numRows() / 100);
|
||||
dygraph.setSelection(dygraphIndex, undefined, true, true);
|
||||
|
||||
$refs.timer.textContent = formatTime(time) + ' / ' + formatTime(video.duration);
|
||||
|
||||
updateTimeQuery(time.toFixed(2));
|
||||
}
|
||||
}" @ended="() => {
|
||||
$refs.btnPlay.classList.remove('hidden');
|
||||
$refs.btnPause.classList.add('hidden');
|
||||
}"
|
||||
@loadedmetadata="() => ($refs.timer.textContent = formatTime(0) + ' / ' + formatTime(video.duration))">
|
||||
<source src="{{ video_info.url }}">
|
||||
Your browser does not support the video tag.
|
||||
</video>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
<!-- Shortcuts info -->
|
||||
<div class="text-sm hidden md:block">
|
||||
Hotkeys: <span class="font-mono">Space</span> to pause/unpause, <span class="font-mono">Arrow Down</span> to go to next episode, <span class="font-mono">Arrow Up</span> to go to previous episode.
|
||||
</div>
|
||||
|
||||
<!-- Controllers -->
|
||||
<div class="flex gap-1 text-3xl items-center">
|
||||
<button x-ref="btnPlay" class="-rotate-90 hidden" class="-rotate-90" title="Play. Toggle with Space" @click="() => {
|
||||
videos.forEach(video => video.play());
|
||||
$refs.btnPlay.classList.toggle('hidden');
|
||||
$refs.btnPause.classList.toggle('hidden');
|
||||
}">🔽</button>
|
||||
<button x-ref="btnPause" title="Pause. Toggle with Space" @click="() => {
|
||||
videos.forEach(video => video.pause());
|
||||
$refs.btnPlay.classList.toggle('hidden');
|
||||
$refs.btnPause.classList.toggle('hidden');
|
||||
}">⏸️</button>
|
||||
<button title="Jump backward 5 seconds"
|
||||
@click="() => (videos.forEach(video => (video.currentTime -= 5)))">⏪</button>
|
||||
<button title="Jump forward 5 seconds"
|
||||
@click="() => (videos.forEach(video => (video.currentTime += 5)))">⏩</button>
|
||||
<button title="Rewind from start"
|
||||
@click="() => (videos.forEach(video => (video.currentTime = 0.0)))">↩️</button>
|
||||
<input x-ref="slider" max="100" min="0" step="1" type="range" value="0" class="w-80 mx-2" @input="() => {
|
||||
const sliderValue = $refs.slider.value;
|
||||
$refs.btnPause.click();
|
||||
videos.forEach(video => {
|
||||
const time = (video.duration * sliderValue) / 100;
|
||||
video.currentTime = time;
|
||||
});
|
||||
}" />
|
||||
<div x-ref="timer" class="font-mono text-sm border border-slate-500 rounded-lg px-1 py-0.5 shrink-0">0:00 /
|
||||
0:00
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Graph -->
|
||||
<div class="flex gap-2 mb-4 flex-wrap">
|
||||
<div>
|
||||
<div id="graph" @mouseleave="() => {
|
||||
dygraph.setSelection(dygraphIndex, undefined, true, true);
|
||||
dygraphTime = video.currentTime;
|
||||
}">
|
||||
</div>
|
||||
<p x-ref="graphTimer" class="font-mono ml-14 mt-4"
|
||||
x-init="$watch('dygraphTime', value => ($refs.graphTimer.innerText = `Time: ${dygraphTime.toFixed(2)}s`))">
|
||||
Time: 0.00s
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
|
||||
<thead>
|
||||
<tr>
|
||||
<th></th>
|
||||
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
|
||||
<th class="border border-slate-700">
|
||||
<div class="flex gap-x-2 justify-between px-2">
|
||||
<input type="checkbox" :checked="isColumnChecked(colIndex)"
|
||||
@change="toggleColumn(colIndex)">
|
||||
<p x-text="`${columnNames[colIndex]}`"></p>
|
||||
</div>
|
||||
</th>
|
||||
</template>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<template x-for="(row, rowIndex) in rows">
|
||||
<tr class="odd:bg-gray-800 even:bg-gray-900">
|
||||
<td class="border border-slate-700">
|
||||
<div class="flex gap-x-2 w-24 font-semibold px-1">
|
||||
<input type="checkbox" :checked="isRowChecked(rowIndex)"
|
||||
@change="toggleRow(rowIndex)">
|
||||
<p x-text="`Motor ${rowIndex}`"></p>
|
||||
</div>
|
||||
</td>
|
||||
<template x-for="(cell, colIndex) in row">
|
||||
<td x-show="cell" class="border border-slate-700">
|
||||
<div class="flex gap-x-2 w-24 justify-between px-2">
|
||||
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
|
||||
<span x-text="`${cell.value.toFixed(2)}`"
|
||||
:style="`color: ${cell.color}`"></span>
|
||||
</div>
|
||||
</td>
|
||||
</template>
|
||||
</tr>
|
||||
</template>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<div id="labels" class="hidden">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function createAlpineData() {
|
||||
return {
|
||||
// state
|
||||
dygraph: null,
|
||||
currentFrameData: null,
|
||||
columnNames: ["state", "action", "pred action"],
|
||||
nColumns: {% if has_policy %}3{% else %}2{% endif %},
|
||||
checked: [],
|
||||
dygraphTime: 0.0,
|
||||
dygraphIndex: 0,
|
||||
videos: null,
|
||||
video: null,
|
||||
colors: null,
|
||||
|
||||
// alpine initialization
|
||||
init() {
|
||||
this.videos = document.querySelectorAll('video');
|
||||
this.video = this.videos[0];
|
||||
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
|
||||
pixelsPerPoint: 0.01,
|
||||
legend: 'always',
|
||||
labelsDiv: document.getElementById('labels'),
|
||||
labelsKMB: true,
|
||||
strokeWidth: 1.5,
|
||||
pointClickCallback: (event, point) => {
|
||||
this.dygraphTime = point.xval;
|
||||
this.updateTableValues(this.dygraphTime);
|
||||
},
|
||||
highlightCallback: (event, x, points, row, seriesName) => {
|
||||
this.dygraphTime = x;
|
||||
this.updateTableValues(this.dygraphTime);
|
||||
},
|
||||
drawCallback: (dygraph, is_initial) => {
|
||||
if (is_initial) {
|
||||
// dygraph initialization
|
||||
this.dygraph.setSelection(this.dygraphIndex, undefined, true, true);
|
||||
this.colors = this.dygraph.getColors();
|
||||
this.checked = Array(this.colors.length).fill(true);
|
||||
|
||||
const seriesNames = this.dygraph.getLabels().slice(1);
|
||||
const colors = [];
|
||||
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||
let lightnessIdx = 0;
|
||||
const chunkSize = Math.ceil(seriesNames.length / this.nColumns);
|
||||
for (let i = 0; i < seriesNames.length; i += chunkSize) {
|
||||
const lightness = LIGHTNESS[lightnessIdx];
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/chunkSize)) {
|
||||
const color = `hsl(${hue}, 100%, ${lightness}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
lightnessIdx += 1;
|
||||
}
|
||||
this.dygraph.updateOptions({ colors });
|
||||
this.colors = colors;
|
||||
|
||||
this.updateTableValues();
|
||||
|
||||
let url = new URL(window.location.href);
|
||||
let params = new URLSearchParams(url.search);
|
||||
let time = params.get("t");
|
||||
if(time){
|
||||
time = parseFloat(time);
|
||||
this.videos.forEach(video => (video.currentTime = time));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
},
|
||||
|
||||
//#region Table Data
|
||||
|
||||
// turn dygraph's 1D data (at a given time t) to 2D data that whose columns names are defined in this.columnNames.
|
||||
// 2d data view is used to create html table element.
|
||||
get rows() {
|
||||
if (!this.currentFrameData) {
|
||||
return [];
|
||||
}
|
||||
const columnSize = Math.ceil(this.currentFrameData.length / this.nColumns);
|
||||
return Array.from({
|
||||
length: columnSize
|
||||
}, (_, rowIndex) => {
|
||||
const row = [
|
||||
this.currentFrameData[rowIndex] || null,
|
||||
this.currentFrameData[rowIndex + columnSize] || null,
|
||||
];
|
||||
if (this.nColumns === 3) {
|
||||
row.push(this.currentFrameData[rowIndex + 2 * columnSize] || null)
|
||||
}
|
||||
return row;
|
||||
});
|
||||
},
|
||||
isRowChecked(rowIndex) {
|
||||
return this.rows[rowIndex].every(cell => cell && cell.checked);
|
||||
},
|
||||
isColumnChecked(colIndex) {
|
||||
return this.rows.every(row => row[colIndex] && row[colIndex].checked);
|
||||
},
|
||||
toggleRow(rowIndex) {
|
||||
const newState = !this.isRowChecked(rowIndex);
|
||||
this.rows[rowIndex].forEach(cell => {
|
||||
if (cell) cell.checked = newState;
|
||||
});
|
||||
this.updateTableValues();
|
||||
},
|
||||
toggleColumn(colIndex) {
|
||||
const newState = !this.isColumnChecked(colIndex);
|
||||
this.rows.forEach(row => {
|
||||
if (row[colIndex]) row[colIndex].checked = newState;
|
||||
});
|
||||
this.updateTableValues();
|
||||
},
|
||||
|
||||
// given time t, update the values in the html table with "data[t]"
|
||||
updateTableValues(time) {
|
||||
if (!this.colors) {
|
||||
return;
|
||||
}
|
||||
let pc = (100 / this.video.duration) * (time === undefined ? this.video.currentTime : time);
|
||||
if (isNaN(pc)) pc = 0;
|
||||
const index = Math.floor(pc * this.dygraph.numRows() / 100);
|
||||
// slice(1) to remove the timestamp point that we do not need
|
||||
const labels = this.dygraph.getLabels().slice(1);
|
||||
const values = this.dygraph.rawData_[index].slice(1);
|
||||
const checkedNew = this.currentFrameData ? this.currentFrameData.map(cell => cell.checked) : Array(
|
||||
this.colors.length).fill(true);
|
||||
this.currentFrameData = labels.map((label, idx) => ({
|
||||
label,
|
||||
value: values[idx],
|
||||
color: this.colors[idx],
|
||||
checked: checkedNew[idx],
|
||||
}));
|
||||
const shouldUpdateVisibility = !this.checked.every((value, index) => value === checkedNew[index]);
|
||||
if (shouldUpdateVisibility) {
|
||||
this.checked = checkedNew;
|
||||
this.dygraph.setVisibility(this.checked);
|
||||
}
|
||||
},
|
||||
|
||||
//#endregion
|
||||
|
||||
updateTimeQuery(time) {
|
||||
let url = new URL(window.location.href);
|
||||
let params = new URLSearchParams(url.search);
|
||||
params.set("t", time);
|
||||
url.search = params.toString();
|
||||
window.history.replaceState({}, '', url.toString());
|
||||
},
|
||||
|
||||
|
||||
formatTime(time) {
|
||||
var hours = Math.floor(time / 3600);
|
||||
var minutes = Math.floor((time % 3600) / 60);
|
||||
var seconds = Math.floor(time % 60);
|
||||
return (hours > 0 ? hours + ':' : '') + (minutes < 10 ? '0' + minutes : minutes) + ':' + (seconds <
|
||||
10 ?
|
||||
'0' + seconds : seconds);
|
||||
}
|
||||
};
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
|
@ -192,6 +192,17 @@ charset-normalizer = ["charset-normalizer"]
|
|||
html5lib = ["html5lib"]
|
||||
lxml = ["lxml"]
|
||||
|
||||
[[package]]
|
||||
name = "blinker"
|
||||
version = "1.8.2"
|
||||
description = "Fast, simple object-to-object and broadcast signaling"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"},
|
||||
{file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2024.7.4"
|
||||
|
@ -584,17 +595,6 @@ files = [
|
|||
{file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
version = "4.4.2"
|
||||
description = "Decorators for Humans"
|
||||
optional = false
|
||||
python-versions = ">=2.6, !=3.0.*, !=3.1.*"
|
||||
files = [
|
||||
{file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
|
||||
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deepdiff"
|
||||
version = "7.0.1"
|
||||
|
@ -795,6 +795,7 @@ files = [
|
|||
{file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"},
|
||||
|
@ -892,6 +893,28 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
|
|||
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"]
|
||||
typing = ["typing-extensions (>=4.8)"]
|
||||
|
||||
[[package]]
|
||||
name = "flask"
|
||||
version = "3.0.3"
|
||||
description = "A simple framework for building complex web applications."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"},
|
||||
{file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
blinker = ">=1.6.2"
|
||||
click = ">=8.1.3"
|
||||
itsdangerous = ">=2.1.2"
|
||||
Jinja2 = ">=3.1.2"
|
||||
Werkzeug = ">=3.0.0"
|
||||
|
||||
[package.extras]
|
||||
async = ["asgiref (>=3.2)"]
|
||||
dotenv = ["python-dotenv"]
|
||||
|
||||
[[package]]
|
||||
name = "frozenlist"
|
||||
version = "1.4.1"
|
||||
|
@ -1550,6 +1573,17 @@ files = [
|
|||
{file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itsdangerous"
|
||||
version = "2.2.0"
|
||||
description = "Safely pass data to untrusted environments and back."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"},
|
||||
{file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jinja2"
|
||||
version = "3.1.4"
|
||||
|
@ -1741,9 +1775,13 @@ files = [
|
|||
{file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
|
||||
|
@ -1901,30 +1939,6 @@ files = [
|
|||
intel-openmp = "==2021.*"
|
||||
tbb = "==2021.*"
|
||||
|
||||
[[package]]
|
||||
name = "moviepy"
|
||||
version = "1.0.3"
|
||||
description = "Video editing with Python"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
decorator = ">=4.0.2,<5.0"
|
||||
imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""}
|
||||
imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""}
|
||||
numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""}
|
||||
proglog = "<=1.0.0"
|
||||
requests = ">=2.8.1,<3.0"
|
||||
tqdm = ">=4.11.2,<5.0"
|
||||
|
||||
[package.extras]
|
||||
doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"]
|
||||
optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"]
|
||||
test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
|
@ -2696,20 +2710,6 @@ nodeenv = ">=0.11.1"
|
|||
pyyaml = ">=5.1"
|
||||
virtualenv = ">=20.10.0"
|
||||
|
||||
[[package]]
|
||||
name = "proglog"
|
||||
version = "0.1.10"
|
||||
description = "Log and progress bar manager for console, notebooks, web..."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"},
|
||||
{file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "5.27.2"
|
||||
|
@ -3276,6 +3276,7 @@ files = [
|
|||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
|
@ -3809,13 +3810,13 @@ test = ["pytest"]
|
|||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "71.0.1"
|
||||
version = "71.0.0"
|
||||
description = "Easily download, build, install, upgrade, and uninstall Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "setuptools-71.0.1-py3-none-any.whl", hash = "sha256:1eb8ef012efae7f6acbc53ec0abde4bc6746c43087fd215ee09e1df48998711f"},
|
||||
{file = "setuptools-71.0.1.tar.gz", hash = "sha256:c51d7fd29843aa18dad362d4b4ecd917022131425438251f4e3d766c964dd1ad"},
|
||||
{file = "setuptools-71.0.0-py3-none-any.whl", hash = "sha256:f06fbe978a91819d250a30e0dc4ca79df713d909e24438a42d0ec300fc52247f"},
|
||||
{file = "setuptools-71.0.0.tar.gz", hash = "sha256:98da3b8aca443b9848a209ae4165e2edede62633219afa493a58fbba57f72e2e"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
|
@ -4215,6 +4216,23 @@ perf = ["orjson"]
|
|||
sweeps = ["sweeps (>=0.2.0)"]
|
||||
workspaces = ["wandb-workspaces"]
|
||||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.0.3"
|
||||
description = "The comprehensive WSGI web application library."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"},
|
||||
{file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
MarkupSafe = ">=2.1.1"
|
||||
|
||||
[package.extras]
|
||||
watchdog = ["watchdog (>=2.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "xxhash"
|
||||
version = "3.4.1"
|
||||
|
@ -4485,4 +4503,4 @@ xarm = ["gym-xarm"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb"
|
||||
content-hash = "25d5a270d770d37b13a93bf72868d3b9e683f8af5252b6332ec926a26fd0c096"
|
||||
|
|
|
@ -57,13 +57,15 @@ pytest-cov = {version = ">=5.0.0", optional = true}
|
|||
datasets = ">=2.19.0"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
moviepy = ">=1.0.3"
|
||||
rerun-sdk = ">=0.15.1"
|
||||
deepdiff = ">=7.0.1"
|
||||
scikit-image = {version = ">=0.23.2", optional = true}
|
||||
flask = ">=3.0.3"
|
||||
pandas = {version = ">=2.2.2", optional = true}
|
||||
scikit-image = {version = ">=0.23.2", optional = true}
|
||||
dynamixel-sdk = {version = ">=3.7.31", optional = true}
|
||||
pynput = {version = ">=1.7.7", optional = true}
|
||||
# TODO(rcadene, salibert): 71.0.1 has a bug
|
||||
setuptools = {version = "!=71.0.1", optional = true}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
def test_visualize_dataset_html(tmpdir, repo_id):
|
||||
tmpdir = Path(tmpdir)
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
serve=False,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
Loading…
Reference in New Issue