From 8546358bc580691e449204bcd3c0a24d5d4a3a22 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 18 Nov 2024 17:54:15 +0100 Subject: [PATCH] Fix test_visualize_dataset_html --- lerobot/scripts/visualize_dataset_html.py | 35 +++++++---------------- tests/test_visualize_dataset_html.py | 20 +++++-------- 2 files changed, 18 insertions(+), 37 deletions(-) diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index b79734d9..475983d3 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -130,16 +130,16 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset): from_idx = dataset.episode_data_index["from"][episode_index] to_idx = dataset.episode_data_index["to"][episode_index] - has_state = "observation.state" in dataset.hf_dataset.features - has_action = "action" in dataset.hf_dataset.features + has_state = "observation.state" in dataset.features + has_action = "action" in dataset.features # init header of csv with state and action names header = ["timestamp"] if has_state: - dim_state = dataset.meta.shapes["observation.state"] + dim_state = dataset.meta.shapes["observation.state"][0] header += [f"state_{i}" for i in range(dim_state)] if has_action: - dim_action = dataset.meta.shapes["action"] + dim_action = dataset.meta.shapes["action"][0] header += [f"action_{i}" for i in range(dim_action)] columns = ["timestamp"] @@ -175,23 +175,8 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str] ] -def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: - # check if the dataset has language instructions - if "language_instruction" not in dataset.hf_dataset.features: - return None - - # get first frame index - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() - - language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] - # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored - # with the tf.tensor appearing in the string - return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") - - def visualize_dataset_html( - repo_id: str, - root: Path | None = None, + dataset: LeRobotDataset, episodes: list[int] = None, output_dir: Path | None = None, serve: bool = True, @@ -201,13 +186,11 @@ def visualize_dataset_html( ) -> Path | None: init_logging() - dataset = LeRobotDataset(repo_id, root=root) - if len(dataset.meta.image_keys) > 0: raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.") if output_dir is None: - output_dir = f"outputs/visualize_dataset_html/{repo_id}" + output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}" output_dir = Path(output_dir) if output_dir.exists(): @@ -296,7 +279,11 @@ def main(): ) args = parser.parse_args() - visualize_dataset_html(**vars(args)) + kwargs = vars(args) + repo_id = kwargs.pop("repo_id") + root = kwargs.pop("root") + dataset = LeRobotDataset(repo_id, root=root, local_files_only=True) + visualize_dataset_html(dataset, **kwargs) if __name__ == "__main__": diff --git a/tests/test_visualize_dataset_html.py b/tests/test_visualize_dataset_html.py index 4dc3c063..53924f56 100644 --- a/tests/test_visualize_dataset_html.py +++ b/tests/test_visualize_dataset_html.py @@ -14,23 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path - -import pytest - from lerobot.scripts.visualize_dataset_html import visualize_dataset_html -@pytest.mark.parametrize( - "repo_id", - ["lerobot/pusht"], -) -def test_visualize_dataset_html(tmpdir, repo_id): - tmpdir = Path(tmpdir) +def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory): + root = tmp_path / "dataset" + output_dir = tmp_path / "outputs" + dataset = lerobot_dataset_factory(root=root) visualize_dataset_html( - repo_id, + dataset, episodes=[0], - output_dir=tmpdir, + output_dir=output_dir, serve=False, ) - assert (tmpdir / "static" / "episode_0.csv").exists() + assert (output_dir / "static" / "episode_0.csv").exists()