Remove inference
This commit is contained in:
parent 6b9dcadbf7
commit ca7f207d74
@@ -50,33 +50,19 @@ python lerobot/scripts/visualize_dataset_html.py \
     --repo-id lerobot/pusht \
     --episodes 7 3 5 1 4
 ```
-
-- Run inference of a policy on the dataset and visualize the results:
-```bash
-python lerobot/scripts/visualize_dataset_html.py \
-    --repo-id lerobot/pusht \
-    --episodes 7 3 5 1 4
-    -p lerobot/diffusion_pusht \
-    --policy-overrides device=cpu
-```
 """

 import argparse
 import logging
 import shutil
-import warnings
 from pathlib import Path

 import torch
 import tqdm
 from flask import Flask, redirect, render_template, url_for
-from safetensors.torch import load_file, save_file

-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.utils import get_pretrained_policy_path
-from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.common.utils.utils import init_logging


 class EpisodeSampler(torch.utils.data.Sampler):
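The hunk above ends at the `EpisodeSampler` definition, whose body lies outside the diff. For orientation, here is a minimal sketch of what such a per-episode sampler looks like; it is an assumption based on the `episode_data_index` accesses visible further down, not the verbatim class from the file.

```python
import torch


class EpisodeSampler(torch.utils.data.Sampler):
    """Yield the frame indices that belong to one episode of a LeRobotDataset (illustrative sketch)."""

    def __init__(self, dataset, episode_index: int):
        # episode_data_index maps an episode to its [from, to) frame range in the flat dataset.
        from_idx = int(dataset.episode_data_index["from"][episode_index])
        to_idx = int(dataset.episode_data_index["to"][episode_index])
        self.frame_ids = range(from_idx, to_idx)

    def __iter__(self):
        return iter(self.frame_ids)

    def __len__(self):
        return len(self.frame_ids)
```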
@@ -99,7 +85,6 @@ def run_server(
     port: str,
     static_folder: Path,
     template_folder: Path,
-    has_policy: bool = False,
 ):
     app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
@@ -139,7 +124,7 @@ def run_server(
             dataset_info=dataset_info,
             videos_info=videos_info,
             ep_csv_url=ep_csv_url,
-            has_policy=has_policy,
+            has_policy=False,
         )

     app.run(host=host, port=port)
@@ -150,7 +135,7 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname


-def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     """Write a csv file containg timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     from_idx = dataset.episode_data_index["from"][episode_index]
@@ -158,7 +143,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):

     has_state = "observation.state" in dataset.hf_dataset.features
     has_action = "action" in dataset.hf_dataset.features
-    has_inference = inference_results is not None

     # init header of csv with state and action names
     header = ["timestamp"]
@@ -168,13 +152,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
     if has_action:
         dim_action = len(dataset.hf_dataset["action"][0])
         header += [f"action_{i}" for i in range(dim_action)]
-    if has_inference:
-        if "action" in inference_results:
-            dim_pred_action = inference_results["action"].shape[1]
-            header += [f"pred_action_{i}" for i in range(dim_pred_action)]
-        for key in inference_results:
-            if "loss" in key:
-                header += [key]

     columns = ["timestamp"]
     if has_state:
@@ -192,18 +169,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
             row += data[i]["action"].tolist()
         rows.append(row)

-    if has_inference:
-        num_frames = len(rows)
-        if "action" in inference_results:
-            assert num_frames == inference_results["action"].shape[0]
-            for i in range(num_frames):
-                rows[i] += inference_results["action"][i].tolist()
-        for key in inference_results:
-            if "loss" in key:
-                assert num_frames == inference_results[key].shape[0]
-                for i in range(num_frames):
-                    rows[i] += [inference_results[key][i].item()]
-
     output_dir.mkdir(parents=True, exist_ok=True)
     with open(output_dir / file_name, "w") as f:
         f.write(",".join(header) + "\n")
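Taken together, the hunks above leave `write_episode_data_csv` emitting only timestamp, state, and action columns. The following is a condensed sketch of the resulting flow, paraphrased from the context lines rather than copied from the post-commit file; the `state_{i}` column names and the direct `hf_dataset[i]` indexing are illustrative assumptions.

```python
from pathlib import Path


def write_episode_data_csv(output_dir: Path, file_name: str, episode_index: int, dataset) -> None:
    """Dump one episode's timestamps, states and actions to a CSV read by the Dygraphs plot (sketch)."""
    from_idx = int(dataset.episode_data_index["from"][episode_index])
    to_idx = int(dataset.episode_data_index["to"][episode_index])

    has_state = "observation.state" in dataset.hf_dataset.features
    has_action = "action" in dataset.hf_dataset.features

    # Header: a timestamp column plus one column per state/action dimension.
    header = ["timestamp"]
    if has_state:
        header += [f"state_{i}" for i in range(len(dataset.hf_dataset["observation.state"][0]))]
    if has_action:
        header += [f"action_{i}" for i in range(len(dataset.hf_dataset["action"][0]))]

    # One row per frame of the episode, flattened to plain floats.
    rows = []
    for i in range(from_idx, to_idx):
        frame = dataset.hf_dataset[i]
        row = [frame["timestamp"].item()]
        if has_state:
            row += frame["observation.state"].tolist()
        if has_action:
            row += frame["action"].tolist()
        rows.append(row)

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w") as f:
        f.write(",".join(header) + "\n")
        for row in rows:
            f.write(",".join(map(str, row)) + "\n")
```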
@@ -221,75 +186,6 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
     ]


-def run_inference(
-    dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
-):
-    if policy_method not in ["select_action", "forward"]:
-        raise ValueError(
-            f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
-        )
-
-    policy.eval()
-    policy.to(device)
-
-    logging.info("Loading dataloader")
-    episode_sampler = EpisodeSampler(dataset, episode_index)
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=num_workers,
-        # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
-        batch_size=1 if policy_method == "select_action" else batch_size,
-        sampler=episode_sampler,
-        drop_last=False,
-    )
-
-    warned_ndim_eq_0 = False
-    warned_ndim_gt_2 = False
-
-    logging.info("Running inference")
-    inference_results = {}
-    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
-        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        with torch.inference_mode():
-            if policy_method == "select_action":
-                gt_action = batch.pop("action")
-                output_dict = {"action": policy.select_action(batch)}
-                batch["action"] = gt_action
-            elif policy_method == "forward":
-                output_dict = policy.forward(batch)
-                # TODO(rcadene): Save and display all predicted actions at a given timestamp
-                # Save predicted action for the next timestamp only
-                output_dict["action"] = output_dict["action"][:, 0, :]
-
-        for key in output_dict:
-            if output_dict[key].ndim == 0:
-                if not warned_ndim_eq_0:
-                    warnings.warn(
-                        f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
-                        stacklevel=1,
-                    )
-                    warned_ndim_eq_0 = True
-                continue
-
-            if output_dict[key].ndim > 2:
-                if not warned_ndim_gt_2:
-                    warnings.warn(
-                        f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
-                        stacklevel=1,
-                    )
-                    warned_ndim_gt_2 = True
-                continue
-
-            if key not in inference_results:
-                inference_results[key] = []
-            inference_results[key].append(output_dict[key].to("cpu"))
-
-    for key in inference_results:
-        inference_results[key] = torch.cat(inference_results[key])
-
-    return inference_results
-
-
 def visualize_dataset_html(
     repo_id: str,
     root: Path | None = None,
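The removed `run_inference` helper supported two entry points on the policy: `select_action` (evaluation-style, one action per frame, hence the batch size of 1) and `forward` (training-style, a loss dict plus an action chunk, of which only the first predicted step was kept). Below is a tiny self-contained illustration of that distinction; the `DummyPolicy` class and all shapes are made up for the example.

```python
import torch


class DummyPolicy(torch.nn.Module):
    """Stand-in exposing the two entry points the removed helper exercised (illustrative only)."""

    def select_action(self, batch: dict) -> torch.Tensor:
        # Evaluation-style call: one action per observation, shape (batch, action_dim).
        return torch.zeros(batch["observation.state"].shape[0], 2)

    def forward(self, batch: dict) -> dict:
        # Training-style call: a scalar loss plus an action chunk of shape (batch, horizon, action_dim).
        n = batch["observation.state"].shape[0]
        return {"loss": torch.tensor(0.0), "action": torch.zeros(n, 16, 2)}


batch = {"observation.state": torch.zeros(4, 2), "action": torch.zeros(4, 2)}
policy = DummyPolicy()

pred = policy.select_action(batch)   # what policy_method="select_action" used
out = policy.forward(batch)          # what policy_method="forward" used
next_step = out["action"][:, 0, :]   # the removed code kept only the first predicted step
```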
@@ -299,28 +195,10 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
-    policy_method: str = "select_action",
-    pretrained_policy_name_or_path: str | None = None,
-    policy_overrides: list[str] | None = None,
 ) -> Path | None:
     init_logging()

-    has_policy = pretrained_policy_name_or_path is not None
-
-    if has_policy:
-        logging.info("Loading policy")
-        pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
-
-        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
-        dataset = make_dataset(hydra_cfg)
-        policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
-
-        if policy_method == "select_action":
-            # Do not load previous observations or future actions, to simulate that the observations come from
-            # an environment.
-            dataset.delta_timestamps = None
-    else:
-        dataset = LeRobotDataset(repo_id, root=root)
+    dataset = LeRobotDataset(repo_id, root=root)

     if not dataset.video:
         raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
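With the policy arguments gone from the signature, `visualize_dataset_html` only needs a repo ID plus the optional episode, output, and serving options. A minimal usage sketch mirroring the call made by the remaining test; the choice of repo ID and episodes is just the example used throughout this file.

```python
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html

# Write the static CSV/HTML assets for one episode without starting the Flask server.
visualize_dataset_html("lerobot/pusht", episodes=[0], serve=False)

# Or build several episodes and serve the viewer on the default host/port (127.0.0.1:9090).
visualize_dataset_html("lerobot/pusht", episodes=[7, 3, 5, 1, 4], serve=True)
```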
@@ -328,11 +206,6 @@ def visualize_dataset_html(
     if output_dir is None:
         output_dir = f"outputs/visualize_dataset_html/{repo_id}"

-    if has_policy:
-        ckpt_str = pretrained_policy_path.parts[-2]
-        exp_name = pretrained_policy_path.parts[-4]
-        output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
-
     output_dir = Path(output_dir)
     if output_dir.exists():
         if force_override:
@@ -357,31 +230,13 @@ def visualize_dataset_html(

     logging.info("Writing CSV files")
     for episode_index in tqdm.tqdm(episodes):
-        inference_results = None
-        if has_policy:
-            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
-            if inference_results_path.exists():
-                inference_results = load_file(inference_results_path)
-            else:
-                inference_results = run_inference(
-                    dataset,
-                    episode_index,
-                    policy,
-                    policy_method,
-                    num_workers=hydra_cfg.training.num_workers,
-                    batch_size=hydra_cfg.training.batch_size,
-                    device=hydra_cfg.device,
-                )
-                inference_results_path.parent.mkdir(parents=True, exist_ok=True)
-                save_file(inference_results, inference_results_path)
-
         # write states and actions in a csv (it can be slow for big datasets)
         ep_csv_fname = get_ep_csv_fname(episode_index)
         # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)
+        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)

     if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)
+        run_server(dataset, episodes, host, port, static_dir, template_dir)


 def main():
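The caching removed above leaned on `safetensors` handling a flat `dict[str, torch.Tensor]`: `save_file` writes such a dict and `load_file` reads one back, which is why `run_inference` concatenated its outputs into a single tensor per key. A small stand-alone illustration of that round trip; the file name and tensor shapes are made up for the example.

```python
import torch
from safetensors.torch import load_file, save_file

# run_inference produced a flat mapping of name -> tensor, e.g. per-frame actions and losses.
inference_results = {
    "action": torch.zeros(50, 2),  # (num_frames, action_dim), placeholder values
    "loss": torch.zeros(50),       # one value per frame
}

save_file(inference_results, "episode_0.safetensors")
restored = load_file("episode_0.safetensors")
assert torch.equal(restored["action"], inference_results["action"])
```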
@@ -437,28 +292,6 @@ def main():
         help="Delete the output directory if it exists already.",
     )
-
-    parser.add_argument(
-        "--policy-method",
-        type=str,
-        default="select_action",
-        choices=["select_action", "forward"],
-        help="Python method used to run the inference. By default, set to `select_action` used during evaluation to output the sequence of actions. Can bet set to `forward` used during training to compute the loss.",
-    )
-    parser.add_argument(
-        "-p",
-        "--pretrained-policy-name-or-path",
-        type=str,
-        help=(
-            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
-            "saved using `Policy.save_pretrained`."
-        ),
-    )
-    parser.add_argument(
-        "--policy-overrides",
-        nargs="*",
-        help="Any key=value arguments to override policy config values (use dots for.nested=overrides)",
-    )

     args = parser.parse_args()
     visualize_dataset_html(**vars(args))

@@ -18,12 +18,7 @@ from pathlib import Path

 import pytest

-from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.logger import Logger
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
-from tests.utils import DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
@@ -39,34 +34,3 @@ def test_visualize_dataset_html(tmpdir, repo_id):
         serve=False,
     )
     assert (tmpdir / "static" / "episode_0.csv").exists()
-
-
-@pytest.mark.parametrize(
-    "repo_id, policy_method",
-    [
-        ("lerobot/pusht", "select_action"),
-        ("lerobot/pusht", "forward"),
-    ],
-)
-def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
-    tmpdir = Path(tmpdir)
-
-    # Create a policy
-    cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
-    dataset = make_dataset(cfg)
-    policy = make_policy(cfg, dataset_stats=dataset.stats)
-
-    # Save a checkpoint
-    logger = Logger(cfg, tmpdir)
-    logger.save_model(tmpdir, policy)
-
-    visualize_dataset_html(
-        repo_id,
-        episodes=[0],
-        output_dir=tmpdir,
-        serve=False,
-        pretrained_policy_name_or_path=tmpdir,
-        policy_method=policy_method,
-    )
-    assert (tmpdir / "static" / "episode_0.csv").exists()
-    assert (tmpdir / "episode_0.safetensors").exists()