From 7ae8d05326430398517b342cef35e8baf545b62b Mon Sep 17 00:00:00 2001
From: Simon Alibert <simon.alibert@huggingface.co>
Date: Wed, 23 Oct 2024 14:20:27 +0200
Subject: [PATCH] Fix visualization

---
 lerobot/scripts/visualize_dataset_html.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index d9d153a0..ec7e4b1f 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -97,14 +97,13 @@ def run_server(
             "num_episodes": dataset.num_episodes,
             "fps": dataset.fps,
         }
-        video_paths = get_episode_video_paths(dataset, episode_id)
-        language_instruction = get_episode_language_instruction(dataset, episode_id)
+        video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
+        tasks = dataset.episode_dicts[episode_id]["tasks"]
         videos_info = [
-            {"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
+            {"url": url_for("static", filename=video_path), "filename": video_path.name}
             for video_path in video_paths
         ]
-        if language_instruction:
-            videos_info[0]["language_instruction"] = language_instruction
+        videos_info[0]["language_instruction"] = tasks
 
         ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
         return render_template(
@@ -137,10 +136,10 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     # init header of csv with state and action names
     header = ["timestamp"]
     if has_state:
-        dim_state = len(dataset.hf_dataset["observation.state"][0])
+        dim_state = dataset.shapes["observation.state"]
         header += [f"state_{i}" for i in range(dim_state)]
     if has_action:
-        dim_action = len(dataset.hf_dataset["action"][0])
+        dim_action = dataset.shapes["action"]
         header += [f"action_{i}" for i in range(dim_action)]
 
     columns = ["timestamp"]
@@ -171,7 +170,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
     # get first frame of episode (hack to get video_path of the episode)
     first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
     return [
-        dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.camera_keys
+        dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
     ]
 
 
@@ -203,8 +202,8 @@ def visualize_dataset_html(
 
     dataset = LeRobotDataset(repo_id, root=root)
 
-    if not dataset.video:
-        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
+    if len(dataset.image_keys) > 0:
+        raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
 
     if output_dir is None:
         output_dir = f"outputs/visualize_dataset_html/{repo_id}"
@@ -224,7 +223,7 @@ def visualize_dataset_html(
     static_dir.mkdir(parents=True, exist_ok=True)
     ln_videos_dir = static_dir / "videos"
     if not ln_videos_dir.exists():
-        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
+        ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
 
     template_dir = Path(__file__).resolve().parent.parent / "templates"