Remove inference

Remi Cadene 2024-08-06 17:15:52 +03:00
parent 6b9dcadbf7
commit ca7f207d74
2 changed files with 6 additions and 209 deletions

lerobot/scripts/visualize_dataset_html.py

@@ -50,33 +50,19 @@ python lerobot/scripts/visualize_dataset_html.py \
     --repo-id lerobot/pusht \
     --episodes 7 3 5 1 4
 ```
-
-- Run inference of a policy on the dataset and visualize the results:
-```bash
-python lerobot/scripts/visualize_dataset_html.py \
-    --repo-id lerobot/pusht \
-    --episodes 7 3 5 1 4 \
-    -p lerobot/diffusion_pusht \
-    --policy-overrides device=cpu
-```
 """
 
 import argparse
 import logging
 import shutil
-import warnings
 from pathlib import Path
 
 import torch
 import tqdm
 from flask import Flask, redirect, render_template, url_for
-from safetensors.torch import load_file, save_file
 
-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.utils import get_pretrained_policy_path
-from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.common.utils.utils import init_logging
 
 
 class EpisodeSampler(torch.utils.data.Sampler):
@@ -99,7 +85,6 @@ def run_server(
     port: str,
     static_folder: Path,
     template_folder: Path,
-    has_policy: bool = False,
 ):
     app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
@@ -139,7 +124,7 @@ def run_server(
             dataset_info=dataset_info,
             videos_info=videos_info,
             ep_csv_url=ep_csv_url,
-            has_policy=has_policy,
+            has_policy=False,
         )
 
     app.run(host=host, port=port)
@@ -150,7 +135,7 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname
 
 
-def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     """Write a csv file containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     from_idx = dataset.episode_data_index["from"][episode_index]
@@ -158,7 +143,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
     has_state = "observation.state" in dataset.hf_dataset.features
     has_action = "action" in dataset.hf_dataset.features
-    has_inference = inference_results is not None
 
     # init header of csv with state and action names
     header = ["timestamp"]
@@ -168,13 +152,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
     if has_action:
         dim_action = len(dataset.hf_dataset["action"][0])
         header += [f"action_{i}" for i in range(dim_action)]
-    if has_inference:
-        if "action" in inference_results:
-            dim_pred_action = inference_results["action"].shape[1]
-            header += [f"pred_action_{i}" for i in range(dim_pred_action)]
-        for key in inference_results:
-            if "loss" in key:
-                header += [key]
 
     columns = ["timestamp"]
     if has_state:
@@ -192,18 +169,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
             row += data[i]["action"].tolist()
         rows.append(row)
 
-    if has_inference:
-        num_frames = len(rows)
-        if "action" in inference_results:
-            assert num_frames == inference_results["action"].shape[0]
-            for i in range(num_frames):
-                rows[i] += inference_results["action"][i].tolist()
-        for key in inference_results:
-            if "loss" in key:
-                assert num_frames == inference_results[key].shape[0]
-                for i in range(num_frames):
-                    rows[i] += [inference_results[key][i].item()]
-
     output_dir.mkdir(parents=True, exist_ok=True)
     with open(output_dir / file_name, "w") as f:
         f.write(",".join(header) + "\n")
@@ -221,75 +186,6 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     ]
 
 
-def run_inference(
-    dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
-):
-    if policy_method not in ["select_action", "forward"]:
-        raise ValueError(
-            f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
-        )
-
-    policy.eval()
-    policy.to(device)
-
-    logging.info("Loading dataloader")
-    episode_sampler = EpisodeSampler(dataset, episode_index)
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=num_workers,
-        # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
-        batch_size=1 if policy_method == "select_action" else batch_size,
-        sampler=episode_sampler,
-        drop_last=False,
-    )
-
-    warned_ndim_eq_0 = False
-    warned_ndim_gt_2 = False
-
-    logging.info("Running inference")
-    inference_results = {}
-    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
-        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        with torch.inference_mode():
-            if policy_method == "select_action":
-                gt_action = batch.pop("action")
-                output_dict = {"action": policy.select_action(batch)}
-                batch["action"] = gt_action
-            elif policy_method == "forward":
-                output_dict = policy.forward(batch)
-                # TODO(rcadene): Save and display all predicted actions at a given timestamp
-                # Save predicted action for the next timestamp only
-                output_dict["action"] = output_dict["action"][:, 0, :]
-
-        for key in output_dict:
-            if output_dict[key].ndim == 0:
-                if not warned_ndim_eq_0:
-                    warnings.warn(
-                        f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
-                        stacklevel=1,
-                    )
-                    warned_ndim_eq_0 = True
-                continue
-
-            if output_dict[key].ndim > 2:
-                if not warned_ndim_gt_2:
-                    warnings.warn(
-                        f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
-                        stacklevel=1,
-                    )
-                    warned_ndim_gt_2 = True
-                continue
-
-            if key not in inference_results:
-                inference_results[key] = []
-            inference_results[key].append(output_dict[key].to("cpu"))
-
-    for key in inference_results:
-        inference_results[key] = torch.cat(inference_results[key])
-
-    return inference_results
-
-
 def visualize_dataset_html(
     repo_id: str,
     root: Path | None = None,
@@ -299,28 +195,10 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
-    policy_method: str = "select_action",
-    pretrained_policy_name_or_path: str | None = None,
-    policy_overrides: list[str] | None = None,
 ) -> Path | None:
     init_logging()
 
-    has_policy = pretrained_policy_name_or_path is not None
-
-    if has_policy:
-        logging.info("Loading policy")
-        pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
-
-        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
-        dataset = make_dataset(hydra_cfg)
-        policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
-
-        if policy_method == "select_action":
-            # Do not load previous observations or future actions, to simulate that the observations come from
-            # an environment.
-            dataset.delta_timestamps = None
-    else:
-        dataset = LeRobotDataset(repo_id, root=root)
+    dataset = LeRobotDataset(repo_id, root=root)
 
     if not dataset.video:
         raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
@@ -328,11 +206,6 @@ def visualize_dataset_html(
     if output_dir is None:
         output_dir = f"outputs/visualize_dataset_html/{repo_id}"
 
-        if has_policy:
-            ckpt_str = pretrained_policy_path.parts[-2]
-            exp_name = pretrained_policy_path.parts[-4]
-            output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
-
     output_dir = Path(output_dir)
     if output_dir.exists():
         if force_override:
@@ -357,31 +230,13 @@ def visualize_dataset_html(
     logging.info("Writing CSV files")
     for episode_index in tqdm.tqdm(episodes):
-        inference_results = None
-        if has_policy:
-            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
-            if inference_results_path.exists():
-                inference_results = load_file(inference_results_path)
-            else:
-                inference_results = run_inference(
-                    dataset,
-                    episode_index,
-                    policy,
-                    policy_method,
-                    num_workers=hydra_cfg.training.num_workers,
-                    batch_size=hydra_cfg.training.batch_size,
-                    device=hydra_cfg.device,
-                )
-                inference_results_path.parent.mkdir(parents=True, exist_ok=True)
-                save_file(inference_results, inference_results_path)
-
         # write states and actions in a csv (it can be slow for big datasets)
         ep_csv_fname = get_ep_csv_fname(episode_index)
         # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)
+        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
 
     if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)
+        run_server(dataset, episodes, host, port, static_dir, template_dir)
 
 
 def main():
@@ -437,28 +292,6 @@ def main():
         help="Delete the output directory if it exists already.",
     )
 
-    parser.add_argument(
-        "--policy-method",
-        type=str,
-        default="select_action",
-        choices=["select_action", "forward"],
-        help="Python method used to run the inference. By default, set to `select_action`, which is used during evaluation to output the sequence of actions. Can be set to `forward`, which is used during training to compute the loss.",
-    )
-    parser.add_argument(
-        "-p",
-        "--pretrained-policy-name-or-path",
-        type=str,
-        help=(
-            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
-            "saved using `Policy.save_pretrained`."
-        ),
-    )
-    parser.add_argument(
-        "--policy-overrides",
-        nargs="*",
-        help="Any key=value arguments to override policy config values (use dots for nested overrides).",
-    )
-
     args = parser.parse_args()
     visualize_dataset_html(**vars(args))
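
With inference removed, `visualize_dataset_html` only loads a `LeRobotDataset` and serves the CSV/video visualization. A minimal sketch of calling it from Python, pieced together from the signature and docstring visible in this diff (the episode list mirrors the docstring example; all other arguments keep the defaults shown above):

```python
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html

# Writes one CSV per episode and serves the viewer at http://127.0.0.1:9090
# (host/port defaults are taken from the signature in the hunk above).
visualize_dataset_html(
    "lerobot/pusht",
    episodes=[7, 3, 5, 1, 4],
    serve=True,
)
```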

tests/test_visualize_dataset_html.py

@@ -18,12 +18,7 @@ from pathlib import Path
 
 import pytest
 
-from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.logger import Logger
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
-from tests.utils import DEFAULT_CONFIG_PATH
 
 
 @pytest.mark.parametrize(
@@ -39,34 +34,3 @@ def test_visualize_dataset_html(tmpdir, repo_id):
         serve=False,
     )
     assert (tmpdir / "static" / "episode_0.csv").exists()
-
-
-@pytest.mark.parametrize(
-    "repo_id, policy_method",
-    [
-        ("lerobot/pusht", "select_action"),
-        ("lerobot/pusht", "forward"),
-    ],
-)
-def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
-    tmpdir = Path(tmpdir)
-
-    # Create a policy
-    cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
-    dataset = make_dataset(cfg)
-    policy = make_policy(cfg, dataset_stats=dataset.stats)
-
-    # Save a checkpoint
-    logger = Logger(cfg, tmpdir)
-    logger.save_model(tmpdir, policy)
-
-    visualize_dataset_html(
-        repo_id,
-        episodes=[0],
-        output_dir=tmpdir,
-        serve=False,
-        pretrained_policy_name_or_path=tmpdir,
-        policy_method=policy_method,
-    )
-    assert (tmpdir / "static" / "episode_0.csv").exists()
-    assert (tmpdir / "episode_0.safetensors").exists()
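
For reference, the surviving test keeps only the plain visualization path. Reassembled from the unchanged context lines above, it looks roughly like this (the `parametrize` values are cut out of the hunk, so `lerobot/pusht` is an assumption):

```python
from pathlib import Path

import pytest

from lerobot.scripts.visualize_dataset_html import visualize_dataset_html


@pytest.mark.parametrize("repo_id", ["lerobot/pusht"])  # assumed; values not shown in the hunk
def test_visualize_dataset_html(tmpdir, repo_id):
    tmpdir = Path(tmpdir)
    visualize_dataset_html(
        repo_id,
        episodes=[0],
        output_dir=tmpdir,
        serve=False,
    )
    # One CSV per episode is written for the Dygraph front end to plot.
    assert (tmpdir / "static" / "episode_0.csv").exists()
```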