[HIL-SERL port] Add Reward classifier benchmark tracking to choose best visual encoder (#688)

Eugene Mironov 2025-02-07 00:39:51 +07:00 committed by GitHub
parent 12525242ce
commit b63738674c
2 changed files with 88 additions and 3 deletions


@@ -27,6 +27,8 @@ training:
  # image_keys: ["observation.images.top", "observation.images.wrist"]
  image_keys: ["observation.images.laptop", "observation.images.phone"]
  label_key: "next.reward"
  profile_inference_time: false
  profile_inference_time_iters: 20

eval:
  batch_size: 16
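The two new keys live under cfg.training and gate the optional profiling paths added below. A minimal sketch of how the guard behaves, using a hand-built OmegaConf config (the key names match the diff; everything else here is illustrative):

from omegaconf import OmegaConf

# Illustrative config fragment mirroring the two new training keys.
cfg = OmegaConf.create(
    {
        "training": {
            "profile_inference_time": False,  # profiling stays off unless explicitly enabled
            "profile_inference_time_iters": 20,  # iterations used by the standalone benchmark
        }
    }
)

if cfg.training.profile_inference_time:
    print(f"Profiling inference time for {cfg.training.profile_inference_time_iters} iterations")
else:
    print("Inference-time profiling disabled")

Since the training script is a Hydra entrypoint (see the @hydra.main decorator further down), these defaults can presumably be overridden from the command line with training.profile_inference_time=true and training.profile_inference_time_iters=50, for example.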


@@ -20,17 +20,19 @@ from pathlib import Path
from pprint import pformat

import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from tqdm import tqdm

from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
@@ -124,6 +126,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
    batch_start_time = time.perf_counter()
    samples = []
    running_loss = 0
    inference_times = []

    with (
        torch.no_grad(),
@@ -133,7 +136,18 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
            labels = batch[cfg.training.label_key].float().to(device)

            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
                with (
                    profiler.profile(record_shapes=True) as prof,
                    profiler.record_function("model_inference"),
                ):
                    outputs = model(images)
                inference_times.append(
                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
                )
            else:
                outputs = model(images)

            loss = criterion(outputs.logits, labels)

            # Track metrics
@@ -177,9 +191,76 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
        else None,
    }

    if len(inference_times) > 0:
        eval_info["inference_time_avg"] = np.mean(inference_times)
        eval_info["inference_time_median"] = np.median(inference_times)
        eval_info["inference_time_std"] = np.std(inference_times)
        eval_info["inference_time_batch_size"] = val_loader.batch_size
        print(
            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
        )

    return accuracy, eval_info
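The validation-time measurement relies on torch.autograd.profiler: the forward pass is wrapped in a named record_function region, and the averaged CPU time (in microseconds) for that region is read back from key_averages(). A self-contained sketch of the same pattern, with a stand-in model and random input (both illustrative, not from the commit):

import torch
from torch.autograd import profiler

model = torch.nn.Linear(8, 2)  # stand-in for the reward classifier
x = torch.randn(16, 8)  # stand-in batch

with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
    _ = model(x)

# key_averages() aggregates events by name; cpu_time is the average CPU time per call, in microseconds.
entry = next(e for e in prof.key_averages() if e.key == "model_inference")
print(f"model_inference: {entry.cpu_time:.2f} us")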
def benchmark_inference_time(model, dataset, logger, cfg, device, step):
    if not cfg.training.profile_inference_time:
        return

    iters = cfg.training.profile_inference_time_iters
    inference_times = []

    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=cfg.training.num_workers,
        sampler=RandomSampler(dataset),
        pin_memory=True,
    )

    model.eval()
    with torch.no_grad():
        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
            x = next(iter(loader))
            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]

            # Warm up
            for _ in range(10):
                _ = model(x)

            # sync the device
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()

            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
                _ = model(x)

            inference_times.append(
                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
            )

    inference_times = np.array(inference_times)
    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
    print(
        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
    )

    if logger._cfg.wandb.enable:
        logger.log_dict(
            {
                "inference_time_benchmark_avg": avg,
                "inference_time_benchmark_median": median,
                "inference_time_benchmark_std": std,
            },
            step + 1,
            mode="eval",
        )

    return avg, median, std
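benchmark_inference_time synchronizes CUDA or MPS after the warm-up passes so that kernels queued during warm-up do not leak into the profiled forward pass. For comparison, a hedged sketch of the same measurement done with wall-clock timing instead of the autograd profiler (not part of the commit; the helper name and defaults are made up):

import time

import numpy as np
import torch


def time_forward_wall_clock(model, inputs, device, iters=20, warmup=10):
    # Wall-clock alternative to the profiler-based measurement above; returns stats in microseconds.
    model.eval()
    times_us = []
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(inputs)
        if device.type == "cuda":
            torch.cuda.synchronize()
        for _ in range(iters):
            start = time.perf_counter()
            _ = model(inputs)
            if device.type == "cuda":
                torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
            times_us.append((time.perf_counter() - start) * 1e6)
    return float(np.mean(times_us)), float(np.median(times_us)), float(np.std(times_us))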
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
    # Main training pipeline with support for resuming training
@@ -313,6 +394,8 @@ def train(cfg: DictConfig) -> None:
        step += len(train_loader)

    benchmark_inference_time(model, dataset, logger, cfg, device, step)

    logging.info("Training completed")