Merge branch 'main' into bugfix/opencv-script-types

This commit is contained in:
Alex Thiele 2025-01-03 18:05:28 +00:00
commit 8af3316a87
12 changed files with 465 additions and 131 deletions

View File

@ -3,7 +3,7 @@ default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: debug-statements
@ -14,11 +14,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.19.0
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2
rev: v0.8.2
hooks:
- id: ruff
args: [--fix]
@ -32,6 +32,6 @@ repos:
- "--check"
- "--no-update"
- repo: https://github.com/gitleaks/gitleaks
rev: v8.18.4
rev: v8.21.2
hooks:
- id: gitleaks

View File

@ -68,7 +68,7 @@
### Acknowledgment
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.

View File

@ -10,10 +10,10 @@ from torchvision.transforms import ToPILImage, v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
dataset_repo_id = "lerobot/aloha_static_tape"
dataset_repo_id = "lerobot/aloha_static_screw_driver"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(dataset_repo_id)
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
# Get the index of the first observation in the first episode
@ -28,12 +28,13 @@ transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.ColorJitter(hue=(-0.1, 0.1)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]

View File

@ -56,7 +56,7 @@ python lerobot/scripts/control_robot.py teleoperate \
--robot-overrides max_relative_target=5
```
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teloperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/aloha.yaml \

View File

@ -28,7 +28,7 @@ def safe_stop_image_writer(func):
try:
return func(*args, **kwargs)
except Exception as e:
dataset = kwargs.get("dataset", None)
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
print("Waiting for image writer to terminate...")

View File

@ -17,9 +17,11 @@ import importlib.resources
import json
import logging
import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any
import datasets
@ -477,7 +479,6 @@ def create_lerobot_dataset_card(
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""
card_tags = ["LeRobot"]
card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md")
if tags:
card_tags += tags
@ -497,8 +498,65 @@ def create_lerobot_dataset_card(
],
)
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
return DatasetCard.from_template(
card_data=card_data,
template_path=str(card_template_path),
template_str=card_template,
**kwargs,
)
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
>>> ns = IterableNamespace(data)
>>> ns.name
'Alice'
>>> ns.details.age
25
>>> list(ns.keys())
['name', 'details']
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
if isinstance(value, dict):
setattr(self, key, IterableNamespace(value))
else:
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
return vars(self)[key]
def items(self):
return vars(self).items()
def values(self):
return vars(self).values()
def keys(self):
return vars(self).keys()

View File

@ -184,7 +184,7 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
def warmup_record(
robot,
events,
enable_teloperation,
enable_teleoperation,
warmup_time_s,
display_cameras,
fps,
@ -195,7 +195,7 @@ def warmup_record(
display_cameras=display_cameras,
events=events,
fps=fps,
teleoperate=enable_teloperation,
teleoperate=enable_teleoperation,
)

View File

@ -35,7 +35,7 @@ python lerobot/scripts/visualize_dataset.py \
--episode-index 0
```
- Replay a sequence of test episodes:
- Replay a sequence of test episodes:
```bash
python lerobot/scripts/control_sim_robot.py replay \
--robot-path lerobot/configs/robot/your_robot_config.yaml \

View File

@ -53,20 +53,29 @@ python lerobot/scripts/visualize_dataset_html.py \
"""
import argparse
import csv
import json
import logging
import re
import shutil
import tempfile
from io import StringIO
from pathlib import Path
import tqdm
from flask import Flask, redirect, render_template, url_for
import numpy as np
import pandas as pd
import requests
from flask import Flask, redirect, render_template, request, url_for
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import IterableNamespace
from lerobot.common.utils.utils import init_logging
def run_server(
dataset: LeRobotDataset,
episodes: list[int],
dataset: LeRobotDataset | IterableNamespace | None,
episodes: list[int] | None,
host: str,
port: str,
static_folder: Path,
@ -76,10 +85,50 @@ def run_server(
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
def index():
# home page redirects to the first episode page
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
first_episode_id = episodes[0]
def hommepage(dataset=dataset):
if dataset:
dataset_namespace, dataset_name = dataset.repo_id.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=0,
)
)
dataset_param, episode_param = None, None
all_params = request.args
if "dataset" in all_params:
dataset_param = all_params["dataset"]
if "episode" in all_params:
episode_param = int(all_params["episode"])
if dataset_param:
dataset_namespace, dataset_name = dataset_param.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=episode_param if episode_param is not None else 0,
)
)
featured_datasets = [
"lerobot/aloha_static_cups_open",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/taco_play",
]
return render_template(
"visualize_dataset_homepage.html",
featured_datasets=featured_datasets,
lerobot_datasets=available_datasets,
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
def show_first_episode(dataset_namespace, dataset_name):
first_episode_id = 0
return redirect(
url_for(
"show_episode",
@ -90,30 +139,85 @@ def run_server(
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id):
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
dataset = get_dataset_info(repo_id)
except FileNotFoundError:
return (
"Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
400,
)
dataset_version = (
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
major_version = int(match.group(1))
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": dataset.repo_id,
"num_samples": dataset.num_frames,
"num_episodes": dataset.num_episodes,
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
if isinstance(dataset, LeRobotDataset)
else dataset.total_frames,
"num_episodes": dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes,
"fps": dataset.fps,
}
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
tasks = dataset.meta.episodes[episode_id]["tasks"]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
]
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
for video_path in video_paths
]
tasks = dataset.meta.episodes[0]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.video_path.format(
episode_chunk=int(episode_id) // dataset.chunks_size,
video_key=video_key,
episode_index=episode_id,
),
"filename": video_key,
}
for video_key in video_keys
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
)
return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
ep_csv_url=ep_csv_url,
has_policy=False,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
)
app.run(host=host, port=port)
@ -124,46 +228,84 @@ def get_ep_csv_fname(episode_id: int):
return ep_csv_fname
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
columns = []
has_state = "observation.state" in dataset.features
has_action = "action" in dataset.features
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = dataset.meta.shapes["observation.state"][0]
dim_state = (
dataset.meta.shapes["observation.state"][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features["observation.state"].shape[0]
)
header += [f"state_{i}" for i in range(dim_state)]
column_names = dataset.features["observation.state"]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
columns.append({"key": "state", "value": column_names})
if has_action:
dim_action = dataset.meta.shapes["action"][0]
dim_action = (
dataset.meta.shapes["action"][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features.action.shape[0]
)
header += [f"action_{i}" for i in range(dim_action)]
column_names = dataset.features["action"]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
columns.append({"key": "action", "value": column_names})
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
selected_columns = ["timestamp"]
if has_state:
row += data[i]["observation.state"].tolist()
selected_columns += ["observation.state"]
if has_action:
row += data[i]["action"].tolist()
rows.append(row)
selected_columns += ["action"]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
.with_format("numpy")
)
rows = np.hstack(
(np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]])
).tolist()
else:
repo_id = dataset.repo_id
selected_columns = ["timestamp"]
if "observation.state" in dataset.features:
selected_columns.append("observation.state")
if "action" in dataset.features:
selected_columns.append("action")
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n")
for row in rows:
row_str = [str(col) for col in row]
f.write(",".join(row_str) + "\n")
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
# Convert data to CSV string
csv_buffer = StringIO()
csv_writer = csv.writer(csv_buffer)
# Write header
csv_writer.writerow(header)
# Write data rows
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@ -175,9 +317,31 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
return IterableNamespace(dataset_info)
def visualize_dataset_html(
dataset: LeRobotDataset,
episodes: list[int] = None,
dataset: LeRobotDataset | None,
episodes: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
@ -186,11 +350,11 @@ def visualize_dataset_html(
) -> Path | None:
init_logging()
if len(dataset.meta.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
template_dir = Path(__file__).resolve().parent.parent / "templates"
if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
# Create a temporary directory that will be automatically cleaned up
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = Path(output_dir)
if output_dir.exists():
@ -201,28 +365,33 @@ def visualize_dataset_html(
output_dir.mkdir(parents=True, exist_ok=True)
# Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files.
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
template_dir = Path(__file__).resolve().parent.parent / "templates"
if dataset is None:
if serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
static_folder=static_dir,
template_folder=template_dir,
)
else:
image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else []
if len(image_keys) > 0:
raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.")
if episodes is None:
episodes = list(range(dataset.num_episodes))
# Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
logging.info("Writing CSV files")
for episode_index in tqdm.tqdm(episodes):
# write states and actions in a csv (it can be slow for big datasets)
ep_csv_fname = get_ep_csv_fname(episode_index)
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
@ -231,7 +400,7 @@ def main():
parser.add_argument(
"--repo-id",
type=str,
required=True,
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
@ -246,6 +415,12 @@ def main():
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--load-from-hf-hub",
type=int,
default=0,
help="Load videos and parquet files from HF Hub rather than local system.",
)
parser.add_argument(
"--episodes",
type=int,
@ -287,11 +462,19 @@ def main():
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
visualize_dataset_html(dataset, **kwargs)
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
visualize_dataset_html(dataset, **vars(args))
if __name__ == "__main__":

View File

@ -0,0 +1,68 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Interactive Video Background Page</title>
<script src="https://cdn.tailwindcss.com"></script>
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
</head>
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
inputValue: '',
navigateToDataset() {
const trimmedValue = this.inputValue.trim();
if (trimmedValue) {
window.location.href = `/${trimmedValue}`;
}
}
}">
<div class="fixed inset-0 w-full h-full overflow-hidden">
<video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
<source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
Your browser does not support HTML5 video.
</video>
</div>
<div class="fixed inset-0 bg-black bg-opacity-80"></div>
<div class="relative z-10 flex flex-col items-center justify-center h-screen">
<div class="text-center mb-8">
<h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>
<a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>
<p class="text-xl mb-4"></p>
<div class="text-left inline-block">
<h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
<ul class="list-disc list-inside">
{% for dataset in featured_datasets %}
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
{% endfor %}
</ul>
</div>
</div>
<div class="flex w-full max-w-lg px-4 mb-4">
<input
type="text"
x-model="inputValue"
@keyup.enter="navigateToDataset"
placeholder="enter dataset id (ex: lerobot/droid_100)"
class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
>
<button
@click="navigateToDataset"
class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
>
Go
</button>
</div>
<details class="mt-4 max-w-full px-4">
<summary>More example datasets</summary>
<ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
{% for dataset in lerobot_datasets %}
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
{% endfor %}
</ul>
</details>
</div>
</body>
</html>

View File

@ -31,11 +31,16 @@
}">
<!-- Sidebar -->
<div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
<a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
<img src="https://github.com/huggingface/lerobot/raw/main/media/lerobot-logo-thumbnail.png">
</a>
<a href="https://huggingface.co/datasets/{{ dataset_info.repo_id }}" target="_blank">
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
</a>
<ul>
<li>
Number of samples/frames: {{ dataset_info.num_frames }}
Number of samples/frames: {{ dataset_info.num_samples }}
</li>
<li>
Number of episodes: {{ dataset_info.num_episodes }}
@ -93,10 +98,10 @@
</div>
<!-- Videos -->
<div class="flex flex-wrap gap-1">
<div class="flex flex-wrap gap-x-2 gap-y-6">
{% for video_info in videos_info %}
<div x-show="!videoCodecError" class="max-w-96">
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<div x-show="!videoCodecError" class="max-w-96 relative">
<p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
if (video.duration) {
const time = video.currentTime;
@ -182,12 +187,12 @@
<thead>
<tr>
<th></th>
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
<template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
<th class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2">
<input type="checkbox" :checked="isColumnChecked(colIndex)"
@change="toggleColumn(colIndex)">
<p x-text="`${columnNames[colIndex]}`"></p>
<p x-text="`${columns[colIndex].key}`"></p>
</div>
</th>
</template>
@ -197,10 +202,10 @@
<template x-for="(row, rowIndex) in rows">
<tr class="odd:bg-gray-800 even:bg-gray-900">
<td class="border border-slate-700">
<div class="flex gap-x-2 w-24 font-semibold px-1">
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
<p x-text="`Motor ${rowIndex}`"></p>
<p x-text="`${rowLabels[rowIndex]}`"></p>
</div>
</td>
<template x-for="(cell, colIndex) in row">
@ -222,16 +227,20 @@
</div>
</div>
<script>
const parentOrigin = "https://huggingface.co";
const searchParams = new URLSearchParams();
searchParams.set("dataset", "{{ dataset_info.repo_id }}");
searchParams.set("episode", "{{ episode_id }}");
window.parent.postMessage({ queryString: searchParams.toString() }, parentOrigin);
</script>
<script>
function createAlpineData() {
return {
// state
dygraph: null,
currentFrameData: null,
columnNames: ["state", "action", "pred action"],
nColumns: 2,
nStates: 0,
nActions: 0,
checked: [],
dygraphTime: 0.0,
dygraphIndex: 0,
@ -241,6 +250,8 @@
nVideos: {{ videos_info | length }},
nVideoReadyToPlay: 0,
videoCodecError: false,
columns: {{ columns | tojson }},
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
// alpine initialization
init() {
@ -251,10 +262,17 @@
this.videoCodecError = true;
}
// process CSV data
const csvDataStr = {{ episode_data_csv_str|tojson|safe }};
// Create a Blob with the CSV data
const blob = new Blob([csvDataStr], { type: 'text/csv;charset=utf-8;' });
// Create a URL for the Blob
const csvUrl = URL.createObjectURL(blob);
// process CSV data
this.videos = document.querySelectorAll('video');
this.video = this.videos[0];
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
this.dygraph = new Dygraph(document.getElementById("graph"), csvUrl, {
pixelsPerPoint: 0.01,
legend: 'always',
labelsDiv: document.getElementById('labels'),
@ -275,21 +293,17 @@
this.colors = this.dygraph.getColors();
this.checked = Array(this.colors.length).fill(true);
const seriesNames = this.dygraph.getLabels().slice(1);
this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
this.nActions = seriesNames.length - this.nStates;
const colors = [];
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
// colors for "state" lines
for (let hue = 0; hue < 360; hue += parseInt(360/this.nStates)) {
const color = `hsl(${hue}, 100%, ${LIGHTNESS[0]}%)`;
colors.push(color);
}
// colors for "action" lines
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
colors.push(color);
let lightness = 30; // const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
for(const column of this.columns){
const nValues = column.value.length;
for (let hue = 0; hue < 360; hue += parseInt(360/nValues)) {
const color = `hsl(${hue}, 100%, ${lightness}%)`;
colors.push(color);
}
lightness += 35;
}
this.dygraph.updateOptions({ colors });
this.colors = colors;
@ -316,17 +330,19 @@
return [];
}
const rows = [];
const nRows = Math.max(this.nStates, this.nActions);
const nRows = Math.max(...this.columns.map(column => column.value.length));
let rowIndex = 0;
while(rowIndex < nRows){
const row = [];
// number of states may NOT match number of actions. In this case, we null-pad the 2D array to make a fully rectangular 2d array
const nullCell = { isNull: true };
const stateValueIdx = rowIndex;
const actionValueIdx = stateValueIdx + this.nStates; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
// row consists of [state value, action value]
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
let idx = rowIndex;
for(const column of this.columns){
const nColumn = column.value.length;
row.push(rowIndex < nColumn ? this.currentFrameData[idx] : nullCell);
idx += nColumn; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
}
rowIndex += 1;
rows.push(row);
}

View File

@ -14,17 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
from huggingface_hub import DatasetCard
from lerobot.common.datasets.utils import create_lerobot_dataset_card
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
root = tmp_path / "dataset"
output_dir = tmp_path / "outputs"
dataset = lerobot_dataset_factory(root=root)
visualize_dataset_html(
dataset,
episodes=[0],
output_dir=output_dir,
serve=False,
)
assert (output_dir / "static" / "episode_0.csv").exists()
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]