Add a gradio visualizer

This commit is contained in:
Jeremy Leibs 2024-05-22 23:55:00 +02:00
parent 4bde4fb987
commit e5d8955802
6 changed files with 215 additions and 36 deletions

View File

@ -68,8 +68,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.stats = load_stats(repo_id, version, root)
self.info = load_info(repo_id, version, root)
self.include_video_images = include_video_images
if self.video:
self.videos_dir = load_videos(repo_id, version, root)
# if self.video:
# self.videos_dir = load_videos(repo_id, version, root)
@property
def fps(self) -> int:

View File

@ -70,7 +70,11 @@ def hf_transform_to_torch(items_dict):
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
elif (
isinstance(first_item, dict)
and "path" in first_item
and "timestamp" in first_item
):
# video frame will be processed downstream
pass
else:
@ -85,7 +89,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
if "%" in split:
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
raise NotImplementedError(
f"We dont support splitting based on percentage for now ({split})."
)
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
@ -118,7 +124,10 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
repo_id,
"meta_data/episode_data_index.safetensors",
repo_type="dataset",
revision=version,
)
return load_file(path)
@ -135,7 +144,12 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
path = hf_hub_download(
repo_id,
"meta_data/stats.safetensors",
repo_type="dataset",
revision=version,
)
stats = load_file(path)
return unflatten_dict(stats)
@ -152,7 +166,9 @@ def load_info(repo_id, version, root) -> dict:
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
path = hf_hub_download(
repo_id, "meta_data/info.json", repo_type="dataset", revision=version
)
with open(path) as f:
info = json.load(f)
@ -219,7 +235,9 @@ def load_previous_and_future_frames(
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = hf_dataset.select_columns("timestamp")[
ep_data_id_from:ep_data_id_to
]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
@ -241,7 +259,9 @@ def load_previous_and_future_frames(
is_pad = min_ > tolerance_s
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
assert (
(query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
@ -263,7 +283,9 @@ def load_previous_and_future_frames(
return item
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
@ -334,7 +356,9 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
}
def modify_ep_idx_func(example):
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
example["episode_index"] = episode_idx_to_reset_idx_mapping[
example["episode_index"].item()
]
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)

View File

@ -15,20 +15,20 @@
# limitations under the License.
import logging
import math
import multiprocessing
import subprocess
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar
import av
import pyarrow as pa
import rerun as rr
import torch
import torchvision
from datasets.features.features import register_feature
import av
import multiprocessing
import rerun as rr
from huggingface_hub.file_download import hf_hub_download
class PeekableIterator:
@ -75,20 +75,23 @@ class SequentialRerunVideoReader:
Frames must be consumed in-order.
"""
def __init__(self, video_dir: Path, tolerance: float, compression: int | None = 95):
self.video_dir = video_dir
def __init__(self, repo_id: str, tolerance: float, compression: int | None = 95):
self.repo_id = repo_id
self.streams: dict[Path, PeekableIterator] = {}
self.tolerance = tolerance
self.compression = compression
def next_frame(self, path, timestamp):
def start_downloading(self, path):
if path not in self.streams:
self.streams[path] = PeekableIterator(
stream_rerun_images_from_video_mp(
self.video_dir / path, compression=self.compression
self.repo_id, path, compression=self.compression
)
)
def next_frame(self, path, timestamp):
self.start_downloading(path)
(next_frame_ts, next_frame) = self.streams[path].peek()
while (
@ -108,7 +111,10 @@ class SequentialRerunVideoReader:
def stream_rerun_images_from_video(
video_path: Path, frame_queue: multiprocessing.Queue, compression: int | None
repo_id,
video_path: str,
frame_queue: multiprocessing.Queue,
compression: int | None,
):
"""Streams frames from a video file
@ -117,7 +123,9 @@ def stream_rerun_images_from_video(
frame_queue (multiprocessing.Queue): Queue to store the frames
compression (int | None): Compression level for the images
"""
container = av.open(video_path)
cached_video_path = hf_hub_download(repo_id, video_path, repo_type="dataset")
container = av.open(cached_video_path)
for frame in container.decode(video=0):
pts = float(frame.pts * frame.time_base)
@ -131,14 +139,16 @@ def stream_rerun_images_from_video(
frame_queue.put(None)
def stream_rerun_images_from_video_mp(video_path: Path, compression: int | None) -> Any:
def stream_rerun_images_from_video_mp(
repo_id: str, video_path: str, compression: int | None
) -> Any:
frame_queue: multiprocessing.Queue[(int, rr.Image)] = multiprocessing.Queue(
maxsize=5
)
extractor_proc = multiprocessing.Process(
target=stream_rerun_images_from_video,
args=(video_path, frame_queue, compression),
args=(repo_id, video_path, frame_queue, compression),
)
extractor_proc.start()

View File

@ -0,0 +1,144 @@
#!/usr/bin/env python
""" Visualize data from individual LeRobot dataset episodes.
```bash
$ python lerobot/scripts/visualize_dataset_gradio.py
$ open http://127.0.0.1:7860
```
"""
import gradio as gr
import rerun as rr
import rerun.blueprint as rrb
import torch
import tqdm
from gradio_rerun import Rerun
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import SequentialRerunVideoReader
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset, episode_index):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self):
return iter(self.frame_ids)
def __len__(self):
return len(self.frame_ids)
@rr.thread_local_stream("lerobot_visualization")
def visualize_dataset(
dataset: dict[str, LeRobotDataset],
episode_index: int,
):
stream = rr.binary_stream()
batch_size = 32
num_workers = 0
dataset = dataset["dataset"]
rr.send_blueprint(
rrb.Vertical(
rrb.TimeSeriesView(),
rrb.Horizontal(
contents=[rrb.Spatial2DView(origin=key) for key in dataset.camera_keys]
),
)
)
yield stream.read()
video_reader = SequentialRerunVideoReader(
dataset.repo_id, dataset.tolerance_s, compression=95
)
for key in dataset.camera_keys:
video_reader.start_downloading(key)
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=episode_sampler,
)
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
for key in dataset.camera_keys:
img = video_reader.next_frame(
batch[key]["path"][i], batch[key]["timestamp"][i]
)
if img is not None:
rr.log(key, img)
# display each dimension of action space (e.g. actuators command)
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space)
if "observation.state" in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
if "next.done" in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
if "next.reward" in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
yield stream.read()
def update_episodes(dataset, loaded_dataset):
loaded_dataset["dataset"] = LeRobotDataset(dataset)
dataset = loaded_dataset["dataset"]
return gr.update(choices=list(range(dataset.num_episodes)), value=0)
def main():
with gr.Blocks() as demo:
loaded_dataset = gr.State({})
with gr.Row():
with gr.Column(scale=0.3):
with gr.Row():
dataset = gr.Dropdown(choices=lerobot.available_real_world_datasets)
with gr.Row():
episode = gr.Dropdown(choices=[], interactive=True)
with gr.Row():
load = gr.Button("Load")
with gr.Column():
viewer = Rerun(streaming=True, height=800)
dataset.change(
update_episodes, inputs=[dataset, loaded_dataset], outputs=[episode]
)
load.click(
visualize_dataset, inputs=[loaded_dataset, episode], outputs=[viewer]
)
demo.queue(default_concurrency_limit=10)
demo.launch()
if __name__ == "__main__":
main()

View File

@ -123,7 +123,7 @@ def visualize_dataset(
# to do so. Those will be retrieved directly.
dataset = LeRobotDataset(repo_id, include_video_images=False)
video_reader = SequentialRerunVideoReader(
dataset.videos_dir.parent, dataset.tolerance_s, compression=95
dataset.repo_id, dataset.tolerance_s, compression=95
)
logging.info("Loading dataloader")

View File

@ -13,7 +13,7 @@ authors = [
repository = "https://github.com/huggingface/lerobot"
readme = "README.md"
license = "Apache-2.0"
classifiers=[
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Education",
@ -23,7 +23,7 @@ classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
]
packages = [{include = "lerobot"}]
packages = [{ include = "lerobot" }]
[tool.poetry.dependencies]
@ -31,7 +31,7 @@ python = ">=3.10,<3.13"
termcolor = ">=2.4.0"
omegaconf = ">=2.3.0"
wandb = ">=0.16.3"
imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
imageio = { extras = ["ffmpeg"], version = ">=2.34.0" }
gdown = ">=5.1.0"
hydra-core = ">=1.3.2"
einops = ">=0.8.0"
@ -43,21 +43,22 @@ opencv-python = ">=4.9.0"
diffusers = "^0.27.2"
torchvision = ">=0.18.0"
h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
huggingface-hub = { extras = ["hf-transfer"], version = "^0.23.0" }
gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1"
gym-pusht = { version = ">=0.1.3", optional = true}
gym-xarm = { version = ">=0.1.1", optional = true}
gym-aloha = { version = ">=0.1.1", optional = true}
pre-commit = {version = ">=3.7.0", optional = true}
debugpy = {version = ">=1.8.1", optional = true}
pytest = {version = ">=8.1.0", optional = true}
pytest-cov = {version = ">=5.0.0", optional = true}
gym-pusht = { version = ">=0.1.3", optional = true }
gym-xarm = { version = ">=0.1.1", optional = true }
gym-aloha = { version = ">=0.1.1", optional = true }
pre-commit = { version = ">=3.7.0", optional = true }
debugpy = { version = ">=1.8.1", optional = true }
pytest = { version = ">=8.1.0", optional = true }
pytest-cov = { version = ">=5.0.0", optional = true }
datasets = "^2.19.0"
imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1"
rerun-sdk = ">=0.16.0"
gradio_rerun = { url = "https://huggingface.co/spaces/jleibs/rerun_streaming_poc/resolve/main/gradio_rerun-0.0.2-py3-none-any.whl?download=true" }
[tool.poetry.extras]