Fix visualization
This commit is contained in:
parent
a2a8538ac9
commit
7ae8d05326
|
@ -97,14 +97,13 @@ def run_server(
|
|||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
video_paths = get_episode_video_paths(dataset, episode_id)
|
||||
language_instruction = get_episode_language_instruction(dataset, episode_id)
|
||||
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
|
||||
tasks = dataset.episode_dicts[episode_id]["tasks"]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
if language_instruction:
|
||||
videos_info[0]["language_instruction"] = language_instruction
|
||||
videos_info[0]["language_instruction"] = tasks
|
||||
|
||||
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
|
||||
return render_template(
|
||||
|
@ -137,10 +136,10 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = len(dataset.hf_dataset["observation.state"][0])
|
||||
dim_state = dataset.shapes["observation.state"]
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = len(dataset.hf_dataset["action"][0])
|
||||
dim_action = dataset.shapes["action"]
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
|
||||
columns = ["timestamp"]
|
||||
|
@ -171,7 +170,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
|||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.camera_keys
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
|
||||
]
|
||||
|
||||
|
||||
|
@ -203,8 +202,8 @@ def visualize_dataset_html(
|
|||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if not dataset.video:
|
||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||
if len(dataset.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
|
@ -224,7 +223,7 @@ def visualize_dataset_html(
|
|||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
|
||||
|
|
Loading…
Reference in New Issue