From 5b74205e1691b64d55a7b771ece3e5e020801edd Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sat, 1 Jun 2024 15:45:26 +0000 Subject: [PATCH] Add custom visualize_dataset.py --- lerobot/common/logger.py | 3 - lerobot/common/policies/act/modeling_act.py | 23 +- lerobot/scripts/train.py | 2 +- lerobot/scripts/visualize_dataset.py | 597 +++++++++++++++----- lerobot/scripts/visualize_dataset_rerun.py | 263 +++++++++ tests/test_visualize_dataset.py | 5 +- 6 files changed, 730 insertions(+), 163 deletions(-) create mode 100644 lerobot/scripts/visualize_dataset_rerun.py diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 853acbc3..71d961a1 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -233,9 +233,6 @@ class Logger: if self._wandb is not None: for k, v in d.items(): if not isinstance(v, (int, float, str)): - logging.warning( - f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' - ) continue self._wandb.log({f"{mode}/{k}": v}, step=step) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index bef59bec..a3dc4ccb 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -139,25 +139,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() + bsize = actions_hat.shape[0] + l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") + l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1) + l1_loss = l1_loss.view(bsize, -1).mean(dim=1) + + out_dict = {} + out_dict["l1_loss"] = l1_loss - loss_dict = {"l1_loss": l1_loss.item()} if self.config.use_vae: # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total # KL-divergence per batch element, then take the mean over the batch. # (See App. B of https://arxiv.org/abs/1312.6114 for more details). - mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() - ) - loss_dict["kld_loss"] = mean_kld.item() - loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight + kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1) + out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight else: - loss_dict["loss"] = l1_loss + out_dict["loss"] = l1_loss - return loss_dict + out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"] + return out_dict class ACT(nn.Module): diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 860412bd..e63a5633 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -107,7 +107,7 @@ def update_policy( with torch.autocast(device_type=device.type) if use_amp else nullcontext(): output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) - loss = output_dict["loss"] + loss = output_dict["loss"].mean() grad_scaler.scale(loss).backward() # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**. diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 58da6a47..6534343d 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -30,48 +30,46 @@ Examples: - Visualize data stored on a local machine: ``` local$ python lerobot/scripts/visualize_dataset.py \ - --repo-id lerobot/pusht \ - --episode-index 0 + --repo-id lerobot/pusht + +local$ open http://localhost:9090 ``` - Visualize data stored on a distant machine with a local viewer: ``` distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht + +local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel +local$ open http://localhost:9090 +``` + +- Select episodes to visualize: +``` +python lerobot/scripts/visualize_dataset.py \ --repo-id lerobot/pusht \ - --episode-index 0 \ - --save 1 \ - --output-dir path/to/directory - -local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd . -local$ rerun lerobot_pusht_episode_0.rrd + --episode-indices 7 3 5 1 4 ``` - -- Visualize data stored on a distant machine through streaming: -(You need to forward the websocket port to the distant machine, with -`ssh -L 9087:localhost:9087 username@remote-host`) -``` -distant$ python lerobot/scripts/visualize_dataset.py \ - --repo-id lerobot/pusht \ - --episode-index 0 \ - --mode distant \ - --ws-port 9087 - -local$ rerun ws://localhost:9087 -``` - """ import argparse -import gc +import http.server import logging -import time +import os +import shutil +import socketserver from pathlib import Path -import rerun as rr import torch import tqdm +import yaml +from bs4 import BeautifulSoup +from huggingface_hub import snapshot_download +from safetensors.torch import load_file, save_file from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.act.modeling_act import ACTPolicy +from lerobot.common.utils.utils import init_logging class EpisodeSampler(torch.utils.data.Sampler): @@ -87,33 +85,307 @@ class EpisodeSampler(torch.utils.data.Sampler): return len(self.frame_ids) -def to_hwc_uint8_numpy(chw_float32_torch): - assert chw_float32_torch.dtype == torch.float32 - assert chw_float32_torch.ndim == 3 - c, h, w = chw_float32_torch.shape - assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" - hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() - return hwc_uint8_numpy +class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): + def end_headers(self): + self.send_header("Cache-Control", "no-store, no-cache, must-revalidate") + self.send_header("Pragma", "no-cache") + self.send_header("Expires", "0") + super().end_headers() -def visualize_dataset( - repo_id: str, - episode_index: int, - batch_size: int = 32, - num_workers: int = 0, - mode: str = "local", - web_port: int = 9090, - ws_port: int = 9087, - save: bool = False, - output_dir: Path | None = None, -) -> Path | None: - if save: - assert ( - output_dir is not None - ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." +def run_server(path, port): + # Change directory to serve 'index.html` as front page + os.chdir(path) - logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id) + with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd: + logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...") + httpd.serve_forever() + + +def create_html_page(page_title: str): + """Create a html page with beautiful soop with default doctype, meta, header and title.""" + soup = BeautifulSoup("", "html.parser") + + doctype = soup.new_tag("!DOCTYPE html") + soup.append(doctype) + + html = soup.new_tag("html", lang="en") + soup.append(html) + + head = soup.new_tag("head") + html.append(head) + + meta_charset = soup.new_tag("meta", charset="UTF-8") + head.append(meta_charset) + + meta_viewport = soup.new_tag( + "meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"} + ) + head.append(meta_viewport) + + title = soup.new_tag("title") + title.string = page_title + head.append(title) + + body = soup.new_tag("body") + html.append(body) + + main_div = soup.new_tag("div") + body.append(main_div) + return soup, head, body + + +def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None): + """Write a csv file containg timeseries data of an episode (e.g. state and action). + This file will be loaded by Dygraph javascript to plot data in real time.""" + from_idx = dataset.episode_data_index["from"][episode_index] + to_idx = dataset.episode_data_index["to"][episode_index] + + has_state = "observation.state" in dataset.hf_dataset.features + has_action = "action" in dataset.hf_dataset.features + has_inference = inference_results is not None + + # init header of csv with state and action names + header = ["timestamp"] + if has_state: + dim_state = len(dataset.hf_dataset["observation.state"][0]) + header += [f"state_{i}" for i in range(dim_state)] + if has_action: + dim_action = len(dataset.hf_dataset["action"][0]) + header += [f"action_{i}" for i in range(dim_action)] + if has_inference: + assert "actions" in inference_results + assert "loss" in inference_results + dim_pred_action = inference_results["actions"].shape[2] + header += [f"pred_action_{i}" for i in range(dim_pred_action)] + header += ["loss"] + + columns = ["timestamp"] + if has_state: + columns += ["observation.state"] + if has_action: + columns += ["action"] + + rows = [] + data = dataset.hf_dataset.select_columns(columns) + for i in range(from_idx, to_idx): + row = [data[i]["timestamp"].item()] + if has_state: + row += data[i]["observation.state"].tolist() + if has_action: + row += data[i]["action"].tolist() + rows.append(row) + + if has_inference: + num_frames = len(rows) + assert num_frames == inference_results["actions"].shape[0] + assert num_frames == inference_results["loss"].shape[0] + for i in range(num_frames): + rows[i] += inference_results["actions"][i, 0].tolist() + rows[i] += [inference_results["loss"][i].item()] + + output_dir.mkdir(parents=True, exist_ok=True) + with open(output_dir / file_name, "w") as f: + f.write(",".join(header) + "\n") + for row in rows: + row_str = [str(col) for col in row] + f.write(",".join(row_str) + "\n") + + +def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset): + """Write a javascript file containing logic to synchronize camera feeds and timeseries.""" + s = "" + s += "document.addEventListener('DOMContentLoaded', function () {\n" + for i, key in enumerate(dataset.video_frame_keys): + s += f" const video{i} = document.getElementById('video_{key}');\n" + s += " const slider = document.getElementById('videoControl');\n" + s += " const playButton = document.getElementById('playButton');\n" + s += f" const dygraph = new Dygraph(document.getElementById('graph'), '{ep_csv_fname}', " + "{\n" + s += " pixelsPerPoint: 0.01,\n" + s += " legend: 'always',\n" + s += " labelsDiv: document.getElementById('labels'),\n" + s += " labelsSeparateLines: true,\n" + s += " labelsKMB: true,\n" + s += " highlightCircleSize: 1.5,\n" + s += " highlightSeriesOpts: {\n" + s += " strokeWidth: 1.5,\n" + s += " strokeBorderWidth: 1,\n" + s += " highlightCircleSize: 3\n" + s += " }\n" + s += " });\n" + s += "\n" + s += " // Function to play both videos\n" + s += " playButton.addEventListener('click', function () {\n" + for i in range(len(dataset.video_frame_keys)): + s += f" video{i}.play();\n" + s += " // playButton.disabled = true; // Optional: disable button after playing\n" + s += " });\n" + s += "\n" + s += " // Update the video time when the slider value changes\n" + s += " slider.addEventListener('input', function () {\n" + s += " const sliderValue = slider.value;\n" + for i in range(len(dataset.video_frame_keys)): + s += f" const time{i} = (video{i}.duration * sliderValue) / 100;\n" + for i in range(len(dataset.video_frame_keys)): + s += f" video{i}.currentTime = time{i};\n" + s += " });\n" + s += "\n" + s += " // Synchronize slider with the video's current time\n" + s += " const syncSlider = (video) => {\n" + s += " video.addEventListener('timeupdate', function () {\n" + s += " if (video.duration) {\n" + s += " const pc = (100 / video.duration) * video.currentTime;\n" + s += " slider.value = pc;\n" + s += " const index = Math.floor(pc * dygraph.numRows() / 100);\n" + s += " dygraph.setSelection(index, undefined, true, true);\n" + s += " }\n" + s += " });\n" + s += " };\n" + s += "\n" + for i in range(len(dataset.video_frame_keys)): + s += f" syncSlider(video{i});\n" + s += "\n" + s += "});\n" + + output_dir.mkdir(parents=True, exist_ok=True) + with open(output_dir / file_name, "w", encoding="utf-8") as f: + f.write(s) + + +def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): + """Write an html file containg video feeds and timeseries associated to an episode.""" + soup, head, body = create_html_page("") + + css_style = soup.new_tag("style") + css_style.string = "" + css_style.string += "#labels > span.highlight {\n" + css_style.string += " border: 1px solid grey;\n" + css_style.string += "}" + head.append(css_style) + + # Add videos from camera feeds + + videos_control_div = soup.new_tag("div") + body.append(videos_control_div) + + videos_div = soup.new_tag("div") + videos_control_div.append(videos_div) + + def create_video(id, src): + video = soup.new_tag("video", id=id, width="320", height="240", controls="") + source = soup.new_tag("source", src=src, type="video/mp4") + video.string = "Your browser does not support the video tag." + video.append(source) + return video + + # get first frame of episode (hack to get video_path of the episode) + first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + + for key in dataset.video_frame_keys: + # Example of video_path: 'videos/observation.image_episode_000004.mp4' + video_path = dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] + videos_div.append(create_video(f"video_{key}", video_path)) + + # Add controls for videos and graph + + control_div = soup.new_tag("div") + videos_control_div.append(control_div) + + button_div = soup.new_tag("div") + control_div.append(button_div) + + button = soup.new_tag("button", id="playButton") + button.string = "Play Videos" + button_div.append(button) + + slider_div = soup.new_tag("div") + control_div.append(slider_div) + + slider = soup.new_tag("input", type="range", id="videoControl", min="0", max="100", value="0", step="1") + control_div.append(slider) + + # Add graph of states/actions, and its labels + + graph_labels_div = soup.new_tag("div", style="display: flex;") + body.append(graph_labels_div) + + graph_div = soup.new_tag("div", id="graph", style="flex: 1; width: 85%") + graph_labels_div.append(graph_div) + + labels_div = soup.new_tag("div", id="labels", style="flex: 1; width: 15%") + graph_labels_div.append(labels_div) + + # add dygraph library + script = soup.new_tag("script", type="text/javascript", src=js_fname) + body.append(script) + + script_dygraph = soup.new_tag( + "script", + type="text/javascript", + src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js", + ) + body.append(script_dygraph) + + link_dygraph = soup.new_tag( + "link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css" + ) + body.append(link_dygraph) + + # Write as a html file + + output_dir.mkdir(parents=True, exist_ok=True) + with open(output_dir / file_name, "w", encoding="utf-8") as f: + f.write(soup.prettify()) + + +def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset): + """Write an html file containing information related to the dataset and a list of links to + html pages of episodes.""" + soup, head, body = create_html_page("TODO") + + h3 = soup.new_tag("h3") + h3.string = "TODO" + body.append(h3) + + ul_info = soup.new_tag("ul") + body.append(ul_info) + + li_info = soup.new_tag("li") + li_info.string = f"Number of samples/frames: {dataset.num_samples}" + ul_info.append(li_info) + + li_info = soup.new_tag("li") + li_info.string = f"Number of episodes: {dataset.num_episodes}" + ul_info.append(li_info) + + li_info = soup.new_tag("li") + li_info.string = f"Frames per second: {dataset.fps}" + ul_info.append(li_info) + + # li_info = soup.new_tag("li") + # li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B" + # ul_info.append(li_info) + + ul = soup.new_tag("ul") + body.append(ul) + + for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=False): + li = soup.new_tag("li") + ul.append(li) + + a = soup.new_tag("a", href=ep_html_fname) + a.string = f"Episode number {ep_idx}" + + li.append(a) + + output_dir.mkdir(parents=True, exist_ok=True) + with open(output_dir / file_name, "w", encoding="utf-8") as f: + f.write(soup.prettify()) + + +def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"): + policy.eval() + policy.to(device) logging.info("Loading dataloader") episode_sampler = EpisodeSampler(dataset, episode_index) @@ -124,70 +396,104 @@ def visualize_dataset( sampler=episode_sampler, ) - logging.info("Starting Rerun") - - if mode not in ["local", "distant"]: - raise ValueError(mode) - - spawn_local_viewer = mode == "local" and not save - rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) - - # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush - # when iterating on a dataloader with `num_workers` > 0 - # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix - gc.collect() - - if mode == "distant": - rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) - - logging.info("Logging to Rerun") - + logging.info("Running inference") + inference_results = {} for batch in tqdm.tqdm(dataloader, total=len(dataloader)): - # iterate over the batch - for i in range(len(batch["index"])): - rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) - rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) + batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} + with torch.inference_mode(): + output_dict = policy.forward(batch) - # display each camera image - for key in dataset.camera_keys: - # TODO(rcadene): add `.compress()`? is it lossless? - rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) + for key in output_dict: + if key not in inference_results: + inference_results[key] = [] + inference_results[key].append(output_dict[key].to("cpu")) - # display each dimension of action space (e.g. actuators command) - if "action" in batch: - for dim_idx, val in enumerate(batch["action"][i]): - rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) + for key in inference_results: + inference_results[key] = torch.cat(inference_results[key]) - # display each dimension of observed state space (e.g. agent position in joint space) - if "observation.state" in batch: - for dim_idx, val in enumerate(batch["observation.state"][i]): - rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) + return inference_results - if "next.done" in batch: - rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) - if "next.reward" in batch: - rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) +def visualize_dataset( + repo_id: str, + episode_indices: list[int] = None, + output_dir: Path | None = None, + serve: bool = True, + port: int = 9090, + force_overwrite: bool = True, + policy_repo_id: str | None = None, + policy_ckpt_path: Path | None = None, + batch_size: int = 32, + num_workers: int = 4, +) -> Path | None: + init_logging() - if "next.success" in batch: - rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) + has_policy = policy_repo_id or policy_ckpt_path - if mode == "local" and save: - # save .rrd locally - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - repo_id_str = repo_id.replace("/", "_") - rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" - rr.save(rrd_path) - return rrd_path + if has_policy: + logging.info("Loading policy") + if policy_repo_id: + pretrained_policy_path = Path(snapshot_download(policy_repo_id)) + elif policy_ckpt_path: + pretrained_policy_path = Path(policy_ckpt_path) + policy = ACTPolicy.from_pretrained(pretrained_policy_path) + with open(pretrained_policy_path / "config.yaml") as f: + cfg = yaml.safe_load(f) + delta_timestamps = cfg["training"]["delta_timestamps"] + else: + delta_timestamps = None - elif mode == "distant": - # stop the process from exiting since it is serving the websocket connection - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - print("Ctrl-C received. Exiting.") + logging.info("Loading dataset") + dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps) + + if not dataset.video: + raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.") + + if output_dir is None: + output_dir = f"outputs/visualize_dataset/{repo_id}" + + output_dir = Path(output_dir) + if force_overwrite and output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Create a simlink from the dataset video folder containg mp4 files to the output directory + # so that the http server can get access to the mp4 files. + ln_videos_dir = output_dir / "videos" + if not ln_videos_dir.exists(): + ln_videos_dir.symlink_to(dataset.videos_dir.resolve()) + + if episode_indices is None: + episode_indices = list(range(dataset.num_episodes)) + + logging.info("Writing html") + ep_html_fnames = [] + for episode_index in tqdm.tqdm(episode_indices): + inference_results = None + if has_policy: + inference_results_path = output_dir / f"episode_{episode_index}.safetensors" + if inference_results_path.exists(): + inference_results = load_file(inference_results_path) + else: + inference_results = run_inference(dataset, episode_index, policy) + save_file(inference_results, inference_results_path) + + # write states and actions in a csv + ep_csv_fname = f"episode_{episode_index}.csv" + write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results) + + js_fname = f"episode_{episode_index}.js" + write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset) + + # write a html page to view videos and timeseries + ep_html_fname = f"episode_{episode_index}.html" + write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset) + ep_html_fnames.append(ep_html_fname) + + write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset) + + if serve: + run_server(output_dir, port) def main(): @@ -197,13 +503,51 @@ def main(): "--repo-id", type=str, required=True, - help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", + help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", ) parser.add_argument( - "--episode-index", + "--episode-indices", type=int, - required=True, - help="Episode to visualize.", + nargs="*", + default=None, + help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.", + ) + parser.add_argument( + "--serve", + type=int, + default=1, + help="Launch web server.", + ) + parser.add_argument( + "--port", + type=int, + default=9090, + help="Web port used by the http server.", + ) + parser.add_argument( + "--force-overwrite", + type=int, + default=1, + help="Delete the output directory if it exists already.", + ) + + parser.add_argument( + "--policy-repo-id", + type=str, + default=None, + help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).", + ) + parser.add_argument( + "--policy-ckpt-path", + type=str, + default=None, + help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).", ) parser.add_argument( "--batch-size", @@ -217,43 +561,6 @@ def main(): default=4, help="Number of processes of Dataloader for loading the data.", ) - parser.add_argument( - "--mode", - type=str, - default="local", - help=( - "Mode of viewing between 'local' or 'distant'. " - "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " - "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." - ), - ) - parser.add_argument( - "--web-port", - type=int, - default=9090, - help="Web port for rerun.io when `--mode distant` is set.", - ) - parser.add_argument( - "--ws-port", - type=int, - default=9087, - help="Web socket port for rerun.io when `--mode distant` is set.", - ) - parser.add_argument( - "--save", - type=int, - default=0, - help=( - "Save a .rrd file in the directory provided by `--output-dir`. " - "It also deactivates the spawning of a viewer. ", - "Visualize the data by running `rerun path/to/file.rrd` on your local machine.", - ), - ) - parser.add_argument( - "--output-dir", - type=str, - help="Directory path to write a .rrd file when `--save 1` is set.", - ) args = parser.parse_args() visualize_dataset(**vars(args)) diff --git a/lerobot/scripts/visualize_dataset_rerun.py b/lerobot/scripts/visualize_dataset_rerun.py new file mode 100644 index 00000000..58da6a47 --- /dev/null +++ b/lerobot/scripts/visualize_dataset_rerun.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. + +Note: The last frame of the episode doesnt always correspond to a final state. +That's because our datasets are composed of transition from state to state up to +the antepenultimate state associated to the ultimate action to arrive in the final state. +However, there might not be a transition from a final state to another state. + +Note: This script aims to visualize the data used to train the neural networks. +~What you see is what you get~. When visualizing image modality, it is often expected to observe +lossly compression artifacts since these images have been decoded from compressed mp4 videos to +save disk space. The compression factor applied has been tuned to not affect success rate. + +Examples: + +- Visualize data stored on a local machine: +``` +local$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 +``` + +- Visualize data stored on a distant machine with a local viewer: +``` +distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --save 1 \ + --output-dir path/to/directory + +local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd . +local$ rerun lerobot_pusht_episode_0.rrd +``` + +- Visualize data stored on a distant machine through streaming: +(You need to forward the websocket port to the distant machine, with +`ssh -L 9087:localhost:9087 username@remote-host`) +``` +distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --mode distant \ + --ws-port 9087 + +local$ rerun ws://localhost:9087 +``` + +""" + +import argparse +import gc +import logging +import time +from pathlib import Path + +import rerun as rr +import torch +import tqdm + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +class EpisodeSampler(torch.utils.data.Sampler): + def __init__(self, dataset, episode_index): + from_idx = dataset.episode_data_index["from"][episode_index].item() + to_idx = dataset.episode_data_index["to"][episode_index].item() + self.frame_ids = range(from_idx, to_idx) + + def __iter__(self): + return iter(self.frame_ids) + + def __len__(self): + return len(self.frame_ids) + + +def to_hwc_uint8_numpy(chw_float32_torch): + assert chw_float32_torch.dtype == torch.float32 + assert chw_float32_torch.ndim == 3 + c, h, w = chw_float32_torch.shape + assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" + hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() + return hwc_uint8_numpy + + +def visualize_dataset( + repo_id: str, + episode_index: int, + batch_size: int = 32, + num_workers: int = 0, + mode: str = "local", + web_port: int = 9090, + ws_port: int = 9087, + save: bool = False, + output_dir: Path | None = None, +) -> Path | None: + if save: + assert ( + output_dir is not None + ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." + + logging.info("Loading dataset") + dataset = LeRobotDataset(repo_id) + + logging.info("Loading dataloader") + episode_sampler = EpisodeSampler(dataset, episode_index) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_size=batch_size, + sampler=episode_sampler, + ) + + logging.info("Starting Rerun") + + if mode not in ["local", "distant"]: + raise ValueError(mode) + + spawn_local_viewer = mode == "local" and not save + rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) + + # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush + # when iterating on a dataloader with `num_workers` > 0 + # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix + gc.collect() + + if mode == "distant": + rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) + + logging.info("Logging to Rerun") + + for batch in tqdm.tqdm(dataloader, total=len(dataloader)): + # iterate over the batch + for i in range(len(batch["index"])): + rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) + rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) + + # display each camera image + for key in dataset.camera_keys: + # TODO(rcadene): add `.compress()`? is it lossless? + rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) + + # display each dimension of action space (e.g. actuators command) + if "action" in batch: + for dim_idx, val in enumerate(batch["action"][i]): + rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) + + # display each dimension of observed state space (e.g. agent position in joint space) + if "observation.state" in batch: + for dim_idx, val in enumerate(batch["observation.state"][i]): + rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) + + if "next.done" in batch: + rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) + + if "next.reward" in batch: + rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) + + if "next.success" in batch: + rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) + + if mode == "local" and save: + # save .rrd locally + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + repo_id_str = repo_id.replace("/", "_") + rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" + rr.save(rrd_path) + return rrd_path + + elif mode == "distant": + # stop the process from exiting since it is serving the websocket connection + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("Ctrl-C received. Exiting.") + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", + ) + parser.add_argument( + "--episode-index", + type=int, + required=True, + help="Episode to visualize.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=32, + help="Batch size loaded by DataLoader.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of processes of Dataloader for loading the data.", + ) + parser.add_argument( + "--mode", + type=str, + default="local", + help=( + "Mode of viewing between 'local' or 'distant'. " + "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " + "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + ), + ) + parser.add_argument( + "--web-port", + type=int, + default=9090, + help="Web port for rerun.io when `--mode distant` is set.", + ) + parser.add_argument( + "--ws-port", + type=int, + default=9087, + help="Web socket port for rerun.io when `--mode distant` is set.", + ) + parser.add_argument( + "--save", + type=int, + default=0, + help=( + "Save a .rrd file in the directory provided by `--output-dir`. " + "It also deactivates the spawning of a viewer. ", + "Visualize the data by running `rerun path/to/file.rrd` on your local machine.", + ), + ) + parser.add_argument( + "--output-dir", + type=str, + help="Directory path to write a .rrd file when `--save 1` is set.", + ) + + args = parser.parse_args() + visualize_dataset(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 99954040..71819568 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -25,9 +25,8 @@ from lerobot.scripts.visualize_dataset import visualize_dataset def test_visualize_dataset(tmpdir, repo_id): rrd_path = visualize_dataset( repo_id, - episode_index=0, - batch_size=32, - save=True, + episode_indices=[0], output_dir=tmpdir, + serve=False, ) assert rrd_path.exists()