diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 00145aad..2531fbd0 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -50,33 +50,19 @@ python lerobot/scripts/visualize_dataset_html.py \ --repo-id lerobot/pusht \ --episodes 7 3 5 1 4 ``` - -- Run inference of a policy on the dataset and visualize the results: -```bash -python lerobot/scripts/visualize_dataset_html.py \ - --repo-id lerobot/pusht \ - --episodes 7 3 5 1 4 - -p lerobot/diffusion_pusht \ - --policy-overrides device=cpu -``` """ import argparse import logging import shutil -import warnings from pathlib import Path import torch import tqdm from flask import Flask, redirect, render_template, url_for -from safetensors.torch import load_file, save_file -from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.policies.factory import make_policy -from lerobot.common.policies.utils import get_pretrained_policy_path -from lerobot.common.utils.utils import init_hydra_config, init_logging +from lerobot.common.utils.utils import init_logging class EpisodeSampler(torch.utils.data.Sampler): @@ -99,7 +85,6 @@ def run_server( port: str, static_folder: Path, template_folder: Path, - has_policy: bool = False, ): app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache @@ -139,7 +124,7 @@ def run_server( dataset_info=dataset_info, videos_info=videos_info, ep_csv_url=ep_csv_url, - has_policy=has_policy, + has_policy=False, ) app.run(host=host, port=port) @@ -150,7 +135,7 @@ def get_ep_csv_fname(episode_id: int): return ep_csv_fname -def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None): +def write_episode_data_csv(output_dir, file_name, episode_index, dataset): """Write a csv file containg timeseries data of an episode (e.g. state and action). This file will be loaded by Dygraph javascript to plot data in real time.""" from_idx = dataset.episode_data_index["from"][episode_index] @@ -158,7 +143,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere has_state = "observation.state" in dataset.hf_dataset.features has_action = "action" in dataset.hf_dataset.features - has_inference = inference_results is not None # init header of csv with state and action names header = ["timestamp"] @@ -168,13 +152,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere if has_action: dim_action = len(dataset.hf_dataset["action"][0]) header += [f"action_{i}" for i in range(dim_action)] - if has_inference: - if "action" in inference_results: - dim_pred_action = inference_results["action"].shape[1] - header += [f"pred_action_{i}" for i in range(dim_pred_action)] - for key in inference_results: - if "loss" in key: - header += [key] columns = ["timestamp"] if has_state: @@ -192,18 +169,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere row += data[i]["action"].tolist() rows.append(row) - if has_inference: - num_frames = len(rows) - if "action" in inference_results: - assert num_frames == inference_results["action"].shape[0] - for i in range(num_frames): - rows[i] += inference_results["action"][i].tolist() - for key in inference_results: - if "loss" in key: - assert num_frames == inference_results[key].shape[0] - for i in range(num_frames): - rows[i] += [inference_results[key][i].item()] - output_dir.mkdir(parents=True, exist_ok=True) with open(output_dir / file_name, "w") as f: f.write(",".join(header) + "\n") @@ -221,75 +186,6 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str] ] -def run_inference( - dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda" -): - if policy_method not in ["select_action", "forward"]: - raise ValueError( - f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead." - ) - - policy.eval() - policy.to(device) - - logging.info("Loading dataloader") - episode_sampler = EpisodeSampler(dataset, episode_index) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=num_workers, - # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion. - batch_size=1 if policy_method == "select_action" else batch_size, - sampler=episode_sampler, - drop_last=False, - ) - - warned_ndim_eq_0 = False - warned_ndim_gt_2 = False - - logging.info("Running inference") - inference_results = {} - for batch in tqdm.tqdm(dataloader, total=len(dataloader)): - batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} - with torch.inference_mode(): - if policy_method == "select_action": - gt_action = batch.pop("action") - output_dict = {"action": policy.select_action(batch)} - batch["action"] = gt_action - elif policy_method == "forward": - output_dict = policy.forward(batch) - # TODO(rcadene): Save and display all predicted actions at a given timestamp - # Save predicted action for the next timestamp only - output_dict["action"] = output_dict["action"][:, 0, :] - - for key in output_dict: - if output_dict[key].ndim == 0: - if not warned_ndim_eq_0: - warnings.warn( - f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).", - stacklevel=1, - ) - warned_ndim_eq_0 = True - continue - - if output_dict[key].ndim > 2: - if not warned_ndim_gt_2: - warnings.warn( - f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.", - stacklevel=1, - ) - warned_ndim_gt_2 = True - continue - - if key not in inference_results: - inference_results[key] = [] - inference_results[key].append(output_dict[key].to("cpu")) - - for key in inference_results: - inference_results[key] = torch.cat(inference_results[key]) - - return inference_results - - def visualize_dataset_html( repo_id: str, root: Path | None = None, @@ -299,28 +195,10 @@ def visualize_dataset_html( host: str = "127.0.0.1", port: int = 9090, force_override: bool = False, - policy_method: str = "select_action", - pretrained_policy_name_or_path: str | None = None, - policy_overrides: list[str] | None = None, ) -> Path | None: init_logging() - has_policy = pretrained_policy_name_or_path is not None - - if has_policy: - logging.info("Loading policy") - pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path) - - hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides) - dataset = make_dataset(hydra_cfg) - policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) - - if policy_method == "select_action": - # Do not load previous observations or future actions, to simulate that the observations come from - # an environment. - dataset.delta_timestamps = None - else: - dataset = LeRobotDataset(repo_id, root=root) + dataset = LeRobotDataset(repo_id, root=root) if not dataset.video: raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.") @@ -328,11 +206,6 @@ def visualize_dataset_html( if output_dir is None: output_dir = f"outputs/visualize_dataset_html/{repo_id}" - if has_policy: - ckpt_str = pretrained_policy_path.parts[-2] - exp_name = pretrained_policy_path.parts[-4] - output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}" - output_dir = Path(output_dir) if output_dir.exists(): if force_override: @@ -357,31 +230,13 @@ def visualize_dataset_html( logging.info("Writing CSV files") for episode_index in tqdm.tqdm(episodes): - inference_results = None - if has_policy: - inference_results_path = output_dir / f"episode_{episode_index}.safetensors" - if inference_results_path.exists(): - inference_results = load_file(inference_results_path) - else: - inference_results = run_inference( - dataset, - episode_index, - policy, - policy_method, - num_workers=hydra_cfg.training.num_workers, - batch_size=hydra_cfg.training.batch_size, - device=hydra_cfg.device, - ) - inference_results_path.parent.mkdir(parents=True, exist_ok=True) - save_file(inference_results, inference_results_path) - # write states and actions in a csv (it can be slow for big datasets) ep_csv_fname = get_ep_csv_fname(episode_index) # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors? - write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results) + write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset) if serve: - run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy) + run_server(dataset, episodes, host, port, static_dir, template_dir) def main(): @@ -437,28 +292,6 @@ def main(): help="Delete the output directory if it exists already.", ) - parser.add_argument( - "--policy-method", - type=str, - default="select_action", - choices=["select_action", "forward"], - help="Python method used to run the inference. By default, set to `select_action` used during evaluation to output the sequence of actions. Can bet set to `forward` used during training to compute the loss.", - ) - parser.add_argument( - "-p", - "--pretrained-policy-name-or-path", - type=str, - help=( - "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " - "saved using `Policy.save_pretrained`." - ), - ) - parser.add_argument( - "--policy-overrides", - nargs="*", - help="Any key=value arguments to override policy config values (use dots for.nested=overrides)", - ) - args = parser.parse_args() visualize_dataset_html(**vars(args)) diff --git a/tests/test_visualize_dataset_html.py b/tests/test_visualize_dataset_html.py index 77ababfa..4dc3c063 100644 --- a/tests/test_visualize_dataset_html.py +++ b/tests/test_visualize_dataset_html.py @@ -18,12 +18,7 @@ from pathlib import Path import pytest -from lerobot.common.datasets.factory import make_dataset -from lerobot.common.logger import Logger -from lerobot.common.policies.factory import make_policy -from lerobot.common.utils.utils import init_hydra_config from lerobot.scripts.visualize_dataset_html import visualize_dataset_html -from tests.utils import DEFAULT_CONFIG_PATH @pytest.mark.parametrize( @@ -39,34 +34,3 @@ def test_visualize_dataset_html(tmpdir, repo_id): serve=False, ) assert (tmpdir / "static" / "episode_0.csv").exists() - - -@pytest.mark.parametrize( - "repo_id, policy_method", - [ - ("lerobot/pusht", "select_action"), - ("lerobot/pusht", "forward"), - ], -) -def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method): - tmpdir = Path(tmpdir) - - # Create a policy - cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"]) - dataset = make_dataset(cfg) - policy = make_policy(cfg, dataset_stats=dataset.stats) - - # Save a checkpoint - logger = Logger(cfg, tmpdir) - logger.save_model(tmpdir, policy) - - visualize_dataset_html( - repo_id, - episodes=[0], - output_dir=tmpdir, - serve=False, - pretrained_policy_name_or_path=tmpdir, - policy_method=policy_method, - ) - assert (tmpdir / "static" / "episode_0.csv").exists() - assert (tmpdir / "episode_0.safetensors").exists()