#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
|
|
|
Note: The last frame of the episode doesnt always correspond to a final state.
|
|
That's because our datasets are composed of transition from state to state up to
|
|
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
|
However, there might not be a transition from a final state to another state.
|
|
|
|
Note: This script aims to visualize the data used to train the neural networks.
|
|
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
|
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
|
|
save disk space. The compression factor applied has been tuned to not affect success rate.
|
|
|
|
Example of usage:
|
|
|
|
- Visualize data stored on a local machine:
|
|
```bash
|
|
local$ python lerobot/scripts/visualize_dataset_html.py \
|
|
--repo-id lerobot/pusht
|
|
|
|
local$ open http://localhost:9090
|
|
```
|
|
|
|
- Visualize data stored on a distant machine with a local viewer:
|
|
```bash
|
|
distant$ python lerobot/scripts/visualize_dataset_html.py \
|
|
--repo-id lerobot/pusht
|
|
|
|
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
|
|
local$ open http://localhost:9090
|
|
```
|
|
|
|
- Select episodes to visualize:
|
|
```bash
|
|
python lerobot/scripts/visualize_dataset_html.py \
|
|
--repo-id lerobot/pusht \
|
|
--episodes 7 3 5 1 4
|
|
```
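
- Write the CSV files and video symlinks without launching the web server
  (this combines the script's existing `--serve` and `--output-dir` flags;
  the output path shown is illustrative):
```bash
python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht \
    --serve 0 \
    --output-dir outputs/pusht_viz
```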
"""

import argparse
import logging
import shutil
from pathlib import Path

import tqdm
from flask import Flask, redirect, render_template, url_for

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import init_logging


def run_server(
    dataset: LeRobotDataset,
    episodes: list[int],
    host: str,
    port: int,
    static_folder: Path,
    template_folder: Path,
):
    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # do not cache served files

    @app.route("/")
    def index():
        # the home page redirects to the first episode page
        dataset_namespace, dataset_name = dataset.repo_id.split("/")
        first_episode_id = episodes[0]
        return redirect(
            url_for(
                "show_episode",
                dataset_namespace=dataset_namespace,
                dataset_name=dataset_name,
                episode_id=first_episode_id,
            )
        )
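
    # The per-episode page is served at e.g. /lerobot/pusht/episode_0,
    # matching the route pattern below.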
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
|
def show_episode(dataset_namespace, dataset_name, episode_id):
|
|
dataset_info = {
|
|
"repo_id": dataset.repo_id,
|
|
"num_samples": dataset.num_samples,
|
|
"num_episodes": dataset.num_episodes,
|
|
"fps": dataset.fps,
|
|
}
|
|
video_paths = get_episode_video_paths(dataset, episode_id)
|
|
language_instruction = get_episode_language_instruction(dataset, episode_id)
|
|
videos_info = [
|
|
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
|
|
for video_path in video_paths
|
|
]
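        # Each entry looks like (illustrative; the exact url depends on how the
        # dataset lays out its video files under the static/videos symlink):
        #   {"url": "/static/videos/episode_0.mp4", "filename": "episode_0.mp4"}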
        if language_instruction:
            videos_info[0]["language_instruction"] = language_instruction

        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
        return render_template(
            "visualize_dataset_template.html",
            episode_id=episode_id,
            episodes=episodes,
            dataset_info=dataset_info,
            videos_info=videos_info,
            ep_csv_url=ep_csv_url,
            has_policy=False,
        )

    app.run(host=host, port=port)


def get_ep_csv_fname(episode_id: int):
    """Return the csv filename for a given episode index, e.g. episode_3.csv."""
    ep_csv_fname = f"episode_{episode_id}.csv"
    return ep_csv_fname


def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
    """Write a csv file containing timeseries data of an episode (e.g. state and action).
    This file will be loaded by the Dygraphs javascript library to plot data in real time."""
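    # The resulting file looks like (assuming 2 state dims and 2 action dims;
    # the values are illustrative):
    #   timestamp,state_0,state_1,action_0,action_1
    #   0.0,0.12,-0.34,0.56,0.78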
    from_idx = dataset.episode_data_index["from"][episode_index]
    to_idx = dataset.episode_data_index["to"][episode_index]

    has_state = "observation.state" in dataset.hf_dataset.features
    has_action = "action" in dataset.hf_dataset.features

    # init header of csv with state and action names
    header = ["timestamp"]
    if has_state:
        dim_state = len(dataset.hf_dataset["observation.state"][0])
        header += [f"state_{i}" for i in range(dim_state)]
    if has_action:
        dim_action = len(dataset.hf_dataset["action"][0])
        header += [f"action_{i}" for i in range(dim_action)]

    columns = ["timestamp"]
    if has_state:
        columns += ["observation.state"]
    if has_action:
        columns += ["action"]

    rows = []
    data = dataset.hf_dataset.select_columns(columns)
    for i in range(from_idx, to_idx):
        row = [data[i]["timestamp"].item()]
        if has_state:
            row += data[i]["observation.state"].tolist()
        if has_action:
            row += data[i]["action"].tolist()
        rows.append(row)

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w") as f:
        f.write(",".join(header) + "\n")
        for row in rows:
            row_str = [str(col) for col in row]
            f.write(",".join(row_str) + "\n")


def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    # get the first frame of the episode (hack to get the video paths of the episode)
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
    return [
        dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
        for key in dataset.video_frame_keys
    ]


def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> str | None:
    # check if the dataset has language instructions
    if "language_instruction" not in dataset.hf_dataset.features:
        return None

    # get the first frame index
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()

    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
    # with the tf.tensor appearing in the string
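    # e.g. "tf.Tensor(b'pick up the cube', shape=(), dtype=string)" -> "pick up the cube"
    # (the instruction text here is illustrative)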
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")


def visualize_dataset_html(
    repo_id: str,
    root: Path | None = None,
    episodes: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    host: str = "127.0.0.1",
    port: int = 9090,
    force_override: bool = False,
) -> Path | None:
    init_logging()

    dataset = LeRobotDataset(repo_id, root=root)

    if not dataset.video:
        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")

    if output_dir is None:
        output_dir = f"outputs/visualize_dataset_html/{repo_id}"

    output_dir = Path(output_dir)
    if output_dir.exists():
        if force_override:
            shutil.rmtree(output_dir)
        else:
            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Create a symlink from the dataset video folder containing mp4 files to the output directory
    # so that the http server can get access to the mp4 files.
    static_dir = output_dir / "static"
    static_dir.mkdir(parents=True, exist_ok=True)
    ln_videos_dir = static_dir / "videos"
    if not ln_videos_dir.exists():
        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())

    template_dir = Path(__file__).resolve().parent.parent / "templates"

    if episodes is None:
        episodes = list(range(dataset.num_episodes))

    logging.info("Writing CSV files")
    for episode_index in tqdm.tqdm(episodes):
        # write states and actions in a csv (this can be slow for big datasets)
        ep_csv_fname = get_ep_csv_fname(episode_index)
        # TODO(rcadene): speed up the script by loading directly from the dataset, pyarrow, parquet, safetensors?
        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)

    if serve:
        run_server(dataset, episodes, host, port, static_dir, template_dir)
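
# Programmatic use is also possible, e.g. (a minimal sketch; the repo id and
# episode list are illustrative):
#   visualize_dataset_html("lerobot/pusht", episodes=[0, 1], serve=False)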


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of the Hugging Face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset is loaded from the Hugging Face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write html files and kick off a web server. By default writes them to 'outputs/visualize_dataset_html/REPO_ID'.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch the web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it already exists.",
    )

    args = parser.parse_args()
    visualize_dataset_html(**vars(args))


if __name__ == "__main__":
    main()