From 49e1fa708a890f784f402b99f52981400ad7963d Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Fri, 17 May 2024 02:05:44 +0000
Subject: [PATCH] WIP

---
 lerobot/scripts/visualize_dataset.py | 473 ++++++++++++++++++---------
 1 file changed, 310 insertions(+), 163 deletions(-)

diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 58da6a47..3ba1ec54 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -30,164 +30,341 @@ Examples:
 - Visualize data stored on a local machine:
 ```
 local$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0
+    --repo-id lerobot/pusht
+
+local$ open http://localhost:9090
 ```
 
 - Visualize data stored on a distant machine with a local viewer:
 ```
 distant$ python lerobot/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht
+
+local$ ssh -L 9090:localhost:9090 distant  # create an ssh tunnel
+local$ open http://localhost:9090
+```
+
+- Select episodes to visualize:
+```
+python lerobot/scripts/visualize_dataset.py \
     --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --save 1 \
-    --output-dir path/to/directory
-
-local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
-local$ rerun lerobot_pusht_episode_0.rrd
+    --episode-indices 7 3 5 1 4
 ```
-
-- Visualize data stored on a distant machine through streaming:
-(You need to forward the websocket port to the distant machine, with
-`ssh -L 9087:localhost:9087 username@remote-host`)
-```
-distant$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --mode distant \
-    --ws-port 9087
-
-local$ rerun ws://localhost:9087
-```
-
 """
 
 import argparse
-import gc
+import http.server
 import logging
-import time
+import os
+import shutil
+import socketserver
 from pathlib import Path
 
-import rerun as rr
-import torch
 import tqdm
+from bs4 import BeautifulSoup, Doctype
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.utils.utils import format_big_number, init_logging
 
 
-class EpisodeSampler(torch.utils.data.Sampler):
-    def __init__(self, dataset, episode_index):
-        from_idx = dataset.episode_data_index["from"][episode_index].item()
-        to_idx = dataset.episode_data_index["to"][episode_index].item()
-        self.frame_ids = range(from_idx, to_idx)
-
-    def __iter__(self):
-        return iter(self.frame_ids)
-
-    def __len__(self):
-        return len(self.frame_ids)
+class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
+    def end_headers(self):
+        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
+        self.send_header("Pragma", "no-cache")
+        self.send_header("Expires", "0")
+        super().end_headers()
 
 
-def to_hwc_uint8_numpy(chw_float32_torch):
-    assert chw_float32_torch.dtype == torch.float32
-    assert chw_float32_torch.ndim == 3
-    c, h, w = chw_float32_torch.shape
-    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
-    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
-    return hwc_uint8_numpy
+def run_server(path, port):
+    # Change directory to serve 'index.html' as the front page
+    os.chdir(path)
+
+    with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd:
+        logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...")
+        httpd.serve_forever()
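The no-cache headers above matter because the html, csv and js files are regenerated on every run, so a cached copy in the browser could silently show a previous dataset. A minimal smoke test of the server could look like the sketch below, assuming `run_server` above is importable (the `check_no_cache_headers` helper, its port, and the fixed sleep are illustrative assumptions, not part of the patch):

```python
import tempfile
import threading
import time
import urllib.request
from pathlib import Path

def check_no_cache_headers(port=9091):
    # Serve a throwaway directory in a background thread, then inspect the headers.
    tmp_dir = tempfile.mkdtemp()
    (Path(tmp_dir) / "index.html").write_text("<html></html>")
    threading.Thread(target=run_server, args=(tmp_dir, port), daemon=True).start()
    time.sleep(0.5)  # crude wait for the server to start listening
    with urllib.request.urlopen(f"http://localhost:{port}/index.html") as resp:
        return resp.headers["Cache-Control"] == "no-store, no-cache, must-revalidate"
```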
+
+
+def create_html_page(page_title: str):
+    """Create an html page with BeautifulSoup, with default doctype, meta, head and title."""
+    soup = BeautifulSoup("", "html.parser")
+
+    doctype = Doctype("html")
+    soup.append(doctype)
+
+    html = soup.new_tag("html", lang="en")
+    soup.append(html)
+
+    head = soup.new_tag("head")
+    html.append(head)
+
+    meta_charset = soup.new_tag("meta", charset="UTF-8")
+    head.append(meta_charset)
+
+    meta_viewport = soup.new_tag(
+        "meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
+    )
+    head.append(meta_viewport)
+
+    title = soup.new_tag("title")
+    title.string = page_title
+    head.append(title)
+
+    body = soup.new_tag("body")
+    html.append(body)
+
+    return soup, body
+
+
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
+    """Write a csv file containing timeseries data of an episode (e.g. state and action).
+    This file will be loaded by Dygraph javascript to plot data in real time."""
+    from_idx = dataset.episode_data_index["from"][episode_index]
+    to_idx = dataset.episode_data_index["to"][episode_index]
+
+    has_state = "observation.state" in dataset.hf_dataset.features
+    has_action = "action" in dataset.hf_dataset.features
+
+    # init header of csv with state and action names
+    header = ["timestamp"]
+    if has_state:
+        dim_state = len(dataset.hf_dataset["observation.state"][0])
+        header += [f"state_{i}" for i in range(dim_state)]
+    if has_action:
+        dim_action = len(dataset.hf_dataset["action"][0])
+        header += [f"action_{i}" for i in range(dim_action)]
+
+    columns = ["timestamp"]
+    if has_state:
+        columns += ["observation.state"]
+    if has_action:
+        columns += ["action"]
+
+    rows = []
+    data = dataset.hf_dataset.select_columns(columns)
+    for i in range(from_idx, to_idx):
+        row = [data[i]["timestamp"].item()]
+        if has_state:
+            row += data[i]["observation.state"].tolist()
+        if has_action:
+            row += data[i]["action"].tolist()
+        rows.append(row)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w") as f:
+        f.write(",".join(header) + "\n")
+        for row in rows:
+            row_str = [str(col) for col in row]
+            f.write(",".join(row_str) + "\n")
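The resulting file is a plain csv with one row per frame, e.g. a header of `timestamp,state_0,state_1,action_0,action_1` for a 2-dimensional state and action. A hypothetical helper to load such a file back for inspection (not part of the patch; the path in the usage comment assumes the default output layout):

```python
import csv
from pathlib import Path

def read_episode_csv(csv_path):
    # Return one dict per frame, with float values keyed by the csv header.
    with open(csv_path) as f:
        return [{k: float(v) for k, v in row.items()} for row in csv.DictReader(f)]

# rows = read_episode_csv(Path("outputs/visualize_dataset/lerobot/pusht/episode_0.csv"))
# print(rows[0]["timestamp"], rows[0]["action_0"])
```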
+
+
+def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
+    """Write a javascript file containing logic to synchronize camera feeds and timeseries."""
+    s = ""
+    s += "document.addEventListener('DOMContentLoaded', function () {\n"
+    for i, key in enumerate(dataset.video_frame_keys):
+        s += f"  const video{i} = document.getElementById('video_{key}');\n"
+    s += "  const slider = document.getElementById('videoControl');\n"
+    s += "  const playButton = document.getElementById('playButton');\n"
+    s += f"  const dygraph = new Dygraph(document.getElementById('graphdiv'), '{ep_csv_fname}', " + "{\n"
+    s += "    pixelsPerPoint: 0.01\n"
+    s += "  });\n"
+    s += "\n"
+    s += "  // Function to play all videos\n"
+    s += "  playButton.addEventListener('click', function () {\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"    video{i}.play();\n"
+    s += "    // playButton.disabled = true; // Optional: disable button after playing\n"
+    s += "  });\n"
+    s += "\n"
+    s += "  // Update the video time when the slider value changes\n"
+    s += "  slider.addEventListener('input', function () {\n"
+    s += "    const sliderValue = slider.value;\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"    const time{i} = (video{i}.duration * sliderValue) / 100;\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"    video{i}.currentTime = time{i};\n"
+    s += "  });\n"
+    s += "\n"
+    s += "  // Synchronize slider with the video's current time\n"
+    s += "  const syncSlider = (video) => {\n"
+    s += "    video.addEventListener('timeupdate', function () {\n"
+    s += "      if (video.duration) {\n"
+    s += "        const pc = (100 / video.duration) * video.currentTime;\n"
+    s += "        slider.value = pc;\n"
+    s += "        const index = Math.floor(pc * dygraph.numRows() / 100);\n"
+    s += "        dygraph.setSelection(index, undefined, true, true);\n"
+    s += "      }\n"
+    s += "    });\n"
+    s += "  };\n"
+    s += "\n"
+    for i in range(len(dataset.video_frame_keys)):
+        s += f"  syncSlider(video{i});\n"
+    s += "\n"
+    s += "});\n"
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w", encoding="utf-8") as f:
+        f.write(s)
+
+
+def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
+    """Write an html file containing video feeds and timeseries associated to an episode."""
+    soup, body = create_html_page("")
+
+    main_div = soup.new_tag("div")
+    body.append(main_div)
+
+    def create_video(id, src):
+        video = soup.new_tag("video", id=id, width="320", height="240", controls="")
+        source = soup.new_tag("source", src=src, type="video/mp4")
+        video.append(source)
+        video.append("Your browser does not support the video tag.")
+        return video
+
+    # get first frame of episode (hack to get video_path of the episode)
+    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+
+    for key in dataset.video_frame_keys:
+        # Example of video_path: '../videos/observation.image_episode_000004.mp4'
+        video_path = Path("..") / dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
+        main_div.append(create_video(f"video_{key}", video_path))
+
+    control_div = soup.new_tag("div")
+    body.append(control_div)
+
+    graph_div = soup.new_tag("div", id="graphdiv", style="width: 600px; height: 300px;")
+    body.append(graph_div)
+
+    button = soup.new_tag("button", id="playButton")
+    button.string = "Play Videos"
+    control_div.append(button)
+
+    range_input = soup.new_tag(
+        "input", type="range", id="videoControl", min="0", max="100", value="0", step="1"
+    )
+    control_div.append(range_input)
+
+    script = soup.new_tag("script", type="text/javascript", src=js_fname)
+    body.append(script)
+
+    script_dygraph = soup.new_tag(
+        "script",
+        type="text/javascript",
+        src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js",
+    )
+    body.append(script_dygraph)
+
+    link_dygraph = soup.new_tag(
+        "link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css"
+    )
+    body.append(link_dygraph)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w", encoding="utf-8") as f:
+        f.write(soup.prettify())
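Each episode thus gets three sibling files that reference one another: the html page loads the js file, which in turn feeds the csv to Dygraphs. A sketch of the per-episode flow, mirroring the loop in `visualize_dataset` below (it assumes the writer functions from this script are importable; the repo id and output directory are examples):

```python
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
out_dir = Path("outputs/visualize_dataset/lerobot/pusht")

ep_idx = 0
ep_csv_fname = f"episode_{ep_idx}.csv"
write_episode_data_csv(out_dir, ep_csv_fname, ep_idx, dataset)  # timeseries data
js_fname = f"episode_{ep_idx}.js"
write_episode_data_js(out_dir, js_fname, ep_csv_fname, dataset)  # sync logic
write_episode_data_html(out_dir, f"episode_{ep_idx}.html", js_fname, ep_idx, dataset)  # page
```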
+
+
+def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
+    """Write an html file containing information related to the dataset and a list of links to
+    html pages of episodes."""
+    soup, body = create_html_page(dataset.hf_dataset.info.dataset_name)
+
+    h3 = soup.new_tag("h3")
+    h3.string = dataset.hf_dataset.info.dataset_name
+    body.append(h3)
+
+    ul_info = soup.new_tag("ul")
+    body.append(ul_info)
+
+    li_info = soup.new_tag("li")
+    li_info.string = f"Number of samples/frames: {dataset.num_samples}"
+    ul_info.append(li_info)
+
+    li_info = soup.new_tag("li")
+    li_info.string = f"Number of episodes: {dataset.num_episodes}"
+    ul_info.append(li_info)
+
+    li_info = soup.new_tag("li")
+    li_info.string = f"Frames per second: {dataset.fps}"
+    ul_info.append(li_info)
+
+    li_info = soup.new_tag("li")
+    li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B"
+    ul_info.append(li_info)
+
+    ul = soup.new_tag("ul")
+    body.append(ul)
+
+    for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=False):
+        li = soup.new_tag("li")
+        ul.append(li)
+
+        a = soup.new_tag("a", href=ep_html_fname)
+        a.string = f"Episode number {ep_idx}"
+        li.append(a)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w", encoding="utf-8") as f:
+        f.write(soup.prettify())
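All of the html generation here leans on the same BeautifulSoup `new_tag`/`append` idiom. As a standalone illustration of that pattern, independent of the dataset classes (the file names are made up):

```python
from bs4 import BeautifulSoup

soup = BeautifulSoup("", "html.parser")
ul = soup.new_tag("ul")
soup.append(ul)
for fname in ["episode_0.html", "episode_1.html"]:
    li = soup.new_tag("li")
    a = soup.new_tag("a", href=fname)  # tag attributes are passed as keyword arguments
    a.string = f"Link to {fname}"
    li.append(a)
    ul.append(li)
print(soup.prettify())
```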
 
 
 def visualize_dataset(
     repo_id: str,
-    episode_index: int,
-    batch_size: int = 32,
-    num_workers: int = 0,
-    mode: str = "local",
-    web_port: int = 9090,
-    ws_port: int = 9087,
-    save: bool = False,
+    episode_indices: list[int] | None = None,
     output_dir: Path | None = None,
+    serve: bool = True,
+    web_port: int = 9090,
 ) -> Path | None:
-    if save:
-        assert (
-            output_dir is not None
-        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+    init_logging()
 
     logging.info("Loading dataset")
     dataset = LeRobotDataset(repo_id)
 
-    logging.info("Loading dataloader")
-    episode_sampler = EpisodeSampler(dataset, episode_index)
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=num_workers,
-        batch_size=batch_size,
-        sampler=episode_sampler,
-    )
+    if not dataset.video:
+        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
 
-    logging.info("Starting Rerun")
+    if output_dir is None:
+        output_dir = f"outputs/visualize_dataset/{repo_id}"
 
-    if mode not in ["local", "distant"]:
-        raise ValueError(mode)
+    output_dir = Path(output_dir)
+    if output_dir.exists():
+        shutil.rmtree(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
-    spawn_local_viewer = mode == "local" and not save
-    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+    # Create a symlink from the dataset video folder containing mp4 files to the output directory
+    # so that the http server can get access to the mp4 files.
+    ln_videos_dir = output_dir / "videos"
+    ln_videos_dir.symlink_to(dataset.videos_dir)
 
-    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
-    # when iterating on a dataloader with `num_workers` > 0
-    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
-    gc.collect()
+    if episode_indices is None:
+        episode_indices = list(range(dataset.num_episodes))
 
-    if mode == "distant":
-        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+    logging.info("Writing html")
+    ep_html_fnames = []
+    for episode_idx in tqdm.tqdm(episode_indices):
+        # write states and actions in a csv
+        ep_csv_fname = f"episode_{episode_idx}.csv"
+        write_episode_data_csv(output_dir, ep_csv_fname, episode_idx, dataset)
 
-    logging.info("Logging to Rerun")
+        js_fname = f"episode_{episode_idx}.js"
+        write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
 
-    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
-        # iterate over the batch
-        for i in range(len(batch["index"])):
-            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
-            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+        # write a html page to view videos and timeseries
+        ep_html_fname = f"episode_{episode_idx}.html"
+        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_idx, dataset)
+        ep_html_fnames.append(ep_html_fname)
 
-            # display each camera image
-            for key in dataset.camera_keys:
-                # TODO(rcadene): add `.compress()`? is it lossless?
-                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
+    write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
 
-            # display each dimension of action space (e.g. actuators command)
-            if "action" in batch:
-                for dim_idx, val in enumerate(batch["action"][i]):
-                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
-
-            # display each dimension of observed state space (e.g. agent position in joint space)
-            if "observation.state" in batch:
-                for dim_idx, val in enumerate(batch["observation.state"][i]):
-                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
-
-            if "next.done" in batch:
-                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
-
-            if "next.reward" in batch:
-                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
-
-            if "next.success" in batch:
-                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
-
-    if mode == "local" and save:
-        # save .rrd locally
-        output_dir = Path(output_dir)
-        output_dir.mkdir(parents=True, exist_ok=True)
-        repo_id_str = repo_id.replace("/", "_")
-        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
-        rr.save(rrd_path)
-        return rrd_path
-
-    elif mode == "distant":
-        # stop the process from exiting since it is serving the websocket connection
-        try:
-            while True:
-                time.sleep(1)
-        except KeyboardInterrupt:
-            print("Ctrl-C received. Exiting.")
+    if serve:
+        run_server(output_dir, web_port)
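Called from Python rather than the CLI, the entry point above would presumably be used along these lines (the repo id and episode list are examples; `serve=False` skips the http server so the call returns after writing the files):

```python
visualize_dataset(
    repo_id="lerobot/pusht",
    episode_indices=[0, 1, 5],
    output_dir="outputs/visualize_dataset/lerobot/pusht",
    serve=False,  # only write the html/csv/js files; do not start the http server
)
```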
 
 
 def main():
@@ -200,59 +377,29 @@ def main():
         help="Name of Hugging Face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
     )
     parser.add_argument(
-        "--episode-index",
+        "--episode-indices",
         type=int,
-        required=True,
-        help="Episode to visualize.",
+        nargs="*",
+        default=None,
+        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes with indices 0, 1, 5 and 6). By default loads all episodes.",
     )
     parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=32,
-        help="Batch size loaded by DataLoader.",
-    )
-    parser.add_argument(
-        "--num-workers",
-        type=int,
-        default=4,
-        help="Number of processes of Dataloader for loading the data.",
-    )
-    parser.add_argument(
-        "--mode",
+        "--output-dir",
         type=str,
-        default="local",
-        help=(
-            "Mode of viewing between 'local' or 'distant'. "
-            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
-            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
-        ),
+        default=None,
+        help="Directory path to write html files and kick off a web server. By default writes them to 'outputs/visualize_dataset/REPO_ID'.",
+    )
+    parser.add_argument(
+        "--serve",
+        type=int,
+        default=1,
+        help="Launch web server (use 0 to only write the files).",
     )
     parser.add_argument(
         "--web-port",
         type=int,
         default=9090,
-        help="Web port for rerun.io when `--mode distant` is set.",
-    )
-    parser.add_argument(
-        "--ws-port",
-        type=int,
-        default=9087,
-        help="Web socket port for rerun.io when `--mode distant` is set.",
-    )
-    parser.add_argument(
-        "--save",
-        type=int,
-        default=0,
-        help=(
-            "Save a .rrd file in the directory provided by `--output-dir`. "
-            "It also deactivates the spawning of a viewer. ",
-            "Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
-        ),
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        help="Directory path to write a .rrd file when `--save 1` is set.",
+        help="Web port used by the http server.",
     )
 
     args = parser.parse_args()