diff --git a/lerobot/scripts/visualize_dataset_rerun.py b/lerobot/scripts/visualize_dataset_rerun.py
new file mode 100644
index 00000000..58da6a47
--- /dev/null
+++ b/lerobot/scripts/visualize_dataset_rerun.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+
+Note: The last frame of the episode doesn't always correspond to a final state.
+That's because our datasets are composed of transitions from state to state, up to
+the antepenultimate state, which is paired with the last action that leads to the final state.
+However, there might not be a transition from a final state to another state.
+
+Note: This script aims to visualize the data used to train the neural networks.
+~What you see is what you get~. When visualizing the image modalities, you should expect to see
+lossy compression artifacts, since these images have been decoded from compressed mp4 videos to
+save disk space. The compression factor was tuned so that it does not affect the success rate.
+
+Examples:
+
+- Visualize data stored on a local machine:
+```
+local$ python lerobot/scripts/visualize_dataset_rerun.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0
+```
+
+- Visualize data stored on a distant machine with a local viewer:
+```
+distant$ python lerobot/scripts/visualize_dataset_rerun.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --save 1 \
+    --output-dir path/to/directory
+
+local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
+local$ rerun lerobot_pusht_episode_0.rrd
+```
+
+- Visualize data stored on a distant machine through streaming:
+(You need to forward the distant machine's websocket port to your local machine, e.g. with
+`ssh -L 9087:localhost:9087 username@remote-host`.)
+```
+distant$ python lerobot/scripts/visualize_dataset_rerun.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --mode distant \
+    --ws-port 9087
+
+local$ rerun ws://localhost:9087
+```
+
+"""
+
+import argparse
+import gc
+import logging
+import time
+from pathlib import Path
+
+import rerun as rr
+import torch
+import tqdm
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+
+class EpisodeSampler(torch.utils.data.Sampler):
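+    """Sampler that yields, in order, the frame indices belonging to a single episode."""
+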
+    def __init__(self, dataset, episode_index):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self):
+        return iter(self.frame_ids)
+
+    def __len__(self):
+        return len(self.frame_ids)
+
+
+def to_hwc_uint8_numpy(chw_float32_torch):
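+    """Convert a channel-first float32 torch image in [0, 1] to a channel-last uint8 numpy array."""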
+    assert chw_float32_torch.dtype == torch.float32
+    assert chw_float32_torch.ndim == 3
+    c, h, w = chw_float32_torch.shape
+    assert c < h and c < w, f"expected channel-first images, but got shape {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    return hwc_uint8_numpy
+
+
+def visualize_dataset(
+    repo_id: str,
+    episode_index: int,
+    batch_size: int = 32,
+    num_workers: int = 0,
+    mode: str = "local",
+    web_port: int = 9090,
+    ws_port: int = 9087,
+    save: bool = False,
+    output_dir: Path | None = None,
+) -> Path | None:
+    if save:
+        assert (
+            output_dir is not None
+        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+
+    logging.info("Loading dataset")
+    dataset = LeRobotDataset(repo_id)
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        batch_size=batch_size,
+        sampler=episode_sampler,
+    )
+
+    logging.info("Starting Rerun")
+
+    if mode not in ["local", "distant"]:
+        raise ValueError(mode)
+
+    spawn_local_viewer = mode == "local" and not save
+    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+
+    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0
+    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+    gc.collect()
+
+    if mode == "distant":
+        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+
+    logging.info("Logging to Rerun")
+
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        # iterate over the batch
+        for i in range(len(batch["index"])):
+            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
+            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+
+            # display each camera image
+            for key in dataset.camera_keys:
+                # TODO(rcadene): add `.compress()`? is it lossless?
+                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
+
+            # display each dimension of action space (e.g. actuators command)
+            if "action" in batch:
+                for dim_idx, val in enumerate(batch["action"][i]):
+                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+
+            # display each dimension of observed state space (e.g. agent position in joint space)
+            if "observation.state" in batch:
+                for dim_idx, val in enumerate(batch["observation.state"][i]):
+                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
+
+            if "next.done" in batch:
+                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+
+            if "next.reward" in batch:
+                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+
+            if "next.success" in batch:
+                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
+
+    if mode == "local" and save:
+        # save .rrd locally
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        repo_id_str = repo_id.replace("/", "_")
+        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
+        rr.save(rrd_path)
+        return rrd_path
+
+    elif mode == "distant":
+        # stop the process from exiting since it is serving the websocket connection
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            print("Ctrl-C received. Exiting.")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Name of the Hugging Face repository containing a LeRobotDataset (e.g. `lerobot/pusht`).",
+    )
+    parser.add_argument(
+        "--episode-index",
+        type=int,
+        required=True,
+        help="Episode to visualize.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size loaded by DataLoader.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of DataLoader worker processes for loading the data.",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="local",
+        help=(
+            "Viewing mode: 'local' or 'distant'. "
+            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
+            "'distant' creates a server on the distant machine where the data is stored. "
+            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+        ),
+    )
+    parser.add_argument(
+        "--web-port",
+        type=int,
+        default=9090,
+        help="Web port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--ws-port",
+        type=int,
+        default=9087,
+        help="WebSocket port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--save",
+        type=int,
+        default=0,
+        help=(
+            "Save a .rrd file in the directory provided by `--output-dir`. "
+            "It also deactivates the spawning of a viewer. "
+            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
+        ),
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        help="Directory path to write a .rrd file when `--save 1` is set.",
+    )
+
+    args = parser.parse_args()
+    visualize_dataset(**vars(args))
+
+
+if __name__ == "__main__":
+    main()