From e5d8955802c8ba385ebc7897b258ec6032989643 Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Wed, 22 May 2024 23:55:00 +0200 Subject: [PATCH] Add a gradio visualizer --- lerobot/common/datasets/lerobot_dataset.py | 4 +- lerobot/common/datasets/utils.py | 42 ++++-- lerobot/common/datasets/video_utils.py | 34 +++-- lerobot/scripts/visualize_dataset_gradio.py | 144 ++++++++++++++++++++ lerobot/scripts/visualize_dataset_rerun.py | 2 +- pyproject.toml | 25 ++-- 6 files changed, 215 insertions(+), 36 deletions(-) create mode 100644 lerobot/scripts/visualize_dataset_gradio.py diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 7885c783..cbaac07f 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -68,8 +68,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self.stats = load_stats(repo_id, version, root) self.info = load_info(repo_id, version, root) self.include_video_images = include_video_images - if self.video: - self.videos_dir = load_videos(repo_id, version, root) + # if self.video: + # self.videos_dir = load_videos(repo_id, version, root) @property def fps(self) -> int: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 86fef8d4..5083292d 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -70,7 +70,11 @@ def hf_transform_to_torch(items_dict): if isinstance(first_item, PILImage.Image): to_tensor = transforms.ToTensor() items_dict[key] = [to_tensor(img) for img in items_dict[key]] - elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item: + elif ( + isinstance(first_item, dict) + and "path" in first_item + and "timestamp" in first_item + ): # video frame will be processed downstream pass else: @@ -85,7 +89,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: # TODO(rcadene): clean this which enables getting a subset of dataset if split != "train": if "%" in split: - raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).") + raise NotImplementedError( + f"We dont support splitting based on percentage for now ({split})." + ) match_from = re.search(r"train\[(\d+):\]", split) match_to = re.search(r"train\[:(\d+)\]", split) if match_from: @@ -118,7 +124,10 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]: path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors" else: path = hf_hub_download( - repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version + repo_id, + "meta_data/episode_data_index.safetensors", + repo_type="dataset", + revision=version, ) return load_file(path) @@ -135,7 +144,12 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]: if root is not None: path = Path(root) / repo_id / "meta_data" / "stats.safetensors" else: - path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) + path = hf_hub_download( + repo_id, + "meta_data/stats.safetensors", + repo_type="dataset", + revision=version, + ) stats = load_file(path) return unflatten_dict(stats) @@ -152,7 +166,9 @@ def load_info(repo_id, version, root) -> dict: if root is not None: path = Path(root) / repo_id / "meta_data" / "info.json" else: - path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version) + path = hf_hub_download( + repo_id, "meta_data/info.json", repo_type="dataset", revision=version + ) with open(path) as f: info = json.load(f) @@ -219,7 +235,9 @@ def load_previous_and_future_frames( ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) # load timestamps - ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"] + ep_timestamps = hf_dataset.select_columns("timestamp")[ + ep_data_id_from:ep_data_id_to + ]["timestamp"] ep_timestamps = torch.stack(ep_timestamps) # we make the assumption that the timestamps are sorted @@ -241,7 +259,9 @@ def load_previous_and_future_frames( is_pad = min_ > tolerance_s # check violated query timestamps are all outside the episode range - assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), ( + assert ( + (query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad]) + ).all(), ( f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range." "This might be due to synchronization issues with timestamps during data collection." ) @@ -263,7 +283,9 @@ def load_previous_and_future_frames( return item -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: +def calculate_episode_data_index( + hf_dataset: datasets.Dataset, +) -> Dict[str, torch.Tensor]: """ Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. @@ -334,7 +356,9 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: } def modify_ep_idx_func(example): - example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()] + example["episode_index"] = episode_idx_to_reset_idx_mapping[ + example["episode_index"].item() + ] return example hf_dataset = hf_dataset.map(modify_ep_idx_func) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 1f0bf4f8..f05d1ccb 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -15,20 +15,20 @@ # limitations under the License. import logging import math +import multiprocessing import subprocess import warnings from dataclasses import dataclass, field from pathlib import Path from typing import Any, ClassVar +import av import pyarrow as pa +import rerun as rr import torch import torchvision from datasets.features.features import register_feature - -import av -import multiprocessing -import rerun as rr +from huggingface_hub.file_download import hf_hub_download class PeekableIterator: @@ -75,20 +75,23 @@ class SequentialRerunVideoReader: Frames must be consumed in-order. """ - def __init__(self, video_dir: Path, tolerance: float, compression: int | None = 95): - self.video_dir = video_dir + def __init__(self, repo_id: str, tolerance: float, compression: int | None = 95): + self.repo_id = repo_id self.streams: dict[Path, PeekableIterator] = {} self.tolerance = tolerance self.compression = compression - def next_frame(self, path, timestamp): + def start_downloading(self, path): if path not in self.streams: self.streams[path] = PeekableIterator( stream_rerun_images_from_video_mp( - self.video_dir / path, compression=self.compression + self.repo_id, path, compression=self.compression ) ) + def next_frame(self, path, timestamp): + self.start_downloading(path) + (next_frame_ts, next_frame) = self.streams[path].peek() while ( @@ -108,7 +111,10 @@ class SequentialRerunVideoReader: def stream_rerun_images_from_video( - video_path: Path, frame_queue: multiprocessing.Queue, compression: int | None + repo_id, + video_path: str, + frame_queue: multiprocessing.Queue, + compression: int | None, ): """Streams frames from a video file @@ -117,7 +123,9 @@ def stream_rerun_images_from_video( frame_queue (multiprocessing.Queue): Queue to store the frames compression (int | None): Compression level for the images """ - container = av.open(video_path) + cached_video_path = hf_hub_download(repo_id, video_path, repo_type="dataset") + + container = av.open(cached_video_path) for frame in container.decode(video=0): pts = float(frame.pts * frame.time_base) @@ -131,14 +139,16 @@ def stream_rerun_images_from_video( frame_queue.put(None) -def stream_rerun_images_from_video_mp(video_path: Path, compression: int | None) -> Any: +def stream_rerun_images_from_video_mp( + repo_id: str, video_path: str, compression: int | None +) -> Any: frame_queue: multiprocessing.Queue[(int, rr.Image)] = multiprocessing.Queue( maxsize=5 ) extractor_proc = multiprocessing.Process( target=stream_rerun_images_from_video, - args=(video_path, frame_queue, compression), + args=(repo_id, video_path, frame_queue, compression), ) extractor_proc.start() diff --git a/lerobot/scripts/visualize_dataset_gradio.py b/lerobot/scripts/visualize_dataset_gradio.py new file mode 100644 index 00000000..322136a5 --- /dev/null +++ b/lerobot/scripts/visualize_dataset_gradio.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python + +""" Visualize data from individual LeRobot dataset episodes. + +```bash +$ python lerobot/scripts/visualize_dataset_gradio.py +$ open http://127.0.0.1:7860 +``` + +""" + + +import gradio as gr +import rerun as rr +import rerun.blueprint as rrb +import torch +import tqdm +from gradio_rerun import Rerun + +import lerobot +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.video_utils import SequentialRerunVideoReader + + +class EpisodeSampler(torch.utils.data.Sampler): + def __init__(self, dataset, episode_index): + from_idx = dataset.episode_data_index["from"][episode_index].item() + to_idx = dataset.episode_data_index["to"][episode_index].item() + self.frame_ids = range(from_idx, to_idx) + + def __iter__(self): + return iter(self.frame_ids) + + def __len__(self): + return len(self.frame_ids) + + +@rr.thread_local_stream("lerobot_visualization") +def visualize_dataset( + dataset: dict[str, LeRobotDataset], + episode_index: int, +): + stream = rr.binary_stream() + + batch_size = 32 + num_workers = 0 + + dataset = dataset["dataset"] + + rr.send_blueprint( + rrb.Vertical( + rrb.TimeSeriesView(), + rrb.Horizontal( + contents=[rrb.Spatial2DView(origin=key) for key in dataset.camera_keys] + ), + ) + ) + + yield stream.read() + + video_reader = SequentialRerunVideoReader( + dataset.repo_id, dataset.tolerance_s, compression=95 + ) + for key in dataset.camera_keys: + video_reader.start_downloading(key) + + episode_sampler = EpisodeSampler(dataset, episode_index) + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_size=batch_size, + sampler=episode_sampler, + ) + + for batch in tqdm.tqdm(dataloader, total=len(dataloader)): + # iterate over the batch + for i in range(len(batch["index"])): + rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) + rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) + + # display each camera image + for key in dataset.camera_keys: + img = video_reader.next_frame( + batch[key]["path"][i], batch[key]["timestamp"][i] + ) + if img is not None: + rr.log(key, img) + + # display each dimension of action space (e.g. actuators command) + if "action" in batch: + for dim_idx, val in enumerate(batch["action"][i]): + rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) + + # display each dimension of observed state space (e.g. agent position in joint space) + if "observation.state" in batch: + for dim_idx, val in enumerate(batch["observation.state"][i]): + rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) + + if "next.done" in batch: + rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) + + if "next.reward" in batch: + rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) + + if "next.success" in batch: + rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) + + yield stream.read() + + +def update_episodes(dataset, loaded_dataset): + loaded_dataset["dataset"] = LeRobotDataset(dataset) + dataset = loaded_dataset["dataset"] + return gr.update(choices=list(range(dataset.num_episodes)), value=0) + + +def main(): + with gr.Blocks() as demo: + loaded_dataset = gr.State({}) + with gr.Row(): + with gr.Column(scale=0.3): + with gr.Row(): + dataset = gr.Dropdown(choices=lerobot.available_real_world_datasets) + with gr.Row(): + episode = gr.Dropdown(choices=[], interactive=True) + with gr.Row(): + load = gr.Button("Load") + with gr.Column(): + viewer = Rerun(streaming=True, height=800) + + dataset.change( + update_episodes, inputs=[dataset, loaded_dataset], outputs=[episode] + ) + load.click( + visualize_dataset, inputs=[loaded_dataset, episode], outputs=[viewer] + ) + + demo.queue(default_concurrency_limit=10) + demo.launch() + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/visualize_dataset_rerun.py b/lerobot/scripts/visualize_dataset_rerun.py index 6bb8d6fa..da04f37a 100644 --- a/lerobot/scripts/visualize_dataset_rerun.py +++ b/lerobot/scripts/visualize_dataset_rerun.py @@ -123,7 +123,7 @@ def visualize_dataset( # to do so. Those will be retrieved directly. dataset = LeRobotDataset(repo_id, include_video_images=False) video_reader = SequentialRerunVideoReader( - dataset.videos_dir.parent, dataset.tolerance_s, compression=95 + dataset.repo_id, dataset.tolerance_s, compression=95 ) logging.info("Loading dataloader") diff --git a/pyproject.toml b/pyproject.toml index f043c9de..2e87ab12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ repository = "https://github.com/huggingface/lerobot" readme = "README.md" license = "Apache-2.0" -classifiers=[ +classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Education", @@ -23,7 +23,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.10", ] -packages = [{include = "lerobot"}] +packages = [{ include = "lerobot" }] [tool.poetry.dependencies] @@ -31,7 +31,7 @@ python = ">=3.10,<3.13" termcolor = ">=2.4.0" omegaconf = ">=2.3.0" wandb = ">=0.16.3" -imageio = {extras = ["ffmpeg"], version = ">=2.34.0"} +imageio = { extras = ["ffmpeg"], version = ">=2.34.0" } gdown = ">=5.1.0" hydra-core = ">=1.3.2" einops = ">=0.8.0" @@ -43,21 +43,22 @@ opencv-python = ">=4.9.0" diffusers = "^0.27.2" torchvision = ">=0.18.0" h5py = ">=3.10.0" -huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"} +huggingface-hub = { extras = ["hf-transfer"], version = "^0.23.0" } gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" -gym-pusht = { version = ">=0.1.3", optional = true} -gym-xarm = { version = ">=0.1.1", optional = true} -gym-aloha = { version = ">=0.1.1", optional = true} -pre-commit = {version = ">=3.7.0", optional = true} -debugpy = {version = ">=1.8.1", optional = true} -pytest = {version = ">=8.1.0", optional = true} -pytest-cov = {version = ">=5.0.0", optional = true} +gym-pusht = { version = ">=0.1.3", optional = true } +gym-xarm = { version = ">=0.1.1", optional = true } +gym-aloha = { version = ">=0.1.1", optional = true } +pre-commit = { version = ">=3.7.0", optional = true } +debugpy = { version = ">=1.8.1", optional = true } +pytest = { version = ">=8.1.0", optional = true } +pytest-cov = { version = ">=5.0.0", optional = true } datasets = "^2.19.0" imagecodecs = { version = ">=2024.1.1", optional = true } pyav = ">=12.0.5" moviepy = ">=1.0.3" -rerun-sdk = ">=0.15.1" +rerun-sdk = ">=0.16.0" +gradio_rerun = { url = "https://huggingface.co/spaces/jleibs/rerun_streaming_poc/resolve/main/gradio_rerun-0.0.2-py3-none-any.whl?download=true" } [tool.poetry.extras]