From 6340d9d17c70846890ee1a04848db7219f6f80ba Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 23 Dec 2024 16:43:55 +0700 Subject: [PATCH] [Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578) --- .../classifier/configuration_classifier.py | 2 +- .../hilserl/classifier/modeling_classifier.py | 8 + pyproject.toml | 2 + tests/conftest.py | 13 + .../check_hiserl_reward_classifier.py | 244 ++++++++++++++++++ .../classifier/test_modelling_classifier.py | 78 ++++++ 6 files changed, 346 insertions(+), 1 deletion(-) create mode 100644 tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py create mode 100644 tests/policies/hilserl/classifier/test_modelling_classifier.py diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index 209ff659..553e4262 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -13,7 +13,7 @@ class ClassifierConfig: hidden_dim: int = 256 dropout_rate: float = 0.1 model_name: str = "microsoft/resnet-50" - device: str = "cuda" if torch.cuda.is_available() else "mps" + device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" def save_pretrained(self, save_dir): diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index dbb434a7..0b8d66ac 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -22,6 +22,11 @@ class ClassifierOutput: self.probabilities = probabilities self.hidden_states = hidden_states + def __repr__(self): + return (f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})") + class Classifier( nn.Module, 
@@ -69,6 +74,8 @@ class Classifier( self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") + + self.encoder = self.encoder.to(self.config.device) def _freeze_encoder(self) -> None: """Freeze the encoder parameters.""" @@ -93,6 +100,7 @@ class Classifier( nn.ReLU(), nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes), ) + self.classifier_head = self.classifier_head.to(self.config.device) def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: """Extract the appropriate output from the encoder.""" diff --git a/pyproject.toml b/pyproject.toml index 1fa7b246..2f5299a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,9 @@ dependencies = [ "termcolor>=2.4.0", "torch>=2.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))", + "torchmetrics>=1.6.0", "torchvision>=0.21.0", + "transformers>=4.47.0", "wandb>=0.16.3", "zarr>=2.17.0", ] diff --git a/tests/conftest.py b/tests/conftest.py index 7eec94bf..cc35768e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import random
 import traceback
 
 import pytest
+import torch
 from serial import SerialException
 
 from lerobot import available_cameras, available_motors, available_robots
@@ -86,3 +88,19 @@ def patch_builtins_input(monkeypatch):
             print(text)
 
     monkeypatch.setattr("builtins.input", print_text)
+
+
+def pytest_addoption(parser):
+    """Register the ``--seed`` command-line option (an integer, 42 by default)."""
+    # type=int makes pytest reject a malformed --seed at startup instead of inside every test.
+    parser.addoption(
+        "--seed", action="store", type=int, default=42, help="Set random seed for reproducibility"
+    )
+
+
+@pytest.fixture(autouse=True)
+def set_random_seed(request):
+    """Seed Python's ``random`` module and PyTorch from ``--seed`` before every test."""
+    seed = request.config.getoption("--seed")
+    random.seed(seed)  # Python random
+    torch.manual_seed(seed)  # PyTorch
diff --git a/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py
new file mode 100644
index 00000000..55e6e381
--- /dev/null
+++ b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import numpy as np
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
+from torchvision.datasets import CIFAR10
+from torchvision.transforms import ToTensor
+
+from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
+
+BATCH_SIZE = 1000
+LR = 0.1
+EPOCH_NUM = 2  # number of full passes over the training split
+
+if torch.cuda.is_available():
+    DEVICE = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    DEVICE = torch.device("mps")
+else:
+    DEVICE = torch.device("cpu")
+
+
+def train_evaluate_multiclass_classifier():
+    """Train a 10-class classifier on CIFAR-10 for EPOCH_NUM epochs, then log test metrics.
+
+    CIFAR-10 is fetched into ``data/`` (``download=True``). Test loss, accuracy, weighted
+    precision/recall/F1 and weighted AUROC are reported through ``logging`` at INFO level.
+    """
+    logging.info(
+        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
+    )
+    multiclass_num_classes = 10
+    multiclass_config = ClassifierConfig(
+        model_name="microsoft/resnet-18", device=DEVICE, num_classes=multiclass_num_classes
+    )
+    multiclass_classifier = Classifier(multiclass_config)
+
+    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
+    testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
+
+    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
+    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
+
+    criterion = CrossEntropyLoss()
+    optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
+
+    multiclass_classifier.train()
+
+    logging.info("Start multiclass classifier training")
+
+    # Epochs are numbered 1..EPOCH_NUM inclusive, so exactly EPOCH_NUM passes are made.
+    for epoch in range(1, EPOCH_NUM + 1):
+        for i, (inputs, labels) in enumerate(trainloader):
+            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
+
+            optimizer.zero_grad()
+            outputs = multiclass_classifier(inputs)
+            loss = criterion(outputs.logits, labels)
+            loss.backward()
+            optimizer.step()
+
+            if i % 10 == 0:  # log every 10 mini-batches
+                logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
+
+    logging.info("Multiclass classifier training finished")
+
+    multiclass_classifier.eval()
+
+    test_loss = 0.0
+    test_labels = []
+    test_predictions = []
+    test_probs = []
+
+    with torch.no_grad():
+        for images, labels in testloader:
+            images, labels = images.to(DEVICE), labels.to(DEVICE)
+            outputs = multiclass_classifier(images)
+            loss = criterion(outputs.logits, labels)
+            # The criterion returns a batch mean: weight it by the real batch size so a
+            # short final batch cannot skew the dataset-level average.
+            test_loss += loss.item() * labels.size(0)
+
+            test_labels.append(labels.cpu())
+            test_predictions.append(outputs.logits.argmax(dim=1).cpu())
+            test_probs.append(outputs.probabilities.cpu())
+
+    test_loss = test_loss / len(testset)
+
+    logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
+
+    test_labels = torch.cat(test_labels)
+    test_predictions = torch.cat(test_predictions)
+    test_probs = torch.cat(test_probs)
+
+    accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
+    precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
+    recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
+    f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
+    auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
+
+    # Calculate metrics
+    acc = accuracy(test_predictions, test_labels)
+    prec = precision(test_predictions, test_labels)
+    rec = recall(test_predictions, test_labels)
+    f1_score = f1(test_predictions, test_labels)
+    auroc_score = auroc(test_probs, test_labels)
+
+    logging.info(f"Accuracy: {acc:.2f}")
+    logging.info(f"Precision: {prec:.2f}")
+    logging.info(f"Recall: {rec:.2f}")
+    logging.info(f"F1 Score: {f1_score:.2f}")
+    logging.info(f"AUROC Score: {auroc_score:.2f}")
+
+
+def train_evaluate_binary_classifier():
+    """Train a one-vs-rest binary classifier (CIFAR-10 class 3 vs. the rest) and log test metrics.
+
+    CIFAR-10 is fetched into ``data/`` (``download=True``). The positive class is re-weighted by
+    the negatives/positives ratio of the training split to offset the class imbalance.
+    Test loss, accuracy, precision, recall, F1 and AUROC are reported through ``logging``.
+    """
+    logging.info(
+        f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
+    )
+
+    target_binary_class = 3
+
+    def one_vs_rest(dataset, target_class):
+        """Relabel ``dataset`` in place: 1.0 for ``target_class``, 0.0 for every other class."""
+        # Read the stored labels directly; iterating the dataset would decode every image.
+        dataset.targets = [1.0 if label == target_class else 0.0 for label in dataset.targets]
+        return dataset
+
+    binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
+    binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
+
+    # Apply one-vs-rest labeling
+    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
+    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
+
+    binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+    binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
+
+    binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
+    binary_classifier = Classifier(binary_config)
+
+    # np.bincount only accepts integers, so the 0.0/1.0 float targets are cast first.
+    class_counts = np.bincount(np.asarray(binary_train_dataset.targets, dtype=np.int64), minlength=2)
+    # Balanced class weights n / (2 * count) reduce to negatives / positives. The tensor is
+    # created as float32 on DEVICE because the criterion itself is never moved there.
+    pos_weight = torch.tensor(class_counts[0] / class_counts[1], dtype=torch.float32, device=DEVICE)
+
+    binary_criterion = BCEWithLogitsLoss(pos_weight=pos_weight)
+    binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
+
+    binary_classifier.train()
+
+    logging.info("Start binary classifier training")
+
+    # Epochs are numbered 1..EPOCH_NUM inclusive, so exactly EPOCH_NUM passes are made.
+    for binary_epoch in range(1, EPOCH_NUM + 1):
+        for i, (inputs, labels) in enumerate(binary_trainloader):
+            # Targets are cast to float32 before being moved to DEVICE.
+            inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
+
+            binary_optimizer.zero_grad()
+            outputs = binary_classifier(inputs)
+            loss = binary_criterion(outputs.logits, labels)
+            loss.backward()
+            binary_optimizer.step()
+
+            if i % 10 == 0:  # log every 10 mini-batches
+                logging.info(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
+
+    logging.info("Binary classifier training finished")
+    logging.info("Start binary classifier evaluation")
+
+    binary_classifier.eval()
+
+    test_loss = 0.0
+    test_labels = []
+    test_predictions = []
+    test_probs = []
+
+    with torch.no_grad():
+        for images, labels in binary_testloader:
+            images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
+            outputs = binary_classifier(images)
+            loss = binary_criterion(outputs.logits, labels)
+            # The criterion returns a batch mean: weight it by the real batch size so a
+            # short final batch cannot skew the dataset-level average.
+            test_loss += loss.item() * labels.size(0)
+
+            test_labels.append(labels.cpu())
+            test_predictions.append(outputs.logits.cpu())
+            test_probs.append(outputs.probabilities.cpu())
+
+    test_loss = test_loss / len(binary_test_dataset)
+
+    logging.info(f"Binary classifier test loss {test_loss:.3f}")
+
+    # torchmetrics expects integer ground-truth labels, so drop the float dtype used for the loss.
+    test_labels = torch.cat(test_labels).long()
+    test_predictions = torch.cat(test_predictions)
+    test_probs = torch.cat(test_probs)
+
+    # Calculate metrics (``test_predictions`` holds raw logits, not thresholded classes)
+    acc = Accuracy(task="binary")(test_predictions, test_labels)
+    prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
+    rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
+    f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
+    auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
+
+    logging.info(f"Accuracy: {acc:.2f}")
+    logging.info(f"Precision: {prec:.2f}")
+    logging.info(f"Recall: {rec:.2f}")
+    logging.info(f"F1 Score: {f1_score:.2f}")
+    logging.info(f"AUROC Score: {auroc_score:.2f}")
+
+
+if __name__ == "__main__":
+    # logging's default WARNING threshold would otherwise drop the INFO progress messages.
+    logging.basicConfig(level=logging.INFO)
+    train_evaluate_multiclass_classifier()
+    train_evaluate_binary_classifier()
diff --git a/tests/policies/hilserl/classifier/test_modelling_classifier.py b/tests/policies/hilserl/classifier/test_modelling_classifier.py
new file mode 100644
index 00000000..014165eb
--- /dev/null
+++ b/tests/policies/hilserl/classifier/test_modelling_classifier.py
@@ -0,0 +1,78 @@
+import torch + +from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ClassifierConfig, + ClassifierOutput, +) +from tests.utils import require_package + + +def test_classifier_output(): + output = ClassifierOutput( + logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None + ) + + assert ( + f"{output}" + == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)" + ) + + +@require_package("transformers") +def test_binary_classifier_with_default_params(): + config = ClassifierConfig() + classifier = Classifier(config) + + batch_size = 10 + + input = torch.rand(batch_size, 3, 224, 224) + output = classifier(input) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 2048]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_multiclass_classifier(): + num_classes = 5 + config = ClassifierConfig(num_classes=num_classes) + classifier = Classifier(config) + + batch_size = 10 + + input = torch.rand(batch_size, 3, 224, 224) + output = classifier(input) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 2048]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + 
+@require_package("transformers") +def test_default_device(): + config = ClassifierConfig() + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") + + +@require_package("transformers") +def test_explicit_device_setup(): + config = ClassifierConfig(device="meta") + assert config.device == "meta" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("meta")