This commit is contained in:
Remi Cadene 2024-11-27 14:57:14 +01:00
parent fc4df91883
commit 272a9d9427
3 changed files with 142 additions and 19 deletions

View File

@ -140,25 +140,25 @@ class ACTPolicy(
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
bsize = actions_hat.shape[0]
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
loss_dict = {"l1_loss": l1_loss.item()}
out_dict = {}
out_dict["l1_loss"] = l1_loss
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
out_dict["loss"] = l1_loss
return loss_dict
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
return out_dict
class ACTTemporalEnsembler:

View File

@ -268,10 +268,11 @@ def main():
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
# root = kwargs.pop("root")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
dataset = LeRobotDataset(repo_id)
visualize_dataset(dataset, **vars(args))

View File

@ -55,13 +55,30 @@ python lerobot/scripts/visualize_dataset_html.py \
import argparse
import logging
import shutil
import warnings
from pathlib import Path
import torch
import tqdm
from flask import Flask, redirect, render_template, url_for
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import init_logging
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, init_logging
from lerobot.scripts.eval import get_pretrained_policy_path
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in increasing order, the frame indices of a single episode."""

    def __init__(self, dataset, episode_index):
        """Look up the frame range of `episode_index` in `dataset.episode_data_index`.

        The "from" entry is used as the first frame index (inclusive) and the "to" entry as the end of
        the range (exclusive).
        """
        episode_bounds = dataset.episode_data_index
        first_frame = episode_bounds["from"][episode_index].item()
        end_frame = episode_bounds["to"][episode_index].item()
        self.frame_ids = range(first_frame, end_frame)

    def __iter__(self):
        """Return an iterator over the frame indices of the episode."""
        return iter(self.frame_ids)

    def __len__(self):
        """Return the number of frames in the episode."""
        return len(self.frame_ids)
def run_server(
@ -119,14 +136,95 @@ def run_server(
app.run(host=host, port=port)
def run_inference(
    dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="mps"
):
    """Run `policy` over every frame of one episode and collect its per-frame outputs.

    Args:
        dataset: Frame-indexed dataset; it must expose `episode_data_index` (read by `EpisodeSampler`).
        episode_index: Index of the episode whose frames are fed to the policy.
        policy: Object exposing `eval()`, `to(device)` and the method named by `policy_method`.
        policy_method: "select_action" feeds one frame at a time; "forward" feeds batches of
            `batch_size` frames.
        num_workers: Number of dataloader workers.
        batch_size: Batch size used with "forward" (ignored with "select_action").
        device: Torch device on which the policy and the tensors of each batch are placed.

    Returns:
        A dict mapping every kept output key to a cpu tensor, concatenated over the batches of the
        episode along the first dimension. Outputs with 0 dimension or more than 2 dimensions are
        dropped, with a warning emitted once per kind.

    Raises:
        ValueError: If `policy_method` is neither "select_action" nor "forward".
    """
    if policy_method not in ("select_action", "forward"):
        raise ValueError(
            f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
        )

    policy.eval()
    policy.to(device)

    frame_by_frame = policy_method == "select_action"
    if frame_by_frame:
        # `select_action` is called on consecutive frames, so the policy may keep state between calls.
        # Clear it (when the policy offers a `reset` hook) so that nothing left over from a previously
        # processed episode leaks into this one.
        reset = getattr(policy, "reset", None)
        if callable(reset):
            reset()

    logging.info("Loading dataloader")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
        batch_size=1 if frame_by_frame else batch_size,
        sampler=EpisodeSampler(dataset, episode_index),
        drop_last=False,
    )

    # Each kind of warning is emitted at most once per call to avoid flooding the logs.
    warned_ndim_eq_0 = False
    warned_ndim_gt_2 = False

    logging.info("Running inference")
    inference_results = {}
    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        # Only tensors can be moved to a device; any other collated value is passed through untouched.
        batch = {
            k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }

        with torch.inference_mode():
            if frame_by_frame:
                # Hide the ground-truth action from the policy, then restore it in the batch.
                gt_action = batch.pop("action", None)
                output_dict = {"action": policy.select_action(batch)}
                if gt_action is not None:
                    batch["action"] = gt_action
            else:
                output_dict = policy.forward(batch)
                # TODO(rcadene): Save and display all predicted actions at a given timestamp
                # Save predicted action for the next timestamp only
                output_dict["action"] = output_dict["action"][:, 0, :]

        for key, value in output_dict.items():
            if value.ndim == 0:
                if not warned_ndim_eq_0:
                    warnings.warn(
                        f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
                        stacklevel=1,
                    )
                    warned_ndim_eq_0 = True
                continue

            if value.ndim > 2:
                if not warned_ndim_gt_2:
                    warnings.warn(
                        f"Ignore output key '{key}'. Its value is a tensor of {value.ndim} dimensions instead of a vector.",
                        stacklevel=1,
                    )
                    warned_ndim_gt_2 = True
                continue

            inference_results.setdefault(key, []).append(value.to("cpu"))

    # Stitch the per-batch chunks back into one tensor per key.
    return {key: torch.cat(chunks) for key, chunks in inference_results.items()}
def get_ep_csv_fname(episode_id: int) -> str:
    """Return the name of the csv file that holds the data of episode `episode_id`."""
    return f"episode_{episode_id}.csv"
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, policy=None):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
if policy is not None:
inference_results = run_inference(
dataset,
episode_index,
policy,
policy_method="select_action",
# num_workers=hydra_cfg.training.num_workers,
# batch_size=hydra_cfg.training.batch_size,
# device=hydra_cfg.device,
)
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
@ -141,21 +239,26 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
if has_action:
dim_action = dataset.meta.shapes["action"][0]
header += [f"action_{i}" for i in range(dim_action)]
if policy is not None:
dim_action = dataset.meta.shapes["action"][0]
header += [f"pred_action_{i}" for i in range(dim_action)]
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
data = dataset.hf_dataset.select_columns(columns)
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if has_state:
row += data[i]["observation.state"].tolist()
if has_action:
row += data[i]["action"].tolist()
if policy is not None:
row += inference_results["action"][i].tolist()
rows.append(row)
output_dir.mkdir(parents=True, exist_ok=True)
@ -183,6 +286,9 @@ def visualize_dataset_html(
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
policy_method: str = "select_action",
pretrained_policy_name_or_path: str | None = None,
policy_overrides: list[str] | None = None,
) -> Path | None:
init_logging()
@ -214,12 +320,28 @@ def visualize_dataset_html(
if episodes is None:
episodes = list(range(dataset.num_episodes))
pretrained_policy_name_or_path = "aliberts/act_reachy_test_model"
policy = None
if pretrained_policy_name_or_path is not None:
logging.info("Loading policy")
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides=["device=mps"])
# dataset = make_dataset(hydra_cfg)
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
if policy_method == "select_action":
# Do not load previous observations or future actions, to simulate that the observations come from
# an environment.
dataset.delta_timestamps = None
logging.info("Writing CSV files")
for episode_index in tqdm.tqdm(episodes):
# write states and actions in a csv (it can be slow for big datasets)
ep_csv_fname = get_ep_csv_fname(episode_index)
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
@ -281,8 +403,8 @@ def main():
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
# root = kwargs.pop("root")
dataset = LeRobotDataset(repo_id)
visualize_dataset_html(dataset, **kwargs)