This commit is contained in:
Remi Cadene 2024-05-17 02:05:44 +00:00
parent 9b62c25f6c
commit 49e1fa708a
1 changed file with 310 additions and 163 deletions

@ -30,164 +30,341 @@ Examples:
- Visualize data stored on a local machine:
```
local$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht

local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```
distant$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant  # create an ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-indices 7 3 5 1 4
```
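- Write the html files without launching the web server, e.g. to copy them to another machine (uses the `--serve` and `--output-dir` flags defined below):
```
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--serve 0 \
--output-dir path/to/directory
```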
"""
import argparse
import http.server
import logging
import os
import shutil
import socketserver
from pathlib import Path

import tqdm
from bs4 import BeautifulSoup
from bs4.element import Doctype

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import format_big_number, init_logging
class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
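"""Serve files with no-cache headers so regenerated csv/js/html files are always reloaded."""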
def end_headers(self):
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
self.send_header("Pragma", "no-cache")
self.send_header("Expires", "0")
super().end_headers()
def run_server(path, port):
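"""Serve the generated html pages (and symlinked videos) from `path` over http."""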
# Change directory to serve 'index.html' as the front page
os.chdir(path)
with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd:
logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...")
httpd.serve_forever()
def create_html_page(page_title: str):
"""Create a html page with beautiful soop with default doctype, meta, header and title."""
soup = BeautifulSoup("", "html.parser")
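# `Doctype` (from bs4.element) renders as '<!DOCTYPE html>'; a regular `new_tag` would emit a spurious closing tag.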
soup.append(Doctype("html"))
html = soup.new_tag("html", lang="en")
soup.append(html)
head = soup.new_tag("head")
html.append(head)
meta_charset = soup.new_tag("meta", charset="UTF-8")
head.append(meta_charset)
meta_viewport = soup.new_tag(
"meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
)
head.append(meta_viewport)
title = soup.new_tag("title")
title.string = page_title
head.append(title)
body = soup.new_tag("body")
html.append(body)
return soup, body
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
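# Resulting header, depending on the available features:
# timestamp,state_0,...,state_N,action_0,...,action_M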
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0])
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = len(dataset.hf_dataset["action"][0])
header += [f"action_{i}" for i in range(dim_action)]
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if has_state:
row += data[i]["observation.state"].tolist()
if has_action:
row += data[i]["action"].tolist()
rows.append(row)
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n")
for row in rows:
row_str = [str(col) for col in row]
f.write(",".join(row_str) + "\n")
def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
"""Write a javascript file containing logic to synchronize camera feeds and timeseries."""
s = ""
s += "document.addEventListener('DOMContentLoaded', function () {\n"
for i, key in enumerate(dataset.video_frame_keys):
s += f" const video{i} = document.getElementById('video_{key}');\n"
s += " const slider = document.getElementById('videoControl');\n"
s += " const playButton = document.getElementById('playButton');\n"
s += f" const dygraph = new Dygraph(document.getElementById('graphdiv'), '{ep_csv_fname}', " + "{\n"
s += " pixelsPerPoint: 0.01\n"
s += " });\n"
s += "\n"
s += " // Function to play both videos\n"
s += " playButton.addEventListener('click', function () {\n"
for i in range(len(dataset.video_frame_keys)):
s += f" video{i}.play();\n"
s += " // playButton.disabled = true; // Optional: disable button after playing\n"
s += " });\n"
s += "\n"
s += " // Update the video time when the slider value changes\n"
s += " slider.addEventListener('input', function () {\n"
s += " const sliderValue = slider.value;\n"
for i in range(len(dataset.video_frame_keys)):
s += f" const time{i} = (video{i}.duration * sliderValue) / 100;\n"
for i in range(len(dataset.video_frame_keys)):
s += f" video{i}.currentTime = time{i};\n"
s += " });\n"
s += "\n"
s += " // Synchronize slider with the video's current time\n"
s += " const syncSlider = (video) => {\n"
s += " video.addEventListener('timeupdate', function () {\n"
s += " if (video.duration) {\n"
s += " const pc = (100 / video.duration) * video.currentTime;\n"
s += " slider.value = pc;\n"
s += " const index = Math.floor(pc * dygraph.numRows() / 100);\n"
s += " dygraph.setSelection(index, undefined, true, true);\n"
s += " }\n"
s += " });\n"
s += " };\n"
s += "\n"
for i in range(len(dataset.video_frame_keys)):
s += f" syncSlider(video{i});\n"
s += "\n"
s += "});\n"
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w", encoding="utf-8") as f:
f.write(s)
def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
"""Write an html file containg video feeds and timeseries associated to an episode."""
soup, body = create_html_page("")
main_div = soup.new_tag("div")
body.append(main_div)
def create_video(id, src):
video = soup.new_tag("video", id=id, width="320", height="240", controls="")
source = soup.new_tag("source", src=src, type="video/mp4")
video.string = "Your browser does not support the video tag."
video.append(source)
return video
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
for key in dataset.video_frame_keys:
# Example of video_path: '../videos/observation.image_episode_000004.mp4'
video_path = Path("..") / dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
main_div.append(create_video(f"video_{key}", video_path))
control_div = soup.new_tag("div")
body.append(control_div)
graph_div = soup.new_tag("div", id="graphdiv", style="width: 600px; height: 300px;")
body.append(graph_div)
button = soup.new_tag("button", id="playButton")
button.string = "Play Videos"
control_div.append(button)
range_input = soup.new_tag(
"input", type="range", id="videoControl", min="0", max="100", value="0", step="1"
)
control_div.append(range_input)
script = soup.new_tag("script", type="text/javascript", src=js_fname)
body.append(script)
script_dygraph = soup.new_tag(
"script",
type="text/javascript",
src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js",
)
body.append(script_dygraph)
link_dygraph = soup.new_tag(
"link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css"
)
body.append(link_dygraph)
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w", encoding="utf-8") as f:
f.write(soup.prettify())
def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
"""Write an html file containing information related to the dataset and a list of links to
html pages of episodes."""
soup, body = create_html_page(dataset.hf_dataset.info.dataset_name)
h3 = soup.new_tag("h3")
h3.string = dataset.hf_dataset.info.dataset_name
body.append(h3)
ul_info = soup.new_tag("ul")
body.append(ul_info)
li_info = soup.new_tag("li")
li_info.string = f"Number of samples/frames: {dataset.num_samples}"
ul_info.append(li_info)
li_info = soup.new_tag("li")
li_info.string = f"Number of episodes: {dataset.num_episodes}"
ul_info.append(li_info)
li_info = soup.new_tag("li")
li_info.string = f"Frames per second: {dataset.fps}"
ul_info.append(li_info)
li_info = soup.new_tag("li")
li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B"
ul_info.append(li_info)
ul = soup.new_tag("ul")
body.append(ul)
for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=True):
li = soup.new_tag("li")
ul.append(li)
a = soup.new_tag("a", href=ep_html_fname)
a.string = f"Episode number {ep_idx}"
li.append(a)
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w", encoding="utf-8") as f:
f.write(soup.prettify())
def visualize_dataset(
repo_id: str,
episode_indices: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
web_port: int = 9090,
) -> None:
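"""Write csv/js/html files for the selected episodes of `repo_id` into `output_dir`,
then serve them over http on `web_port` when `serve` is set."""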
init_logging()
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id)
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=episode_sampler,
)
if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
logging.info("Starting Rerun")
if output_dir is None:
output_dir = f"outputs/visualize_dataset/{repo_id}"
output_dir = Path(output_dir)
if output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Create a symlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can access the mp4 files.
ln_videos_dir = output_dir / "videos"
ln_videos_dir.symlink_to(dataset.videos_dir)
if episode_indices is None:
episode_indices = list(range(dataset.num_episodes))
if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Writing html")
ep_html_fnames = []
for episode_idx in tqdm.tqdm(episode_indices):
# write states and actions in a csv
ep_csv_fname = f"episode_{episode_idx}.csv"
write_episode_data_csv(output_dir, ep_csv_fname, episode_idx, dataset)
logging.info("Logging to Rerun")
js_fname = f"episode_{episode_idx}.js"
write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
# write an html page to view videos and timeseries
ep_html_fname = f"episode_{episode_idx}.html"
write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_idx, dataset)
ep_html_fnames.append(ep_html_fname)
write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
if serve:
run_server(output_dir, web_port)
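# Programmatic use (a sketch; assumes the repo exists on the Hugging Face hub):
# visualize_dataset("lerobot/pusht", episode_indices=[0, 1], serve=False, output_dir="outputs/demo")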
def main():
@ -200,59 +377,29 @@ def main():
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
)
parser.add_argument(
"--episode-index",
"--episode-indices",
type=int,
required=True,
help="Episode to visualize.",
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
)
parser.add_argument(
"--mode",
"--output-dir",
type=str,
default="local",
help=(
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
default=None,
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--web-port",
type=int,
default=9090,
help="Web port for rerun.io when `--mode distant` is set.",
)
)
args = parser.parse_args()