Fix visualization

This commit is contained in:
Simon Alibert 2024-10-23 14:20:27 +02:00
parent a2a8538ac9
commit 7ae8d05326
1 changed files with 10 additions and 11 deletions

View File

@ -97,14 +97,13 @@ def run_server(
"num_episodes": dataset.num_episodes, "num_episodes": dataset.num_episodes,
"fps": dataset.fps, "fps": dataset.fps,
} }
video_paths = get_episode_video_paths(dataset, episode_id) video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
language_instruction = get_episode_language_instruction(dataset, episode_id) tasks = dataset.episode_dicts[episode_id]["tasks"]
videos_info = [ videos_info = [
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name} {"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths for video_path in video_paths
] ]
if language_instruction: videos_info[0]["language_instruction"] = tasks
videos_info[0]["language_instruction"] = language_instruction
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id)) ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
return render_template( return render_template(
@ -137,10 +136,10 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
# init header of csv with state and action names # init header of csv with state and action names
header = ["timestamp"] header = ["timestamp"]
if has_state: if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0]) dim_state = dataset.shapes["observation.state"]
header += [f"state_{i}" for i in range(dim_state)] header += [f"state_{i}" for i in range(dim_state)]
if has_action: if has_action:
dim_action = len(dataset.hf_dataset["action"][0]) dim_action = dataset.shapes["action"]
header += [f"action_{i}" for i in range(dim_action)] header += [f"action_{i}" for i in range(dim_action)]
columns = ["timestamp"] columns = ["timestamp"]
@ -171,7 +170,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
# get first frame of episode (hack to get video_path of the episode) # get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item() first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [ return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.camera_keys dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
] ]
@ -203,8 +202,8 @@ def visualize_dataset_html(
dataset = LeRobotDataset(repo_id, root=root) dataset = LeRobotDataset(repo_id, root=root)
if not dataset.video: if len(dataset.image_keys) > 0:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.") raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
if output_dir is None: if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}" output_dir = f"outputs/visualize_dataset_html/{repo_id}"
@ -224,7 +223,7 @@ def visualize_dataset_html(
static_dir.mkdir(parents=True, exist_ok=True) static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos" ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists(): if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve()) ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
template_dir = Path(__file__).resolve().parent.parent / "templates" template_dir = Path(__file__).resolve().parent.parent / "templates"