From 12a1b8f55ad204c78dd2e7e034cf298d80b80c38 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Thu, 13 Jun 2024 15:56:54 +0000 Subject: [PATCH] rename to html --- lerobot/scripts/visualize_dataset.py | 603 ++++++--------------- lerobot/scripts/visualize_dataset_rerun.py | 263 --------- 2 files changed, 151 insertions(+), 715 deletions(-) delete mode 100644 lerobot/scripts/visualize_dataset_rerun.py diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 6534343d..138084ae 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -30,46 +30,48 @@ Examples: - Visualize data stored on a local machine: ``` local$ python lerobot/scripts/visualize_dataset.py \ - --repo-id lerobot/pusht - -local$ open http://localhost:9090 + --repo-id lerobot/pusht \ + --episode-index 0 ``` - Visualize data stored on a distant machine with a local viewer: ``` distant$ python lerobot/scripts/visualize_dataset.py \ - --repo-id lerobot/pusht - -local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel -local$ open http://localhost:9090 -``` - -- Select episodes to visualize: -``` -python lerobot/scripts/visualize_dataset.py \ --repo-id lerobot/pusht \ - --episode-indices 7 3 5 1 4 + --episode-index 0 \ + --save 1 \ + --output-dir path/to/directory + +local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd . +local$ rerun lerobot_pusht_episode_0.rrd ``` + +- Visualize data stored on a distant machine through streaming: +(You need to forward the websocket port to the distant machine, with +`ssh -L 9087:localhost:9087 username@remote-host`) +``` +distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --mode distant \ + --ws-port 9087 + +local$ rerun ws://localhost:9087 +``` + """ import argparse -import http.server +import gc import logging -import os -import shutil -import socketserver +import time from pathlib import Path +import rerun as rr import torch import tqdm -import yaml -from bs4 import BeautifulSoup -from huggingface_hub import snapshot_download -from safetensors.torch import load_file, save_file from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.policies.act.modeling_act import ACTPolicy -from lerobot.common.utils.utils import init_logging class EpisodeSampler(torch.utils.data.Sampler): @@ -85,307 +87,33 @@ class EpisodeSampler(torch.utils.data.Sampler): return len(self.frame_ids) -class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): - def end_headers(self): - self.send_header("Cache-Control", "no-store, no-cache, must-revalidate") - self.send_header("Pragma", "no-cache") - self.send_header("Expires", "0") - super().end_headers() +def to_hwc_uint8_numpy(chw_float32_torch): + assert chw_float32_torch.dtype == torch.float32 + assert chw_float32_torch.ndim == 3 + c, h, w = chw_float32_torch.shape + assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" + hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() + return hwc_uint8_numpy -def run_server(path, port): - # Change directory to serve 'index.html` as front page - os.chdir(path) +def visualize_dataset( + repo_id: str, + episode_index: int, + batch_size: int = 32, + num_workers: int = 0, + mode: str = "local", + web_port: int = 9090, + ws_port: int = 9087, + save: bool = False, + output_dir: Path | None = None, +) -> Path | None: + if save: + assert ( + output_dir is not None + ), "Set an 
output directory where to write .rrd files with `--output-dir path/to/directory`." - with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd: - logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...") - httpd.serve_forever() - - -def create_html_page(page_title: str): - """Create a html page with beautiful soop with default doctype, meta, header and title.""" - soup = BeautifulSoup("", "html.parser") - - doctype = soup.new_tag("!DOCTYPE html") - soup.append(doctype) - - html = soup.new_tag("html", lang="en") - soup.append(html) - - head = soup.new_tag("head") - html.append(head) - - meta_charset = soup.new_tag("meta", charset="UTF-8") - head.append(meta_charset) - - meta_viewport = soup.new_tag( - "meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"} - ) - head.append(meta_viewport) - - title = soup.new_tag("title") - title.string = page_title - head.append(title) - - body = soup.new_tag("body") - html.append(body) - - main_div = soup.new_tag("div") - body.append(main_div) - return soup, head, body - - -def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None): - """Write a csv file containg timeseries data of an episode (e.g. state and action). - This file will be loaded by Dygraph javascript to plot data in real time.""" - from_idx = dataset.episode_data_index["from"][episode_index] - to_idx = dataset.episode_data_index["to"][episode_index] - - has_state = "observation.state" in dataset.hf_dataset.features - has_action = "action" in dataset.hf_dataset.features - has_inference = inference_results is not None - - # init header of csv with state and action names - header = ["timestamp"] - if has_state: - dim_state = len(dataset.hf_dataset["observation.state"][0]) - header += [f"state_{i}" for i in range(dim_state)] - if has_action: - dim_action = len(dataset.hf_dataset["action"][0]) - header += [f"action_{i}" for i in range(dim_action)] - if has_inference: - assert "actions" in inference_results - assert "loss" in inference_results - dim_pred_action = inference_results["actions"].shape[2] - header += [f"pred_action_{i}" for i in range(dim_pred_action)] - header += ["loss"] - - columns = ["timestamp"] - if has_state: - columns += ["observation.state"] - if has_action: - columns += ["action"] - - rows = [] - data = dataset.hf_dataset.select_columns(columns) - for i in range(from_idx, to_idx): - row = [data[i]["timestamp"].item()] - if has_state: - row += data[i]["observation.state"].tolist() - if has_action: - row += data[i]["action"].tolist() - rows.append(row) - - if has_inference: - num_frames = len(rows) - assert num_frames == inference_results["actions"].shape[0] - assert num_frames == inference_results["loss"].shape[0] - for i in range(num_frames): - rows[i] += inference_results["actions"][i, 0].tolist() - rows[i] += [inference_results["loss"][i].item()] - - output_dir.mkdir(parents=True, exist_ok=True) - with open(output_dir / file_name, "w") as f: - f.write(",".join(header) + "\n") - for row in rows: - row_str = [str(col) for col in row] - f.write(",".join(row_str) + "\n") - - -def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset): - """Write a javascript file containing logic to synchronize camera feeds and timeseries.""" - s = "" - s += "document.addEventListener('DOMContentLoaded', function () {\n" - for i, key in enumerate(dataset.video_frame_keys): - s += f" const video{i} = document.getElementById('video_{key}');\n" - s += " const slider = 
document.getElementById('videoControl');\n" - s += " const playButton = document.getElementById('playButton');\n" - s += f" const dygraph = new Dygraph(document.getElementById('graph'), '{ep_csv_fname}', " + "{\n" - s += " pixelsPerPoint: 0.01,\n" - s += " legend: 'always',\n" - s += " labelsDiv: document.getElementById('labels'),\n" - s += " labelsSeparateLines: true,\n" - s += " labelsKMB: true,\n" - s += " highlightCircleSize: 1.5,\n" - s += " highlightSeriesOpts: {\n" - s += " strokeWidth: 1.5,\n" - s += " strokeBorderWidth: 1,\n" - s += " highlightCircleSize: 3\n" - s += " }\n" - s += " });\n" - s += "\n" - s += " // Function to play both videos\n" - s += " playButton.addEventListener('click', function () {\n" - for i in range(len(dataset.video_frame_keys)): - s += f" video{i}.play();\n" - s += " // playButton.disabled = true; // Optional: disable button after playing\n" - s += " });\n" - s += "\n" - s += " // Update the video time when the slider value changes\n" - s += " slider.addEventListener('input', function () {\n" - s += " const sliderValue = slider.value;\n" - for i in range(len(dataset.video_frame_keys)): - s += f" const time{i} = (video{i}.duration * sliderValue) / 100;\n" - for i in range(len(dataset.video_frame_keys)): - s += f" video{i}.currentTime = time{i};\n" - s += " });\n" - s += "\n" - s += " // Synchronize slider with the video's current time\n" - s += " const syncSlider = (video) => {\n" - s += " video.addEventListener('timeupdate', function () {\n" - s += " if (video.duration) {\n" - s += " const pc = (100 / video.duration) * video.currentTime;\n" - s += " slider.value = pc;\n" - s += " const index = Math.floor(pc * dygraph.numRows() / 100);\n" - s += " dygraph.setSelection(index, undefined, true, true);\n" - s += " }\n" - s += " });\n" - s += " };\n" - s += "\n" - for i in range(len(dataset.video_frame_keys)): - s += f" syncSlider(video{i});\n" - s += "\n" - s += "});\n" - - output_dir.mkdir(parents=True, exist_ok=True) - with open(output_dir / file_name, "w", encoding="utf-8") as f: - f.write(s) - - -def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): - """Write an html file containg video feeds and timeseries associated to an episode.""" - soup, head, body = create_html_page("") - - css_style = soup.new_tag("style") - css_style.string = "" - css_style.string += "#labels > span.highlight {\n" - css_style.string += " border: 1px solid grey;\n" - css_style.string += "}" - head.append(css_style) - - # Add videos from camera feeds - - videos_control_div = soup.new_tag("div") - body.append(videos_control_div) - - videos_div = soup.new_tag("div") - videos_control_div.append(videos_div) - - def create_video(id, src): - video = soup.new_tag("video", id=id, width="320", height="240", controls="") - source = soup.new_tag("source", src=src, type="video/mp4") - video.string = "Your browser does not support the video tag." 
- video.append(source) - return video - - # get first frame of episode (hack to get video_path of the episode) - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() - - for key in dataset.video_frame_keys: - # Example of video_path: 'videos/observation.image_episode_000004.mp4' - video_path = dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] - videos_div.append(create_video(f"video_{key}", video_path)) - - # Add controls for videos and graph - - control_div = soup.new_tag("div") - videos_control_div.append(control_div) - - button_div = soup.new_tag("div") - control_div.append(button_div) - - button = soup.new_tag("button", id="playButton") - button.string = "Play Videos" - button_div.append(button) - - slider_div = soup.new_tag("div") - control_div.append(slider_div) - - slider = soup.new_tag("input", type="range", id="videoControl", min="0", max="100", value="0", step="1") - control_div.append(slider) - - # Add graph of states/actions, and its labels - - graph_labels_div = soup.new_tag("div", style="display: flex;") - body.append(graph_labels_div) - - graph_div = soup.new_tag("div", id="graph", style="flex: 1; width: 85%") - graph_labels_div.append(graph_div) - - labels_div = soup.new_tag("div", id="labels", style="flex: 1; width: 15%") - graph_labels_div.append(labels_div) - - # add dygraph library - script = soup.new_tag("script", type="text/javascript", src=js_fname) - body.append(script) - - script_dygraph = soup.new_tag( - "script", - type="text/javascript", - src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js", - ) - body.append(script_dygraph) - - link_dygraph = soup.new_tag( - "link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css" - ) - body.append(link_dygraph) - - # Write as a html file - - output_dir.mkdir(parents=True, exist_ok=True) - with open(output_dir / file_name, "w", encoding="utf-8") as f: - f.write(soup.prettify()) - - -def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset): - """Write an html file containing information related to the dataset and a list of links to - html pages of episodes.""" - soup, head, body = create_html_page("TODO") - - h3 = soup.new_tag("h3") - h3.string = "TODO" - body.append(h3) - - ul_info = soup.new_tag("ul") - body.append(ul_info) - - li_info = soup.new_tag("li") - li_info.string = f"Number of samples/frames: {dataset.num_samples}" - ul_info.append(li_info) - - li_info = soup.new_tag("li") - li_info.string = f"Number of episodes: {dataset.num_episodes}" - ul_info.append(li_info) - - li_info = soup.new_tag("li") - li_info.string = f"Frames per second: {dataset.fps}" - ul_info.append(li_info) - - # li_info = soup.new_tag("li") - # li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B" - # ul_info.append(li_info) - - ul = soup.new_tag("ul") - body.append(ul) - - for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=False): - li = soup.new_tag("li") - ul.append(li) - - a = soup.new_tag("a", href=ep_html_fname) - a.string = f"Episode number {ep_idx}" - - li.append(a) - - output_dir.mkdir(parents=True, exist_ok=True) - with open(output_dir / file_name, "w", encoding="utf-8") as f: - f.write(soup.prettify()) - - -def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"): - policy.eval() - policy.to(device) + logging.info("Loading dataset") + dataset = LeRobotDataset(repo_id) logging.info("Loading dataloader") episode_sampler = 
EpisodeSampler(dataset, episode_index) @@ -396,104 +124,70 @@ def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, sampler=episode_sampler, ) - logging.info("Running inference") - inference_results = {} + logging.info("Starting Rerun") + + if mode not in ["local", "distant"]: + raise ValueError(mode) + + spawn_local_viewer = mode == "local" and not save + rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) + + # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush + # when iterating on a dataloader with `num_workers` > 0 + # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix + gc.collect() + + if mode == "distant": + rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) + + logging.info("Logging to Rerun") + for batch in tqdm.tqdm(dataloader, total=len(dataloader)): - batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} - with torch.inference_mode(): - output_dict = policy.forward(batch) + # iterate over the batch + for i in range(len(batch["index"])): + rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) + rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) - for key in output_dict: - if key not in inference_results: - inference_results[key] = [] - inference_results[key].append(output_dict[key].to("cpu")) + # display each camera image + for key in dataset.camera_keys: + # TODO(rcadene): add `.compress()`? is it lossless? + rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) - for key in inference_results: - inference_results[key] = torch.cat(inference_results[key]) + # display each dimension of action space (e.g. actuators command) + if "action" in batch: + for dim_idx, val in enumerate(batch["action"][i]): + rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) - return inference_results + # display each dimension of observed state space (e.g. 
agent position in joint space) + if "observation.state" in batch: + for dim_idx, val in enumerate(batch["observation.state"][i]): + rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) + if "next.done" in batch: + rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) -def visualize_dataset( - repo_id: str, - episode_indices: list[int] = None, - output_dir: Path | None = None, - serve: bool = True, - port: int = 9090, - force_overwrite: bool = True, - policy_repo_id: str | None = None, - policy_ckpt_path: Path | None = None, - batch_size: int = 32, - num_workers: int = 4, -) -> Path | None: - init_logging() + if "next.reward" in batch: + rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) - has_policy = policy_repo_id or policy_ckpt_path + if "next.success" in batch: + rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) - if has_policy: - logging.info("Loading policy") - if policy_repo_id: - pretrained_policy_path = Path(snapshot_download(policy_repo_id)) - elif policy_ckpt_path: - pretrained_policy_path = Path(policy_ckpt_path) - policy = ACTPolicy.from_pretrained(pretrained_policy_path) - with open(pretrained_policy_path / "config.yaml") as f: - cfg = yaml.safe_load(f) - delta_timestamps = cfg["training"]["delta_timestamps"] - else: - delta_timestamps = None + if mode == "local" and save: + # save .rrd locally + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + repo_id_str = repo_id.replace("/", "_") + rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" + rr.save(rrd_path) + return rrd_path - logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps) - - if not dataset.video: - raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.") - - if output_dir is None: - output_dir = f"outputs/visualize_dataset/{repo_id}" - - output_dir = Path(output_dir) - if force_overwrite and output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - # Create a simlink from the dataset video folder containg mp4 files to the output directory - # so that the http server can get access to the mp4 files. 
- ln_videos_dir = output_dir / "videos" - if not ln_videos_dir.exists(): - ln_videos_dir.symlink_to(dataset.videos_dir.resolve()) - - if episode_indices is None: - episode_indices = list(range(dataset.num_episodes)) - - logging.info("Writing html") - ep_html_fnames = [] - for episode_index in tqdm.tqdm(episode_indices): - inference_results = None - if has_policy: - inference_results_path = output_dir / f"episode_{episode_index}.safetensors" - if inference_results_path.exists(): - inference_results = load_file(inference_results_path) - else: - inference_results = run_inference(dataset, episode_index, policy) - save_file(inference_results, inference_results_path) - - # write states and actions in a csv - ep_csv_fname = f"episode_{episode_index}.csv" - write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results) - - js_fname = f"episode_{episode_index}.js" - write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset) - - # write a html page to view videos and timeseries - ep_html_fname = f"episode_{episode_index}.html" - write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset) - ep_html_fnames.append(ep_html_fname) - - write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset) - - if serve: - run_server(output_dir, port) + elif mode == "distant": + # stop the process from exiting since it is serving the websocket connection + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("Ctrl-C received. Exiting.") def main(): @@ -503,51 +197,13 @@ def main(): "--repo-id", type=str, required=True, - help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", + help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", ) parser.add_argument( - "--episode-indices", + "--episode-index", type=int, - nargs="*", - default=None, - help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.", - ) - parser.add_argument( - "--output-dir", - type=str, - default=None, - help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.", - ) - parser.add_argument( - "--serve", - type=int, - default=1, - help="Launch web server.", - ) - parser.add_argument( - "--port", - type=int, - default=9090, - help="Web port used by the http server.", - ) - parser.add_argument( - "--force-overwrite", - type=int, - default=1, - help="Delete the output directory if it exists already.", - ) - - parser.add_argument( - "--policy-repo-id", - type=str, - default=None, - help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).", - ) - parser.add_argument( - "--policy-ckpt-path", - type=str, - default=None, - help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).", + required=True, + help="Episode to visualize.", ) parser.add_argument( "--batch-size", @@ -561,6 +217,49 @@ def main(): default=4, help="Number of processes of Dataloader for loading the data.", ) + parser.add_argument( + "--mode", + type=str, + default="local", + help=( + "Mode of viewing between 'local' or 'distant'. " + "'local' requires data to be on a local machine. 
It spawns a viewer to visualize the data locally. "
+            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+        ),
+    )
+    parser.add_argument(
+        "--web-port",
+        type=int,
+        default=9090,
+        help="Web port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--ws-port",
+        type=int,
+        default=9087,
+        help="Web socket port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--save",
+        type=int,
+        default=0,
+        help=(
+            "Save a .rrd file in the directory provided by `--output-dir`. "
+            "It also deactivates the spawning of a viewer. "
+            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
+        ),
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        help="Directory path to write a .rrd file when `--save 1` is set.",
+    )
+
+    parser.add_argument(
+        "--root",
+        type=str,
+        help="Root directory for a dataset stored on a local machine.",
+    )
 
     args = parser.parse_args()
     visualize_dataset(**vars(args))
diff --git a/lerobot/scripts/visualize_dataset_rerun.py b/lerobot/scripts/visualize_dataset_rerun.py
deleted file mode 100644
index 58da6a47..00000000
--- a/lerobot/scripts/visualize_dataset_rerun.py
+++ /dev/null
@@ -1,263 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
-
-Note: The last frame of the episode doesnt always correspond to a final state.
-That's because our datasets are composed of transition from state to state up to
-the antepenultimate state associated to the ultimate action to arrive in the final state.
-However, there might not be a transition from a final state to another state.
-
-Note: This script aims to visualize the data used to train the neural networks.
-~What you see is what you get~. When visualizing image modality, it is often expected to observe
-lossly compression artifacts since these images have been decoded from compressed mp4 videos to
-save disk space. The compression factor applied has been tuned to not affect success rate.
-
-Examples:
-
-- Visualize data stored on a local machine:
-```
-local$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0
-```
-
-- Visualize data stored on a distant machine with a local viewer:
-```
-distant$ python lerobot/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --save 1 \
-    --output-dir path/to/directory
-
-local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
-local$ rerun lerobot_pusht_episode_0.rrd -``` - -- Visualize data stored on a distant machine through streaming: -(You need to forward the websocket port to the distant machine, with -`ssh -L 9087:localhost:9087 username@remote-host`) -``` -distant$ python lerobot/scripts/visualize_dataset.py \ - --repo-id lerobot/pusht \ - --episode-index 0 \ - --mode distant \ - --ws-port 9087 - -local$ rerun ws://localhost:9087 -``` - -""" - -import argparse -import gc -import logging -import time -from pathlib import Path - -import rerun as rr -import torch -import tqdm - -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset - - -class EpisodeSampler(torch.utils.data.Sampler): - def __init__(self, dataset, episode_index): - from_idx = dataset.episode_data_index["from"][episode_index].item() - to_idx = dataset.episode_data_index["to"][episode_index].item() - self.frame_ids = range(from_idx, to_idx) - - def __iter__(self): - return iter(self.frame_ids) - - def __len__(self): - return len(self.frame_ids) - - -def to_hwc_uint8_numpy(chw_float32_torch): - assert chw_float32_torch.dtype == torch.float32 - assert chw_float32_torch.ndim == 3 - c, h, w = chw_float32_torch.shape - assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" - hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() - return hwc_uint8_numpy - - -def visualize_dataset( - repo_id: str, - episode_index: int, - batch_size: int = 32, - num_workers: int = 0, - mode: str = "local", - web_port: int = 9090, - ws_port: int = 9087, - save: bool = False, - output_dir: Path | None = None, -) -> Path | None: - if save: - assert ( - output_dir is not None - ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." - - logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id) - - logging.info("Loading dataloader") - episode_sampler = EpisodeSampler(dataset, episode_index) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=num_workers, - batch_size=batch_size, - sampler=episode_sampler, - ) - - logging.info("Starting Rerun") - - if mode not in ["local", "distant"]: - raise ValueError(mode) - - spawn_local_viewer = mode == "local" and not save - rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) - - # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush - # when iterating on a dataloader with `num_workers` > 0 - # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix - gc.collect() - - if mode == "distant": - rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) - - logging.info("Logging to Rerun") - - for batch in tqdm.tqdm(dataloader, total=len(dataloader)): - # iterate over the batch - for i in range(len(batch["index"])): - rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) - rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) - - # display each camera image - for key in dataset.camera_keys: - # TODO(rcadene): add `.compress()`? is it lossless? - rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) - - # display each dimension of action space (e.g. actuators command) - if "action" in batch: - for dim_idx, val in enumerate(batch["action"][i]): - rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) - - # display each dimension of observed state space (e.g. 
agent position in joint space) - if "observation.state" in batch: - for dim_idx, val in enumerate(batch["observation.state"][i]): - rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) - - if "next.done" in batch: - rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) - - if "next.reward" in batch: - rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) - - if "next.success" in batch: - rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) - - if mode == "local" and save: - # save .rrd locally - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - repo_id_str = repo_id.replace("/", "_") - rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" - rr.save(rrd_path) - return rrd_path - - elif mode == "distant": - # stop the process from exiting since it is serving the websocket connection - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - print("Ctrl-C received. Exiting.") - - -def main(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--repo-id", - type=str, - required=True, - help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", - ) - parser.add_argument( - "--episode-index", - type=int, - required=True, - help="Episode to visualize.", - ) - parser.add_argument( - "--batch-size", - type=int, - default=32, - help="Batch size loaded by DataLoader.", - ) - parser.add_argument( - "--num-workers", - type=int, - default=4, - help="Number of processes of Dataloader for loading the data.", - ) - parser.add_argument( - "--mode", - type=str, - default="local", - help=( - "Mode of viewing between 'local' or 'distant'. " - "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " - "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." - ), - ) - parser.add_argument( - "--web-port", - type=int, - default=9090, - help="Web port for rerun.io when `--mode distant` is set.", - ) - parser.add_argument( - "--ws-port", - type=int, - default=9087, - help="Web socket port for rerun.io when `--mode distant` is set.", - ) - parser.add_argument( - "--save", - type=int, - default=0, - help=( - "Save a .rrd file in the directory provided by `--output-dir`. " - "It also deactivates the spawning of a viewer. ", - "Visualize the data by running `rerun path/to/file.rrd` on your local machine.", - ), - ) - parser.add_argument( - "--output-dir", - type=str, - help="Directory path to write a .rrd file when `--save 1` is set.", - ) - - args = parser.parse_args() - visualize_dataset(**vars(args)) - - -if __name__ == "__main__": - main()
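For reference, the patched script exposes `visualize_dataset()` as a plain Python function, so an episode can also be exported without going through `argparse`. A minimal sketch, assuming the patched `lerobot` package is importable and using a hypothetical output directory:

```
from pathlib import Path

from lerobot.scripts.visualize_dataset import visualize_dataset

# Mirror `--save 1 --output-dir ...` from the docstring examples: log episode 0 of
# lerobot/pusht to a .rrd file instead of spawning a local viewer, then print its path.
rrd_path = visualize_dataset(
    repo_id="lerobot/pusht",
    episode_index=0,
    save=True,
    output_dir=Path("outputs/visualize_dataset"),  # hypothetical path
)
print(rrd_path)  # outputs/visualize_dataset/lerobot_pusht_episode_0.rrd
```

The resulting file can then be opened locally with `rerun path/to/file.rrd`, exactly as in the `--save 1` example in the module docstring.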