Revove inference
This commit is contained in:
parent
6b9dcadbf7
commit
ca7f207d74
|
@ -50,33 +50,19 @@ python lerobot/scripts/visualize_dataset_html.py \
|
|||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
```
|
||||
|
||||
- Run inference of a policy on the dataset and visualize the results:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
-p lerobot/diffusion_pusht \
|
||||
--policy-overrides device=cpu
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.utils import get_pretrained_policy_path
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
|
@ -99,7 +85,6 @@ def run_server(
|
|||
port: str,
|
||||
static_folder: Path,
|
||||
template_folder: Path,
|
||||
has_policy: bool = False,
|
||||
):
|
||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
|
@ -139,7 +124,7 @@ def run_server(
|
|||
dataset_info=dataset_info,
|
||||
videos_info=videos_info,
|
||||
ep_csv_url=ep_csv_url,
|
||||
has_policy=has_policy,
|
||||
has_policy=False,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
|
@ -150,7 +135,7 @@ def get_ep_csv_fname(episode_id: int):
|
|||
return ep_csv_fname
|
||||
|
||||
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
|
@ -158,7 +143,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere
|
|||
|
||||
has_state = "observation.state" in dataset.hf_dataset.features
|
||||
has_action = "action" in dataset.hf_dataset.features
|
||||
has_inference = inference_results is not None
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
|
@ -168,13 +152,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere
|
|||
if has_action:
|
||||
dim_action = len(dataset.hf_dataset["action"][0])
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
if has_inference:
|
||||
if "action" in inference_results:
|
||||
dim_pred_action = inference_results["action"].shape[1]
|
||||
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
|
||||
for key in inference_results:
|
||||
if "loss" in key:
|
||||
header += [key]
|
||||
|
||||
columns = ["timestamp"]
|
||||
if has_state:
|
||||
|
@ -192,18 +169,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere
|
|||
row += data[i]["action"].tolist()
|
||||
rows.append(row)
|
||||
|
||||
if has_inference:
|
||||
num_frames = len(rows)
|
||||
if "action" in inference_results:
|
||||
assert num_frames == inference_results["action"].shape[0]
|
||||
for i in range(num_frames):
|
||||
rows[i] += inference_results["action"][i].tolist()
|
||||
for key in inference_results:
|
||||
if "loss" in key:
|
||||
assert num_frames == inference_results[key].shape[0]
|
||||
for i in range(num_frames):
|
||||
rows[i] += [inference_results[key][i].item()]
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / file_name, "w") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
|
@ -221,75 +186,6 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
|||
]
|
||||
|
||||
|
||||
def run_inference(
|
||||
dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
|
||||
):
|
||||
if policy_method not in ["select_action", "forward"]:
|
||||
raise ValueError(
|
||||
f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
# When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
|
||||
batch_size=1 if policy_method == "select_action" else batch_size,
|
||||
sampler=episode_sampler,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
warned_ndim_eq_0 = False
|
||||
warned_ndim_gt_2 = False
|
||||
|
||||
logging.info("Running inference")
|
||||
inference_results = {}
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
with torch.inference_mode():
|
||||
if policy_method == "select_action":
|
||||
gt_action = batch.pop("action")
|
||||
output_dict = {"action": policy.select_action(batch)}
|
||||
batch["action"] = gt_action
|
||||
elif policy_method == "forward":
|
||||
output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): Save and display all predicted actions at a given timestamp
|
||||
# Save predicted action for the next timestamp only
|
||||
output_dict["action"] = output_dict["action"][:, 0, :]
|
||||
|
||||
for key in output_dict:
|
||||
if output_dict[key].ndim == 0:
|
||||
if not warned_ndim_eq_0:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_eq_0 = True
|
||||
continue
|
||||
|
||||
if output_dict[key].ndim > 2:
|
||||
if not warned_ndim_gt_2:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_gt_2 = True
|
||||
continue
|
||||
|
||||
if key not in inference_results:
|
||||
inference_results[key] = []
|
||||
inference_results[key].append(output_dict[key].to("cpu"))
|
||||
|
||||
for key in inference_results:
|
||||
inference_results[key] = torch.cat(inference_results[key])
|
||||
|
||||
return inference_results
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
|
@ -299,28 +195,10 @@ def visualize_dataset_html(
|
|||
host: str = "127.0.0.1",
|
||||
port: int = 9090,
|
||||
force_override: bool = False,
|
||||
policy_method: str = "select_action",
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: list[str] | None = None,
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
has_policy = pretrained_policy_name_or_path is not None
|
||||
|
||||
if has_policy:
|
||||
logging.info("Loading policy")
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
dataset = make_dataset(hydra_cfg)
|
||||
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
if policy_method == "select_action":
|
||||
# Do not load previous observations or future actions, to simulate that the observations come from
|
||||
# an environment.
|
||||
dataset.delta_timestamps = None
|
||||
else:
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if not dataset.video:
|
||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||
|
@ -328,11 +206,6 @@ def visualize_dataset_html(
|
|||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
|
||||
if has_policy:
|
||||
ckpt_str = pretrained_policy_path.parts[-2]
|
||||
exp_name = pretrained_policy_path.parts[-4]
|
||||
output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
if force_override:
|
||||
|
@ -357,31 +230,13 @@ def visualize_dataset_html(
|
|||
|
||||
logging.info("Writing CSV files")
|
||||
for episode_index in tqdm.tqdm(episodes):
|
||||
inference_results = None
|
||||
if has_policy:
|
||||
inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
|
||||
if inference_results_path.exists():
|
||||
inference_results = load_file(inference_results_path)
|
||||
else:
|
||||
inference_results = run_inference(
|
||||
dataset,
|
||||
episode_index,
|
||||
policy,
|
||||
policy_method,
|
||||
num_workers=hydra_cfg.training.num_workers,
|
||||
batch_size=hydra_cfg.training.batch_size,
|
||||
device=hydra_cfg.device,
|
||||
)
|
||||
inference_results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(inference_results, inference_results_path)
|
||||
|
||||
# write states and actions in a csv (it can be slow for big datasets)
|
||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -437,28 +292,6 @@ def main():
|
|||
help="Delete the output directory if it exists already.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--policy-method",
|
||||
type=str,
|
||||
default="select_action",
|
||||
choices=["select_action", "forward"],
|
||||
help="Python method used to run the inference. By default, set to `select_action` used during evaluation to output the sequence of actions. Can bet set to `forward` used during training to compute the loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy-overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override policy config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset_html(**vars(args))
|
||||
|
||||
|
|
|
@ -18,12 +18,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -39,34 +34,3 @@ def test_visualize_dataset_html(tmpdir, repo_id):
|
|||
serve=False,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, policy_method",
|
||||
[
|
||||
("lerobot/pusht", "select_action"),
|
||||
("lerobot/pusht", "forward"),
|
||||
],
|
||||
)
|
||||
def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
# Create a policy
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
|
||||
# Save a checkpoint
|
||||
logger = Logger(cfg, tmpdir)
|
||||
logger.save_model(tmpdir, policy)
|
||||
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
serve=False,
|
||||
pretrained_policy_name_or_path=tmpdir,
|
||||
policy_method=policy_method,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
assert (tmpdir / "episode_0.safetensors").exists()
|
||||
|
|
Loading…
Reference in New Issue