diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 39b4c27d..cc3f3930 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -232,69 +232,54 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) """Get a csv str containing timeseries data of an episode (e.g. state and action). This file will be loaded by Dygraph javascript to plot data in real time.""" columns = [] - has_state = "observation.state" in dataset.features - has_action = "action" in dataset.features + + selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"] + selected_columns.remove("timestamp") # init header of csv with state and action names header = ["timestamp"] - if has_state: + + for column_name in selected_columns: dim_state = ( - dataset.meta.shapes["observation.state"][0] + dataset.meta.shapes[column_name][0] if isinstance(dataset, LeRobotDataset) - else dataset.features["observation.state"].shape[0] + else dataset.features[column_name].shape[0] ) - header += [f"state_{i}" for i in range(dim_state)] - column_names = dataset.features["observation.state"]["names"] - while not isinstance(column_names, list): - column_names = list(column_names.values())[0] - columns.append({"key": "state", "value": column_names}) - if has_action: - dim_action = ( - dataset.meta.shapes["action"][0] - if isinstance(dataset, LeRobotDataset) - else dataset.features.action.shape[0] - ) - header += [f"action_{i}" for i in range(dim_action)] - column_names = dataset.features["action"]["names"] - while not isinstance(column_names, list): - column_names = list(column_names.values())[0] - columns.append({"key": "action", "value": column_names}) + header += [f"{column_name}_{i}" for i in range(dim_state)] + + if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: + column_names = dataset.features[column_name]["names"] + while not isinstance(column_names, list): + column_names = list(column_names.values())[0] + else: + column_names = [f"motor_{i}" for i in range(dim_state)] + columns.append({"key": column_name, "value": column_names}) + + selected_columns.insert(0, "timestamp") if isinstance(dataset, LeRobotDataset): from_idx = dataset.episode_data_index["from"][episode_index] to_idx = dataset.episode_data_index["to"][episode_index] - selected_columns = ["timestamp"] - if has_state: - selected_columns += ["observation.state"] - if has_action: - selected_columns += ["action"] data = ( dataset.hf_dataset.select(range(from_idx, to_idx)) .select_columns(selected_columns) - .with_format("numpy") + .with_format("pandas") ) - rows = np.hstack( - (np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]]) - ).tolist() else: repo_id = dataset.repo_id - selected_columns = ["timestamp"] - if "observation.state" in dataset.features: - selected_columns.append("observation.state") - if "action" in dataset.features: - selected_columns.append("action") url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index ) df = pd.read_parquet(url) data = df[selected_columns] # Select specific columns - rows = np.hstack( - ( - np.expand_dims(data["timestamp"], axis=1), - *[np.vstack(data[col]) for col in selected_columns[1:]], - ) - ).tolist() + + rows = np.hstack( + ( + np.expand_dims(data["timestamp"], axis=1), + *[np.vstack(data[col]) for col in selected_columns[1:]], + ) + ).tolist() # Convert data to CSV string csv_buffer = StringIO() @@ -379,10 +364,6 @@ def visualize_dataset_html( template_folder=template_dir, ) else: - image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else [] - if len(image_keys) > 0: - raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.") - # Create a simlink from the dataset video folder containg mp4 files to the output directory # so that the http server can get access to the mp4 files. if isinstance(dataset, LeRobotDataset): diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html index 12d6e991..3c93d2d6 100644 --- a/lerobot/templates/visualize_dataset_template.html +++ b/lerobot/templates/visualize_dataset_template.html @@ -98,9 +98,34 @@ </div> <!-- Videos --> + <div class="max-w-32 relative text-sm mb-4 select-none" + @click.outside="isVideosDropdownOpen = false"> + <div + @click="isVideosDropdownOpen = !isVideosDropdownOpen" + class="p-2 border border-slate-500 rounded flex justify-between items-center cursor-pointer" + > + <span class="truncate">filter videos</span> + <div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div> + </div> + + <div x-show="isVideosDropdownOpen" + class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10"> + <div> + <template x-for="option in videosKeys" :key="option"> + <div + @click="videosKeysSelected = videosKeysSelected.includes(option) ? videosKeysSelected.filter(v => v !== option) : [...videosKeysSelected, option]" + class="p-2 cursor-pointer bg-slate-900" + :class="{ 'bg-slate-700': videosKeysSelected.includes(option) }" + x-text="option" + ></div> + </template> + </div> + </div> + </div> + <div class="flex flex-wrap gap-x-2 gap-y-6"> {% for video_info in videos_info %} - <div x-show="!videoCodecError" class="max-w-96 relative"> + <div x-show="!videoCodecError && videosKeysSelected.includes('{{ video_info.filename }}')" class="max-w-96 relative"> <p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p> <video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => { if (video.duration) { @@ -250,6 +275,9 @@ nVideos: {{ videos_info | length }}, nVideoReadyToPlay: 0, videoCodecError: false, + isVideosDropdownOpen: false, + videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }}, + videosKeysSelected: [], columns: {{ columns | tojson }}, rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value, @@ -261,6 +289,7 @@ if(!canPlayVideos){ this.videoCodecError = true; } + this.videosKeysSelected = this.videosKeys.map(opt => opt) // process CSV data const csvDataStr = {{ episode_data_csv_str|tojson|safe }};