Add a gradio visualizer
Parent: 4bde4fb987
Commit: e5d8955802
@@ -68,8 +68,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.stats = load_stats(repo_id, version, root)
         self.info = load_info(repo_id, version, root)
-        if self.video:
-            self.videos_dir = load_videos(repo_id, version, root)
+        self.include_video_images = include_video_images
+        # if self.video:
+        #     self.videos_dir = load_videos(repo_id, version, root)
 
     @property
     def fps(self) -> int:
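The new `include_video_images` flag lets a caller skip image decoding inside the dataset, leaving camera entries as `{"path", "timestamp"}` references (see the `hf_transform_to_torch` hunk below, where such dicts are passed through untouched). A minimal sketch of the intended call pattern, assuming an illustrative repo id; the full constructor signature is not shown in this diff:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Illustrative repo id; any LeRobot video dataset would do.
dataset = LeRobotDataset("lerobot/pusht", include_video_images=False)

# With include_video_images=False, camera entries remain
# {"path": ..., "timestamp": ...} dicts ("video frame will be processed
# downstream"), so a visualizer can decode frames itself instead of
# paying the decoding cost in the dataloader.
item = dataset[0]
```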
@@ -70,7 +70,11 @@ def hf_transform_to_torch(items_dict):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
-        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
+        elif (
+            isinstance(first_item, dict)
+            and "path" in first_item
+            and "timestamp" in first_item
+        ):
             # video frame will be processed downstream
             pass
         else:
@@ -85,7 +89,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
     # TODO(rcadene): clean this which enables getting a subset of dataset
     if split != "train":
         if "%" in split:
-            raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
+            raise NotImplementedError(
+                f"We dont support splitting based on percentage for now ({split})."
+            )
         match_from = re.search(r"train\[(\d+):\]", split)
         match_to = re.search(r"train\[:(\d+)\]", split)
         if match_from:
@@ -118,7 +124,10 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
         path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
     else:
         path = hf_hub_download(
-            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
+            repo_id,
+            "meta_data/episode_data_index.safetensors",
+            repo_type="dataset",
+            revision=version,
         )
 
     return load_file(path)
@@ -135,7 +144,12 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
     if root is not None:
         path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
     else:
-        path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
+        path = hf_hub_download(
+            repo_id,
+            "meta_data/stats.safetensors",
+            repo_type="dataset",
+            revision=version,
+        )
 
     stats = load_file(path)
     return unflatten_dict(stats)
@@ -152,7 +166,9 @@ def load_info(repo_id, version, root) -> dict:
     if root is not None:
         path = Path(root) / repo_id / "meta_data" / "info.json"
     else:
-        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
+        path = hf_hub_download(
+            repo_id, "meta_data/info.json", repo_type="dataset", revision=version
+        )
 
     with open(path) as f:
         info = json.load(f)
@@ -219,7 +235,9 @@ def load_previous_and_future_frames(
     ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
 
     # load timestamps
-    ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
+    ep_timestamps = hf_dataset.select_columns("timestamp")[
+        ep_data_id_from:ep_data_id_to
+    ]["timestamp"]
     ep_timestamps = torch.stack(ep_timestamps)
 
     # we make the assumption that the timestamps are sorted
@@ -241,7 +259,9 @@ def load_previous_and_future_frames(
     is_pad = min_ > tolerance_s
 
     # check violated query timestamps are all outside the episode range
-    assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
+    assert (
+        (query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])
+    ).all(), (
         f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
         "This might be due to synchronization issues with timestamps during data collection."
     )
@@ -263,7 +283,9 @@ def load_previous_and_future_frames(
     return item
 
 
-def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
+def calculate_episode_data_index(
+    hf_dataset: datasets.Dataset,
+) -> Dict[str, torch.Tensor]:
     """
     Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
@@ -334,7 +356,9 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
     }
 
     def modify_ep_idx_func(example):
-        example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
+        example["episode_index"] = episode_idx_to_reset_idx_mapping[
+            example["episode_index"].item()
+        ]
         return example
 
     hf_dataset = hf_dataset.map(modify_ep_idx_func)

----------------------------------------
@@ -15,20 +15,20 @@
 # limitations under the License.
 import logging
 import math
+import multiprocessing
 import subprocess
 import warnings
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, ClassVar
 
+import av
 import pyarrow as pa
+import rerun as rr
 import torch
 import torchvision
 from datasets.features.features import register_feature
-
-import av
-import multiprocessing
-import rerun as rr
+from huggingface_hub.file_download import hf_hub_download
 
 
 class PeekableIterator:
@@ -75,20 +75,23 @@ class SequentialRerunVideoReader:
     Frames must be consumed in-order.
     """
 
-    def __init__(self, video_dir: Path, tolerance: float, compression: int | None = 95):
-        self.video_dir = video_dir
+    def __init__(self, repo_id: str, tolerance: float, compression: int | None = 95):
+        self.repo_id = repo_id
         self.streams: dict[Path, PeekableIterator] = {}
         self.tolerance = tolerance
         self.compression = compression
 
-    def next_frame(self, path, timestamp):
+    def start_downloading(self, path):
         if path not in self.streams:
             self.streams[path] = PeekableIterator(
                 stream_rerun_images_from_video_mp(
-                    self.video_dir / path, compression=self.compression
+                    self.repo_id, path, compression=self.compression
                )
            )
 
+    def next_frame(self, path, timestamp):
+        self.start_downloading(path)
+
         (next_frame_ts, next_frame) = self.streams[path].peek()
 
         while (
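After this change the reader is addressed by Hub repo id rather than a local videos directory, and the new `start_downloading` lets callers warm up a stream before its frames are needed. A minimal usage sketch; the repo id, video path, and tolerance value are illustrative:

```python
from lerobot.common.datasets.video_utils import SequentialRerunVideoReader

reader = SequentialRerunVideoReader("lerobot/pusht", tolerance=0.04, compression=95)

# Optional warm-up: starts fetching and decoding in a background process.
reader.start_downloading("videos/observation.image_episode_000000.mp4")

# Frames must then be requested with monotonically increasing timestamps.
img = reader.next_frame("videos/observation.image_episode_000000.mp4", 0.0)
```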
@@ -108,7 +111,10 @@ class SequentialRerunVideoReader:
 
 
 def stream_rerun_images_from_video(
-    video_path: Path, frame_queue: multiprocessing.Queue, compression: int | None
+    repo_id,
+    video_path: str,
+    frame_queue: multiprocessing.Queue,
+    compression: int | None,
 ):
     """Streams frames from a video file
@@ -117,7 +123,9 @@ def stream_rerun_images_from_video(
         frame_queue (multiprocessing.Queue): Queue to store the frames
         compression (int | None): Compression level for the images
     """
-    container = av.open(video_path)
+    cached_video_path = hf_hub_download(repo_id, video_path, repo_type="dataset")
+
+    container = av.open(cached_video_path)
 
     for frame in container.decode(video=0):
         pts = float(frame.pts * frame.time_base)
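Worth noting: `hf_hub_download` resolves to a path inside the local Hugging Face cache and only hits the network when the file is not already cached, so revisiting an episode reuses the downloaded MP4. A small illustration (the video path is hypothetical):

```python
from huggingface_hub import hf_hub_download

# First call downloads into the local HF cache; subsequent calls with the
# same arguments return the same cached path without re-downloading.
local_path = hf_hub_download(
    "lerobot/pusht",
    "videos/observation.image_episode_000000.mp4",
    repo_type="dataset",
)
print(local_path)  # .../hub/datasets--lerobot--pusht/snapshots/.../...
```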
@@ -131,14 +139,16 @@ def stream_rerun_images_from_video(
     frame_queue.put(None)
 
 
-def stream_rerun_images_from_video_mp(video_path: Path, compression: int | None) -> Any:
+def stream_rerun_images_from_video_mp(
+    repo_id: str, video_path: str, compression: int | None
+) -> Any:
     frame_queue: multiprocessing.Queue[(int, rr.Image)] = multiprocessing.Queue(
         maxsize=5
     )
 
     extractor_proc = multiprocessing.Process(
         target=stream_rerun_images_from_video,
-        args=(video_path, frame_queue, compression),
+        args=(repo_id, video_path, frame_queue, compression),
     )
     extractor_proc.start()
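The decoder process feeds a bounded queue (`maxsize=5`), which applies backpressure so decoding never runs far ahead of the consumer. The consumer side wraps that queue in the `PeekableIterator` defined earlier in this file; its definition falls outside the hunks shown, so here is a minimal sketch of the peek/advance contract the reader relies on (illustrative, not the committed implementation):

```python
class PeekableIterator:
    """Sketch: lets the caller inspect the next item without consuming it."""

    _EMPTY = object()  # sentinel meaning "nothing buffered"

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._buffered = self._EMPTY

    def peek(self):
        # Buffer the next item so a later next() returns the same value.
        if self._buffered is self._EMPTY:
            self._buffered = next(self._it)
        return self._buffered

    def __iter__(self):
        return self

    def __next__(self):
        if self._buffered is not self._EMPTY:
            item, self._buffered = self._buffered, self._EMPTY
            return item
        return next(self._it)
```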
----------------------------------------
new file: lerobot/scripts/visualize_dataset_gradio.py

@@ -0,0 +1,144 @@
#!/usr/bin/env python

"""Visualize data from individual LeRobot dataset episodes.

```bash
$ python lerobot/scripts/visualize_dataset_gradio.py
$ open http://127.0.0.1:7860
```
"""


import gradio as gr
import rerun as rr
import rerun.blueprint as rrb
import torch
import tqdm
from gradio_rerun import Rerun

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import SequentialRerunVideoReader


class EpisodeSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, episode_index):
        from_idx = dataset.episode_data_index["from"][episode_index].item()
        to_idx = dataset.episode_data_index["to"][episode_index].item()
        self.frame_ids = range(from_idx, to_idx)

    def __iter__(self):
        return iter(self.frame_ids)

    def __len__(self):
        return len(self.frame_ids)


@rr.thread_local_stream("lerobot_visualization")
def visualize_dataset(
    dataset: dict[str, LeRobotDataset],
    episode_index: int,
):
    stream = rr.binary_stream()

    batch_size = 32
    num_workers = 0

    dataset = dataset["dataset"]

    rr.send_blueprint(
        rrb.Vertical(
            rrb.TimeSeriesView(),
            rrb.Horizontal(
                contents=[rrb.Spatial2DView(origin=key) for key in dataset.camera_keys]
            ),
        )
    )

    yield stream.read()

    video_reader = SequentialRerunVideoReader(
        dataset.repo_id, dataset.tolerance_s, compression=95
    )
    for key in dataset.camera_keys:
        video_reader.start_downloading(key)

    episode_sampler = EpisodeSampler(dataset, episode_index)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=episode_sampler,
    )

    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        # iterate over the batch
        for i in range(len(batch["index"])):
            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())

            # display each camera image
            for key in dataset.camera_keys:
                img = video_reader.next_frame(
                    batch[key]["path"][i], batch[key]["timestamp"][i]
                )
                if img is not None:
                    rr.log(key, img)

            # display each dimension of action space (e.g. actuators command)
            if "action" in batch:
                for dim_idx, val in enumerate(batch["action"][i]):
                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))

            # display each dimension of observed state space (e.g. agent position in joint space)
            if "observation.state" in batch:
                for dim_idx, val in enumerate(batch["observation.state"][i]):
                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))

            if "next.done" in batch:
                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))

            if "next.reward" in batch:
                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))

            if "next.success" in batch:
                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))

            yield stream.read()


def update_episodes(dataset, loaded_dataset):
    loaded_dataset["dataset"] = LeRobotDataset(dataset)
    dataset = loaded_dataset["dataset"]
    return gr.update(choices=list(range(dataset.num_episodes)), value=0)


def main():
    with gr.Blocks() as demo:
        loaded_dataset = gr.State({})
        with gr.Row():
            with gr.Column(scale=0.3):
                with gr.Row():
                    dataset = gr.Dropdown(choices=lerobot.available_real_world_datasets)
                with gr.Row():
                    episode = gr.Dropdown(choices=[], interactive=True)
                with gr.Row():
                    load = gr.Button("Load")
            with gr.Column():
                viewer = Rerun(streaming=True, height=800)

        dataset.change(
            update_episodes, inputs=[dataset, loaded_dataset], outputs=[episode]
        )
        load.click(
            visualize_dataset, inputs=[loaded_dataset, episode], outputs=[viewer]
        )

    demo.queue(default_concurrency_limit=10)
    demo.launch()


if __name__ == "__main__":
    main()
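The script streams incrementally: `visualize_dataset` is a generator, and each `yield stream.read()` hands Gradio the Rerun messages logged since the previous read, which the `Rerun` component forwards to the embedded viewer. The core pattern, reduced to a self-contained sketch (entity path, timeline name, and button label are illustrative):

```python
import gradio as gr
import rerun as rr
from gradio_rerun import Rerun


@rr.thread_local_stream("streaming_sketch")
def stream_scalars():
    # Log into an in-memory binary stream instead of a live viewer connection.
    stream = rr.binary_stream()
    for step in range(100):
        rr.set_time_sequence("step", step)
        rr.log("value", rr.Scalar(step * 0.1))
        # Flush everything logged since the last read to the client.
        yield stream.read()


with gr.Blocks() as demo:
    viewer = Rerun(streaming=True, height=400)
    gr.Button("Run").click(stream_scalars, outputs=[viewer])

demo.launch()
```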
----------------------------------------

@@ -123,7 +123,7 @@ def visualize_dataset(
     # to do so. Those will be retrieved directly.
     dataset = LeRobotDataset(repo_id, include_video_images=False)
     video_reader = SequentialRerunVideoReader(
-        dataset.videos_dir.parent, dataset.tolerance_s, compression=95
+        dataset.repo_id, dataset.tolerance_s, compression=95
     )
 
     logging.info("Loading dataloader")

----------------------------------------
pyproject.toml
@@ -13,7 +13,7 @@ authors = [
 repository = "https://github.com/huggingface/lerobot"
 readme = "README.md"
 license = "Apache-2.0"
-classifiers=[
+classifiers = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",
     "Intended Audience :: Education",
@@ -23,7 +23,7 @@ classifiers=[
     "License :: OSI Approved :: Apache Software License",
     "Programming Language :: Python :: 3.10",
 ]
-packages = [{include = "lerobot"}]
+packages = [{ include = "lerobot" }]
 
 
 [tool.poetry.dependencies]
|
|||
termcolor = ">=2.4.0"
|
||||
omegaconf = ">=2.3.0"
|
||||
wandb = ">=0.16.3"
|
||||
imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
|
||||
imageio = { extras = ["ffmpeg"], version = ">=2.34.0" }
|
||||
gdown = ">=5.1.0"
|
||||
hydra-core = ">=1.3.2"
|
||||
einops = ">=0.8.0"
|
||||
|
@ -43,21 +43,22 @@ opencv-python = ">=4.9.0"
|
|||
diffusers = "^0.27.2"
|
||||
torchvision = ">=0.18.0"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
||||
huggingface-hub = { extras = ["hf-transfer"], version = "^0.23.0" }
|
||||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||
pre-commit = {version = ">=3.7.0", optional = true}
|
||||
debugpy = {version = ">=1.8.1", optional = true}
|
||||
pytest = {version = ">=8.1.0", optional = true}
|
||||
pytest-cov = {version = ">=5.0.0", optional = true}
|
||||
gym-pusht = { version = ">=0.1.3", optional = true }
|
||||
gym-xarm = { version = ">=0.1.1", optional = true }
|
||||
gym-aloha = { version = ">=0.1.1", optional = true }
|
||||
pre-commit = { version = ">=3.7.0", optional = true }
|
||||
debugpy = { version = ">=1.8.1", optional = true }
|
||||
pytest = { version = ">=8.1.0", optional = true }
|
||||
pytest-cov = { version = ">=5.0.0", optional = true }
|
||||
datasets = "^2.19.0"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
moviepy = ">=1.0.3"
|
||||
rerun-sdk = ">=0.15.1"
|
||||
rerun-sdk = ">=0.16.0"
|
||||
gradio_rerun = { url = "https://huggingface.co/spaces/jleibs/rerun_streaming_poc/resolve/main/gradio_rerun-0.0.2-py3-none-any.whl?download=true" }
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|