diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index d9d153a0..ec7e4b1f 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -97,14 +97,13 @@ def run_server( "num_episodes": dataset.num_episodes, "fps": dataset.fps, } - video_paths = get_episode_video_paths(dataset, episode_id) - language_instruction = get_episode_language_instruction(dataset, episode_id) + video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys] + tasks = dataset.episode_dicts[episode_id]["tasks"] videos_info = [ - {"url": url_for("static", filename=video_path), "filename": Path(video_path).name} + {"url": url_for("static", filename=video_path), "filename": video_path.name} for video_path in video_paths ] - if language_instruction: - videos_info[0]["language_instruction"] = language_instruction + videos_info[0]["language_instruction"] = tasks ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id)) return render_template( @@ -137,10 +136,10 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset): # init header of csv with state and action names header = ["timestamp"] if has_state: - dim_state = len(dataset.hf_dataset["observation.state"][0]) + dim_state = dataset.shapes["observation.state"] header += [f"state_{i}" for i in range(dim_state)] if has_action: - dim_action = len(dataset.hf_dataset["action"][0]) + dim_action = dataset.shapes["action"] header += [f"action_{i}" for i in range(dim_action)] columns = ["timestamp"] @@ -171,7 +170,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str] # get first frame of episode (hack to get video_path of the episode) first_frame_idx = dataset.episode_data_index["from"][ep_index].item() return [ - dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.camera_keys + dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys ] @@ -203,8 +202,8 @@ def visualize_dataset_html( dataset = LeRobotDataset(repo_id, root=root) - if not dataset.video: - raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.") + if len(dataset.image_keys) > 0: + raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.") if output_dir is None: output_dir = f"outputs/visualize_dataset_html/{repo_id}" @@ -224,7 +223,7 @@ def visualize_dataset_html( static_dir.mkdir(parents=True, exist_ok=True) ln_videos_dir = static_dir / "videos" if not ln_videos_dir.exists(): - ln_videos_dir.symlink_to(dataset.videos_dir.resolve()) + ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) template_dir = Path(__file__).resolve().parent.parent / "templates"