[Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578)
This commit is contained in:
parent
66268fcf85
commit
6340d9d17c
|
@ -13,7 +13,7 @@ class ClassifierConfig:
|
||||||
hidden_dim: int = 256
|
hidden_dim: int = 256
|
||||||
dropout_rate: float = 0.1
|
dropout_rate: float = 0.1
|
||||||
model_name: str = "microsoft/resnet-50"
|
model_name: str = "microsoft/resnet-50"
|
||||||
device: str = "cuda" if torch.cuda.is_available() else "mps"
|
device: str = "cpu"
|
||||||
model_type: str = "cnn" # "transformer" or "cnn"
|
model_type: str = "cnn" # "transformer" or "cnn"
|
||||||
|
|
||||||
def save_pretrained(self, save_dir):
|
def save_pretrained(self, save_dir):
|
||||||
|
|
|
@ -22,6 +22,11 @@ class ClassifierOutput:
|
||||||
self.probabilities = probabilities
|
self.probabilities = probabilities
|
||||||
self.hidden_states = hidden_states
|
self.hidden_states = hidden_states
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f"ClassifierOutput(logits={self.logits}, "
|
||||||
|
f"probabilities={self.probabilities}, "
|
||||||
|
f"hidden_states={self.hidden_states})")
|
||||||
|
|
||||||
|
|
||||||
class Classifier(
|
class Classifier(
|
||||||
nn.Module,
|
nn.Module,
|
||||||
|
@ -69,6 +74,8 @@ class Classifier(
|
||||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported CNN architecture")
|
raise ValueError("Unsupported CNN architecture")
|
||||||
|
|
||||||
|
self.encoder = self.encoder.to(self.config.device)
|
||||||
|
|
||||||
def _freeze_encoder(self) -> None:
|
def _freeze_encoder(self) -> None:
|
||||||
"""Freeze the encoder parameters."""
|
"""Freeze the encoder parameters."""
|
||||||
|
@ -93,6 +100,7 @@ class Classifier(
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||||
)
|
)
|
||||||
|
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||||
|
|
||||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Extract the appropriate output from the encoder."""
|
"""Extract the appropriate output from the encoder."""
|
||||||
|
|
|
@ -70,7 +70,9 @@ dependencies = [
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
"torch>=2.2.1",
|
"torch>=2.2.1",
|
||||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
|
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
|
||||||
|
"torchmetrics>=1.6.0",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
|
"transformers>=4.47.0",
|
||||||
"wandb>=0.16.3",
|
"wandb>=0.16.3",
|
||||||
"zarr>=2.17.0",
|
"zarr>=2.17.0",
|
||||||
]
|
]
|
||||||
|
|
|
@ -14,9 +14,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from serial import SerialException
|
from serial import SerialException
|
||||||
|
|
||||||
from lerobot import available_cameras, available_motors, available_robots
|
from lerobot import available_cameras, available_motors, available_robots
|
||||||
|
@ -86,3 +88,14 @@ def patch_builtins_input(monkeypatch):
|
||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
monkeypatch.setattr("builtins.input", print_text)
|
monkeypatch.setattr("builtins.input", print_text)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def set_random_seed(request):
|
||||||
|
seed = int(request.config.getoption("--seed"))
|
||||||
|
random.seed(seed) # Python random
|
||||||
|
torch.manual_seed(seed) # PyTorch
|
||||||
|
|
|
@ -0,0 +1,244 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
|
||||||
|
from torch.optim import Adam
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
|
||||||
|
from torchvision.datasets import CIFAR10
|
||||||
|
from torchvision.transforms import ToTensor
|
||||||
|
|
||||||
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
|
||||||
|
|
||||||
|
BATCH_SIZE = 1000
|
||||||
|
LR = 0.1
|
||||||
|
EPOCH_NUM = 2
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
DEVICE = torch.device("cuda")
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
DEVICE = torch.device("mps")
|
||||||
|
else:
|
||||||
|
DEVICE = torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def train_evaluate_multiclass_classifier():
|
||||||
|
logging.info(
|
||||||
|
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
||||||
|
)
|
||||||
|
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
|
||||||
|
multiclass_classifier = Classifier(multiclass_config)
|
||||||
|
|
||||||
|
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||||
|
testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
|
||||||
|
|
||||||
|
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
|
||||||
|
|
||||||
|
multiclass_num_classes = 10
|
||||||
|
epoch = 1
|
||||||
|
|
||||||
|
criterion = CrossEntropyLoss()
|
||||||
|
optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
|
||||||
|
|
||||||
|
multiclass_classifier.train()
|
||||||
|
|
||||||
|
logging.info("Start multiclass classifier training")
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
while epoch < EPOCH_NUM: # loop over the dataset multiple times
|
||||||
|
for i, data in enumerate(trainloader):
|
||||||
|
inputs, labels = data
|
||||||
|
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||||
|
|
||||||
|
# Zero the parameter gradients
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
outputs = multiclass_classifier(inputs)
|
||||||
|
|
||||||
|
loss = criterion(outputs.logits, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
if i % 10 == 0: # print every 10 mini-batches
|
||||||
|
logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
|
||||||
|
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
|
print("Multiclass classifier training finished")
|
||||||
|
|
||||||
|
multiclass_classifier.eval()
|
||||||
|
|
||||||
|
test_loss = 0.0
|
||||||
|
test_labels = []
|
||||||
|
test_pridections = []
|
||||||
|
test_probs = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data in testloader:
|
||||||
|
images, labels = data
|
||||||
|
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
||||||
|
outputs = multiclass_classifier(images)
|
||||||
|
loss = criterion(outputs.logits, labels)
|
||||||
|
test_loss += loss.item() * BATCH_SIZE
|
||||||
|
|
||||||
|
_, predicted = torch.max(outputs.logits, 1)
|
||||||
|
test_labels.extend(labels.cpu())
|
||||||
|
test_pridections.extend(predicted.cpu())
|
||||||
|
test_probs.extend(outputs.probabilities.cpu())
|
||||||
|
|
||||||
|
test_loss = test_loss / len(testset)
|
||||||
|
|
||||||
|
logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
|
||||||
|
|
||||||
|
test_labels = torch.stack(test_labels)
|
||||||
|
test_predictions = torch.stack(test_pridections)
|
||||||
|
test_probs = torch.stack(test_probs)
|
||||||
|
|
||||||
|
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
|
||||||
|
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||||
|
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||||
|
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||||
|
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
|
||||||
|
|
||||||
|
# Calculate metrics
|
||||||
|
acc = accuracy(test_predictions, test_labels)
|
||||||
|
prec = precision(test_predictions, test_labels)
|
||||||
|
rec = recall(test_predictions, test_labels)
|
||||||
|
f1_score = f1(test_predictions, test_labels)
|
||||||
|
auroc_score = auroc(test_probs, test_labels)
|
||||||
|
|
||||||
|
logging.info(f"Accuracy: {acc:.2f}")
|
||||||
|
logging.info(f"Precision: {prec:.2f}")
|
||||||
|
logging.info(f"Recall: {rec:.2f}")
|
||||||
|
logging.info(f"F1 Score: {f1_score:.2f}")
|
||||||
|
logging.info(f"AUROC Score: {auroc_score:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
def train_evaluate_binary_classifier():
|
||||||
|
logging.info(
|
||||||
|
f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
||||||
|
)
|
||||||
|
|
||||||
|
target_binary_class = 3
|
||||||
|
|
||||||
|
def one_vs_rest(dataset, target_class):
|
||||||
|
new_targets = []
|
||||||
|
for _, label in dataset:
|
||||||
|
new_label = float(1.0) if label == target_class else float(0.0)
|
||||||
|
new_targets.append(new_label)
|
||||||
|
|
||||||
|
dataset.targets = new_targets # Replace the original labels with the binary ones
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||||
|
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
|
||||||
|
|
||||||
|
# Apply one-vs-rest labeling
|
||||||
|
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
|
||||||
|
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
|
||||||
|
|
||||||
|
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||||
|
|
||||||
|
binary_epoch = 1
|
||||||
|
|
||||||
|
binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
|
||||||
|
binary_classifier = Classifier(binary_config)
|
||||||
|
|
||||||
|
class_counts = np.bincount(binary_train_dataset.targets)
|
||||||
|
n = len(binary_train_dataset)
|
||||||
|
w0 = n / (2.0 * class_counts[0])
|
||||||
|
w1 = n / (2.0 * class_counts[1])
|
||||||
|
|
||||||
|
binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
|
||||||
|
binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
|
||||||
|
|
||||||
|
binary_classifier.train()
|
||||||
|
|
||||||
|
logging.info("Start binary classifier training")
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
while binary_epoch < EPOCH_NUM: # loop over the dataset multiple times
|
||||||
|
for i, data in enumerate(binary_trainloader):
|
||||||
|
inputs, labels = data
|
||||||
|
inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
|
||||||
|
|
||||||
|
# Zero the parameter gradients
|
||||||
|
binary_optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
outputs = binary_classifier(inputs)
|
||||||
|
loss = binary_criterion(outputs.logits, labels)
|
||||||
|
loss.backward()
|
||||||
|
binary_optimizer.step()
|
||||||
|
|
||||||
|
if i % 10 == 0: # print every 10 mini-batches
|
||||||
|
print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
|
||||||
|
binary_epoch += 1
|
||||||
|
|
||||||
|
logging.info("Binary classifier training finished")
|
||||||
|
logging.info("Start binary classifier evaluation")
|
||||||
|
|
||||||
|
binary_classifier.eval()
|
||||||
|
|
||||||
|
test_loss = 0.0
|
||||||
|
test_labels = []
|
||||||
|
test_pridections = []
|
||||||
|
test_probs = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data in binary_testloader:
|
||||||
|
images, labels = data
|
||||||
|
images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
|
||||||
|
outputs = binary_classifier(images)
|
||||||
|
loss = binary_criterion(outputs.logits, labels)
|
||||||
|
test_loss += loss.item() * BATCH_SIZE
|
||||||
|
|
||||||
|
test_labels.extend(labels.cpu())
|
||||||
|
test_pridections.extend(outputs.logits.cpu())
|
||||||
|
test_probs.extend(outputs.probabilities.cpu())
|
||||||
|
|
||||||
|
test_loss = test_loss / len(binary_test_dataset)
|
||||||
|
|
||||||
|
logging.info(f"Binary classifier test loss {test_loss:.3f}")
|
||||||
|
|
||||||
|
test_labels = torch.stack(test_labels)
|
||||||
|
test_predictions = torch.stack(test_pridections)
|
||||||
|
test_probs = torch.stack(test_probs)
|
||||||
|
|
||||||
|
# Calculate metrics
|
||||||
|
acc = Accuracy(task="binary")(test_predictions, test_labels)
|
||||||
|
prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
|
||||||
|
rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
|
||||||
|
f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
|
||||||
|
auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
|
||||||
|
|
||||||
|
logging.info(f"Accuracy: {acc:.2f}")
|
||||||
|
logging.info(f"Precision: {prec:.2f}")
|
||||||
|
logging.info(f"Recall: {rec:.2f}")
|
||||||
|
logging.info(f"F1 Score: {f1_score:.2f}")
|
||||||
|
logging.info(f"AUROC Score: {auroc_score:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_evaluate_multiclass_classifier()
|
||||||
|
train_evaluate_binary_classifier()
|
|
@ -0,0 +1,78 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||||
|
Classifier,
|
||||||
|
ClassifierConfig,
|
||||||
|
ClassifierOutput,
|
||||||
|
)
|
||||||
|
from tests.utils import require_package
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_output():
|
||||||
|
output = ClassifierOutput(
|
||||||
|
logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
f"{output}"
|
||||||
|
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
|
def test_binary_classifier_with_default_params():
|
||||||
|
config = ClassifierConfig()
|
||||||
|
classifier = Classifier(config)
|
||||||
|
|
||||||
|
batch_size = 10
|
||||||
|
|
||||||
|
input = torch.rand(batch_size, 3, 224, 224)
|
||||||
|
output = classifier(input)
|
||||||
|
|
||||||
|
assert output is not None
|
||||||
|
assert output.logits.shape == torch.Size([batch_size])
|
||||||
|
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||||
|
assert output.probabilities.shape == torch.Size([batch_size])
|
||||||
|
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||||
|
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
|
||||||
|
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
|
def test_multiclass_classifier():
|
||||||
|
num_classes = 5
|
||||||
|
config = ClassifierConfig(num_classes=num_classes)
|
||||||
|
classifier = Classifier(config)
|
||||||
|
|
||||||
|
batch_size = 10
|
||||||
|
|
||||||
|
input = torch.rand(batch_size, 3, 224, 224)
|
||||||
|
output = classifier(input)
|
||||||
|
|
||||||
|
assert output is not None
|
||||||
|
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||||
|
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||||
|
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||||
|
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||||
|
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
|
||||||
|
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
|
def test_default_device():
|
||||||
|
config = ClassifierConfig()
|
||||||
|
assert config.device == "cpu"
|
||||||
|
|
||||||
|
classifier = Classifier(config)
|
||||||
|
for p in classifier.parameters():
|
||||||
|
assert p.device == torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
|
def test_explicit_device_setup():
|
||||||
|
config = ClassifierConfig(device="meta")
|
||||||
|
assert config.device == "meta"
|
||||||
|
|
||||||
|
classifier = Classifier(config)
|
||||||
|
for p in classifier.parameters():
|
||||||
|
assert p.device == torch.device("meta")
|
Loading…
Reference in New Issue