[HIL-SERL port] Add Reward classifier benchmark tracking to choose best visual encoder (#688)

Eugene Mironov 2025-02-07 00:39:51 +07:00 committed by GitHub
parent 12525242ce
commit b63738674c
2 changed files with 88 additions and 3 deletions

View File

@@ -27,6 +27,8 @@ training:
  # image_keys: ["observation.images.top", "observation.images.wrist"]
  image_keys: ["observation.images.laptop", "observation.images.phone"]
  label_key: "next.reward"
  profile_inference_time: false
  profile_inference_time_iters: 20

eval:
  batch_size: 16
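The two new keys gate an optional inference-time profiling pass and set how many iterations it runs. A minimal sketch of how a script might read them from the loaded config; the key names come from the diff, while the OmegaConf construction below is only a stand-in for the Hydra-loaded config:

from omegaconf import OmegaConf

# Stand-in for the Hydra-loaded DictConfig; only the two new keys are shown.
cfg = OmegaConf.create(
    {
        "training": {
            "profile_inference_time": False,  # off by default, as in the config above
            "profile_inference_time_iters": 20,  # number of profiled forward passes
        }
    }
)

if cfg.training.profile_inference_time:
    print(f"Profiling inference for {cfg.training.profile_inference_time_iters} iterations")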

View File

@@ -20,17 +20,19 @@ from pathlib import Path
from pprint import pformat
import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
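The newly imported torch.autograd.profiler is what the code below uses to time a labeled region. A self-contained sketch of that pattern on a dummy module; the module and input are placeholders, not part of the commit:

import torch
import torch.nn as nn
from torch.autograd import profiler

model = nn.Linear(16, 2)  # placeholder model
x = torch.randn(8, 16)  # placeholder batch

with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
    _ = model(x)

# key_averages() aggregates recorded events by name; the record_function label
# appears as an event key, and cpu_time is reported in microseconds.
event = next(e for e in prof.key_averages() if e.key == "model_inference")
print(f"model_inference took {event.cpu_time:.2f} us")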
@@ -124,6 +126,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
    batch_start_time = time.perf_counter()
    samples = []
    running_loss = 0
    inference_times = []

    with (
        torch.no_grad(),
@@ -133,7 +136,18 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
            labels = batch[cfg.training.label_key].float().to(device)
            outputs = model(images)
            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
                with (
                    profiler.profile(record_shapes=True) as prof,
                    profiler.record_function("model_inference"),
                ):
                    outputs = model(images)
                inference_times.append(
                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
                )
            else:
                outputs = model(images)

            loss = criterion(outputs.logits, labels)

            # Track metrics
@@ -177,9 +191,76 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
        else None,
    }

    if len(inference_times) > 0:
        eval_info["inference_time_avg"] = np.mean(inference_times)
        eval_info["inference_time_median"] = np.median(inference_times)
        eval_info["inference_time_std"] = np.std(inference_times)
        eval_info["inference_time_batch_size"] = val_loader.batch_size
        print(
            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
        )

    return accuracy, eval_info
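The logged times are per validation batch, in microseconds, so a rough per-sample latency follows by dividing by the batch size. A small sketch with made-up numbers:

import numpy as np

inference_times = [5200.0, 5100.0, 5350.0]  # per-batch times in microseconds (illustrative)
batch_size = 16  # matches eval.batch_size in the config above

avg_us = np.mean(inference_times)
per_sample_ms = avg_us / batch_size / 1000.0
print(f"avg per-batch: {avg_us:.2f} us, approx per-sample: {per_sample_ms:.3f} ms")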
def benchmark_inference_time(model, dataset, logger, cfg, device, step):
    if not cfg.training.profile_inference_time:
        return

    iters = cfg.training.profile_inference_time_iters
    inference_times = []

    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=cfg.training.num_workers,
        sampler=RandomSampler(dataset),
        pin_memory=True,
    )

    model.eval()
    with torch.no_grad():
        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
            x = next(iter(loader))
            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]

            # Warm up
            for _ in range(10):
                _ = model(x)

            # sync the device
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()

            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
                _ = model(x)

            inference_times.append(
                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
            )

    inference_times = np.array(inference_times)
    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
    print(
        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
    )
    if logger._cfg.wandb.enable:
        logger.log_dict(
            {
                "inference_time_benchmark_avg": avg,
                "inference_time_benchmark_median": median,
                "inference_time_benchmark_std": std,
            },
            step + 1,
            mode="eval",
        )

    return avg, median, std
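As a rough cross-check on the profiler figures, the same measurement can be done with wall-clock timing and explicit device synchronization. This is only a sketch, not part of the commit; the argument names mirror benchmark_inference_time above:

import time
import numpy as np
import torch

def wall_clock_inference_time(model, images, device, iters=20):
    # Wall-clock counterpart to benchmark_inference_time; returned stats are in microseconds.
    times_us = []
    model.eval()
    with torch.no_grad():
        for _ in range(iters):
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()
            start = time.perf_counter()
            _ = model(images)
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()
            times_us.append((time.perf_counter() - start) * 1e6)
    return float(np.mean(times_us)), float(np.median(times_us)), float(np.std(times_us))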
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
    # Main training pipeline with support for resuming training
@@ -313,6 +394,8 @@ def train(cfg: DictConfig) -> None:
        step += len(train_loader)

    benchmark_inference_time(model, dataset, logger, cfg, device, step)

    logging.info("Training completed")