WIP
This commit is contained in:
parent
fc4df91883
commit
272a9d9427
|
@ -140,25 +140,25 @@ class ACTPolicy(
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||||
|
|
||||||
l1_loss = (
|
bsize = actions_hat.shape[0]
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||||
).mean()
|
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
|
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss.item()}
|
out_dict = {}
|
||||||
|
out_dict["l1_loss"] = l1_loss
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
# KL-divergence per batch element, then take the mean over the batch.
|
# KL-divergence per batch element (kept per element; no mean over the batch is taken here).
|
||||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||||
mean_kld = (
|
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
|
||||||
)
|
|
||||||
loss_dict["kld_loss"] = mean_kld.item()
|
|
||||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
|
||||||
else:
|
else:
|
||||||
loss_dict["loss"] = l1_loss
|
out_dict["loss"] = l1_loss
|
||||||
|
|
||||||
return loss_dict
|
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
class ACTTemporalEnsembler:
|
class ACTTemporalEnsembler:
|
||||||
|
|
|
@ -268,10 +268,11 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
repo_id = kwargs.pop("repo_id")
|
repo_id = kwargs.pop("repo_id")
|
||||||
root = kwargs.pop("root")
|
# root = kwargs.pop("root")
|
||||||
|
|
||||||
logging.info("Loading dataset")
|
logging.info("Loading dataset")
|
||||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
|
||||||
|
dataset = LeRobotDataset(repo_id)
|
||||||
|
|
||||||
visualize_dataset(dataset, **vars(args))
|
visualize_dataset(dataset, **vars(args))
|
||||||
|
|
||||||
|
|
|
@ -55,13 +55,30 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from flask import Flask, redirect, render_template, url_for
|
from flask import Flask, redirect, render_template, url_for
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.utils.utils import init_logging
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||||
|
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in order, the frame indices of a single episode.

    The episode boundaries are read from `dataset.episode_data_index`, whose
    "from" and "to" entries are indexed by `episode_index`. The "to" bound is
    exclusive.
    """

    def __init__(self, dataset, episode_index):
        """Store the frame range of episode `episode_index` of `dataset`.

        Args:
            dataset: Object exposing `episode_data_index["from"]` and
                `episode_data_index["to"]`, each indexable by episode.
            episode_index: Index of the episode to iterate over.
        """
        boundaries = dataset.episode_data_index
        # `int()` accepts 0-d integer tensors as well as plain Python or numpy
        # integers, so the sampler does not depend on the container type.
        first_frame = int(boundaries["from"][episode_index])
        end_frame = int(boundaries["to"][episode_index])
        self.frame_ids = range(first_frame, end_frame)

    def __iter__(self):
        """Return an iterator over the frame indices of the episode."""
        return iter(self.frame_ids)

    def __len__(self) -> int:
        """Return the number of frames in the episode."""
        return len(self.frame_ids)
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
def run_server(
|
||||||
|
@ -119,14 +136,95 @@ def run_server(
|
||||||
app.run(host=host, port=port)
|
app.run(host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
def run_inference(
    dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="mps"
):
    """Run `policy` on every frame of one episode and gather its per-frame outputs.

    Args:
        dataset: Dataset the frames are loaded from. It is wrapped in a
            `torch.utils.data.DataLoader` driven by an `EpisodeSampler`.
        episode_index: Index of the episode whose frames are fed to the policy.
        policy: Policy to run. It is switched to eval mode and moved to `device`.
        policy_method: "select_action" feeds one frame at a time (batch size 1)
            with the "action" entry withheld from the policy; "forward" feeds
            batches of `batch_size` frames to `policy.forward`.
        num_workers: Number of dataloader worker processes.
        batch_size: Batch size used when `policy_method` is "forward".
        device: Device the policy and the batches are moved to.
            NOTE(review): the default "mps" only exists on Apple hardware;
            callers on other machines have to pass `device` explicitly.

    Returns:
        A dict mapping each kept output key to a CPU tensor obtained by
        concatenating the per-batch outputs along dimension 0. Outputs that are
        scalars (0 dims) or have more than 2 dims are skipped with a warning.
        The dict is empty when the episode yields no batch.

    Raises:
        ValueError: If `policy_method` is neither "select_action" nor "forward".
    """
    if policy_method not in ("select_action", "forward"):
        raise ValueError(
            f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
        )

    policy.eval()
    policy.to(device)
    # NOTE(review): if the policy keeps state between `select_action` calls (e.g. an
    # action queue), it presumably has to be reset here before each episode — confirm.

    logging.info("Loading dataloader")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
        batch_size=1 if policy_method == "select_action" else batch_size,
        sampler=EpisodeSampler(dataset, episode_index),
        drop_last=False,
    )

    # Keys already reported as skipped: every skipped key is warned about exactly
    # once, instead of only the first skipped key of each category.
    reported_keys = set()

    logging.info("Running inference")
    collected = {}
    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        # Only tensors can be moved to a device; other collated values are kept as they are.
        batch = {
            name: value.to(device, non_blocking=True) if isinstance(value, torch.Tensor) else value
            for name, value in batch.items()
        }

        with torch.inference_mode():
            if policy_method == "select_action":
                # Withhold the ground-truth action from the policy without mutating `batch`
                # (this also works when the dataset has no "action" entry).
                observation = {name: value for name, value in batch.items() if name != "action"}
                output_dict = {"action": policy.select_action(observation)}
            else:
                # Shallow copy so that the dict owned by the policy is not modified.
                output_dict = dict(policy.forward(batch))
                # TODO(rcadene): Save and display all predicted actions at a given timestamp
                # Save predicted action for the next timestamp only
                output_dict["action"] = output_dict["action"][:, 0, :]

        for key, value in output_dict.items():
            if value.ndim == 0:
                if key not in reported_keys:
                    warnings.warn(
                        f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
                        stacklevel=2,
                    )
                    reported_keys.add(key)
                continue

            if value.ndim > 2:
                if key not in reported_keys:
                    warnings.warn(
                        f"Ignore output key '{key}'. Its value is a tensor of {value.ndim} dimensions instead of a vector.",
                        stacklevel=2,
                    )
                    reported_keys.add(key)
                continue

            collected.setdefault(key, []).append(value.to("cpu"))

    return {key: torch.cat(chunks) for key, chunks in collected.items()}
|
||||||
|
|
||||||
|
|
||||||
def get_ep_csv_fname(episode_id: int) -> str:
    """Return the name of the csv file holding the data of episode `episode_id`."""
    return f"episode_{episode_id}.csv"
|
||||||
|
|
||||||
|
|
||||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, policy=None):
|
||||||
"""Write a csv file containing timeseries data of an episode (e.g. state and action).
|
"""Write a csv file containing timeseries data of an episode (e.g. state and action).
|
||||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||||
|
|
||||||
|
if policy is not None:
|
||||||
|
inference_results = run_inference(
|
||||||
|
dataset,
|
||||||
|
episode_index,
|
||||||
|
policy,
|
||||||
|
policy_method="select_action",
|
||||||
|
# num_workers=hydra_cfg.training.num_workers,
|
||||||
|
# batch_size=hydra_cfg.training.batch_size,
|
||||||
|
# device=hydra_cfg.device,
|
||||||
|
)
|
||||||
|
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||||
|
|
||||||
|
@ -141,21 +239,26 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||||
if has_action:
|
if has_action:
|
||||||
dim_action = dataset.meta.shapes["action"][0]
|
dim_action = dataset.meta.shapes["action"][0]
|
||||||
header += [f"action_{i}" for i in range(dim_action)]
|
header += [f"action_{i}" for i in range(dim_action)]
|
||||||
|
if policy is not None:
|
||||||
|
dim_action = dataset.meta.shapes["action"][0]
|
||||||
|
header += [f"pred_action_{i}" for i in range(dim_action)]
|
||||||
|
|
||||||
columns = ["timestamp"]
|
columns = ["timestamp"]
|
||||||
if has_state:
|
if has_state:
|
||||||
columns += ["observation.state"]
|
columns += ["observation.state"]
|
||||||
if has_action:
|
if has_action:
|
||||||
columns += ["action"]
|
columns += ["action"]
|
||||||
|
data = dataset.hf_dataset.select_columns(columns)
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
data = dataset.hf_dataset.select_columns(columns)
|
|
||||||
for i in range(from_idx, to_idx):
|
for i in range(from_idx, to_idx):
|
||||||
row = [data[i]["timestamp"].item()]
|
row = [data[i]["timestamp"].item()]
|
||||||
if has_state:
|
if has_state:
|
||||||
row += data[i]["observation.state"].tolist()
|
row += data[i]["observation.state"].tolist()
|
||||||
if has_action:
|
if has_action:
|
||||||
row += data[i]["action"].tolist()
|
row += data[i]["action"].tolist()
|
||||||
|
if policy is not None:
|
||||||
|
row += inference_results["action"][i].tolist()
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -183,6 +286,9 @@ def visualize_dataset_html(
|
||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 9090,
|
port: int = 9090,
|
||||||
force_override: bool = False,
|
force_override: bool = False,
|
||||||
|
policy_method: str = "select_action",
|
||||||
|
pretrained_policy_name_or_path: str | None = None,
|
||||||
|
policy_overrides: list[str] | None = None,
|
||||||
) -> Path | None:
|
) -> Path | None:
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
|
@ -214,12 +320,28 @@ def visualize_dataset_html(
|
||||||
if episodes is None:
|
if episodes is None:
|
||||||
episodes = list(range(dataset.num_episodes))
|
episodes = list(range(dataset.num_episodes))
|
||||||
|
|
||||||
|
pretrained_policy_name_or_path = "aliberts/act_reachy_test_model"
|
||||||
|
|
||||||
|
policy = None
|
||||||
|
if pretrained_policy_name_or_path is not None:
|
||||||
|
logging.info("Loading policy")
|
||||||
|
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||||
|
|
||||||
|
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides=["device=mps"])
|
||||||
|
# dataset = make_dataset(hydra_cfg)
|
||||||
|
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||||
|
|
||||||
|
if policy_method == "select_action":
|
||||||
|
# Do not load previous observations or future actions, to simulate that the observations come from
|
||||||
|
# an environment.
|
||||||
|
dataset.delta_timestamps = None
|
||||||
|
|
||||||
logging.info("Writing CSV files")
|
logging.info("Writing CSV files")
|
||||||
for episode_index in tqdm.tqdm(episodes):
|
for episode_index in tqdm.tqdm(episodes):
|
||||||
# write states and actions in a csv (it can be slow for big datasets)
|
# write states and actions in a csv (it can be slow for big datasets)
|
||||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
|
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
|
||||||
|
|
||||||
if serve:
|
if serve:
|
||||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||||
|
@ -281,8 +403,8 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
repo_id = kwargs.pop("repo_id")
|
repo_id = kwargs.pop("repo_id")
|
||||||
root = kwargs.pop("root")
|
# root = kwargs.pop("root")
|
||||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
dataset = LeRobotDataset(repo_id)
|
||||||
visualize_dataset_html(dataset, **kwargs)
|
visualize_dataset_html(dataset, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue