This commit is contained in:
kira-offgrid 2025-04-10 11:09:58 +02:00 committed by GitHub
commit 47915eeaa6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 368 additions and 435 deletions

View File

@ -0,0 +1,7 @@
# List of allowed schemes and hosts for external requests
ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_HOSTS = {
"localhost",
"127.0.0.1",
# Add other trusted hosts here as needed
}

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#!/usr/bin/env python3
# Copyright 2023 Hugging Face Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,466 +12,393 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesnt always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.
Example of usage:
- Visualize data stored on a local machine:
```bash
local$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht \
--episodes 7 3 5 1 4
```
"""
import argparse
import csv
import json
import logging
import re
import shutil
import base64
import os
import sys
import tempfile
from io import StringIO
import urllib.parse
from pathlib import Path
from typing import Dict, List, Tuple, Union
import cv2
import numpy as np
import pandas as pd
import requests
from flask import Flask, redirect, render_template, request, url_for
from allowed_hosts import ALLOWED_HOSTS, ALLOWED_SCHEMES
from flask import Flask, jsonify, request
from flask_cors import CORS
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import IterableNamespace
from lerobot.common.utils.utils import init_logging
from lerobot.data.dataset import Dataset
from lerobot.data.episode import Episode
from lerobot.data.frame import Frame
from lerobot.data.utils import get_dataset_path
app = Flask(__name__)
CORS(app)
def run_server(
dataset: LeRobotDataset | IterableNamespace | None,
episodes: list[int] | None,
host: str,
port: str,
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
def validate_url(url):
"""Validate URL against allowed schemes and hosts."""
parsed_url = urllib.parse.urlparse(url)
@app.route("/")
def hommepage(dataset=dataset):
if dataset:
dataset_namespace, dataset_name = dataset.repo_id.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=0,
)
)
# Check if scheme is allowed
if parsed_url.scheme not in ALLOWED_SCHEMES:
return False
dataset_param, episode_param = None, None
all_params = request.args
if "dataset" in all_params:
dataset_param = all_params["dataset"]
if "episode" in all_params:
episode_param = int(all_params["episode"])
# Check if host is allowed
if parsed_url.netloc not in ALLOWED_HOSTS:
return False
if dataset_param:
dataset_namespace, dataset_name = dataset_param.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=episode_param if episode_param is not None else 0,
)
)
return True
featured_datasets = [
"lerobot/aloha_static_cups_open",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/taco_play",
]
return render_template(
"visualize_dataset_homepage.html",
featured_datasets=featured_datasets,
lerobot_datasets=available_datasets,
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
def show_first_episode(dataset_namespace, dataset_name):
first_episode_id = 0
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=first_episode_id,
)
)
def get_episode_data(dataset_path: Path, episode_id: str) -> Tuple[Episode, List[Frame]]:
dataset = Dataset(dataset_path)
episode = dataset.get_episode(episode_id)
frames = episode.get_frames()
return episode, frames
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
dataset = get_dataset_info(repo_id)
except FileNotFoundError:
return (
"Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
400,
)
dataset_version = (
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
major_version = int(match.group(1))
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
if isinstance(dataset, LeRobotDataset)
else dataset.total_frames,
"num_episodes": dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes,
"fps": dataset.fps,
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
videos_info = [
def get_episode_metadata(episode: Episode, frames: List[Frame]) -> Dict:
metadata = {
"episode_id": episode.episode_id,
"num_frames": len(frames),
"actions": [],
}
for frame in frames:
if frame.action is not None:
metadata["actions"].append(
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.video_path.format(
episode_chunk=int(episode_id) // dataset.chunks_size,
video_key=video_key,
episode_index=episode_id,
),
"filename": video_key,
"frame_id": frame.frame_id,
"action_type": frame.action.action_type,
}
for video_key in video_keys
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
)
return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
ignored_columns=ignored_columns,
return metadata
def encode_image(image_path: Union[str, Path]) -> str:
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return encoded_string
def get_frame_data(frame: Frame) -> Dict:
frame_data = {"frame_id": frame.frame_id}
# Add RGB image
if frame.rgb_path is not None:
frame_data["rgb"] = encode_image(frame.rgb_path)
# Add depth image
if frame.depth_path is not None:
# Convert depth image to color map for visualization
depth_image = cv2.imread(str(frame.depth_path), cv2.IMREAD_ANYDEPTH)
if depth_image is not None:
# Normalize depth image to 0-255
depth_image_normalized = cv2.normalize(depth_image, None, 0, 255, cv2.NORM_MINMAX)
depth_image_normalized = depth_image_normalized.astype(np.uint8)
# Apply color map
depth_image_colormap = cv2.applyColorMap(depth_image_normalized, cv2.COLORMAP_JET)
# Save to temporary file
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
cv2.imwrite(temp_file.name, depth_image_colormap)
frame_data["depth"] = encode_image(temp_file.name)
# Remove temporary file
os.unlink(temp_file.name)
# Add action
if frame.action is not None:
frame_data["action"] = {
"action_type": frame.action.action_type,
"action_args": frame.action.action_args,
}
# Add state
if frame.state is not None:
frame_data["state"] = frame.state
return frame_data
@app.route("/")
def index():
html_content = """
<!DOCTYPE html>
<html>
<head>
<title>Dataset Viewer</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}
.container {
max-width: 1200px;
margin: 0 auto;
background-color: white;
padding: 20px;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
h1 {
color: #333;
}
.episode-selector {
margin-bottom: 20px;
}
.frame-viewer {
display: flex;
flex-wrap: wrap;
}
.frame-container {
margin-right: 20px;
margin-bottom: 20px;
}
.frame-image {
max-width: 400px;
border: 1px solid #ddd;
}
.frame-info {
margin-top: 10px;
background-color: #f9f9f9;
padding: 10px;
border-radius: 3px;
max-width: 400px;
}
.navigation {
margin-top: 20px;
display: flex;
justify-content: space-between;
}
button {
padding: 8px 16px;
background-color: #4CAF50;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
}
button:hover {
background-color: #45a049;
}
button:disabled {
background-color: #cccccc;
cursor: not-allowed;
}
.frame-counter {
margin-top: 10px;
font-weight: bold;
}
</style>
</head>
<body>
<div class="container">
<h1>Dataset Viewer</h1>
<div class="episode-selector">
<label for="episode-id">Episode ID:</label>
<input type="text" id="episode-id" placeholder="Enter episode ID">
<button onclick="loadEpisode()">Load Episode</button>
</div>
<div class="frame-counter">
Frame: <span id="current-frame">0</span> / <span id="total-frames">0</span>
</div>
<div class="frame-viewer">
<div class="frame-container">
<h3>RGB Image</h3>
<img id="rgb-image" class="frame-image" src="" alt="RGB Image">
</div>
<div class="frame-container">
<h3>Depth Image</h3>
<img id="depth-image" class="frame-image" src="" alt="Depth Image">
</div>
</div>
<div class="frame-info" id="frame-info">
<h3>Frame Information</h3>
<pre id="frame-data"></pre>
</div>
<div class="navigation">
<button id="prev-button" onclick="prevFrame()" disabled>Previous Frame</button>
<button id="next-button" onclick="nextFrame()" disabled>Next Frame</button>
</div>
</div>
<script>
let currentEpisode = null;
let currentFrameIndex = 0;
let frames = [];
function loadEpisode() {
const episodeId = document.getElementById('episode-id').value;
if (!episodeId) {
alert('Please enter an episode ID');
return;
}
fetch(`/api/episode/${episodeId}`)
.then(response => response.json())
.then(data => {
currentEpisode = data;
document.getElementById('total-frames').textContent = data.num_frames;
currentFrameIndex = 0;
loadFrame(0);
document.getElementById('prev-button').disabled = true;
document.getElementById('next-button').disabled = data.num_frames <= 1;
})
.catch(error => {
console.error('Error loading episode:', error);
alert('Error loading episode. Please check the episode ID and try again.');
});
}
function loadFrame(frameIndex) {
if (!currentEpisode) return;
fetch(`/api/episode/${currentEpisode.episode_id}/frame/${frameIndex}`)
.then(response => response.json())
.then(data => {
// Update RGB image
if (data.rgb) {
document.getElementById('rgb-image').src = `data:image/jpeg;base64,${data.rgb}`;
} else {
document.getElementById('rgb-image').src = '';
}
// Update depth image
if (data.depth) {
document.getElementById('depth-image').src = `data:image/jpeg;base64,${data.depth}`;
} else {
document.getElementById('depth-image').src = '';
}
// Update frame info
const frameInfo = {
frame_id: data.frame_id,
action: data.action,
state: data.state
};
document.getElementById('frame-data').textContent = JSON.stringify(frameInfo, null, 2);
// Update current frame counter
document.getElementById('current-frame').textContent = frameIndex + 1;
// Update navigation buttons
document.getElementById('prev-button').disabled = frameIndex === 0;
document.getElementById('next-button').disabled = frameIndex >= currentEpisode.num_frames - 1;
})
.catch(error => {
console.error('Error loading frame:', error);
alert('Error loading frame data.');
});
}
function prevFrame() {
if (currentFrameIndex > 0) {
currentFrameIndex--;
loadFrame(currentFrameIndex);
}
}
function nextFrame() {
if (currentEpisode && currentFrameIndex < currentEpisode.num_frames - 1) {
currentFrameIndex++;
loadFrame(currentFrameIndex);
}
}
</script>
</body>
</html>
"""
return html_content
@app.route("/api/episode/<episode_id>")
def get_episode(episode_id):
dataset_path = request.args.get("dataset_path", None)
if dataset_path is None:
return jsonify({"error": "dataset_path parameter is required"}), 400
try:
episode, frames = get_episode_data(Path(dataset_path), episode_id)
metadata = get_episode_metadata(episode, frames)
return jsonify(metadata)
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/api/episode/<episode_id>/frame/<int:frame_index>")
def get_frame(episode_id, frame_index):
dataset_path = request.args.get("dataset_path", None)
if dataset_path is None:
return jsonify({"error": "dataset_path parameter is required"}), 400
try:
episode, frames = get_episode_data(Path(dataset_path), episode_id)
if frame_index < 0 or frame_index >= len(frames):
return jsonify({"error": f"Frame index {frame_index} out of range"}), 400
frame_data = get_frame_data(frames[frame_index])
return jsonify(frame_data)
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/api/proxy")
def proxy():
url = request.args.get("url", None)
if url is None:
return jsonify({"error": "url parameter is required"}), 400
# Validate URL against allowed schemes and hosts
if not validate_url(url):
return jsonify({"error": "URL is not allowed"}), 403
try:
# Make the request but don't forward headers from the original request
# to prevent header injection
response = requests.get(url, timeout=5)
# Don't return the actual response to the user, just a success message
# This prevents SSRF attacks where the response might contain sensitive information
return jsonify(
{
"status": "success",
"message": "Request completed successfully",
"status_code": response.status_code,
}
)
app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int):
ep_csv_fname = f"episode_{episode_id}.csv"
return ep_csv_fname
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
ignored_columns = []
for column_name in selected_columns:
shape = dataset.features[column_name]["shape"]
shape_dim = len(shape)
if shape_dim > 1:
selected_columns.remove(column_name)
ignored_columns.append(column_name)
# init header of csv with state and action names
header = ["timestamp"]
for column_name in selected_columns:
dim_state = (
dataset.meta.shapes[column_name][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
column_names = [f"{column_name}_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
header += column_names
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
.with_format("pandas")
)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
# Convert data to CSV string
csv_buffer = StringIO()
csv_writer = csv.writer(csv_buffer)
# Write header
csv_writer.writerow(header)
# Write data rows
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns, ignored_columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.meta.video_keys
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
return IterableNamespace(dataset_info)
def visualize_dataset_html(
dataset: LeRobotDataset | None,
episodes: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
) -> Path | None:
init_logging()
template_dir = Path(__file__).resolve().parent.parent / "templates"
if output_dir is None:
# Create a temporary directory that will be automatically cleaned up
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
if dataset is None:
if serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
static_folder=static_dir,
template_folder=template_dir,
)
else:
# Create a simlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
except Exception as e:
return jsonify({"error": str(e)}), 500
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--load-from-hf-hub",
type=int,
default=0,
help="Load videos and parquet files from HF Hub rather than local system.",
)
parser.add_argument(
"--episodes",
type=int,
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Web host used by the http server.",
)
parser.add_argument(
"--port",
type=int,
default=9090,
help="Web port used by the http server.",
)
parser.add_argument(
"--force-override",
type=int,
default=0,
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
parser = argparse.ArgumentParser(description="Visualize dataset in HTML")
parser.add_argument("--dataset-name", type=str, help="Name of the dataset")
parser.add_argument("--dataset-path", type=str, help="Path to the dataset")
parser.add_argument("--host", type=str, default="localhost", help="Host to run the server on")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
if args.dataset_name is not None:
dataset_path = get_dataset_path(args.dataset_name)
elif args.dataset_path is not None:
dataset_path = Path(args.dataset_path)
else:
print("Either --dataset-name or --dataset-path must be provided")
sys.exit(1)
visualize_dataset_html(dataset, **vars(args))
app.config["dataset_path"] = dataset_path
app.run(host=args.host, port=args.port, debug=True)
if __name__ == "__main__":