diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index f947e610..6cff5752 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -108,8 +108,8 @@ def visualize_dataset(
     web_port: int = 9090,
     ws_port: int = 9087,
     save: bool = False,
-    output_dir: Path | None = None,
     root: Path | None = None,
+    output_dir: Path | None = None,
 ) -> Path | None:
     if save:
         assert (
@@ -209,6 +209,18 @@ def main():
         required=True,
         help="Episode to visualize.",
     )
+    parser.add_argument(
+        "--root",
+        type=Path,
+        default=None,
+        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from the Hugging Face cache folder, or downloaded from the hub if available.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        default=None,
+        help="Directory path to write a .rrd file when `--save 1` is set.",
+    )
     parser.add_argument(
         "--batch-size",
         type=int,
@@ -254,17 +266,6 @@
             "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
         ),
     )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        help="Directory path to write a .rrd file when `--save 1` is set.",
-    )
-
-    parser.add_argument(
-        "--root",
-        type=str,
-        help="Root directory for a dataset stored on a local machine.",
-    )
 
     args = parser.parse_args()
     visualize_dataset(**vars(args))
diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
new file mode 100644
index 00000000..2531fbd0
--- /dev/null
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+
+Note: The last frame of the episode doesn't always correspond to a final state.
+That's because our datasets are composed of transitions from state to state, up to
+the antepenultimate state associated with the ultimate action needed to arrive in the final state.
+However, there might not be a transition from a final state to another state.
+
+Note: This script aims to visualize the data used to train the neural networks.
+~What you see is what you get~. When visualizing the image modality, you can expect to observe
+lossy compression artifacts, since these images have been decoded from compressed mp4 videos to
+save disk space. The compression factor applied has been tuned to not affect success rate.
+
+Examples of usage:
+
+- Visualize data stored on a local machine:
+```bash
+local$ python lerobot/scripts/visualize_dataset_html.py \
+    --repo-id lerobot/pusht
+
+local$ open http://localhost:9090
+```
+
+- Visualize data stored on a distant machine with a local viewer:
+```bash
+distant$ python lerobot/scripts/visualize_dataset_html.py \
+    --repo-id lerobot/pusht
+
+local$ ssh -L 9090:localhost:9090 distant  # create an ssh tunnel
+local$ open http://localhost:9090
+```
+
+- Select episodes to visualize:
+```bash
+python lerobot/scripts/visualize_dataset_html.py \
+    --repo-id lerobot/pusht \
+    --episodes 7 3 5 1 4
+```
+"""
+
+import argparse
+import logging
+import shutil
+from pathlib import Path
+
+import torch
+import tqdm
+from flask import Flask, redirect, render_template, url_for
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.utils.utils import init_logging
+
+
+class EpisodeSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset, episode_index):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self):
+        return iter(self.frame_ids)
+
+    def __len__(self):
+        return len(self.frame_ids)
+
+
+def run_server(
+    dataset: LeRobotDataset,
+    episodes: list[int],
+    host: str,
+    port: int,
+    static_folder: Path,
+    template_folder: Path,
+):
+    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
+    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # do not cache static files
+
+    @app.route("/")
+    def index():
+        # the home page redirects to the first episode page
+        [dataset_namespace, dataset_name] = dataset.repo_id.split("/")
+        first_episode_id = episodes[0]
+        return redirect(
+            url_for(
+                "show_episode",
+                dataset_namespace=dataset_namespace,
+                dataset_name=dataset_name,
+                episode_id=first_episode_id,
+            )
+        )
+
+    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
+    def show_episode(dataset_namespace, dataset_name, episode_id):
+        dataset_info = {
+            "repo_id": dataset.repo_id,
+            "num_samples": dataset.num_samples,
+            "num_episodes": dataset.num_episodes,
+            "fps": dataset.fps,
+        }
+        video_paths = get_episode_video_paths(dataset, episode_id)
+        videos_info = [
+            {"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
+            for video_path in video_paths
+        ]
+        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
+        return render_template(
+            "visualize_dataset_template.html",
+            episode_id=episode_id,
+            episodes=episodes,
+            dataset_info=dataset_info,
+            videos_info=videos_info,
+            ep_csv_url=ep_csv_url,
+            has_policy=False,
+        )
+
+    app.run(host=host, port=port)
+
+
+def get_ep_csv_fname(episode_id: int):
+    ep_csv_fname = f"episode_{episode_id}.csv"
+    return ep_csv_fname
+
+
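The per-episode CSV referenced by `ep_csv_url` is produced by `write_episode_data_csv` below: a `timestamp` column followed by one column per state dimension and per action dimension. For a hypothetical dataset with a 2-D state and a 2-D action (values illustrative, not from the diff), the first rows would look like:

```
timestamp,state_0,state_1,action_0,action_1
0.0,-0.14,0.52,-0.13,0.51
0.1,-0.12,0.50,-0.11,0.49
```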
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
+    """Write a csv file containing the timeseries data of an episode (e.g. state and action).
+    This file will be loaded by the Dygraph javascript library to plot data in real time."""
+    from_idx = dataset.episode_data_index["from"][episode_index]
+    to_idx = dataset.episode_data_index["to"][episode_index]
+
+    has_state = "observation.state" in dataset.hf_dataset.features
+    has_action = "action" in dataset.hf_dataset.features
+
+    # init the csv header with state and action names
+    header = ["timestamp"]
+    if has_state:
+        dim_state = len(dataset.hf_dataset["observation.state"][0])
+        header += [f"state_{i}" for i in range(dim_state)]
+    if has_action:
+        dim_action = len(dataset.hf_dataset["action"][0])
+        header += [f"action_{i}" for i in range(dim_action)]
+
+    columns = ["timestamp"]
+    if has_state:
+        columns += ["observation.state"]
+    if has_action:
+        columns += ["action"]
+
+    rows = []
+    data = dataset.hf_dataset.select_columns(columns)
+    for i in range(from_idx, to_idx):
+        row = [data[i]["timestamp"].item()]
+        if has_state:
+            row += data[i]["observation.state"].tolist()
+        if has_action:
+            row += data[i]["action"].tolist()
+        rows.append(row)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with open(output_dir / file_name, "w") as f:
+        f.write(",".join(header) + "\n")
+        for row in rows:
+            row_str = [str(col) for col in row]
+            f.write(",".join(row_str) + "\n")
+
+
+def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
+    # get the first frame of the episode (hack to get the video path of the episode)
+    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+    return [
+        dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
+        for key in dataset.video_frame_keys
+    ]
+
+
+def visualize_dataset_html(
+    repo_id: str,
+    root: Path | None = None,
+    episodes: list[int] | None = None,
+    output_dir: Path | None = None,
+    serve: bool = True,
+    host: str = "127.0.0.1",
+    port: int = 9090,
+    force_override: bool = False,
+) -> Path | None:
+    init_logging()
+
+    dataset = LeRobotDataset(repo_id, root=root)
+
+    if not dataset.video:
+        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
+
+    if output_dir is None:
+        output_dir = f"outputs/visualize_dataset_html/{repo_id}"
+
+    output_dir = Path(output_dir)
+    if output_dir.exists():
+        if force_override:
+            shutil.rmtree(output_dir)
+        else:
+            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Create a symlink from the dataset video folder containing the mp4 files to the output
+    # directory, so that the http server can access the mp4 files.
+    static_dir = output_dir / "static"
+    static_dir.mkdir(parents=True, exist_ok=True)
+    ln_videos_dir = static_dir / "videos"
+    if not ln_videos_dir.exists():
+        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
+
+    template_dir = Path(__file__).resolve().parent.parent / "templates"
+
+    if episodes is None:
+        episodes = list(range(dataset.num_episodes))
+
+    logging.info("Writing CSV files")
+    for episode_index in tqdm.tqdm(episodes):
+        # write states and actions in a csv (it can be slow for big datasets)
+        ep_csv_fname = get_ep_csv_fname(episode_index)
+        # TODO(rcadene): speed up this script by loading directly from the dataset, pyarrow, parquet, safetensors?
+        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
+
+    if serve:
+        run_server(dataset, episodes, host, port, static_dir, template_dir)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Name of the Hugging Face repository containing a LeRobotDataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
+    )
+    parser.add_argument(
+        "--root",
+        type=Path,
+        default=None,
+        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from the Hugging Face cache folder, or downloaded from the hub if available.",
+    )
+    parser.add_argument(
+        "--episodes",
+        type=int,
+        nargs="*",
+        default=None,
+        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        default=None,
+        help="Directory path to write html files and kick off a web server. By default, writes them to 'outputs/visualize_dataset_html/REPO_ID'.",
+    )
+    parser.add_argument(
+        "--serve",
+        type=int,
+        default=1,
+        help="Launch web server.",
+    )
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="127.0.0.1",
+        help="Web host used by the http server.",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=9090,
+        help="Web port used by the http server.",
+    )
+    parser.add_argument(
+        "--force-override",
+        type=int,
+        default=0,
+        help="Delete the output directory if it already exists.",
+    )
+
+    args = parser.parse_args()
+    visualize_dataset_html(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
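For reviewers who prefer to drive the new entry point from Python rather than the CLI, here is a minimal sketch based on the signature above (the output path is hypothetical; by default it is derived from the repo id):

```python
from pathlib import Path

from lerobot.scripts.visualize_dataset_html import visualize_dataset_html

# Writes the per-episode CSVs, symlinks the videos into static/,
# then serves the viewer on 127.0.0.1:9090.
visualize_dataset_html(
    "lerobot/pusht",
    episodes=[0, 1],
    output_dir=Path("outputs/visualize_dataset_html/lerobot/pusht"),
    serve=True,
)
```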
diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html
new file mode 100644
index 00000000..16ca0fa3
--- /dev/null
+++ b/lerobot/templates/visualize_dataset_template.html
@@ -0,0 +1,360 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="utf-8">
+    <title>{{ dataset_info.repo_id }} episode {{ episode_id }}</title>
+</head>
+<body>
+    <!-- Sidebar: dataset info and episode list -->
+    <div>
+        <p>{{ dataset_info.repo_id }}</p>
+        <ul>
+            <li>Number of samples/frames: {{ dataset_info.num_samples }}</li>
+            <li>Number of episodes: {{ dataset_info.num_episodes }}</li>
+            <li>Frames per second: {{ dataset_info.fps }}</li>
+        </ul>
+        <p>Episodes:</p>
+        <ul>
+            {% for episode in episodes %}
+            <li><a href="episode_{{ episode }}">Episode {{ episode }}</a></li>
+            {% endfor %}
+        </ul>
+    </div>
+
+    <!-- Main panel: videos, playback position, and the Dygraph plot -->
+    <div>
+        <h1>Episode {{ episode_id }}</h1>
+
+        {% for video_info in videos_info %}
+        <div>
+            <p>{{ video_info.filename }}</p>
+            <video muted loop controls>
+                <source src="{{ video_info.url }}" type="video/mp4">
+            </video>
+        </div>
+        {% endfor %}
+
+        <p><span>0:00</span> / <span>0:00</span></p>
+        <p>Time: 0.00s</p>
+
+        <div id="graph"></div>
+        <div id="labels"></div>
+    </div>
+</body>
+</html>
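The template is also expected to load Dygraphs and point it at `ep_csv_url` (see the `write_episode_data_csv` docstring); a minimal sketch of that wiring, with hypothetical element ids and an indicative CDN URL:

```html
<script src="https://cdnjs.cloudflare.com/ajax/libs/dygraph/2.2.1/dygraph.min.js"></script>
<script>
  // Dygraph accepts a CSV URL directly; the first column (timestamp) becomes the x-axis.
  new Dygraph(document.getElementById("graph"), "{{ ep_csv_url }}", {
    legend: "always",
    labelsDiv: document.getElementById("labels"),
  });
</script>
```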
+ + + + + diff --git a/poetry.lock b/poetry.lock index ae500299..084af6b6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -192,6 +192,17 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "blinker" +version = "1.8.2" +description = "Fast, simple object-to-object and broadcast signaling" +optional = false +python-versions = ">=3.8" +files = [ + {file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"}, + {file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"}, +] + [[package]] name = "certifi" version = "2024.7.4" @@ -584,17 +595,6 @@ files = [ {file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"}, ] -[[package]] -name = "decorator" -version = "4.4.2" -description = "Decorators for Humans" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*" -files = [ - {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"}, - {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, -] - [[package]] name = "deepdiff" version = "7.0.1" @@ -795,6 +795,7 @@ files = [ {file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"}, {file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"}, + {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"}, {file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"}, {file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"}, @@ -892,6 +893,28 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "flask" +version = "3.0.3" +description = "A simple framework for building complex web applications." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"}, + {file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"}, +] + +[package.dependencies] +blinker = ">=1.6.2" +click = ">=8.1.3" +itsdangerous = ">=2.1.2" +Jinja2 = ">=3.1.2" +Werkzeug = ">=3.0.0" + +[package.extras] +async = ["asgiref (>=3.2)"] +dotenv = ["python-dotenv"] + [[package]] name = "frozenlist" version = "1.4.1" @@ -1550,6 +1573,17 @@ files = [ {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +description = "Safely pass data to untrusted environments and back." +optional = false +python-versions = ">=3.8" +files = [ + {file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"}, + {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, +] + [[package]] name = "jinja2" version = "3.1.4" @@ -1741,9 +1775,13 @@ files = [ {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"}, {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"}, + {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"}, + {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"}, + {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"}, + {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"}, {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"}, {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"}, @@ -1901,30 +1939,6 @@ files = [ intel-openmp = "==2021.*" tbb = "==2021.*" -[[package]] -name = "moviepy" -version = "1.0.3" -description = "Video editing with Python" -optional = false -python-versions = "*" -files = [ - {file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"}, -] - -[package.dependencies] -decorator = 
">=4.0.2,<5.0" -imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""} -imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""} -numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""} -proglog = "<=1.0.0" -requests = ">=2.8.1,<3.0" -tqdm = ">=4.11.2,<5.0" - -[package.extras] -doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"] -optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"] -test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"] - [[package]] name = "mpmath" version = "1.3.0" @@ -2696,20 +2710,6 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" -[[package]] -name = "proglog" -version = "0.1.10" -description = "Log and progress bar manager for console, notebooks, web..." -optional = false -python-versions = "*" -files = [ - {file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"}, - {file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"}, -] - -[package.dependencies] -tqdm = "*" - [[package]] name = "protobuf" version = "5.27.2" @@ -3276,6 +3276,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3809,13 +3810,13 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "71.0.1" +version = "71.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-71.0.1-py3-none-any.whl", hash = "sha256:1eb8ef012efae7f6acbc53ec0abde4bc6746c43087fd215ee09e1df48998711f"}, - {file = "setuptools-71.0.1.tar.gz", hash = "sha256:c51d7fd29843aa18dad362d4b4ecd917022131425438251f4e3d766c964dd1ad"}, + {file = "setuptools-71.0.0-py3-none-any.whl", hash = "sha256:f06fbe978a91819d250a30e0dc4ca79df713d909e24438a42d0ec300fc52247f"}, + {file = "setuptools-71.0.0.tar.gz", hash = "sha256:98da3b8aca443b9848a209ae4165e2edede62633219afa493a58fbba57f72e2e"}, ] [package.extras] @@ -4215,6 +4216,23 @@ perf = ["orjson"] sweeps = ["sweeps (>=0.2.0)"] workspaces = ["wandb-workspaces"] +[[package]] +name = "werkzeug" +version = "3.0.3" +description = "The comprehensive WSGI web application library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "xxhash" version = "3.4.1" @@ -4485,4 +4503,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb" +content-hash = "25d5a270d770d37b13a93bf72868d3b9e683f8af5252b6332ec926a26fd0c096" diff --git a/pyproject.toml b/pyproject.toml index 787984a0..8d4ea792 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,13 +57,15 @@ pytest-cov = {version = ">=5.0.0", optional = true} datasets = ">=2.19.0" imagecodecs = { version = ">=2024.1.1", optional = true } pyav = ">=12.0.5" -moviepy = ">=1.0.3" rerun-sdk = ">=0.15.1" deepdiff = ">=7.0.1" -scikit-image = {version = ">=0.23.2", optional = true} +flask = ">=3.0.3" pandas = {version = ">=2.2.2", optional = true} +scikit-image = {version = ">=0.23.2", optional = true} dynamixel-sdk = {version = ">=3.7.31", optional = true} pynput = {version = ">=1.7.7", optional = true} +# TODO(rcadene, salibert): 71.0.1 has a bug +setuptools = {version = "!=71.0.1", optional = true} diff --git a/tests/test_visualize_dataset_html.py b/tests/test_visualize_dataset_html.py new file mode 100644 index 00000000..4dc3c063 --- /dev/null +++ b/tests/test_visualize_dataset_html.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest + +from lerobot.scripts.visualize_dataset_html import visualize_dataset_html + + +@pytest.mark.parametrize( + "repo_id", + ["lerobot/pusht"], +) +def test_visualize_dataset_html(tmpdir, repo_id): + tmpdir = Path(tmpdir) + visualize_dataset_html( + repo_id, + episodes=[0], + output_dir=tmpdir, + serve=False, + ) + assert (tmpdir / "static" / "episode_0.csv").exists()
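To exercise the new test in isolation (note that `LeRobotDataset` will download `lerobot/pusht` on first use if it is not already cached):

```bash
pytest tests/test_visualize_dataset_html.py
```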