268 lines
8.9 KiB
Python
268 lines
8.9 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
|
|
from torch.optim import Adam
|
|
from torch.utils.data import DataLoader
|
|
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
|
|
from torchvision.datasets import CIFAR10
|
|
from torchvision.transforms import ToTensor
|
|
|
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
|
Classifier,
|
|
ClassifierConfig,
|
|
)
|
|
|
|
BATCH_SIZE = 1000
|
|
LR = 0.1
|
|
EPOCH_NUM = 2
|
|
|
|
if torch.cuda.is_available():
|
|
DEVICE = torch.device("cuda")
|
|
elif torch.backends.mps.is_available():
|
|
DEVICE = torch.device("mps")
|
|
else:
|
|
DEVICE = torch.device("cpu")
|
|
|
|
|
|
def train_evaluate_multiclass_classifier():
|
|
logging.info(
|
|
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
|
)
|
|
multiclass_config = ClassifierConfig(
|
|
model_name="microsoft/resnet-18", device=DEVICE, num_classes=10
|
|
)
|
|
multiclass_classifier = Classifier(multiclass_config)
|
|
|
|
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
|
testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
|
|
|
|
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
|
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
|
|
|
|
multiclass_num_classes = 10
|
|
epoch = 1
|
|
|
|
criterion = CrossEntropyLoss()
|
|
optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
|
|
|
|
multiclass_classifier.train()
|
|
|
|
logging.info("Start multiclass classifier training")
|
|
|
|
# Training loop
|
|
while epoch < EPOCH_NUM: # loop over the dataset multiple times
|
|
for i, data in enumerate(trainloader):
|
|
inputs, labels = data
|
|
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
|
|
|
# Zero the parameter gradients
|
|
optimizer.zero_grad()
|
|
|
|
# Forward pass
|
|
outputs = multiclass_classifier(inputs)
|
|
|
|
loss = criterion(outputs.logits, labels)
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
if i % 10 == 0: # print every 10 mini-batches
|
|
logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
|
|
|
|
epoch += 1
|
|
|
|
print("Multiclass classifier training finished")
|
|
|
|
multiclass_classifier.eval()
|
|
|
|
test_loss = 0.0
|
|
test_labels = []
|
|
test_pridections = []
|
|
test_probs = []
|
|
|
|
with torch.no_grad():
|
|
for data in testloader:
|
|
images, labels = data
|
|
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
|
outputs = multiclass_classifier(images)
|
|
loss = criterion(outputs.logits, labels)
|
|
test_loss += loss.item() * BATCH_SIZE
|
|
|
|
_, predicted = torch.max(outputs.logits, 1)
|
|
test_labels.extend(labels.cpu())
|
|
test_pridections.extend(predicted.cpu())
|
|
test_probs.extend(outputs.probabilities.cpu())
|
|
|
|
test_loss = test_loss / len(testset)
|
|
|
|
logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
|
|
|
|
test_labels = torch.stack(test_labels)
|
|
test_predictions = torch.stack(test_pridections)
|
|
test_probs = torch.stack(test_probs)
|
|
|
|
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
|
|
precision = Precision(
|
|
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
|
)
|
|
recall = Recall(
|
|
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
|
)
|
|
f1 = F1Score(
|
|
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
|
)
|
|
auroc = AUROC(
|
|
task="multiclass", num_classes=multiclass_num_classes, average="weighted"
|
|
)
|
|
|
|
# Calculate metrics
|
|
acc = accuracy(test_predictions, test_labels)
|
|
prec = precision(test_predictions, test_labels)
|
|
rec = recall(test_predictions, test_labels)
|
|
f1_score = f1(test_predictions, test_labels)
|
|
auroc_score = auroc(test_probs, test_labels)
|
|
|
|
logging.info(f"Accuracy: {acc:.2f}")
|
|
logging.info(f"Precision: {prec:.2f}")
|
|
logging.info(f"Recall: {rec:.2f}")
|
|
logging.info(f"F1 Score: {f1_score:.2f}")
|
|
logging.info(f"AUROC Score: {auroc_score:.2f}")
|
|
|
|
|
|
def train_evaluate_binary_classifier():
|
|
logging.info(
|
|
f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
|
)
|
|
|
|
target_binary_class = 3
|
|
|
|
def one_vs_rest(dataset, target_class):
|
|
new_targets = []
|
|
for _, label in dataset:
|
|
new_label = float(1.0) if label == target_class else float(0.0)
|
|
new_targets.append(new_label)
|
|
|
|
dataset.targets = (
|
|
new_targets # Replace the original labels with the binary ones
|
|
)
|
|
return dataset
|
|
|
|
binary_train_dataset = CIFAR10(
|
|
root="data", train=True, download=True, transform=ToTensor()
|
|
)
|
|
binary_test_dataset = CIFAR10(
|
|
root="data", train=False, download=True, transform=ToTensor()
|
|
)
|
|
|
|
# Apply one-vs-rest labeling
|
|
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
|
|
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
|
|
|
|
binary_trainloader = DataLoader(
|
|
binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True
|
|
)
|
|
binary_testloader = DataLoader(
|
|
binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False
|
|
)
|
|
|
|
binary_epoch = 1
|
|
|
|
binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
|
|
binary_classifier = Classifier(binary_config)
|
|
|
|
class_counts = np.bincount(binary_train_dataset.targets)
|
|
n = len(binary_train_dataset)
|
|
w0 = n / (2.0 * class_counts[0])
|
|
w1 = n / (2.0 * class_counts[1])
|
|
|
|
binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
|
|
binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
|
|
|
|
binary_classifier.train()
|
|
|
|
logging.info("Start binary classifier training")
|
|
|
|
# Training loop
|
|
while binary_epoch < EPOCH_NUM: # loop over the dataset multiple times
|
|
for i, data in enumerate(binary_trainloader):
|
|
inputs, labels = data
|
|
inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
|
|
|
|
# Zero the parameter gradients
|
|
binary_optimizer.zero_grad()
|
|
|
|
# Forward pass
|
|
outputs = binary_classifier(inputs)
|
|
loss = binary_criterion(outputs.logits, labels)
|
|
loss.backward()
|
|
binary_optimizer.step()
|
|
|
|
if i % 10 == 0: # print every 10 mini-batches
|
|
print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
|
|
binary_epoch += 1
|
|
|
|
logging.info("Binary classifier training finished")
|
|
logging.info("Start binary classifier evaluation")
|
|
|
|
binary_classifier.eval()
|
|
|
|
test_loss = 0.0
|
|
test_labels = []
|
|
test_pridections = []
|
|
test_probs = []
|
|
|
|
with torch.no_grad():
|
|
for data in binary_testloader:
|
|
images, labels = data
|
|
images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
|
|
outputs = binary_classifier(images)
|
|
loss = binary_criterion(outputs.logits, labels)
|
|
test_loss += loss.item() * BATCH_SIZE
|
|
|
|
test_labels.extend(labels.cpu())
|
|
test_pridections.extend(outputs.logits.cpu())
|
|
test_probs.extend(outputs.probabilities.cpu())
|
|
|
|
test_loss = test_loss / len(binary_test_dataset)
|
|
|
|
logging.info(f"Binary classifier test loss {test_loss:.3f}")
|
|
|
|
test_labels = torch.stack(test_labels)
|
|
test_predictions = torch.stack(test_pridections)
|
|
test_probs = torch.stack(test_probs)
|
|
|
|
# Calculate metrics
|
|
acc = Accuracy(task="binary")(test_predictions, test_labels)
|
|
prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
|
|
rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
|
|
f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
|
|
auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
|
|
|
|
logging.info(f"Accuracy: {acc:.2f}")
|
|
logging.info(f"Precision: {prec:.2f}")
|
|
logging.info(f"Recall: {rec:.2f}")
|
|
logging.info(f"F1 Score: {f1_score:.2f}")
|
|
logging.info(f"AUROC Score: {auroc_score:.2f}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
train_evaluate_multiclass_classifier()
|
|
train_evaluate_binary_classifier()
|