From 272a9d942703b5c1539495e7d5d84bf6600fd8b9 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Wed, 27 Nov 2024 14:57:14 +0100
Subject: [PATCH] WIP

---
 lerobot/common/policies/act/modeling_act.py |  22 ++--
 lerobot/scripts/visualize_dataset.py        |   5 +-
 lerobot/scripts/visualize_dataset_html.py   | 134 +++++++++++++++++++-
 3 files changed, 142 insertions(+), 19 deletions(-)

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 418863a1..8e0516d6 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -140,25 +140,25 @@ class ACTPolicy(
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
 
-        loss_dict = {"l1_loss": l1_loss.item()}
+        out_dict = {}
+        out_dict["l1_loss"] = l1_loss
         if self.config.use_vae:
             # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
         else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss
 
-        return loss_dict
+        out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        return out_dict
 
 
 class ACTTemporalEnsembler:
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 03205f25..9f0b6781 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -268,10 +268,11 @@ def main():
     args = parser.parse_args()
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
-    root = kwargs.pop("root")
+    # root = kwargs.pop("root")
 
     logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
+
+    dataset = LeRobotDataset(repo_id)
 
     visualize_dataset(dataset, **vars(args))
 
diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index 475983d3..51272cb1 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -55,13 +55,30 @@ python lerobot/scripts/visualize_dataset_html.py \
 import argparse
 import logging
 import shutil
+import warnings
 from pathlib import Path
 
+import torch
 import tqdm
 from flask import Flask, redirect, render_template, url_for
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.utils.utils import init_logging
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.scripts.eval import get_pretrained_policy_path
+
+
+class EpisodeSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset, episode_index):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self):
+        return iter(self.frame_ids)
+
+    def __len__(self):
+        return len(self.frame_ids)
 
 
 def run_server(
@@ -119,14 +136,95 @@ def run_server(
     app.run(host=host, port=port)
 
 
+def run_inference(
+    dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="mps"
+):
+    if policy_method not in ["select_action", "forward"]:
+        raise ValueError(
+            f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
+        )
+
+    policy.eval()
+    policy.to(device)
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
+        batch_size=1 if policy_method == "select_action" else batch_size,
+        sampler=episode_sampler,
+        drop_last=False,
+    )
+
+    warned_ndim_eq_0 = False
+    warned_ndim_gt_2 = False
+
+    logging.info("Running inference")
+    inference_results = {}
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        with torch.inference_mode():
+            if policy_method == "select_action":
+                gt_action = batch.pop("action")
+                output_dict = {"action": policy.select_action(batch)}
+                batch["action"] = gt_action
+            elif policy_method == "forward":
+                output_dict = policy.forward(batch)
+                # TODO(rcadene): Save and display all predicted actions at a given timestamp
+                # Save predicted action for the next timestamp only
+                output_dict["action"] = output_dict["action"][:, 0, :]
+
+        for key in output_dict:
+            if output_dict[key].ndim == 0:
+                if not warned_ndim_eq_0:
+                    warnings.warn(
+                        f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
+                        stacklevel=1,
+                    )
+                    warned_ndim_eq_0 = True
+                continue
+
+            if output_dict[key].ndim > 2:
+                if not warned_ndim_gt_2:
+                    warnings.warn(
+                        f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
+                        stacklevel=1,
+                    )
+                    warned_ndim_gt_2 = True
+                continue
+
+            if key not in inference_results:
+                inference_results[key] = []
+            inference_results[key].append(output_dict[key].to("cpu"))
+
+    for key in inference_results:
+        inference_results[key] = torch.cat(inference_results[key])
+
+    return inference_results
+
+
 def get_ep_csv_fname(episode_id: int):
     ep_csv_fname = f"episode_{episode_id}.csv"
     return ep_csv_fname
 
 
-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset, policy=None):
     """Write a csv file containg timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
+
+    if policy is not None:
+        inference_results = run_inference(
+            dataset,
+            episode_index,
+            policy,
+            policy_method="select_action",
+            # num_workers=hydra_cfg.training.num_workers,
+            # batch_size=hydra_cfg.training.batch_size,
+            # device=hydra_cfg.device,
+        )
+
     from_idx = dataset.episode_data_index["from"][episode_index]
     to_idx = dataset.episode_data_index["to"][episode_index]
 
@@ -141,21 +239,26 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     if has_action:
         dim_action = dataset.meta.shapes["action"][0]
         header += [f"action_{i}" for i in range(dim_action)]
+    if policy is not None:
+        dim_action = dataset.meta.shapes["action"][0]
+        header += [f"pred_action_{i}" for i in range(dim_action)]
 
     columns = ["timestamp"]
     if has_state:
         columns += ["observation.state"]
     if has_action:
         columns += ["action"]
+    data = dataset.hf_dataset.select_columns(columns)
 
     rows = []
-    data = dataset.hf_dataset.select_columns(columns)
     for i in range(from_idx, to_idx):
         row = [data[i]["timestamp"].item()]
         if has_state:
             row += data[i]["observation.state"].tolist()
         if has_action:
             row += data[i]["action"].tolist()
+        if policy is not None:
+            row += inference_results["action"][i - from_idx].tolist()
         rows.append(row)
 
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -183,6 +286,9 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
+    policy_method: str = "select_action",
+    pretrained_policy_name_or_path: str | None = None,
+    policy_overrides: list[str] | None = None,
 ) -> Path | None:
     init_logging()
 
@@ -214,12 +320,28 @@ def visualize_dataset_html(
     if episodes is None:
         episodes = list(range(dataset.num_episodes))
 
+    pretrained_policy_name_or_path = "aliberts/act_reachy_test_model"
+
+    policy = None
+    if pretrained_policy_name_or_path is not None:
+        logging.info("Loading policy")
+        pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
+
+        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides=["device=mps"])
+        # dataset = make_dataset(hydra_cfg)
+        policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+
+        if policy_method == "select_action":
+            # Do not load previous observations or future actions, to simulate that the observations come from
+            # an environment.
+            dataset.delta_timestamps = None
+
     logging.info("Writing CSV files")
     for episode_index in tqdm.tqdm(episodes):
         # write states and actions in a csv (it can be slow for big datasets)
         ep_csv_fname = get_ep_csv_fname(episode_index)
         # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
+        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
 
     if serve:
         run_server(dataset, episodes, host, port, static_dir, template_dir)
@@ -281,8 +403,8 @@ def main():
     args = parser.parse_args()
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
-    root = kwargs.pop("root")
-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
+    # root = kwargs.pop("root")
+    dataset = LeRobotDataset(repo_id)
 
     visualize_dataset_html(dataset, **kwargs)
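
Usage notes (illustrative sketches, not part of the patch):

With the ACTPolicy.forward change above, the returned dict's "loss" and "l1_loss" entries are per-sample tensors of shape (batch_size,) rather than scalars, and the predicted actions are returned after being mapped back to the original action space via unnormalize_outputs. A caller therefore has to reduce the loss itself before backpropagating. A minimal sketch of such a training step follows; the training_step helper and the optimizer argument are illustrative, not part of LeRobot.

    def training_step(policy, batch, optimizer):
        # forward() now returns per-sample losses of shape (batch_size,) plus
        # the predicted actions under the "action" key.
        out = policy.forward(batch)
        loss = out["loss"].mean()  # reduce over the batch before backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() is only used here for logging, after the backward pass.
        return {"loss": loss.item(), "l1_loss": out["l1_loss"].mean().item()}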
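
The new EpisodeSampler restricts a torch DataLoader to the frames of a single episode, in temporal order, using the dataset's episode_data_index. A usage sketch, assuming the patch is applied; the repo id "lerobot/pusht", episode_index=0 and the batch size are only examples.

    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.scripts.visualize_dataset_html import EpisodeSampler

    dataset = LeRobotDataset("lerobot/pusht")
    sampler = EpisodeSampler(dataset, episode_index=0)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,  # run_inference forces batch_size=1 when policy_method="select_action"
        sampler=sampler,  # yields only the frame indices of episode 0, in temporal order
        drop_last=False,
    )
    for batch in dataloader:
        pass  # each batch contains consecutive frames of episode 0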
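
write_episode_data_csv now optionally runs a policy over the episode and appends pred_action_* columns next to the ground-truth action_* columns, so the Dygraph plot can overlay predictions and ground truth on the same time axis. A sketch of driving it directly is below; the checkpoint directory and output directory are illustrative, and the device override mirrors the WIP "mps" default used by run_inference (i.e. it assumes an Apple Silicon machine).

    from pathlib import Path

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.common.policies.factory import make_policy
    from lerobot.common.utils.utils import init_hydra_config
    from lerobot.scripts.visualize_dataset_html import write_episode_data_csv

    dataset = LeRobotDataset("lerobot/pusht")

    # Illustrative checkpoint directory containing config.yaml and the model weights.
    pretrained_path = Path("outputs/train/act_pusht/checkpoints/last/pretrained_model")
    hydra_cfg = init_hydra_config(pretrained_path / "config.yaml", overrides=["device=mps"])
    policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_path)

    # Writes outputs/visualize/episode_0.csv with the timestamp, state/action columns and
    # the new pred_action_* columns produced by running the policy over episode 0.
    write_episode_data_csv(Path("outputs/visualize"), "episode_0.csv", 0, dataset, policy=policy)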