#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
|
|
|
Note: The last frame of the episode doesnt always correspond to a final state.
|
|
That's because our datasets are composed of transition from state to state up to
|
|
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
|
However, there might not be a transition from a final state to another state.
|
|
|
|
Note: This script aims to visualize the data used to train the neural networks.
|
|
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
|
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
|
|
save disk space. The compression factor applied has been tuned to not affect success rate.
|
|
|
|
Examples:
|
|
|
|
- Visualize data stored on a local machine:
|
|
```
|
|
local$ python lerobot/scripts/visualize_dataset.py \
|
|
--repo-id lerobot/pusht
|
|
|
|
local$ open http://localhost:9090
|
|
```
|
|
|
|
- Visualize data stored on a distant machine with a local viewer:
|
|
```
|
|
distant$ python lerobot/scripts/visualize_dataset.py \
|
|
--repo-id lerobot/pusht
|
|
|
|
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
|
|
local$ open http://localhost:9090
|
|
```
|
|
|
|
- Select episodes to visualize:
|
|
```
|
|
python lerobot/scripts/visualize_dataset.py \
|
|
--repo-id lerobot/pusht \
|
|
--episode-indices 7 3 5 1 4
|
|
```
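
- Overlay actions predicted by a pretrained policy on top of the recorded data (this
  script loads the checkpoint as an ACT policy; the repo id below is only illustrative,
  and a local checkpoint can be passed with `--policy-ckpt-path` instead):
```
python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --policy-repo-id lerobot/act_pusht
```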
"""

import argparse
import http.server
import logging
import os
import shutil
import socketserver
from pathlib import Path

import torch
import tqdm
import yaml
from bs4 import BeautifulSoup, Doctype
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.utils.utils import init_logging
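

# Samples the frame indices of a single episode, in order, so that a DataLoader
# iterates over exactly one episode of the dataset.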
class EpisodeSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, episode_index):
        from_idx = dataset.episode_data_index["from"][episode_index].item()
        to_idx = dataset.episode_data_index["to"][episode_index].item()
        self.frame_ids = range(from_idx, to_idx)

    def __iter__(self):
        return iter(self.frame_ids)

    def __len__(self):
        return len(self.frame_ids)
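

# Disable HTTP caching so that a browser refresh always picks up freshly regenerated
# csv/js/html files instead of stale cached copies.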
class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
    def end_headers(self):
        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
        self.send_header("Pragma", "no-cache")
        self.send_header("Expires", "0")
        super().end_headers()


def run_server(path, port):
    # Change directory to serve 'index.html' as the front page
    os.chdir(path)

    with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd:
        logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...")
        httpd.serve_forever()


def create_html_page(page_title: str):
    """Create an html page with BeautifulSoup, with default doctype, meta, header and title."""
    soup = BeautifulSoup("", "html.parser")

    # Use a Doctype node so that prettify() emits `<!DOCTYPE html>` without a bogus closing tag.
    soup.append(Doctype("html"))

    html = soup.new_tag("html", lang="en")
    soup.append(html)

    head = soup.new_tag("head")
    html.append(head)

    meta_charset = soup.new_tag("meta", charset="UTF-8")
    head.append(meta_charset)

    meta_viewport = soup.new_tag(
        "meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
    )
    head.append(meta_viewport)

    title = soup.new_tag("title")
    title.string = page_title
    head.append(title)

    body = soup.new_tag("body")
    html.append(body)

    main_div = soup.new_tag("div")
    body.append(main_div)
    return soup, head, body
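

# The csv produced below has a header row followed by one row per frame of the episode,
# with columns: timestamp, state_0..state_{N-1}, action_0..action_{M-1}, and, when
# inference results are provided, pred_action_* and loss.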
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
    """Write a csv file containing timeseries data of an episode (e.g. state and action).
    This file will be loaded by Dygraph javascript to plot data in real time."""
    from_idx = dataset.episode_data_index["from"][episode_index]
    to_idx = dataset.episode_data_index["to"][episode_index]

    has_state = "observation.state" in dataset.hf_dataset.features
    has_action = "action" in dataset.hf_dataset.features
    has_inference = inference_results is not None

    # init header of csv with state and action names
    header = ["timestamp"]
    if has_state:
        dim_state = len(dataset.hf_dataset["observation.state"][0])
        header += [f"state_{i}" for i in range(dim_state)]
    if has_action:
        dim_action = len(dataset.hf_dataset["action"][0])
        header += [f"action_{i}" for i in range(dim_action)]
    if has_inference:
        assert "actions" in inference_results
        assert "loss" in inference_results
        dim_pred_action = inference_results["actions"].shape[2]
        header += [f"pred_action_{i}" for i in range(dim_pred_action)]
        header += ["loss"]

    columns = ["timestamp"]
    if has_state:
        columns += ["observation.state"]
    if has_action:
        columns += ["action"]

    rows = []
    data = dataset.hf_dataset.select_columns(columns)
    for i in range(from_idx, to_idx):
        row = [data[i]["timestamp"].item()]
        if has_state:
            row += data[i]["observation.state"].tolist()
        if has_action:
            row += data[i]["action"].tolist()
        rows.append(row)

    if has_inference:
        num_frames = len(rows)
        assert num_frames == inference_results["actions"].shape[0]
        assert num_frames == inference_results["loss"].shape[0]
        for i in range(num_frames):
            rows[i] += inference_results["actions"][i, 0].tolist()
            rows[i] += [inference_results["loss"][i].item()]

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w") as f:
        f.write(",".join(header) + "\n")
        for row in rows:
            row_str = [str(col) for col in row]
            f.write(",".join(row_str) + "\n")
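

# The generated script wires up the page written by `write_episode_data_html`: it plays
# every <video> with the play button, scrubs them with the slider, and keeps the Dygraph
# selection in sync with the videos' current time.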
def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
    """Write a javascript file containing logic to synchronize camera feeds and timeseries."""
    s = ""
    s += "document.addEventListener('DOMContentLoaded', function () {\n"
    for i, key in enumerate(dataset.video_frame_keys):
        s += f"  const video{i} = document.getElementById('video_{key}');\n"
    s += "  const slider = document.getElementById('videoControl');\n"
    s += "  const playButton = document.getElementById('playButton');\n"
    s += f"  const dygraph = new Dygraph(document.getElementById('graph'), '{ep_csv_fname}', " + "{\n"
    s += "    pixelsPerPoint: 0.01,\n"
    s += "    legend: 'always',\n"
    s += "    labelsDiv: document.getElementById('labels'),\n"
    s += "    labelsSeparateLines: true,\n"
    s += "    labelsKMB: true,\n"
    s += "    highlightCircleSize: 1.5,\n"
    s += "    highlightSeriesOpts: {\n"
    s += "      strokeWidth: 1.5,\n"
    s += "      strokeBorderWidth: 1,\n"
    s += "      highlightCircleSize: 3\n"
    s += "    }\n"
    s += "  });\n"
    s += "\n"
    s += "  // Function to play both videos\n"
    s += "  playButton.addEventListener('click', function () {\n"
    for i in range(len(dataset.video_frame_keys)):
        s += f"    video{i}.play();\n"
    s += "    // playButton.disabled = true; // Optional: disable button after playing\n"
    s += "  });\n"
    s += "\n"
    s += "  // Update the video time when the slider value changes\n"
    s += "  slider.addEventListener('input', function () {\n"
    s += "    const sliderValue = slider.value;\n"
    for i in range(len(dataset.video_frame_keys)):
        s += f"    const time{i} = (video{i}.duration * sliderValue) / 100;\n"
    for i in range(len(dataset.video_frame_keys)):
        s += f"    video{i}.currentTime = time{i};\n"
    s += "  });\n"
    s += "\n"
    s += "  // Synchronize slider with the video's current time\n"
    s += "  const syncSlider = (video) => {\n"
    s += "    video.addEventListener('timeupdate', function () {\n"
    s += "      if (video.duration) {\n"
    s += "        const pc = (100 / video.duration) * video.currentTime;\n"
    s += "        slider.value = pc;\n"
    s += "        const index = Math.floor(pc * dygraph.numRows() / 100);\n"
    s += "        dygraph.setSelection(index, undefined, true, true);\n"
    s += "      }\n"
    s += "    });\n"
    s += "  };\n"
    s += "\n"
    for i in range(len(dataset.video_frame_keys)):
        s += f"  syncSlider(video{i});\n"
    s += "\n"
    s += "});\n"

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w", encoding="utf-8") as f:
        f.write(s)
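

# Page layout produced below: a row of <video> elements (one per camera key), a play
# button and a scrub slider, then a Dygraph plot of the csv next to its legend, with the
# dygraphs library pulled from a CDN.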
def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
    """Write an html file containing video feeds and timeseries associated to an episode."""
    soup, head, body = create_html_page("")

    css_style = soup.new_tag("style")
    css_style.string = ""
    css_style.string += "#labels > span.highlight {\n"
    css_style.string += "  border: 1px solid grey;\n"
    css_style.string += "}"
    head.append(css_style)

    # Add videos from camera feeds

    videos_control_div = soup.new_tag("div")
    body.append(videos_control_div)

    videos_div = soup.new_tag("div")
    videos_control_div.append(videos_div)

    def create_video(id, src):
        video = soup.new_tag("video", id=id, width="320", height="240", controls="")
        source = soup.new_tag("source", src=src, type="video/mp4")
        video.string = "Your browser does not support the video tag."
        video.append(source)
        return video

    # get first frame of episode (hack to get video_path of the episode)
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()

    for key in dataset.video_frame_keys:
        # Example of video_path: 'videos/observation.image_episode_000004.mp4'
        video_path = dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
        videos_div.append(create_video(f"video_{key}", video_path))

    # Add controls for videos and graph

    control_div = soup.new_tag("div")
    videos_control_div.append(control_div)

    button_div = soup.new_tag("div")
    control_div.append(button_div)

    button = soup.new_tag("button", id="playButton")
    button.string = "Play Videos"
    button_div.append(button)

    slider_div = soup.new_tag("div")
    control_div.append(slider_div)

    slider = soup.new_tag("input", type="range", id="videoControl", min="0", max="100", value="0", step="1")
    control_div.append(slider)

    # Add graph of states/actions, and its labels

    graph_labels_div = soup.new_tag("div", style="display: flex;")
    body.append(graph_labels_div)

    graph_div = soup.new_tag("div", id="graph", style="flex: 1; width: 85%")
    graph_labels_div.append(graph_div)

    labels_div = soup.new_tag("div", id="labels", style="flex: 1; width: 15%")
    graph_labels_div.append(labels_div)

    # Add the episode script and the dygraph library
    script = soup.new_tag("script", type="text/javascript", src=js_fname)
    body.append(script)

    script_dygraph = soup.new_tag(
        "script",
        type="text/javascript",
        src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js",
    )
    body.append(script_dygraph)

    link_dygraph = soup.new_tag(
        "link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css"
    )
    body.append(link_dygraph)

    # Write as a html file

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w", encoding="utf-8") as f:
        f.write(soup.prettify())


def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
    """Write an html file containing information related to the dataset and a list of links to
    html pages of episodes."""
    soup, head, body = create_html_page("TODO")

    h3 = soup.new_tag("h3")
    h3.string = "TODO"
    body.append(h3)

    ul_info = soup.new_tag("ul")
    body.append(ul_info)

    li_info = soup.new_tag("li")
    li_info.string = f"Number of samples/frames: {dataset.num_samples}"
    ul_info.append(li_info)

    li_info = soup.new_tag("li")
    li_info.string = f"Number of episodes: {dataset.num_episodes}"
    ul_info.append(li_info)

    li_info = soup.new_tag("li")
    li_info.string = f"Frames per second: {dataset.fps}"
    ul_info.append(li_info)

    # li_info = soup.new_tag("li")
    # li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B"
    # ul_info.append(li_info)

    ul = soup.new_tag("ul")
    body.append(ul)

    for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=False):
        li = soup.new_tag("li")
        ul.append(li)

        a = soup.new_tag("a", href=ep_html_fname)
        a.string = f"Episode number {ep_idx}"

        li.append(a)

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w", encoding="utf-8") as f:
        f.write(soup.prettify())
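

# Runs the policy over every frame of one episode, in order, and returns the per-frame
# outputs (e.g. predicted actions and loss) concatenated on cpu.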
def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"):
    policy.eval()
    policy.to(device)

    logging.info("Loading dataloader")
    episode_sampler = EpisodeSampler(dataset, episode_index)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=episode_sampler,
    )

    logging.info("Running inference")
    inference_results = {}
    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.inference_mode():
            output_dict = policy.forward(batch)

        for key in output_dict:
            if key not in inference_results:
                inference_results[key] = []
            inference_results[key].append(output_dict[key].to("cpu"))

    for key in inference_results:
        inference_results[key] = torch.cat(inference_results[key])

    return inference_results
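

# Writes, under `output_dir`: an `index.html` listing all episodes, and per episode an
# `episode_N.html` page, `episode_N.js` script and `episode_N.csv` timeseries (plus
# `episode_N.safetensors` inference results when a policy is given). A `videos/` symlink
# exposes the dataset mp4 files to the http server.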
def visualize_dataset(
    repo_id: str,
    episode_indices: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    port: int = 9090,
    force_overwrite: bool = True,
    policy_repo_id: str | None = None,
    policy_ckpt_path: Path | None = None,
    batch_size: int = 32,
    num_workers: int = 4,
) -> Path | None:
    init_logging()

    has_policy = policy_repo_id or policy_ckpt_path

    if has_policy:
        logging.info("Loading policy")
        if policy_repo_id:
            pretrained_policy_path = Path(snapshot_download(policy_repo_id))
        elif policy_ckpt_path:
            pretrained_policy_path = Path(policy_ckpt_path)
        policy = ACTPolicy.from_pretrained(pretrained_policy_path)
        with open(pretrained_policy_path / "config.yaml") as f:
            cfg = yaml.safe_load(f)
        delta_timestamps = cfg["training"]["delta_timestamps"]
    else:
        delta_timestamps = None

    logging.info("Loading dataset")
    dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)

    if not dataset.video:
        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")

    if output_dir is None:
        output_dir = f"outputs/visualize_dataset/{repo_id}"

    output_dir = Path(output_dir)
    if force_overwrite and output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create a symlink from the dataset video folder containing mp4 files to the output directory
    # so that the http server can get access to the mp4 files.
    ln_videos_dir = output_dir / "videos"
    if not ln_videos_dir.exists():
        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())

    if episode_indices is None:
        episode_indices = list(range(dataset.num_episodes))

    logging.info("Writing html")
    ep_html_fnames = []
    for episode_index in tqdm.tqdm(episode_indices):
        inference_results = None
        if has_policy:
            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
            if inference_results_path.exists():
                inference_results = load_file(inference_results_path)
            else:
                inference_results = run_inference(
                    dataset, episode_index, policy, num_workers=num_workers, batch_size=batch_size
                )
                save_file(inference_results, inference_results_path)

        # write states and actions in a csv
        ep_csv_fname = f"episode_{episode_index}.csv"
        write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)

        js_fname = f"episode_{episode_index}.js"
        write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)

        # write a html page to view videos and timeseries
        ep_html_fname = f"episode_{episode_index}.html"
        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
        ep_html_fnames.append(ep_html_fname)

    write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)

    if serve:
        run_server(output_dir, port)
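

# CLI entry point: the flags below mirror the parameters of `visualize_dataset`;
# `--serve` and `--force-overwrite` are parsed as 0/1 integers.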
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--episode-indices",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-overwrite",
        type=int,
        default=1,
        help="Delete the output directory if it exists already.",
    )

    parser.add_argument(
        "--policy-repo-id",
        type=str,
        default=None,
        help="Name of hugging face repository containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
    )
    parser.add_argument(
        "--policy-ckpt-path",
        type=str,
        default=None,
        help="Path to a local directory containing a pretrained policy checkpoint, used instead of downloading it from the hub.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size loaded by DataLoader.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of processes of Dataloader for loading the data.",
    )

    args = parser.parse_args()
    visualize_dataset(**vars(args))


if __name__ == "__main__":
    main()