diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 3ba1ec54..2f6ab545 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -58,7 +58,6 @@ import logging import os import shutil import socketserver -from http.server import HTTPServer, SimpleHTTPRequestHandler from pathlib import Path import tqdm @@ -80,9 +79,6 @@ def run_server(path, port): # Change directory to serve 'index.html` as front page os.chdir(path) - server_address = ("", port) - httpd = HTTPServer(server_address, SimpleHTTPRequestHandler) - with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd: logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...") httpd.serve_forever() @@ -171,8 +167,12 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset): s += f" const video{i} = document.getElementById('video_{key}');\n" s += " const slider = document.getElementById('videoControl');\n" s += " const playButton = document.getElementById('playButton');\n" - s += f" const dygraph = new Dygraph(document.getElementById('graphdiv'), '{ep_csv_fname}', " + "{\n" - s += " pixelsPerPoint: 0.01\n" + s += f" const dygraph = new Dygraph(document.getElementById('graph'), '{ep_csv_fname}', " + "{\n" + s += " pixelsPerPoint: 0.01,\n" + s += " legend: 'always',\n" + s += " labelsDiv: document.getElementById('labels'),\n" + s += " labelsSeparateLines: true,\n" + s += " labelsKMB: true\n" s += " });\n" s += "\n" s += " // Function to play both videos\n" @@ -217,8 +217,13 @@ def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): """Write an html file containg video feeds and timeseries associated to an episode.""" soup, body = create_html_page("") - main_div = soup.new_tag("div") - body.append(main_div) + # Add videos from camera feeds + + videos_control_div = soup.new_tag("div") + body.append(videos_control_div) + + videos_div = soup.new_tag("div") + videos_control_div.append(videos_div) def create_video(id, src): video = soup.new_tag("video", id=id, width="320", height="240", controls="") @@ -231,25 +236,40 @@ def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): first_frame_idx = dataset.episode_data_index["from"][ep_index].item() for key in dataset.video_frame_keys: - # Example of video_path: '../videos/observation.image_episode_000004.mp4' - video_path = Path("..") / dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] - main_div.append(create_video(f"video_{key}", video_path)) + # Example of video_path: 'videos/observation.image_episode_000004.mp4' + video_path = dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] + videos_div.append(create_video(f"video_{key}", video_path)) + + # Add controls for videos and graph control_div = soup.new_tag("div") - body.append(control_div) + videos_control_div.append(control_div) - graph_div = soup.new_tag("div", id="graphdiv", style="width: 600px; height: 300px;") - body.append(graph_div) + button_div = soup.new_tag("div") + control_div.append(button_div) button = soup.new_tag("button", id="playButton") button.string = "Play Videos" - control_div.append(button) + button_div.append(button) - range_input = soup.new_tag( - "input", type="range", id="videoControl", min="0", max="100", value="0", step="1" - ) - control_div.append(range_input) + slider_div = soup.new_tag("div") + control_div.append(slider_div) + slider = soup.new_tag("input", type="range", id="videoControl", min="0", max="100", value="0", step="1") + control_div.append(slider) + + # Add graph of states/actions, and its labels + + graph_labels_div = soup.new_tag("div", style="display: flex;") + body.append(graph_labels_div) + + graph_div = soup.new_tag("div", id="graph", style="flex: 1; width: 85%") + graph_labels_div.append(graph_div) + + labels_div = soup.new_tag("div", id="labels", style="flex: 1; width: 15%") + graph_labels_div.append(labels_div) + + # add dygraph library script = soup.new_tag("script", type="text/javascript", src=js_fname) body.append(script) @@ -265,6 +285,8 @@ def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): ) body.append(link_dygraph) + # Write as a html file + output_dir.mkdir(parents=True, exist_ok=True) with open(output_dir / file_name, "w", encoding="utf-8") as f: f.write(soup.prettify()) @@ -320,7 +342,7 @@ def visualize_dataset( episode_indices: list[int] = None, output_dir: Path | None = None, serve: bool = True, - web_port: int = 9090, + port: int = 9090, ) -> Path | None: init_logging() @@ -364,7 +386,7 @@ def visualize_dataset( write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset) if serve: - run_server(output_dir, web_port) + run_server(output_dir, port) def main(): @@ -396,7 +418,7 @@ def main(): help="Launch web server.", ) parser.add_argument( - "--web-port", + "--port", type=int, default=9090, help="Web port used by the http server.",