From 100f54ee07418c4d72fd93400ada1fc98f639296 Mon Sep 17 00:00:00 2001 From: Mishig Date: Thu, 9 Jan 2025 11:39:54 +0100 Subject: [PATCH] [viz] Fixes & updates to html visualizer (#617) --- lerobot/scripts/visualize_dataset_html.py | 71 +++++++------------ .../templates/visualize_dataset_template.html | 31 +++++++- 2 files changed, 56 insertions(+), 46 deletions(-) diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 39b4c27d..cc3f3930 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -232,69 +232,54 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) """Get a csv str containing timeseries data of an episode (e.g. state and action). This file will be loaded by Dygraph javascript to plot data in real time.""" columns = [] - has_state = "observation.state" in dataset.features - has_action = "action" in dataset.features + + selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"] + selected_columns.remove("timestamp") # init header of csv with state and action names header = ["timestamp"] - if has_state: + + for column_name in selected_columns: dim_state = ( - dataset.meta.shapes["observation.state"][0] + dataset.meta.shapes[column_name][0] if isinstance(dataset, LeRobotDataset) - else dataset.features["observation.state"].shape[0] + else dataset.features[column_name].shape[0] ) - header += [f"state_{i}" for i in range(dim_state)] - column_names = dataset.features["observation.state"]["names"] - while not isinstance(column_names, list): - column_names = list(column_names.values())[0] - columns.append({"key": "state", "value": column_names}) - if has_action: - dim_action = ( - dataset.meta.shapes["action"][0] - if isinstance(dataset, LeRobotDataset) - else dataset.features.action.shape[0] - ) - header += [f"action_{i}" for i in range(dim_action)] - column_names = dataset.features["action"]["names"] - while not isinstance(column_names, list): - column_names = list(column_names.values())[0] - columns.append({"key": "action", "value": column_names}) + header += [f"{column_name}_{i}" for i in range(dim_state)] + + if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: + column_names = dataset.features[column_name]["names"] + while not isinstance(column_names, list): + column_names = list(column_names.values())[0] + else: + column_names = [f"motor_{i}" for i in range(dim_state)] + columns.append({"key": column_name, "value": column_names}) + + selected_columns.insert(0, "timestamp") if isinstance(dataset, LeRobotDataset): from_idx = dataset.episode_data_index["from"][episode_index] to_idx = dataset.episode_data_index["to"][episode_index] - selected_columns = ["timestamp"] - if has_state: - selected_columns += ["observation.state"] - if has_action: - selected_columns += ["action"] data = ( dataset.hf_dataset.select(range(from_idx, to_idx)) .select_columns(selected_columns) - .with_format("numpy") + .with_format("pandas") ) - rows = np.hstack( - (np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]]) - ).tolist() else: repo_id = dataset.repo_id - selected_columns = ["timestamp"] - if "observation.state" in dataset.features: - selected_columns.append("observation.state") - if "action" in dataset.features: - selected_columns.append("action") url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index ) df = pd.read_parquet(url) data = df[selected_columns] # Select specific columns - rows = np.hstack( - ( - np.expand_dims(data["timestamp"], axis=1), - *[np.vstack(data[col]) for col in selected_columns[1:]], - ) - ).tolist() + + rows = np.hstack( + ( + np.expand_dims(data["timestamp"], axis=1), + *[np.vstack(data[col]) for col in selected_columns[1:]], + ) + ).tolist() # Convert data to CSV string csv_buffer = StringIO() @@ -379,10 +364,6 @@ def visualize_dataset_html( template_folder=template_dir, ) else: - image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else [] - if len(image_keys) > 0: - raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.") - # Create a simlink from the dataset video folder containg mp4 files to the output directory # so that the http server can get access to the mp4 files. if isinstance(dataset, LeRobotDataset): diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html index 12d6e991..3c93d2d6 100644 --- a/lerobot/templates/visualize_dataset_template.html +++ b/lerobot/templates/visualize_dataset_template.html @@ -98,9 +98,34 @@ +
+
+ filter videos +
🔽
+
+ +
+
+ +
+
+
+
{% for video_info in videos_info %} -
+

{{ video_info.filename }}