[HIL-SERL port] Add Reward classifier benchmark tracking to choose the best visual encoder (#688)
parent 12525242ce
commit b63738674c
@@ -27,6 +27,8 @@ training:
   # image_keys: ["observation.images.top", "observation.images.wrist"]
   image_keys: ["observation.images.laptop", "observation.images.phone"]
   label_key: "next.reward"
+  profile_inference_time: false
+  profile_inference_time_iters: 20
 
 eval:
   batch_size: 16
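The two new `training` keys are read through Hydra/OmegaConf like any other config field. A minimal sketch of how the gate behaves; the inline config here is illustrative and stands in for the full `hilserl_classifier` YAML:

```python
# Illustrative only: a stripped-down config showing how the new keys gate
# the benchmark. Real values come from the hilserl_classifier config file.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "training": {
            "profile_inference_time": True,
            "profile_inference_time_iters": 20,
        }
    }
)

if cfg.training.profile_inference_time:
    print(f"Benchmarking over {cfg.training.profile_inference_time_iters} iterations")
```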
@@ -20,17 +20,19 @@ from pathlib import Path
 from pprint import pformat
 
 import hydra
+import numpy as np
 import torch
 import torch.nn as nn
-import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import optim
+from torch.autograd import profiler
 from torch.cuda.amp import GradScaler
-from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
+from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
 from tqdm import tqdm
 
+import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger
@@ -124,6 +126,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
     batch_start_time = time.perf_counter()
     samples = []
     running_loss = 0
+    inference_times = []
 
     with (
         torch.no_grad(),
@@ -133,7 +136,18 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
         images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
         labels = batch[cfg.training.label_key].float().to(device)
 
-        outputs = model(images)
+        if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
+            with (
+                profiler.profile(record_shapes=True) as prof,
+                profiler.record_function("model_inference"),
+            ):
+                outputs = model(images)
+            inference_times.append(
+                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+            )
+        else:
+            outputs = model(images)
 
         loss = criterion(outputs.logits, labels)
 
         # Track metrics
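The profiling pattern above, a `torch.autograd.profiler.profile` context wrapping a named `record_function` region whose aggregate time is read back from `key_averages()`, also works standalone. A minimal sketch with a toy module (the model and tensor shapes are illustrative, not from this PR):

```python
# Standalone sketch of the profiling pattern used in the diff. The toy model
# and tensor shape are illustrative, not from this PR.
import torch
import torch.nn as nn
from torch.autograd import profiler

model = nn.Linear(128, 2).eval()
x = torch.randn(16, 128)

with torch.no_grad():
    with (
        profiler.profile(record_shapes=True) as prof,
        profiler.record_function("model_inference"),
    ):
        _ = model(x)

# key_averages() aggregates profiler events by name; cpu_time is the average
# CPU time per call, reported in microseconds.
event = next(e for e in prof.key_averages() if e.key == "model_inference")
print(f"model_inference: {event.cpu_time:.2f} us")
```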
@@ -177,9 +191,76 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
         else None,
     }
 
+    if len(inference_times) > 0:
+        eval_info["inference_time_avg"] = np.mean(inference_times)
+        eval_info["inference_time_median"] = np.median(inference_times)
+        eval_info["inference_time_std"] = np.std(inference_times)
+        eval_info["inference_time_batch_size"] = val_loader.batch_size
+
+        print(
+            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
+        )
+
     return accuracy, eval_info
 
 
+def benchmark_inference_time(model, dataset, logger, cfg, device, step):
+    if not cfg.training.profile_inference_time:
+        return
+
+    iters = cfg.training.profile_inference_time_iters
+    inference_times = []
+
+    loader = DataLoader(
+        dataset,
+        batch_size=1,
+        num_workers=cfg.training.num_workers,
+        sampler=RandomSampler(dataset),
+        pin_memory=True,
+    )
+
+    model.eval()
+    with torch.no_grad():
+        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
+            x = next(iter(loader))
+            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]
+
+            # Warm up
+            for _ in range(10):
+                _ = model(x)
+
+            # sync the device
+            if device.type == "cuda":
+                torch.cuda.synchronize()
+            elif device.type == "mps":
+                torch.mps.synchronize()
+
+            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
+                _ = model(x)
+
+            inference_times.append(
+                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+            )
+
+    inference_times = np.array(inference_times)
+    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
+    print(
+        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
+    )
+    if logger._cfg.wandb.enable:
+        logger.log_dict(
+            {
+                "inference_time_benchmark_avg": avg,
+                "inference_time_benchmark_median": median,
+                "inference_time_benchmark_std": std,
+            },
+            step + 1,
+            mode="eval",
+        )
+
+    return avg, median, std
+
+
 @hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
 def train(cfg: DictConfig) -> None:
     # Main training pipeline with support for resuming training
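Two details in `benchmark_inference_time` matter for getting stable numbers: the warm-up iterations let one-time costs (allocator growth, cudnn autotuning) happen outside the measured region, and the explicit `synchronize()` is needed because CUDA and MPS kernels launch asynchronously. A hedged cross-check, not part of the PR, is to time the same forward pass with a synchronized wall clock; on CPU this should roughly agree with the profiler's `cpu_time`, while on CUDA the profiler's `cpu_time` only covers the host-side portion, so the wall clock is a useful sanity check:

```python
# Hedged cross-check (not in the PR): synchronized wall-clock timing of a
# single forward pass. The toy model stands in for the reward classifier.
import time

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(128, 2).to(device).eval()
x = torch.randn(1, 128, device=device)

with torch.no_grad():
    for _ in range(10):  # warm-up, mirroring the diff
        _ = model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()

    start = time.perf_counter()
    _ = model(x)
    if device.type == "cuda":
        # Kernels launch asynchronously; wait before reading the clock.
        torch.cuda.synchronize()
    print(f"wall-clock inference: {(time.perf_counter() - start) * 1e6:.2f} us")
```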
@@ -313,6 +394,8 @@ def train(cfg: DictConfig) -> None:
 
     step += len(train_loader)
 
+    benchmark_inference_time(model, dataset, logger, cfg, device, step)
+
     logging.info("Training completed")
 
 
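With these pieces in place, the benchmark runs once at the end of training whenever the flag is set. Since the config is Hydra-managed, it can be flipped from the command line of the training script, e.g. by appending overrides like `training.profile_inference_time=true training.profile_inference_time_iters=50`; the aggregate statistics then appear under eval mode in the wandb logs as `inference_time_benchmark_avg`, `inference_time_benchmark_median`, and `inference_time_benchmark_std`.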