# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pprint import pformat

import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
from tqdm import tqdm

from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# NOTE: the import paths for `Logger` and `_policy_cfg_from_hydra_cfg` assume the standard
# LeRobot module layout; both names are required by the functions below.
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.server.buffer import random_shift
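
# Typical usage (illustrative; adjust the script path and overrides to your setup): this script is
# driven by the `hilserl_classifier` Hydra config declared on train_cli() at the bottom, e.g.
#   python lerobot/scripts/train_hilserl_classifier.py dataset_repo_id=<user/dataset> device=cuda
# where any config key (training.batch_size, training.num_epochs, ...) can be overridden on the CLI.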


def get_model(cfg, logger):  # noqa I001
    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    model = Classifier(classifier_config)
    if cfg.resume:
        model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
    return model


def create_balanced_sampler(dataset, cfg):
    # Get underlying dataset if using Subset
    original_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset

    # Get indices if using Subset (for slicing)
    indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None

    # Get labels from Hugging Face dataset
    if indices is not None:
        # Get subset of labels using Hugging Face's select()
        hf_subset = original_dataset.hf_dataset.select(indices)
        labels = hf_subset[cfg.training.label_key]
    else:
        # Get all labels directly
        labels = original_dataset.hf_dataset[cfg.training.label_key]

    labels = torch.stack(labels)
    _, counts = torch.unique(labels, return_counts=True)
    class_weights = 1.0 / counts.float()
    sample_weights = class_weights[labels]

    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
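
# Worked example for the weighting above (illustrative numbers): with labels [0, 0, 0, 1] the class
# counts are [3, 1], so class_weights = [1/3, 1.0] and sample_weights = [1/3, 1/3, 1/3, 1.0]. The
# WeightedRandomSampler then draws positive and negative frames with equal probability in
# expectation, which keeps the classifier from collapsing onto the majority class.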


def support_amp(device: torch.device, cfg: DictConfig) -> bool:
    # Check whether the device supports AMP.
    # MPS does not support autocast properly yet, so AMP is only enabled on CUDA and CPU devices.
    return cfg.training.use_amp and device.type in ("cuda", "cpu")


def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
    # Single epoch training loop with AMP support and progress tracking
    model.train()
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        start_time = time.perf_counter()
        images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
        images = [random_shift(img, 4) for img in images]
        labels = batch[cfg.training.label_key].float().to(device)

        # Forward pass with optional AMP
        with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
            outputs = model(images)
            loss = criterion(outputs.logits, labels)

        # Backward pass with gradient scaling if AMP enabled
        optimizer.zero_grad()
        if cfg.training.use_amp:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Track metrics
        if model.config.num_classes == 2:
            predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
        else:
            predictions = torch.argmax(outputs.logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        current_acc = 100 * correct / total
        train_info = {
            "loss": loss.item(),
            "accuracy": current_acc,
            "dataloading_s": time.perf_counter() - start_time,
        }

        logger.log_dict(train_info, step + batch_idx, mode="train")
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
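
# Prediction rule used in train_epoch() above and validate() below: in the binary case the single
# logit is passed through a sigmoid and thresholded at 0.5 (e.g. a logit of 0.8 gives
# sigmoid(0.8) ≈ 0.69, so the frame is predicted positive), while the multi-class case takes the
# argmax over the logits.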


def validate(model, val_loader, criterion, device, logger, cfg):
    # Validation loop with metric tracking and sample logging
    model.eval()
    correct = 0
    total = 0
    batch_start_time = time.perf_counter()
    samples = []
    running_loss = 0
    inference_times = []

    with (
        torch.no_grad(),
        torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
    ):
        for batch in tqdm(val_loader, desc="Validation"):
            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
            labels = batch[cfg.training.label_key].float().to(device)

            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
                with (
                    profiler.profile(record_shapes=True) as prof,
                    profiler.record_function("model_inference"),
                ):
                    outputs = model(images)
                inference_times.append(
                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
                )
            else:
                outputs = model(images)

            loss = criterion(outputs.logits, labels)

            # Track metrics
            if model.config.num_classes == 2:
                predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
            else:
                predictions = torch.argmax(outputs.logits, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
            running_loss += loss.item()

            # Log sample predictions for visualization
            if len(samples) < cfg.eval.num_samples_to_log:
                for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
                    if model.config.num_classes == 2:
                        confidence = round(outputs.probabilities[i].item(), 3)
                    else:
                        confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                    samples.append(
                        {
                            **{
                                f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
                                for img_idx, img_key in enumerate(cfg.training.image_keys)
                            },
                            "true_label": labels[i].item(),
                            "predicted": predictions[i].item(),
                            "confidence": confidence,
                        }
                    )

    accuracy = 100 * correct / total
    avg_loss = running_loss / len(val_loader)
    print(f"Average validation loss: {avg_loss}, accuracy: {accuracy}%")

    eval_info = {
        "loss": avg_loss,
        "accuracy": accuracy,
        "eval_s": time.perf_counter() - batch_start_time,
        "eval/prediction_samples": wandb.Table(
            data=[list(s.values()) for s in samples],
            columns=list(samples[0].keys()),
        )
        if logger._cfg.wandb.enable
        else None,
    }

    if len(inference_times) > 0:
        eval_info["inference_time_avg"] = np.mean(inference_times)
        eval_info["inference_time_median"] = np.median(inference_times)
        eval_info["inference_time_std"] = np.std(inference_times)
        eval_info["inference_time_batch_size"] = val_loader.batch_size

        print(
            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, "
            f"median: {eval_info['inference_time_median']:.2f} us, "
            f"std: {eval_info['inference_time_std']:.2f} us, "
            f"with {len(inference_times)} iterations on {device.type} device, "
            f"batch size: {eval_info['inference_time_batch_size']}"
        )

    return accuracy, eval_info
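
# NOTE: the `cpu_time` values read from `prof.key_averages()` above (and in
# benchmark_inference_time() below) are reported by the autograd profiler in microseconds, which is
# why the printed figures are labelled "us".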


def benchmark_inference_time(model, dataset, logger, cfg, device, step):
    if not cfg.training.profile_inference_time:
        return

    iters = cfg.training.profile_inference_time_iters
    inference_times = []

    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=cfg.training.num_workers,
        sampler=RandomSampler(dataset),
        pin_memory=True,
    )

    model.eval()
    with torch.no_grad():
        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
            x = next(iter(loader))
            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]

            # Warm up
            for _ in range(10):
                _ = model(x)

            # Sync the device before profiling
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()

            with (
                profiler.profile(record_shapes=True) as prof,
                profiler.record_function("model_inference"),
            ):
                _ = model(x)

            inference_times.append(
                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
            )

    inference_times = np.array(inference_times)
    avg, median, std = (
        inference_times.mean(),
        np.median(inference_times),
        inference_times.std(),
    )
    print(
        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, "
        f"with {iters} iterations on {device.type} device"
    )
    if logger._cfg.wandb.enable:
        logger.log_dict(
            {
                "inference_time_benchmark_avg": avg,
                "inference_time_benchmark_median": median,
                "inference_time_benchmark_std": std,
            },
            step + 1,
            mode="eval",
        )

    return avg, median, std


def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None) -> None:
    if out_dir is None:
        raise NotImplementedError()
    if job_name is None:
        raise NotImplementedError()

    # Main training pipeline with support for resuming training
    init_logging()
    logging.info(OmegaConf.to_yaml(cfg))

    # Initialize training environment
    device = get_safe_torch_device(cfg.device, log=True)
    set_global_seed(cfg.seed)

    # Setup dataset and dataloaders
    dataset = LeRobotDataset(
        cfg.dataset_repo_id,
        root=cfg.dataset_root,
        local_files_only=cfg.local_files_only,
    )
    logging.info(f"Dataset size: {len(dataset)}")

    n_total = len(dataset)
    n_train = int(cfg.train_split_proportion * len(dataset))
    train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
    val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))

    sampler = create_balanced_sampler(train_dataset, cfg)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.training.num_workers,
        sampler=sampler,
        pin_memory=device.type == "cuda",
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.eval.batch_size,
        shuffle=False,
        num_workers=cfg.training.num_workers,
        pin_memory=device.type == "cuda",
    )

    # Resume training if requested
    step = 0
    best_val_acc = 0

    if cfg.resume:
        if not Logger.get_last_checkpoint_dir(out_dir).exists():
            raise RuntimeError(
                "You have set resume=True, but there is no model checkpoint in "
                f"{Logger.get_last_checkpoint_dir(out_dir)}"
            )
        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
        logging.info(
            colored(
                "You have set resume=True, indicating that you wish to resume a run",
                color="yellow",
                attrs=["bold"],
            )
        )
        # Load and validate checkpoint configuration
        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
        # Check for differences between the checkpoint configuration and the provided configuration.
        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
        resolve_delta_timestamps(cfg)
        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
        # Ignore the `resume` parameter when diffing.
        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
            del diff["values_changed"]["root['resume']"]
        if len(diff) > 0:
            logging.warning(
                "At least one difference was detected between the checkpoint configuration and "
                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
                "takes precedence.",
            )
        # Use the checkpoint config instead of the provided config (but keep the `resume` parameter).
        cfg = checkpoint_cfg
        cfg.resume = True

    # Metric/checkpoint logger used throughout training. The constructor signature assumed here
    # follows the Logger API used by the other LeRobot training scripts.
    logger = Logger(cfg, out_dir, wandb_job_name=job_name)

    # Initialize model and training components
    model = get_model(cfg=cfg, logger=logger).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
    # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
    criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
    grad_scaler = GradScaler(enabled=cfg.training.use_amp)

    # Log model parameters
    num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in model.parameters())
    logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
    logging.info(f"Total parameters: {format_big_number(num_total_params)}")

    if cfg.resume:
        step = logger.load_last_training_state(optimizer, None)

    # Training loop with validation and checkpointing
    for epoch in range(cfg.training.num_epochs):
        logging.info(f"\nEpoch {epoch + 1}/{cfg.training.num_epochs}")

        train_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            grad_scaler,
            device,
            logger,
            step,
            cfg,
        )

        # Periodic validation
        if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
            val_acc, eval_info = validate(
                model,
                val_loader,
                criterion,
                device,
                logger,
                cfg,
            )
            logger.log_dict(eval_info, step + len(train_loader), mode="eval")

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                logger.save_checkpoint(
                    train_step=step + len(train_loader),
                    policy=model,
                    optimizer=optimizer,
                    scheduler=None,
                    identifier="best",
                )

        # Periodic checkpointing
        if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
            logger.save_checkpoint(
                train_step=step + len(train_loader),
                policy=model,
                optimizer=optimizer,
                scheduler=None,
                identifier=f"{epoch + 1:06d}",
            )

        step += len(train_loader)

    benchmark_inference_time(model, dataset, logger, cfg, device, step)

    logging.info("Training completed")
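
# Configuration keys read by this script (rough reference only; the authoritative schema lives in
# the `hilserl_classifier` Hydra config):
#   cfg.device, cfg.seed, cfg.resume
#   cfg.dataset_repo_id, cfg.dataset_root, cfg.local_files_only, cfg.train_split_proportion
#   cfg.training: batch_size, num_workers, learning_rate, num_epochs, eval_freq, save_checkpoint,
#                 save_freq, use_amp, image_keys, label_key, profile_inference_time,
#                 profile_inference_time_iters
#   cfg.eval: batch_size, num_samples_to_log
#   cfg.wandb.enable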


@hydra.main(
    version_base="1.2",
    config_name="hilserl_classifier",
    config_path="../configs/policy",
)
def train_cli(cfg: dict):
    train(
        cfg,
        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
    )


def train_notebook(
    out_dir=None,
    job_name=None,
    config_name="hilserl_classifier",
    config_path="../configs/policy",
):
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    cfg = compose(config_name=config_name)
    train(cfg, out_dir=out_dir, job_name=job_name)
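
# Example notebook usage (illustrative; the output directory and job name are arbitrary):
#   train_notebook(out_dir="outputs/reward_classifier", job_name="reward_classifier")
# This resolves the same `hilserl_classifier` config as the CLI entry point; train_notebook() as
# written does not forward Hydra overrides, so different settings must be edited in the config file.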


if __name__ == "__main__":
    train_cli()