From b63738674ca5406cfa0db767dd467ad17d1ce371 Mon Sep 17 00:00:00 2001
From: Eugene Mironov
Date: Fri, 7 Feb 2025 00:39:51 +0700
Subject: [PATCH] [HIL-SERL port] Add Reward classifier benchmark tracking to
 choose best visual encoder (#688)

---
 .../configs/policy/hilserl_classifier.yaml  |  2 +
 lerobot/scripts/train_hilserl_classifier.py | 89 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 3 deletions(-)

diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml
index 21fd4a1a..a315902b 100644
--- a/lerobot/configs/policy/hilserl_classifier.yaml
+++ b/lerobot/configs/policy/hilserl_classifier.yaml
@@ -27,6 +27,8 @@ training:
   # image_keys: ["observation.images.top", "observation.images.wrist"]
   image_keys: ["observation.images.laptop", "observation.images.phone"]
   label_key: "next.reward"
+  profile_inference_time: false
+  profile_inference_time_iters: 20
 
 eval:
   batch_size: 16
diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py
index 458e3ff1..0ca8eae4 100644
--- a/lerobot/scripts/train_hilserl_classifier.py
+++ b/lerobot/scripts/train_hilserl_classifier.py
@@ -20,17 +20,19 @@ from pathlib import Path
 from pprint import pformat
 
 import hydra
+import numpy as np
 import torch
 import torch.nn as nn
-import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import optim
+from torch.autograd import profiler
 from torch.cuda.amp import GradScaler
-from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
+from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
 from tqdm import tqdm
 
+import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger
@@ -124,6 +126,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
     batch_start_time = time.perf_counter()
     samples = []
     running_loss = 0
+    inference_times = []
 
     with (
         torch.no_grad(),
@@ -133,7 +136,18 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
             images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
             labels = batch[cfg.training.label_key].float().to(device)
 
-            outputs = model(images)
+            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
+                with (
+                    profiler.profile(record_shapes=True) as prof,
+                    profiler.record_function("model_inference"),
+                ):
+                    outputs = model(images)
+                inference_times.append(
+                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+                )
+            else:
+                outputs = model(images)
+
             loss = criterion(outputs.logits, labels)
 
             # Track metrics
@@ -177,9 +191,76 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
         else None,
     }
 
+    if len(inference_times) > 0:
+        eval_info["inference_time_avg"] = np.mean(inference_times)
+        eval_info["inference_time_median"] = np.median(inference_times)
+        eval_info["inference_time_std"] = np.std(inference_times)
+        eval_info["inference_time_batch_size"] = val_loader.batch_size
+
+        print(
+            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
+        )
+
     return accuracy, eval_info
 
 
+def benchmark_inference_time(model, dataset, logger, cfg, device, step):
+    if not cfg.training.profile_inference_time:
+        return
+
+    iters = cfg.training.profile_inference_time_iters
+    inference_times = []
+
+    loader = DataLoader(
+        dataset,
+        batch_size=1,
+        num_workers=cfg.training.num_workers,
+        sampler=RandomSampler(dataset),
+        pin_memory=True,
+    )
+
+    model.eval()
+    with torch.no_grad():
+        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
+            x = next(iter(loader))
+            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]
+
+            # Warm up
+            for _ in range(10):
+                _ = model(x)
+
+            # sync the device
+            if device.type == "cuda":
+                torch.cuda.synchronize()
+            elif device.type == "mps":
+                torch.mps.synchronize()
+
+            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
+                _ = model(x)
+
+            inference_times.append(
+                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+            )
+
+    inference_times = np.array(inference_times)
+    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
+    print(
+        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
+    )
+    if logger._cfg.wandb.enable:
+        logger.log_dict(
+            {
+                "inference_time_benchmark_avg": avg,
+                "inference_time_benchmark_median": median,
+                "inference_time_benchmark_std": std,
+            },
+            step + 1,
+            mode="eval",
+        )
+
+    return avg, median, std
+
+
 @hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
 def train(cfg: DictConfig) -> None:
     # Main training pipeline with support for resuming training
@@ -313,6 +394,8 @@ def train(cfg: DictConfig) -> None:
 
         step += len(train_loader)
 
+    benchmark_inference_time(model, dataset, logger, cfg, device, step)
+
     logging.info("Training completed")
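Note on the measurement technique: both call sites in the patch time the forward pass the same way, by wrapping it in profiler.record_function("model_inference") and reading the averaged cpu_time (reported in microseconds) back out of prof.key_averages(). A minimal standalone sketch of that pattern follows; the nn.Linear model and input shape are illustrative stand-ins, not taken from the patch:

    import torch
    import torch.nn as nn
    from torch.autograd import profiler

    model = nn.Linear(128, 2).eval()  # toy stand-in for the reward classifier
    x = torch.randn(1, 128)

    with torch.no_grad():
        # Warm up so one-time allocation costs do not skew the timing.
        for _ in range(10):
            _ = model(x)

        with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
            _ = model(x)

    # key_averages() aggregates profiler events by name; cpu_time is the
    # average CPU time per call, in microseconds.
    event = next(e for e in prof.key_averages() if e.key == "model_inference")
    print(f"model_inference: {event.cpu_time:.2f} us")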
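The benchmark is disabled by default (profile_inference_time: false). Assuming Hydra's standard dotted-override syntax (the script is launched through @hydra.main), the two new keys can be set from the command line without editing the YAML; the iteration count of 50 below is an arbitrary example value:

    python lerobot/scripts/train_hilserl_classifier.py \
        training.profile_inference_time=true \
        training.profile_inference_time_iters=50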