rename to html

parent 205e0c9dde
commit 12a1b8f55a
@@ -30,46 +30,48 @@ Examples:
 - Visualize data stored on a local machine:
 ```
 local$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht
-local$ open http://localhost:9090
+    --repo-id lerobot/pusht \
+    --episode-index 0
 ```

 - Visualize data stored on a distant machine with a local viewer:
 ```
 distant$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht
-
-local$ ssh -L 9090:localhost:9090 distant  # create a ssh tunnel
-local$ open http://localhost:9090
-```
-
-- Select episodes to visualize:
-```
-python lerobot/scripts/visualize_dataset.py \
     --repo-id lerobot/pusht \
-    --episode-indices 7 3 5 1 4
+    --episode-index 0 \
+    --save 1 \
+    --output-dir path/to/directory
+
+local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
+local$ rerun lerobot_pusht_episode_0.rrd
 ```
+
+- Visualize data stored on a distant machine through streaming:
+(You need to forward the websocket port to the distant machine, with
+`ssh -L 9087:localhost:9087 username@remote-host`)
+```
+distant$ python lerobot/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --mode distant \
+    --ws-port 9087
+
+local$ rerun ws://localhost:9087
+```
 """

 import argparse
-import http.server
+import gc
 import logging
-import os
-import shutil
-import socketserver
+import time
 from pathlib import Path

+import rerun as rr
 import torch
 import tqdm
-import yaml
-from bs4 import BeautifulSoup
-from huggingface_hub import snapshot_download
-from safetensors.torch import load_file, save_file

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.policies.act.modeling_act import ACTPolicy
-from lerobot.common.utils.utils import init_logging


 class EpisodeSampler(torch.utils.data.Sampler):
@@ -85,307 +87,33 @@ class EpisodeSampler(torch.utils.data.Sampler):
         return len(self.frame_ids)


-class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
-    def end_headers(self):
-        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
-        self.send_header("Pragma", "no-cache")
-        self.send_header("Expires", "0")
-        super().end_headers()
+def to_hwc_uint8_numpy(chw_float32_torch):
+    assert chw_float32_torch.dtype == torch.float32
+    assert chw_float32_torch.ndim == 3
+    c, h, w = chw_float32_torch.shape
+    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    return hwc_uint8_numpy


-def run_server(path, port):
-    # Change directory to serve 'index.html' as front page
-    os.chdir(path)
-    with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd:
-        logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...")
-        httpd.serve_forever()
-
-
-def create_html_page(page_title: str):
-    """Create a html page with BeautifulSoup with default doctype, meta, header and title."""
-    soup = BeautifulSoup("", "html.parser")
-
-    doctype = soup.new_tag("!DOCTYPE html")
-    soup.append(doctype)
-
-    html = soup.new_tag("html", lang="en")
-    soup.append(html)
-
-    head = soup.new_tag("head")
-    html.append(head)
-
-    meta_charset = soup.new_tag("meta", charset="UTF-8")
-    head.append(meta_charset)
-
-    meta_viewport = soup.new_tag(
-        "meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
-    )
-    head.append(meta_viewport)
-
-    title = soup.new_tag("title")
-    title.string = page_title
-    head.append(title)
-
-    body = soup.new_tag("body")
-    html.append(body)
-
-    main_div = soup.new_tag("div")
-    body.append(main_div)
-    return soup, head, body
-
-
-def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
-    """Write a csv file containing timeseries data of an episode (e.g. state and action).
-    This file will be loaded by Dygraph javascript to plot data in real time."""
-    from_idx = dataset.episode_data_index["from"][episode_index]
-    to_idx = dataset.episode_data_index["to"][episode_index]
-
-    has_state = "observation.state" in dataset.hf_dataset.features
-    has_action = "action" in dataset.hf_dataset.features
-    has_inference = inference_results is not None
-
-    # init header of csv with state and action names
-    header = ["timestamp"]
-    if has_state:
-        dim_state = len(dataset.hf_dataset["observation.state"][0])
-        header += [f"state_{i}" for i in range(dim_state)]
-    if has_action:
-        dim_action = len(dataset.hf_dataset["action"][0])
-        header += [f"action_{i}" for i in range(dim_action)]
-    if has_inference:
-        assert "actions" in inference_results
-        assert "loss" in inference_results
-        dim_pred_action = inference_results["actions"].shape[2]
-        header += [f"pred_action_{i}" for i in range(dim_pred_action)]
-        header += ["loss"]
-
-    columns = ["timestamp"]
-    if has_state:
-        columns += ["observation.state"]
-    if has_action:
-        columns += ["action"]
-
-    rows = []
-    data = dataset.hf_dataset.select_columns(columns)
-    for i in range(from_idx, to_idx):
-        row = [data[i]["timestamp"].item()]
-        if has_state:
-            row += data[i]["observation.state"].tolist()
-        if has_action:
-            row += data[i]["action"].tolist()
-        rows.append(row)
-
-    if has_inference:
-        num_frames = len(rows)
-        assert num_frames == inference_results["actions"].shape[0]
-        assert num_frames == inference_results["loss"].shape[0]
-        for i in range(num_frames):
-            rows[i] += inference_results["actions"][i, 0].tolist()
-            rows[i] += [inference_results["loss"][i].item()]
-
-    output_dir.mkdir(parents=True, exist_ok=True)
-    with open(output_dir / file_name, "w") as f:
-        f.write(",".join(header) + "\n")
-        for row in rows:
-            row_str = [str(col) for col in row]
-            f.write(",".join(row_str) + "\n")
-
-
-def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
-    """Write a javascript file containing logic to synchronize camera feeds and timeseries."""
-    s = ""
-    s += "document.addEventListener('DOMContentLoaded', function () {\n"
-    for i, key in enumerate(dataset.video_frame_keys):
-        s += f" const video{i} = document.getElementById('video_{key}');\n"
-    s += " const slider = document.getElementById('videoControl');\n"
-    s += " const playButton = document.getElementById('playButton');\n"
-    s += f" const dygraph = new Dygraph(document.getElementById('graph'), '{ep_csv_fname}', " + "{\n"
-    s += " pixelsPerPoint: 0.01,\n"
-    s += " legend: 'always',\n"
-    s += " labelsDiv: document.getElementById('labels'),\n"
-    s += " labelsSeparateLines: true,\n"
-    s += " labelsKMB: true,\n"
-    s += " highlightCircleSize: 1.5,\n"
-    s += " highlightSeriesOpts: {\n"
-    s += " strokeWidth: 1.5,\n"
-    s += " strokeBorderWidth: 1,\n"
-    s += " highlightCircleSize: 3\n"
-    s += " }\n"
-    s += " });\n"
-    s += "\n"
-    s += " // Function to play both videos\n"
-    s += " playButton.addEventListener('click', function () {\n"
-    for i in range(len(dataset.video_frame_keys)):
-        s += f" video{i}.play();\n"
-    s += " // playButton.disabled = true; // Optional: disable button after playing\n"
-    s += " });\n"
-    s += "\n"
-    s += " // Update the video time when the slider value changes\n"
-    s += " slider.addEventListener('input', function () {\n"
-    s += " const sliderValue = slider.value;\n"
-    for i in range(len(dataset.video_frame_keys)):
-        s += f" const time{i} = (video{i}.duration * sliderValue) / 100;\n"
-    for i in range(len(dataset.video_frame_keys)):
-        s += f" video{i}.currentTime = time{i};\n"
-    s += " });\n"
-    s += "\n"
-    s += " // Synchronize slider with the video's current time\n"
-    s += " const syncSlider = (video) => {\n"
-    s += " video.addEventListener('timeupdate', function () {\n"
-    s += " if (video.duration) {\n"
-    s += " const pc = (100 / video.duration) * video.currentTime;\n"
-    s += " slider.value = pc;\n"
-    s += " const index = Math.floor(pc * dygraph.numRows() / 100);\n"
-    s += " dygraph.setSelection(index, undefined, true, true);\n"
-    s += " }\n"
-    s += " });\n"
-    s += " };\n"
-    s += "\n"
-    for i in range(len(dataset.video_frame_keys)):
-        s += f" syncSlider(video{i});\n"
-    s += "\n"
-    s += "});\n"
-
-    output_dir.mkdir(parents=True, exist_ok=True)
-    with open(output_dir / file_name, "w", encoding="utf-8") as f:
-        f.write(s)
-
-
-def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
-    """Write an html file containing video feeds and timeseries associated to an episode."""
-    soup, head, body = create_html_page("")
-
-    css_style = soup.new_tag("style")
-    css_style.string = ""
-    css_style.string += "#labels > span.highlight {\n"
-    css_style.string += " border: 1px solid grey;\n"
-    css_style.string += "}"
-    head.append(css_style)
-
-    # Add videos from camera feeds
-    videos_control_div = soup.new_tag("div")
-    body.append(videos_control_div)
-
-    videos_div = soup.new_tag("div")
-    videos_control_div.append(videos_div)
-
-    def create_video(id, src):
-        video = soup.new_tag("video", id=id, width="320", height="240", controls="")
-        source = soup.new_tag("source", src=src, type="video/mp4")
-        video.string = "Your browser does not support the video tag."
-        video.append(source)
-        return video
-
-    # get first frame of episode (hack to get video_path of the episode)
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
-
-    for key in dataset.video_frame_keys:
-        # Example of video_path: 'videos/observation.image_episode_000004.mp4'
-        video_path = dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
-        videos_div.append(create_video(f"video_{key}", video_path))
-
-    # Add controls for videos and graph
-    control_div = soup.new_tag("div")
-    videos_control_div.append(control_div)
-
-    button_div = soup.new_tag("div")
-    control_div.append(button_div)
-
-    button = soup.new_tag("button", id="playButton")
-    button.string = "Play Videos"
-    button_div.append(button)
-
-    slider_div = soup.new_tag("div")
-    control_div.append(slider_div)
-
-    slider = soup.new_tag("input", type="range", id="videoControl", min="0", max="100", value="0", step="1")
-    control_div.append(slider)
-
-    # Add graph of states/actions, and its labels
-    graph_labels_div = soup.new_tag("div", style="display: flex;")
-    body.append(graph_labels_div)
-
-    graph_div = soup.new_tag("div", id="graph", style="flex: 1; width: 85%")
-    graph_labels_div.append(graph_div)
-
-    labels_div = soup.new_tag("div", id="labels", style="flex: 1; width: 15%")
-    graph_labels_div.append(labels_div)
-
-    # add dygraph library
-    script = soup.new_tag("script", type="text/javascript", src=js_fname)
-    body.append(script)
-
-    script_dygraph = soup.new_tag(
-        "script",
-        type="text/javascript",
-        src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js",
-    )
-    body.append(script_dygraph)
-
-    link_dygraph = soup.new_tag(
-        "link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css"
-    )
-    body.append(link_dygraph)
-
-    # Write as a html file
-    output_dir.mkdir(parents=True, exist_ok=True)
-    with open(output_dir / file_name, "w", encoding="utf-8") as f:
-        f.write(soup.prettify())
-
-
-def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
-    """Write an html file containing information related to the dataset and a list of links to
-    html pages of episodes."""
-    soup, head, body = create_html_page("TODO")
-
-    h3 = soup.new_tag("h3")
-    h3.string = "TODO"
-    body.append(h3)
-
-    ul_info = soup.new_tag("ul")
-    body.append(ul_info)
-
-    li_info = soup.new_tag("li")
-    li_info.string = f"Number of samples/frames: {dataset.num_samples}"
-    ul_info.append(li_info)
-
-    li_info = soup.new_tag("li")
-    li_info.string = f"Number of episodes: {dataset.num_episodes}"
-    ul_info.append(li_info)
-
-    li_info = soup.new_tag("li")
-    li_info.string = f"Frames per second: {dataset.fps}"
-    ul_info.append(li_info)
-
-    # li_info = soup.new_tag("li")
-    # li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B"
-    # ul_info.append(li_info)
-
-    ul = soup.new_tag("ul")
-    body.append(ul)
-
-    for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=False):
-        li = soup.new_tag("li")
-        ul.append(li)
-
-        a = soup.new_tag("a", href=ep_html_fname)
-        a.string = f"Episode number {ep_idx}"
-        li.append(a)
-
-    output_dir.mkdir(parents=True, exist_ok=True)
-    with open(output_dir / file_name, "w", encoding="utf-8") as f:
-        f.write(soup.prettify())
-
-
-def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"):
-    policy.eval()
-    policy.to(device)
-
+def visualize_dataset(
+    repo_id: str,
+    episode_index: int,
+    batch_size: int = 32,
+    num_workers: int = 0,
+    mode: str = "local",
+    web_port: int = 9090,
+    ws_port: int = 9087,
+    save: bool = False,
+    output_dir: Path | None = None,
+) -> Path | None:
+    if save:
+        assert (
+            output_dir is not None
+        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+
+    logging.info("Loading dataset")
+    dataset = LeRobotDataset(repo_id)
+
     logging.info("Loading dataloader")
     episode_sampler = EpisodeSampler(dataset, episode_index)
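For context on the new `to_hwc_uint8_numpy` helper: Rerun's `rr.Image` consumes height-width-channel uint8 arrays, while the dataloader yields channel-first float32 tensors in [0, 1]. A minimal, self-contained sketch of the same conversion (the 3x240x320 shape is an illustrative assumption, not from the commit):

```python
import torch

# Toy channel-first float32 frame, like the tensors the dataloader yields.
chw = torch.rand(3, 240, 320, dtype=torch.float32)

# Same transform as the helper above: rescale [0, 1] floats to uint8, permute CHW -> HWC.
hwc = (chw * 255).type(torch.uint8).permute(1, 2, 0).numpy()

assert hwc.shape == (240, 320, 3)
assert hwc.dtype.name == "uint8"
```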
@@ -396,104 +124,70 @@ def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32,
         sampler=episode_sampler,
     )

-    logging.info("Running inference")
-    inference_results = {}
+    logging.info("Starting Rerun")
+
+    if mode not in ["local", "distant"]:
+        raise ValueError(mode)
+
+    spawn_local_viewer = mode == "local" and not save
+    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+
+    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0
+    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+    gc.collect()
+
+    if mode == "distant":
+        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+
+    logging.info("Logging to Rerun")

     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
-        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        with torch.inference_mode():
-            output_dict = policy.forward(batch)
-
-            for key in output_dict:
-                if key not in inference_results:
-                    inference_results[key] = []
-                inference_results[key].append(output_dict[key].to("cpu"))
-
-    for key in inference_results:
-        inference_results[key] = torch.cat(inference_results[key])
-
-    return inference_results
-
-
-def visualize_dataset(
-    repo_id: str,
-    episode_indices: list[int] = None,
-    output_dir: Path | None = None,
-    serve: bool = True,
-    port: int = 9090,
-    force_overwrite: bool = True,
-    policy_repo_id: str | None = None,
-    policy_ckpt_path: Path | None = None,
-    batch_size: int = 32,
-    num_workers: int = 4,
-) -> Path | None:
-    init_logging()
-
-    has_policy = policy_repo_id or policy_ckpt_path
-
-    if has_policy:
-        logging.info("Loading policy")
-        if policy_repo_id:
-            pretrained_policy_path = Path(snapshot_download(policy_repo_id))
-        elif policy_ckpt_path:
-            pretrained_policy_path = Path(policy_ckpt_path)
-        policy = ACTPolicy.from_pretrained(pretrained_policy_path)
-        with open(pretrained_policy_path / "config.yaml") as f:
-            cfg = yaml.safe_load(f)
-        delta_timestamps = cfg["training"]["delta_timestamps"]
-    else:
-        delta_timestamps = None
-
-    logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
-
-    if not dataset.video:
-        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
-
-    if output_dir is None:
-        output_dir = f"outputs/visualize_dataset/{repo_id}"
-
-    output_dir = Path(output_dir)
-    if force_overwrite and output_dir.exists():
-        shutil.rmtree(output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    # Create a symlink from the dataset video folder containing mp4 files to the output directory
-    # so that the http server can get access to the mp4 files.
-    ln_videos_dir = output_dir / "videos"
-    if not ln_videos_dir.exists():
-        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
-
-    if episode_indices is None:
-        episode_indices = list(range(dataset.num_episodes))
-
-    logging.info("Writing html")
-    ep_html_fnames = []
-    for episode_index in tqdm.tqdm(episode_indices):
-        inference_results = None
-        if has_policy:
-            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
-            if inference_results_path.exists():
-                inference_results = load_file(inference_results_path)
-            else:
-                inference_results = run_inference(dataset, episode_index, policy)
-                save_file(inference_results, inference_results_path)
-
-        # write states and actions in a csv
-        ep_csv_fname = f"episode_{episode_index}.csv"
-        write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)
-
-        js_fname = f"episode_{episode_index}.js"
-        write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
-
-        # write a html page to view videos and timeseries
-        ep_html_fname = f"episode_{episode_index}.html"
-        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
-        ep_html_fnames.append(ep_html_fname)
-
-    write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
-
-    if serve:
-        run_server(output_dir, port)
+        # iterate over the batch
+        for i in range(len(batch["index"])):
+            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
+            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+
+            # display each camera image
+            for key in dataset.camera_keys:
+                # TODO(rcadene): add `.compress()`? is it lossless?
+                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
+
+            # display each dimension of action space (e.g. actuators command)
+            if "action" in batch:
+                for dim_idx, val in enumerate(batch["action"][i]):
+                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+
+            # display each dimension of observed state space (e.g. agent position in joint space)
+            if "observation.state" in batch:
+                for dim_idx, val in enumerate(batch["observation.state"][i]):
+                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
+
+            if "next.done" in batch:
+                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+
+            if "next.reward" in batch:
+                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+
+            if "next.success" in batch:
+                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
+
+    if mode == "local" and save:
+        # save .rrd locally
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        repo_id_str = repo_id.replace("/", "_")
+        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
+        rr.save(rrd_path)
+        return rrd_path
+
+    elif mode == "distant":
+        # stop the process from exiting since it is serving the websocket connection
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            print("Ctrl-C received. Exiting.")


 def main():
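The new logging loop follows the usual Rerun pattern: set the timeline cursors first, then log entities stamped against them. A self-contained toy sketch of that pattern (assumes a rerun-sdk from the 0.15 era, which provides `rr.Scalar`, `rr.set_time_sequence`, and `rr.set_time_seconds`; the data values are made up):

```python
import rerun as rr

rr.init("sketch/episode_0", spawn=True)  # spawn a local viewer, as `--mode local` does

# Toy stand-ins for one episode: (timestamp, reward) per frame.
for frame_index, (timestamp, reward) in enumerate([(0.0, 0.0), (0.1, 0.0), (0.2, 1.0)]):
    # Every rr.log call below is stamped on both timelines.
    rr.set_time_sequence("frame_index", frame_index)
    rr.set_time_seconds("timestamp", timestamp)
    rr.log("next.reward", rr.Scalar(reward))
```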
@@ -503,51 +197,13 @@ def main():
         "--repo-id",
         type=str,
         required=True,
-        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
+        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
     )
     parser.add_argument(
-        "--episode-indices",
+        "--episode-index",
         type=int,
-        nargs="*",
-        default=None,
-        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        default=None,
-        help="Directory path to write html files and kick off a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
-    )
-    parser.add_argument(
-        "--serve",
-        type=int,
-        default=1,
-        help="Launch web server.",
-    )
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=9090,
-        help="Web port used by the http server.",
-    )
-    parser.add_argument(
-        "--force-overwrite",
-        type=int,
-        default=1,
-        help="Delete the output directory if it exists already.",
-    )
-
-    parser.add_argument(
-        "--policy-repo-id",
-        type=str,
-        default=None,
-        help="Name of hugging face repository containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
-    )
-    parser.add_argument(
-        "--policy-ckpt-path",
-        type=str,
-        default=None,
-        help="Name of hugging face repository containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
-    )
+        required=True,
+        help="Episode to visualize.",
     )
     parser.add_argument(
         "--batch-size",
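Since `main()` forwards the parsed arguments with `visualize_dataset(**vars(args))`, the simplified CLI maps one-to-one onto the function. A sketch of the programmatic equivalent of the first docstring example (assuming `lerobot` is installed and importable):

```python
from lerobot.scripts.visualize_dataset import visualize_dataset

# Equivalent to:
#   python lerobot/scripts/visualize_dataset.py --repo-id lerobot/pusht --episode-index 0
visualize_dataset(repo_id="lerobot/pusht", episode_index=0)
```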
@@ -561,6 +217,49 @@ def main():
         default=4,
         help="Number of processes of Dataloader for loading the data.",
     )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="local",
+        help=(
+            "Mode of viewing between 'local' or 'distant'. "
+            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
+            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+        ),
+    )
+    parser.add_argument(
+        "--web-port",
+        type=int,
+        default=9090,
+        help="Web port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--ws-port",
+        type=int,
+        default=9087,
+        help="Web socket port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--save",
+        type=int,
+        default=0,
+        help=(
+            "Save a .rrd file in the directory provided by `--output-dir`. "
+            "It also deactivates the spawning of a viewer. "
+            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
+        ),
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        help="Directory path to write a .rrd file when `--save 1` is set.",
+    )
+
+    parser.add_argument(
+        "--root",
+        type=str,
+        help="Root directory for a dataset stored on a local machine.",
+    )

     args = parser.parse_args()
     visualize_dataset(**vars(args))
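One detail of the new flags: booleans such as `--save` are declared with `type=int` and default `0`, so they are toggled as `--save 1` rather than with a store-true switch. A minimal, self-contained sketch of that pattern:

```python
import argparse

# Integer-as-boolean flag, mirroring how `--save` is declared above.
parser = argparse.ArgumentParser()
parser.add_argument("--save", type=int, default=0)

args = parser.parse_args(["--save", "1"])
assert bool(args.save) is True  # the script later treats the integer as truthy
```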
@@ -1,263 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
-
-Note: The last frame of the episode doesn't always correspond to a final state.
-That's because our datasets are composed of transitions from state to state up to
-the antepenultimate state associated to the ultimate action to arrive in the final state.
-However, there might not be a transition from a final state to another state.
-
-Note: This script aims to visualize the data used to train the neural networks.
-~What you see is what you get~. When visualizing image modality, it is often expected to observe
-lossy compression artifacts since these images have been decoded from compressed mp4 videos to
-save disk space. The compression factor applied has been tuned to not affect success rate.
-
-Examples:
-
-- Visualize data stored on a local machine:
-```
-local$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0
-```
-
-- Visualize data stored on a distant machine with a local viewer:
-```
-distant$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --save 1 \
-    --output-dir path/to/directory
-
-local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
-local$ rerun lerobot_pusht_episode_0.rrd
-```
-
-- Visualize data stored on a distant machine through streaming:
-(You need to forward the websocket port to the distant machine, with
-`ssh -L 9087:localhost:9087 username@remote-host`)
-```
-distant$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --mode distant \
-    --ws-port 9087
-
-local$ rerun ws://localhost:9087
-```
-"""
-
-import argparse
-import gc
-import logging
-import time
-from pathlib import Path
-
-import rerun as rr
-import torch
-import tqdm
-
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-
-
-class EpisodeSampler(torch.utils.data.Sampler):
-    def __init__(self, dataset, episode_index):
-        from_idx = dataset.episode_data_index["from"][episode_index].item()
-        to_idx = dataset.episode_data_index["to"][episode_index].item()
-        self.frame_ids = range(from_idx, to_idx)
-
-    def __iter__(self):
-        return iter(self.frame_ids)
-
-    def __len__(self):
-        return len(self.frame_ids)
-
-
-def to_hwc_uint8_numpy(chw_float32_torch):
-    assert chw_float32_torch.dtype == torch.float32
-    assert chw_float32_torch.ndim == 3
-    c, h, w = chw_float32_torch.shape
-    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
-    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
-    return hwc_uint8_numpy
-
-
-def visualize_dataset(
-    repo_id: str,
-    episode_index: int,
-    batch_size: int = 32,
-    num_workers: int = 0,
-    mode: str = "local",
-    web_port: int = 9090,
-    ws_port: int = 9087,
-    save: bool = False,
-    output_dir: Path | None = None,
-) -> Path | None:
-    if save:
-        assert (
-            output_dir is not None
-        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
-
-    logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id)
-
-    logging.info("Loading dataloader")
-    episode_sampler = EpisodeSampler(dataset, episode_index)
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=num_workers,
-        batch_size=batch_size,
-        sampler=episode_sampler,
-    )
-
-    logging.info("Starting Rerun")
-
-    if mode not in ["local", "distant"]:
-        raise ValueError(mode)
-
-    spawn_local_viewer = mode == "local" and not save
-    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
-
-    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
-    # when iterating on a dataloader with `num_workers` > 0
-    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
-    gc.collect()
-
-    if mode == "distant":
-        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
-
-    logging.info("Logging to Rerun")
-
-    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
-        # iterate over the batch
-        for i in range(len(batch["index"])):
-            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
-            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
-
-            # display each camera image
-            for key in dataset.camera_keys:
-                # TODO(rcadene): add `.compress()`? is it lossless?
-                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
-
-            # display each dimension of action space (e.g. actuators command)
-            if "action" in batch:
-                for dim_idx, val in enumerate(batch["action"][i]):
-                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
-
-            # display each dimension of observed state space (e.g. agent position in joint space)
-            if "observation.state" in batch:
-                for dim_idx, val in enumerate(batch["observation.state"][i]):
-                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
-
-            if "next.done" in batch:
-                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
-
-            if "next.reward" in batch:
-                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
-
-            if "next.success" in batch:
-                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
-
-    if mode == "local" and save:
-        # save .rrd locally
-        output_dir = Path(output_dir)
-        output_dir.mkdir(parents=True, exist_ok=True)
-        repo_id_str = repo_id.replace("/", "_")
-        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
-        rr.save(rrd_path)
-        return rrd_path
-
-    elif mode == "distant":
-        # stop the process from exiting since it is serving the websocket connection
-        try:
-            while True:
-                time.sleep(1)
-        except KeyboardInterrupt:
-            print("Ctrl-C received. Exiting.")
-
-
-def main():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--repo-id",
-        type=str,
-        required=True,
-        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
-    )
-    parser.add_argument(
-        "--episode-index",
-        type=int,
-        required=True,
-        help="Episode to visualize.",
-    )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=32,
-        help="Batch size loaded by DataLoader.",
-    )
-    parser.add_argument(
-        "--num-workers",
-        type=int,
-        default=4,
-        help="Number of processes of Dataloader for loading the data.",
-    )
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="local",
-        help=(
-            "Mode of viewing between 'local' or 'distant'. "
-            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
-            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
-        ),
-    )
-    parser.add_argument(
-        "--web-port",
-        type=int,
-        default=9090,
-        help="Web port for rerun.io when `--mode distant` is set.",
-    )
-    parser.add_argument(
-        "--ws-port",
-        type=int,
-        default=9087,
-        help="Web socket port for rerun.io when `--mode distant` is set.",
-    )
-    parser.add_argument(
-        "--save",
-        type=int,
-        default=0,
-        help=(
-            "Save a .rrd file in the directory provided by `--output-dir`. "
-            "It also deactivates the spawning of a viewer. "
-            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
-        ),
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        help="Directory path to write a .rrd file when `--save 1` is set.",
-    )
-
-    args = parser.parse_args()
-    visualize_dataset(**vars(args))
-
-
-if __name__ == "__main__":
-    main()