#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
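"""Train a HIL-SERL classifier (binary or multi-class) on a LeRobotDataset, configured via Hydra.

Example invocation (illustrative only: the script name, dataset id, and override values are
placeholders; the override keys are the ones this script reads from its "hilserl_classifier"
Hydra config):

    python train_hilserl_classifier.py \
        dataset_repo_id=<user/dataset> \
        training.num_epochs=5 \
        resume=false
"""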
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat

import hydra
import torch
import torch.nn as nn
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm

import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    set_global_seed,
)

def get_model(cfg, logger):
    """Initialize the classifier from the Hydra config; when resuming, load weights from the last checkpoint."""
    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    model = Classifier(classifier_config)
    if cfg.resume:
        model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
    return model

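# Note on the weighting scheme in create_balanced_sampler below: with labels
# [0, 0, 0, 1], torch.unique returns counts [3, 1], so class_weights = [1/3, 1] and
# sample_weights = [1/3, 1/3, 1/3, 1]; sampling with replacement then gives each class
# the same total probability mass, balancing the classes in expectation.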
def create_balanced_sampler(dataset, cfg):
    """Creates a weighted sampler to handle class imbalance."""
    labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
    _, counts = torch.unique(labels, return_counts=True)
    class_weights = 1.0 / counts.float()
    sample_weights = class_weights[labels]

    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
    """Single epoch training loop with AMP support and progress tracking."""
    model.train()
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        start_time = time.perf_counter()
        images = batch[cfg.training.image_key].to(device)
        labels = batch[cfg.training.label_key].float().to(device)

        # Forward pass with optional AMP
        with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
            outputs = model(images)
            loss = criterion(outputs.logits, labels)

        # Backward pass with gradient scaling if AMP enabled
        optimizer.zero_grad()
        if cfg.training.use_amp:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Track metrics
        if model.config.num_classes == 2:
            predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
        else:
            predictions = torch.argmax(outputs.logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        current_acc = 100 * correct / total
        train_info = {
            "loss": loss.item(),
            "accuracy": current_acc,
            "dataloading_s": time.perf_counter() - start_time,
        }

        logger.log_dict(train_info, step + batch_idx, mode="train")
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})

def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
    """Validation loop with metric tracking and sample logging."""
    model.eval()
    correct = 0
    total = 0
    batch_start_time = time.perf_counter()
    samples = []
    running_loss = 0

    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch[cfg.training.image_key].to(device)
            labels = batch[cfg.training.label_key].float().to(device)

            outputs = model(images)
            loss = criterion(outputs.logits, labels)

            # Track metrics
            if model.config.num_classes == 2:
                predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
            else:
                predictions = torch.argmax(outputs.logits, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
            running_loss += loss.item()

            # Log sample predictions for visualization
            if len(samples) < num_samples_to_log:
                for i in range(min(num_samples_to_log - len(samples), len(images))):
                    if model.config.num_classes == 2:
                        confidence = round(outputs.probabilities[i].item(), 3)
                    else:
                        confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                    samples.append(
                        {
                            "image": wandb.Image(images[i].cpu()),
                            "true_label": labels[i].item(),
                            "predicted": predictions[i].item(),
                            "confidence": confidence,
                        }
                    )

    accuracy = 100 * correct / total
    avg_loss = running_loss / len(val_loader)

    eval_info = {
        "loss": avg_loss,
        "accuracy": accuracy,
        "eval_s": time.perf_counter() - batch_start_time,
        "eval/prediction_samples": wandb.Table(
            data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
            columns=["Image", "True Label", "Predicted", "Confidence"],
        )
        if logger._cfg.wandb.enable
        else None,
    }

    return accuracy, eval_info

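# Config keys this script reads (derived from the cfg accesses above and in train() below):
# device, seed, output_dir, dataset_repo_id, train_split_proportion, resume,
# wandb.{enable, job_name}, eval.batch_size, and training.{image_key, label_key,
# batch_size, num_workers, learning_rate, num_epochs, eval_freq, save_checkpoint,
# save_freq, use_amp}.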
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
    """Main training pipeline with support for resuming training."""
    logging.info(OmegaConf.to_yaml(cfg))

    # Initialize training environment
    device = get_safe_torch_device(cfg.device, log=True)
    set_global_seed(cfg.seed)

    out_dir = Path(cfg.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)

    # Setup dataset and dataloaders
    dataset = LeRobotDataset(cfg.dataset_repo_id)
    logging.info(f"Dataset size: {len(dataset)}")

    train_size = int(cfg.train_split_proportion * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    sampler = create_balanced_sampler(train_dataset, cfg)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.training.num_workers,
        sampler=sampler,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.eval.batch_size,
        shuffle=False,
        num_workers=cfg.training.num_workers,
        pin_memory=True,
    )

    # Resume training if requested
    step = 0
    best_val_acc = 0

    if cfg.resume:
        if not Logger.get_last_checkpoint_dir(out_dir).exists():
            raise RuntimeError(
                "You have set resume=True, but there is no model checkpoint in "
                f"{Logger.get_last_checkpoint_dir(out_dir)}"
            )
        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
        logging.info(
            colored(
                "You have set resume=True, indicating that you wish to resume a run",
                color="yellow",
                attrs=["bold"],
            )
        )
        # Load and validate checkpoint configuration
        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
        # Check for differences between the checkpoint configuration and the provided configuration.
        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
        resolve_delta_timestamps(cfg)
        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
        # Ignore the `resume` parameter when diffing.
        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
            del diff["values_changed"]["root['resume']"]
        if len(diff) > 0:
            logging.warning(
                "At least one difference was detected between the checkpoint configuration and "
                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
                "takes precedence.",
            )
        # Use the checkpoint config instead of the provided config (but keep the `resume` parameter).
        cfg = checkpoint_cfg
        cfg.resume = True

    # Initialize model and training components
    model = get_model(cfg=cfg, logger=logger).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
    # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
    criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
    grad_scaler = GradScaler(enabled=cfg.training.use_amp)

    # Log model parameters
    num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in model.parameters())
    logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
    logging.info(f"Total parameters: {format_big_number(num_total_params)}")

    if cfg.resume:
        step = logger.load_last_training_state(optimizer, None)

    # Training loop with validation and checkpointing
    for epoch in range(cfg.training.num_epochs):
        logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")

        train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)

        # Periodic validation
        if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
            val_acc, eval_info = validate(
                model,
                val_loader,
                criterion,
                device,
                logger,
                cfg,
            )
            logger.log_dict(eval_info, step + len(train_loader), mode="eval")

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                logger.save_checkpoint(
                    train_step=step + len(train_loader),
                    policy=model,
                    optimizer=optimizer,
                    scheduler=None,
                    identifier="best",
                )

        # Periodic checkpointing
        if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
            logger.save_checkpoint(
                train_step=step + len(train_loader),
                policy=model,
                optimizer=optimizer,
                scheduler=None,
                identifier=f"{epoch+1:06d}",
            )

        step += len(train_loader)

    logging.info("Training completed")

if __name__ == "__main__":
    train()