[visualizer] for LeRobotDataset V2 (#576)

parent 3bb5ed5e91
commit 0a4e9e25d0
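At a high level, this commit lets the HTML dataset visualizer run either from a locally loaded LeRobotDataset or purely from metadata fetched from the Hugging Face Hub. A minimal sketch of those two entry paths, using names introduced in the diff below (the port value and the exact helper import path are illustrative assumptions, not taken verbatim from this diff):

    # Illustrative only: mirrors what main() below does with --repo-id / --load-from-hf-hub.
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.scripts.visualize_dataset_html import get_dataset_info, visualize_dataset_html

    repo_id = "lerobot/pusht"

    # Path 1: local dataset; videos are served through a symlinked static folder.
    dataset = LeRobotDataset(repo_id)

    # Path 2: Hub-only metadata (what --load-from-hf-hub selects); no full download.
    # dataset = get_dataset_info(repo_id)

    visualize_dataset_html(dataset, episodes=None, serve=True, host="127.0.0.1", port=9090)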
@@ -17,9 +17,11 @@ import importlib.resources
 import json
 import logging
 import textwrap
+from collections.abc import Iterator
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
+from types import SimpleNamespace
 from typing import Any

 import datasets
@@ -502,3 +504,58 @@ def create_lerobot_dataset_card(
         template_path=str(card_template_path),
         **kwargs,
     )
+
+
+class IterableNamespace(SimpleNamespace):
+    """
+    A namespace object that supports both dictionary-like iteration and dot notation access.
+    Automatically converts nested dictionaries into IterableNamespaces.
+
+    This class extends SimpleNamespace to provide:
+    - Dictionary-style iteration over keys
+    - Access to items via both dot notation (obj.key) and brackets (obj["key"])
+    - Dictionary-like methods: items(), keys(), values()
+    - Recursive conversion of nested dictionaries
+
+    Args:
+        dictionary: Optional dictionary to initialize the namespace
+        **kwargs: Additional keyword arguments passed to SimpleNamespace
+
+    Examples:
+        >>> data = {"name": "Alice", "details": {"age": 25}}
+        >>> ns = IterableNamespace(data)
+        >>> ns.name
+        'Alice'
+        >>> ns.details.age
+        25
+        >>> list(ns.keys())
+        ['name', 'details']
+        >>> for key, value in ns.items():
+        ...     print(f"{key}: {value}")
+        name: Alice
+        details: IterableNamespace(age=25)
+    """
+
+    def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
+        super().__init__(**kwargs)
+        if dictionary is not None:
+            for key, value in dictionary.items():
+                if isinstance(value, dict):
+                    setattr(self, key, IterableNamespace(value))
+                else:
+                    setattr(self, key, value)
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(vars(self))
+
+    def __getitem__(self, key: str) -> Any:
+        return vars(self)[key]
+
+    def items(self):
+        return vars(self).items()
+
+    def values(self):
+        return vars(self).values()
+
+    def keys(self):
+        return vars(self).keys()
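To make the intended use in the visualizer concrete, here is a small self-contained sketch (the metadata values are invented for the example, but the keys mirror the v2.0 info.json fields used later in this diff):

    from lerobot.common.datasets.utils import IterableNamespace

    # Invented example payload shaped like a parsed meta/info.json.
    info = {
        "codebase_version": "v2.0",
        "fps": 30,
        "features": {"action": {"dtype": "float32", "shape": [6]}},
    }
    ns = IterableNamespace(info)

    assert ns.fps == 30                         # dot access
    assert ns["codebase_version"] == "v2.0"     # bracket access
    assert ns.features.action.shape == [6]      # nested dicts become namespaces too
    assert sorted(ns.keys()) == ["codebase_version", "features", "fps"]
    for key in ns:                              # iterating yields attribute names
        print(key, ns[key])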
@@ -53,20 +53,29 @@ python lerobot/scripts/visualize_dataset_html.py \
 """

 import argparse
+import csv
+import json
 import logging
+import re
 import shutil
+import tempfile
+from io import StringIO
 from pathlib import Path

-import tqdm
-from flask import Flask, redirect, render_template, url_for
+import numpy as np
+import pandas as pd
+import requests
+from flask import Flask, redirect, render_template, request, url_for

+from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import IterableNamespace
 from lerobot.common.utils.utils import init_logging


 def run_server(
-    dataset: LeRobotDataset,
-    episodes: list[int],
+    dataset: LeRobotDataset | IterableNamespace | None,
+    episodes: list[int] | None,
     host: str,
     port: str,
     static_folder: Path,
@@ -76,10 +85,50 @@ def run_server(
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

     @app.route("/")
-    def index():
-        # home page redirects to the first episode page
-        [dataset_namespace, dataset_name] = dataset.repo_id.split("/")
-        first_episode_id = episodes[0]
+    def hommepage(dataset=dataset):
+        if dataset:
+            dataset_namespace, dataset_name = dataset.repo_id.split("/")
+            return redirect(
+                url_for(
+                    "show_episode",
+                    dataset_namespace=dataset_namespace,
+                    dataset_name=dataset_name,
+                    episode_id=0,
+                )
+            )
+
+        dataset_param, episode_param = None, None
+        all_params = request.args
+        if "dataset" in all_params:
+            dataset_param = all_params["dataset"]
+        if "episode" in all_params:
+            episode_param = int(all_params["episode"])
+
+        if dataset_param:
+            dataset_namespace, dataset_name = dataset_param.split("/")
+            return redirect(
+                url_for(
+                    "show_episode",
+                    dataset_namespace=dataset_namespace,
+                    dataset_name=dataset_name,
+                    episode_id=episode_param if episode_param is not None else 0,
+                )
+            )
+
+        featured_datasets = [
+            "lerobot/aloha_static_cups_open",
+            "lerobot/columbia_cairlab_pusht_real",
+            "lerobot/taco_play",
+        ]
+        return render_template(
+            "visualize_dataset_homepage.html",
+            featured_datasets=featured_datasets,
+            lerobot_datasets=available_datasets,
+        )
+
+    @app.route("/<string:dataset_namespace>/<string:dataset_name>")
+    def show_first_episode(dataset_namespace, dataset_name):
+        first_episode_id = 0
         return redirect(
             url_for(
                 "show_episode",
@@ -90,30 +139,85 @@ def run_server(
         )

     @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
-    def show_episode(dataset_namespace, dataset_name, episode_id):
+    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
+        repo_id = f"{dataset_namespace}/{dataset_name}"
+        try:
+            if dataset is None:
+                dataset = get_dataset_info(repo_id)
+        except FileNotFoundError:
+            return (
+                "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
+                400,
+            )
+        dataset_version = (
+            dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
+        )
+        match = re.search(r"v(\d+)\.", dataset_version)
+        if match:
+            major_version = int(match.group(1))
+            if major_version < 2:
+                return "Make sure to convert your LeRobotDataset to v2 & above."
+
+        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
         dataset_info = {
-            "repo_id": dataset.repo_id,
-            "num_samples": dataset.num_frames,
-            "num_episodes": dataset.num_episodes,
+            "repo_id": f"{dataset_namespace}/{dataset_name}",
+            "num_samples": dataset.num_frames
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.total_frames,
+            "num_episodes": dataset.num_episodes
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.total_episodes,
             "fps": dataset.fps,
         }
-        video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
-        tasks = dataset.meta.episodes[episode_id]["tasks"]
-        videos_info = [
-            {"url": url_for("static", filename=video_path), "filename": video_path.name}
-            for video_path in video_paths
-        ]
+        if isinstance(dataset, LeRobotDataset):
+            video_paths = [
+                dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
+            ]
+            videos_info = [
+                {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
+                for video_path in video_paths
+            ]
+            tasks = dataset.meta.episodes[0]["tasks"]
+        else:
+            video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
+            videos_info = [
+                {
+                    "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+                    + dataset.video_path.format(
+                        episode_chunk=int(episode_id) // dataset.chunks_size,
+                        video_key=video_key,
+                        episode_index=episode_id,
+                    ),
+                    "filename": video_key,
+                }
+                for video_key in video_keys
+            ]
+
+            response = requests.get(
+                f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
+            )
+            response.raise_for_status()
+            # Split into lines and parse each line as JSON
+            tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
+
+            filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
+            tasks = filtered_tasks_jsonl[0]["tasks"]
+
         videos_info[0]["language_instruction"] = tasks

-        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
+        if episodes is None:
+            episodes = list(
+                range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+            )
+
         return render_template(
             "visualize_dataset_template.html",
             episode_id=episode_id,
             episodes=episodes,
             dataset_info=dataset_info,
             videos_info=videos_info,
-            ep_csv_url=ep_csv_url,
-            has_policy=False,
+            episode_data_csv_str=episode_data_csv_str,
+            columns=columns,
         )

     app.run(host=host, port=port)
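For orientation, the Hub-backed branch above reads the episode's task description straight from meta/episodes.jsonl. A minimal stand-alone sketch of that lookup (network access required; the repo id is an example, the URL pattern is the one used in show_episode above):

    import json
    import requests

    repo_id = "lerobot/pusht"  # example repo id
    episode_id = 0

    # episodes.jsonl holds one JSON object per line, each with "episode_index" and "tasks".
    response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl")
    response.raise_for_status()
    tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    tasks = next(row["tasks"] for row in tasks_jsonl if row["episode_index"] == episode_id)
    print(tasks)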
@@ -124,46 +228,84 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname


-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
-    """Write a csv file containg timeseries data of an episode (e.g. state and action).
+def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
+    """Get a csv str containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
-    from_idx = dataset.episode_data_index["from"][episode_index]
-    to_idx = dataset.episode_data_index["to"][episode_index]
+    columns = []

     has_state = "observation.state" in dataset.features
     has_action = "action" in dataset.features

     # init header of csv with state and action names
     header = ["timestamp"]
     if has_state:
-        dim_state = dataset.meta.shapes["observation.state"][0]
+        dim_state = (
+            dataset.meta.shapes["observation.state"][0]
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.features["observation.state"].shape[0]
+        )
         header += [f"state_{i}" for i in range(dim_state)]
+        column_names = dataset.features["observation.state"]["names"]
+        while not isinstance(column_names, list):
+            column_names = list(column_names.values())[0]
+        columns.append({"key": "state", "value": column_names})
     if has_action:
-        dim_action = dataset.meta.shapes["action"][0]
+        dim_action = (
+            dataset.meta.shapes["action"][0]
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.features.action.shape[0]
+        )
         header += [f"action_{i}" for i in range(dim_action)]
+        column_names = dataset.features["action"]["names"]
+        while not isinstance(column_names, list):
+            column_names = list(column_names.values())[0]
+        columns.append({"key": "action", "value": column_names})

-    columns = ["timestamp"]
-    if has_state:
-        columns += ["observation.state"]
-    if has_action:
-        columns += ["action"]
-
-    rows = []
-    data = dataset.hf_dataset.select_columns(columns)
-    for i in range(from_idx, to_idx):
-        row = [data[i]["timestamp"].item()]
-        if has_state:
-            row += data[i]["observation.state"].tolist()
-        if has_action:
-            row += data[i]["action"].tolist()
-        rows.append(row)
-
-    output_dir.mkdir(parents=True, exist_ok=True)
-    with open(output_dir / file_name, "w") as f:
-        f.write(",".join(header) + "\n")
-        for row in rows:
-            row_str = [str(col) for col in row]
-            f.write(",".join(row_str) + "\n")
+    if isinstance(dataset, LeRobotDataset):
+        from_idx = dataset.episode_data_index["from"][episode_index]
+        to_idx = dataset.episode_data_index["to"][episode_index]
+        selected_columns = ["timestamp"]
+        if has_state:
+            selected_columns += ["observation.state"]
+        if has_action:
+            selected_columns += ["action"]
+        data = (
+            dataset.hf_dataset.select(range(from_idx, to_idx))
+            .select_columns(selected_columns)
+            .with_format("numpy")
+        )
+        rows = np.hstack(
+            (np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]])
+        ).tolist()
+    else:
+        repo_id = dataset.repo_id
+        selected_columns = ["timestamp"]
+        if "observation.state" in dataset.features:
+            selected_columns.append("observation.state")
+        if "action" in dataset.features:
+            selected_columns.append("action")
+
+        url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
+            episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
+        )
+        df = pd.read_parquet(url)
+        data = df[selected_columns]  # Select specific columns
+        rows = np.hstack(
+            (
+                np.expand_dims(data["timestamp"], axis=1),
+                *[np.vstack(data[col]) for col in selected_columns[1:]],
+            )
+        ).tolist()
+
+    # Convert data to CSV string
+    csv_buffer = StringIO()
+    csv_writer = csv.writer(csv_buffer)
+    # Write header
+    csv_writer.writerow(header)
+    # Write data rows
+    csv_writer.writerows(rows)
+    csv_string = csv_buffer.getvalue()
+
+    return csv_string, columns


 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
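Since the CSV is now produced as an in-memory string rather than a file on disk, a caller can parse it back directly. A minimal sketch, assuming `lerobot/pusht` is available locally or in the HF cache (import paths mirror those used elsewhere in this diff):

    import csv
    from io import StringIO

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.scripts.visualize_dataset_html import get_episode_data

    dataset = LeRobotDataset("lerobot/pusht")
    episode_data_csv_str, columns = get_episode_data(dataset, episode_index=0)

    reader = csv.reader(StringIO(episode_data_csv_str))
    header = next(reader)                # ["timestamp", "state_0", ..., "action_0", ...]
    first_row = next(reader)             # one timestep of state/action values (as strings)
    print(header[:3], first_row[:3])
    print([c["key"] for c in columns])   # e.g. ["state", "action"]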
@@ -175,9 +317,31 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
     ]


+def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
+    # check if the dataset has language instructions
+    if "language_instruction" not in dataset.features:
+        return None
+
+    # get first frame index
+    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+
+    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
+    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
+    # with the tf.tensor appearing in the string
+    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
+
+
+def get_dataset_info(repo_id: str) -> IterableNamespace:
+    response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
+    response.raise_for_status()  # Raises an HTTPError for bad responses
+    dataset_info = response.json()
+    dataset_info["repo_id"] = repo_id
+    return IterableNamespace(dataset_info)
+
+
 def visualize_dataset_html(
-    dataset: LeRobotDataset,
-    episodes: list[int] = None,
+    dataset: LeRobotDataset | None,
+    episodes: list[int] | None = None,
     output_dir: Path | None = None,
     serve: bool = True,
     host: str = "127.0.0.1",
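For reference, the Hub-only path can be exercised on its own (network access required; the repo id is taken from the featured list above, and the printed fields follow the v2.0 info.json layout that the rest of this diff relies on):

    from lerobot.scripts.visualize_dataset_html import get_dataset_info

    info = get_dataset_info("lerobot/aloha_static_cups_open")
    print(info.repo_id, info.codebase_version, info.fps)
    print(info.total_episodes, info.total_frames)

    # Same dtype-based filtering as show_episode() uses to find video streams.
    video_keys = [key for key, ft in info.features.items() if ft["dtype"] == "video"]
    print(video_keys)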
@@ -186,11 +350,11 @@ def visualize_dataset_html(
 ) -> Path | None:
     init_logging()

-    if len(dataset.meta.image_keys) > 0:
-        raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
+    template_dir = Path(__file__).resolve().parent.parent / "templates"

     if output_dir is None:
-        output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
+        # Create a temporary directory that will be automatically cleaned up
+        output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")

     output_dir = Path(output_dir)
     if output_dir.exists():
@@ -201,28 +365,33 @@ def visualize_dataset_html(
     output_dir.mkdir(parents=True, exist_ok=True)

-    # Create a simlink from the dataset video folder containg mp4 files to the output directory
-    # so that the http server can get access to the mp4 files.
     static_dir = output_dir / "static"
     static_dir.mkdir(parents=True, exist_ok=True)
-    ln_videos_dir = static_dir / "videos"
-    if not ln_videos_dir.exists():
-        ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

-    template_dir = Path(__file__).resolve().parent.parent / "templates"
-
-    if episodes is None:
-        episodes = list(range(dataset.num_episodes))
-
-    logging.info("Writing CSV files")
-    for episode_index in tqdm.tqdm(episodes):
-        # write states and actions in a csv (it can be slow for big datasets)
-        ep_csv_fname = get_ep_csv_fname(episode_index)
-        # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
-
-    if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
+    if dataset is None:
+        if serve:
+            run_server(
+                dataset=None,
+                episodes=None,
+                host=host,
+                port=port,
+                static_folder=static_dir,
+                template_folder=template_dir,
+            )
+    else:
+        image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else []
+        if len(image_keys) > 0:
+            raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.")
+
+        # Create a simlink from the dataset video folder containg mp4 files to the output directory
+        # so that the http server can get access to the mp4 files.
+        if isinstance(dataset, LeRobotDataset):
+            ln_videos_dir = static_dir / "videos"
+            if not ln_videos_dir.exists():
+                ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
+
+        if serve:
+            run_server(dataset, episodes, host, port, static_dir, template_dir)


 def main():
@@ -231,7 +400,7 @@ def main():
     parser.add_argument(
         "--repo-id",
         type=str,
-        required=True,
+        default=None,
         help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
     parser.add_argument(
@@ -246,6 +415,12 @@ def main():
         default=None,
         help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
     )
+    parser.add_argument(
+        "--load-from-hf-hub",
+        type=int,
+        default=0,
+        help="Load videos and parquet files from HF Hub rather than local system.",
+    )
     parser.add_argument(
         "--episodes",
         type=int,
@@ -287,11 +462,19 @@ def main():
     args = parser.parse_args()
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
+    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
     root = kwargs.pop("root")
     local_files_only = kwargs.pop("local_files_only")

-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
-    visualize_dataset_html(dataset, **kwargs)
+    dataset = None
+    if repo_id:
+        dataset = (
+            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
+            if not load_from_hf_hub
+            else get_dataset_info(repo_id)
+        )
+
+    visualize_dataset_html(dataset, **vars(args))


 if __name__ == "__main__":
@@ -0,0 +1,68 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Interactive Video Background Page</title>
+    <script src="https://cdn.tailwindcss.com"></script>
+    <script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
+</head>
+<body class="h-screen overflow-hidden font-mono text-white" x-data="{
+    inputValue: '',
+    navigateToDataset() {
+        const trimmedValue = this.inputValue.trim();
+        if (trimmedValue) {
+            window.location.href = `/${trimmedValue}`;
+        }
+    }
+}">
+    <div class="fixed inset-0 w-full h-full overflow-hidden">
+        <video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
+            <source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
+            Your browser does not support HTML5 video.
+        </video>
+    </div>
+    <div class="fixed inset-0 bg-black bg-opacity-80"></div>
+    <div class="relative z-10 flex flex-col items-center justify-center h-screen">
+        <div class="text-center mb-8">
+            <h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>
+
+            <a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>
+
+            <p class="text-xl mb-4"></p>
+            <div class="text-left inline-block">
+                <h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
+                <ul class="list-disc list-inside">
+                    {% for dataset in featured_datasets %}
+                    <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
+                    {% endfor %}
+                </ul>
+            </div>
+        </div>
+        <div class="flex w-full max-w-lg px-4 mb-4">
+            <input
+                type="text"
+                x-model="inputValue"
+                @keyup.enter="navigateToDataset"
+                placeholder="enter dataset id (ex: lerobot/droid_100)"
+                class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
+            >
+            <button
+                @click="navigateToDataset"
+                class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
+            >
+                Go
+            </button>
+        </div>
+
+        <details class="mt-4 max-w-full px-4">
+            <summary>More example datasets</summary>
+            <ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
+                {% for dataset in lerobot_datasets %}
+                <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
+                {% endfor %}
+            </ul>
+        </details>
+    </div>
+</body>
+</html>
@@ -31,11 +31,16 @@
 }">
     <!-- Sidebar -->
     <div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
-        <h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
+        <a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
+            <img src="https://github.com/huggingface/lerobot/raw/main/media/lerobot-logo-thumbnail.png">
+        </a>
+        <a href="https://huggingface.co/datasets/{{ dataset_info.repo_id }}" target="_blank">
+            <h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
+        </a>
+
         <ul>
             <li>
-                Number of samples/frames: {{ dataset_info.num_frames }}
+                Number of samples/frames: {{ dataset_info.num_samples }}
             </li>
             <li>
                 Number of episodes: {{ dataset_info.num_episodes }}
@@ -93,10 +98,10 @@
         </div>

         <!-- Videos -->
-        <div class="flex flex-wrap gap-1">
+        <div class="flex flex-wrap gap-x-2 gap-y-6">
             {% for video_info in videos_info %}
-            <div x-show="!videoCodecError" class="max-w-96">
-                <p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
+            <div x-show="!videoCodecError" class="max-w-96 relative">
+                <p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
                 <video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
                     if (video.duration) {
                       const time = video.currentTime;
@@ -182,12 +187,12 @@
                 <thead>
                     <tr>
                         <th></th>
-                        <template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
+                        <template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
                             <th class="border border-slate-700">
                                 <div class="flex gap-x-2 justify-between px-2">
                                     <input type="checkbox" :checked="isColumnChecked(colIndex)"
                                         @change="toggleColumn(colIndex)">
-                                    <p x-text="`${columnNames[colIndex]}`"></p>
+                                    <p x-text="`${columns[colIndex].key}`"></p>
                                 </div>
                             </th>
                         </template>
@@ -197,10 +202,10 @@
                     <template x-for="(row, rowIndex) in rows">
                         <tr class="odd:bg-gray-800 even:bg-gray-900">
                             <td class="border border-slate-700">
-                                <div class="flex gap-x-2 w-24 font-semibold px-1">
+                                <div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
                                     <input type="checkbox" :checked="isRowChecked(rowIndex)"
                                         @change="toggleRow(rowIndex)">
-                                    <p x-text="`Motor ${rowIndex}`"></p>
+                                    <p x-text="`${rowLabels[rowIndex]}`"></p>
                                 </div>
                             </td>
                             <template x-for="(cell, colIndex) in row">
@@ -222,16 +227,20 @@
             </div>
         </div>

+    <script>
+        const parentOrigin = "https://huggingface.co";
+        const searchParams = new URLSearchParams();
+        searchParams.set("dataset", "{{ dataset_info.repo_id }}");
+        searchParams.set("episode", "{{ episode_id }}");
+        window.parent.postMessage({ queryString: searchParams.toString() }, parentOrigin);
+    </script>
+
     <script>
         function createAlpineData() {
             return {
                 // state
                 dygraph: null,
                 currentFrameData: null,
-                columnNames: ["state", "action", "pred action"],
-                nColumns: 2,
-                nStates: 0,
-                nActions: 0,
                 checked: [],
                 dygraphTime: 0.0,
                 dygraphIndex: 0,
@@ -241,6 +250,8 @@
                 nVideos: {{ videos_info | length }},
                 nVideoReadyToPlay: 0,
                 videoCodecError: false,
+                columns: {{ columns | tojson }},
+                rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,

                 // alpine initialization
                 init() {
@@ -251,10 +262,17 @@
                         this.videoCodecError = true;
                     }

+                    // process CSV data
+                    const csvDataStr = {{ episode_data_csv_str|tojson|safe }};
+                    // Create a Blob with the CSV data
+                    const blob = new Blob([csvDataStr], { type: 'text/csv;charset=utf-8;' });
+                    // Create a URL for the Blob
+                    const csvUrl = URL.createObjectURL(blob);
+
                     // process CSV data
                     this.videos = document.querySelectorAll('video');
                     this.video = this.videos[0];
-                    this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
+                    this.dygraph = new Dygraph(document.getElementById("graph"), csvUrl, {
                         pixelsPerPoint: 0.01,
                         legend: 'always',
                         labelsDiv: document.getElementById('labels'),
@@ -275,21 +293,17 @@
                     this.colors = this.dygraph.getColors();
                     this.checked = Array(this.colors.length).fill(true);

-                    const seriesNames = this.dygraph.getLabels().slice(1);
-                    this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
-                    this.nActions = seriesNames.length - this.nStates;
                     const colors = [];
-                    const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
-                    // colors for "state" lines
-                    for (let hue = 0; hue < 360; hue += parseInt(360/this.nStates)) {
-                        const color = `hsl(${hue}, 100%, ${LIGHTNESS[0]}%)`;
-                        colors.push(color);
-                    }
-                    // colors for "action" lines
-                    for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
-                        const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
-                        colors.push(color);
-                    }
+                    let lightness = 30; // const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
+                    for(const column of this.columns){
+                        const nValues = column.value.length;
+                        for (let hue = 0; hue < 360; hue += parseInt(360/nValues)) {
+                            const color = `hsl(${hue}, 100%, ${lightness}%)`;
+                            colors.push(color);
+                        }
+                        lightness += 35;
+                    }

                     this.dygraph.updateOptions({ colors });
                     this.colors = colors;
@@ -316,17 +330,19 @@
                         return [];
                     }
                     const rows = [];
-                    const nRows = Math.max(this.nStates, this.nActions);
+                    const nRows = Math.max(...this.columns.map(column => column.value.length));
                     let rowIndex = 0;
                     while(rowIndex < nRows){
                         const row = [];
                         // number of states may NOT match number of actions. In this case, we null-pad the 2D array to make a fully rectangular 2d array
                         const nullCell = { isNull: true };
-                        const stateValueIdx = rowIndex;
-                        const actionValueIdx = stateValueIdx + this.nStates; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
                         // row consists of [state value, action value]
-                        row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
-                        row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
+                        let idx = rowIndex;
+                        for(const column of this.columns){
+                            const nColumn = column.value.length;
+                            row.push(rowIndex < nColumn ? this.currentFrameData[idx] : nullCell);
+                            idx += nColumn; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
+                        }
                         rowIndex += 1;
                         rows.push(row);
                     }
@@ -1,30 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
-
-
-def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
-    root = tmp_path / "dataset"
-    output_dir = tmp_path / "outputs"
-    dataset = lerobot_dataset_factory(root=root)
-    visualize_dataset_html(
-        dataset,
-        episodes=[0],
-        output_dir=output_dir,
-        serve=False,
-    )
-    assert (output_dir / "static" / "episode_0.csv").exists()