[Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578)

Eugene Mironov 2024-12-23 16:43:55 +07:00 committed by GitHub
parent 7b68bfb73b
commit 70b652f791
7 changed files with 499 additions and 2 deletions


@@ -13,7 +13,7 @@ class ClassifierConfig:
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "microsoft/resnet-50"
device: str = "cuda" if torch.cuda.is_available() else "mps"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
def save_pretrained(self, save_dir):
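With the default device now pinned to "cpu", callers that want an accelerator select one explicitly when building the config. A minimal sketch, assuming only the ClassifierConfig fields shown above:

import torch

from lerobot.common.policies.hilserl.classifier.modeling_classifier import ClassifierConfig

# Prefer CUDA, then Apple MPS, and fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
config = ClassifierConfig(model_name="microsoft/resnet-50", device=device)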


@@ -22,6 +22,11 @@ class ClassifierOutput:
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})")
class Classifier(
nn.Module,
@@ -69,6 +74,8 @@ class Classifier(
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
@@ -93,6 +100,7 @@ class Classifier(
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
)
self.classifier_head = self.classifier_head.to(self.config.device)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""

poetry.lock (generated)

@@ -3139,6 +3139,27 @@ dev = ["changelist (==0.5)"]
lint = ["pre-commit (==3.7.0)"]
test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"]
[[package]]
name = "lightning-utilities"
version = "0.11.9"
description = "Lightning toolbox for across the our ecosystem."
optional = true
python-versions = ">=3.8"
files = [
{file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"},
{file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"},
]
[package.dependencies]
packaging = ">=17.1"
setuptools = "*"
typing-extensions = "*"
[package.extras]
cli = ["fire"]
docs = ["requests (>=2.0.0)"]
typing = ["mypy (>=1.0.0)", "types-setuptools"]
[[package]]
name = "llvmlite"
version = "0.43.0"
@@ -6798,6 +6819,38 @@ webencodings = ">=0.4"
doc = ["sphinx", "sphinx_rtd_theme"]
test = ["pytest", "ruff"]
[[package]]
name = "tokenizers"
version = "0.21.0"
description = ""
optional = true
python-versions = ">=3.7"
files = [
{file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
{file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"},
{file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"},
{file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"},
{file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"},
]
[package.dependencies]
huggingface-hub = ">=0.16.4,<1.0"
[package.extras]
dev = ["tokenizers[testing]"]
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
[[package]]
name = "tomli"
version = "2.0.2"
@@ -6863,6 +6916,34 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.11.0)"]
[[package]]
name = "torchmetrics"
version = "1.6.0"
description = "PyTorch native Metrics"
optional = true
python-versions = ">=3.9"
files = [
{file = "torchmetrics-1.6.0-py3-none-any.whl", hash = "sha256:a508cdd87766cedaaf55a419812bf9f493aff8fffc02cc19df5a8e2e7ccb942a"},
{file = "torchmetrics-1.6.0.tar.gz", hash = "sha256:aebba248708fb90def20cccba6f55bddd134a58de43fb22b0c5ca0f3a89fa984"},
]
[package.dependencies]
lightning-utilities = ">=0.8.0"
numpy = ">1.20.0"
packaging = ">17.1"
torch = ">=2.0.0"
[package.extras]
all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.13.0)", "nltk (>3.8.1)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"]
detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"]
dev = ["PyTDC (==0.4.1)", "SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "monai (==1.4.0)", "mypy (==1.13.0)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.0)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.1)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"]
multimodal = ["piq (<=0.8.0)", "transformers (>=4.42.3)"]
text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>4.4.0)"]
typing = ["mypy (==1.13.0)", "torch (==2.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"]
[[package]]
name = "torchvision"
version = "0.19.1"
@@ -6956,6 +7037,75 @@ files = [
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"]
[[package]]
name = "transformers"
version = "4.47.0"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = true
python-versions = ">=3.9.0"
files = [
{file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"},
{file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"},
]
[package.dependencies]
filelock = "*"
huggingface-hub = ">=0.24.0,<1.0"
numpy = ">=1.17"
packaging = ">=20.0"
pyyaml = ">=5.1"
regex = "!=2019.12.17"
requests = "*"
safetensors = ">=0.4.1"
tokenizers = ">=0.21,<0.22"
tqdm = ">=4.27"
[package.extras]
accelerate = ["accelerate (>=0.26.0)"]
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
benchmark = ["optimum-benchmark (>=0.3.0)"]
codecarbon = ["codecarbon (==1.2.0)"]
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"]
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
modelcreation = ["cookiecutter (==1.7.3)"]
natten = ["natten (>=0.14.6,<0.15.0)"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
ruff = ["ruff (==0.5.1)"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
tiktoken = ["blobfile", "tiktoken"]
timm = ["timm (<=1.0.11)"]
tokenizers = ["tokenizers (>=0.21,<0.22)"]
torch = ["accelerate (>=0.26.0)", "torch"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch", "tqdm (>=4.27)"]
video = ["av (==9.2.0)"]
vision = ["Pillow (>=10.0.1,<=15.0)"]
[[package]]
name = "transforms3d"
version = "0.4.2"
@@ -7558,6 +7708,7 @@ dev = ["debugpy", "pre-commit"]
dora = ["gym-dora"]
dynamixel = ["dynamixel-sdk", "pynput"]
feetech = ["feetech-servo-sdk", "pynput"]
hilserl = ["torchmetrics", "transformers"]
intelrealsense = ["pyrealsense2"]
pusht = ["gym-pusht"]
stretch = ["hello-robot-stretch-body", "pynput", "pyrealsense2", "pyrender"]
@@ -7569,4 +7720,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8"
content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda"


@@ -71,6 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
jsonlines = ">=4.0.0"
transformers = {version = "^4.47.0", optional = true}
torchmetrics = {version = "^1.6.0", optional = true}
[tool.poetry.extras]
@@ -86,6 +88,7 @@ dynamixel = ["dynamixel-sdk", "pynput"]
feetech = ["feetech-servo-sdk", "pynput"]
intelrealsense = ["pyrealsense2"]
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
hilserl = ["transformers", "torchmetrics"]
[tool.ruff]
line-length = 110
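Because transformers and torchmetrics now sit behind the optional hilserl extra, code paths that touch the reward classifier should degrade gracefully when they are absent. A minimal runtime guard, sketched as an assumption about how a caller might check for them (the helper name is hypothetical):

import importlib.util

def hilserl_deps_available() -> bool:
    # Both optional packages must be importable before using the reward classifier.
    return all(importlib.util.find_spec(pkg) is not None for pkg in ("transformers", "torchmetrics"))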


@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import traceback
import pytest
import torch
from serial import SerialException
from lerobot import available_cameras, available_motors, available_robots
@@ -124,3 +126,14 @@ def patch_builtins_input(monkeypatch):
print(text)
monkeypatch.setattr("builtins.input", print_text)
def pytest_addoption(parser):
parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
@pytest.fixture(autouse=True)
def set_random_seed(request):
seed = int(request.config.getoption("--seed"))
random.seed(seed) # Python random
torch.manual_seed(seed) # PyTorch
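Since the fixture is autouse, every test now runs with a deterministic seed (42 unless overridden), and a different seed can be requested per run with the new option, e.g. pytest --seed 123.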


@@ -0,0 +1,244 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
BATCH_SIZE = 1000
LR = 0.1
EPOCH_NUM = 2
if torch.cuda.is_available():
DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
DEVICE = torch.device("mps")
else:
DEVICE = torch.device("cpu")
def train_evaluate_multiclass_classifier():
logging.info(
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
)
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
multiclass_classifier = Classifier(multiclass_config)
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
multiclass_num_classes = 10
epoch = 1
criterion = CrossEntropyLoss()
optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
multiclass_classifier.train()
logging.info("Start multiclass classifier training")
# Training loop
while epoch < EPOCH_NUM: # loop over the dataset multiple times
for i, data in enumerate(trainloader):
inputs, labels = data
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass
outputs = multiclass_classifier(inputs)
loss = criterion(outputs.logits, labels)
loss.backward()
optimizer.step()
if i % 10 == 0: # print every 10 mini-batches
logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
epoch += 1
print("Multiclass classifier training finished")
multiclass_classifier.eval()
test_loss = 0.0
test_labels = []
test_predictions = []
test_probs = []
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = multiclass_classifier(images)
loss = criterion(outputs.logits, labels)
test_loss += loss.item() * BATCH_SIZE
_, predicted = torch.max(outputs.logits, 1)
test_labels.extend(labels.cpu())
test_predictions.extend(predicted.cpu())
test_probs.extend(outputs.probabilities.cpu())
test_loss = test_loss / len(testset)
logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
test_labels = torch.stack(test_labels)
test_predictions = torch.stack(test_predictions)
test_probs = torch.stack(test_probs)
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
# Calculate metrics
acc = accuracy(test_predictions, test_labels)
prec = precision(test_predictions, test_labels)
rec = recall(test_predictions, test_labels)
f1_score = f1(test_predictions, test_labels)
auroc_score = auroc(test_probs, test_labels)
logging.info(f"Accuracy: {acc:.2f}")
logging.info(f"Precision: {prec:.2f}")
logging.info(f"Recall: {rec:.2f}")
logging.info(f"F1 Score: {f1_score:.2f}")
logging.info(f"AUROC Score: {auroc_score:.2f}")
def train_evaluate_binary_classifier():
logging.info(
f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
)
target_binary_class = 3
def one_vs_rest(dataset, target_class):
new_targets = []
for _, label in dataset:
new_label = float(1.0) if label == target_class else float(0.0)
new_targets.append(new_label)
dataset.targets = new_targets # Replace the original labels with the binary ones
return dataset
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
# Apply one-vs-rest labeling
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
binary_epoch = 1
binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
binary_classifier = Classifier(binary_config)
class_counts = np.bincount(np.asarray(binary_train_dataset.targets, dtype=np.int64))  # targets are 0.0/1.0 floats after one_vs_rest
n = len(binary_train_dataset)
w0 = n / (2.0 * class_counts[0])
w1 = n / (2.0 * class_counts[1])
binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
binary_classifier.train()
logging.info("Start binary classifier training")
# Training loop
while binary_epoch < EPOCH_NUM: # loop over the dataset multiple times
for i, data in enumerate(binary_trainloader):
inputs, labels = data
inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
# Zero the parameter gradients
binary_optimizer.zero_grad()
# Forward pass
outputs = binary_classifier(inputs)
loss = binary_criterion(outputs.logits, labels)
loss.backward()
binary_optimizer.step()
if i % 10 == 0: # print every 10 mini-batches
print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
binary_epoch += 1
logging.info("Binary classifier training finished")
logging.info("Start binary classifier evaluation")
binary_classifier.eval()
test_loss = 0.0
test_labels = []
test_predictions = []
test_probs = []
with torch.no_grad():
for data in binary_testloader:
images, labels = data
images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
outputs = binary_classifier(images)
loss = binary_criterion(outputs.logits, labels)
test_loss += loss.item() * BATCH_SIZE
test_labels.extend(labels.cpu())
test_predictions.extend(outputs.logits.cpu())
test_probs.extend(outputs.probabilities.cpu())
test_loss = test_loss / len(binary_test_dataset)
logging.info(f"Binary classifier test loss {test_loss:.3f}")
test_labels = torch.stack(test_labels)
test_predictions = torch.stack(test_predictions)
test_probs = torch.stack(test_probs)
# Calculate metrics
acc = Accuracy(task="binary")(test_predictions, test_labels)
prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
logging.info(f"Accuracy: {acc:.2f}")
logging.info(f"Precision: {prec:.2f}")
logging.info(f"Recall: {rec:.2f}")
logging.info(f"F1 Score: {f1_score:.2f}")
logging.info(f"AUROC Score: {auroc_score:.2f}")
if __name__ == "__main__":
train_evaluate_multiclass_classifier()
train_evaluate_binary_classifier()
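One detail worth noting in the binary branch above: since w1 / w0 simplifies to class_counts[0] / class_counts[1], the pos_weight handed to BCEWithLogitsLoss is just the negative-to-positive class ratio. A standalone numeric check (the counts mirror CIFAR-10 one-vs-rest but are used here only for illustration):

import torch
from torch.nn import BCEWithLogitsLoss

class_counts = [45_000, 5_000]  # negatives, positives
n = sum(class_counts)
w0 = n / (2.0 * class_counts[0])
w1 = n / (2.0 * class_counts[1])
pos_weight = torch.tensor(w1 / w0)  # equals class_counts[0] / class_counts[1] == 9.0
criterion = BCEWithLogitsLoss(pos_weight=pos_weight)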


@@ -0,0 +1,78 @@
import torch
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig,
ClassifierOutput,
)
from tests.utils import require_package
def test_classifier_output():
output = ClassifierOutput(
logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
)
assert (
f"{output}"
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
)
@require_package("transformers")
def test_binary_classifier_with_default_params():
config = ClassifierConfig()
classifier = Classifier(config)
batch_size = 10
input = torch.rand(batch_size, 3, 224, 224)
output = classifier(input)
assert output is not None
assert output.logits.shape == torch.Size([batch_size])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_multiclass_classifier():
num_classes = 5
config = ClassifierConfig(num_classes=num_classes)
classifier = Classifier(config)
batch_size = 10
input = torch.rand(batch_size, 3, 224, 224)
output = classifier(input)
assert output is not None
assert output.logits.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_default_device():
config = ClassifierConfig()
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@require_package("transformers")
def test_explicit_device_setup():
config = ClassifierConfig(device="meta")
assert config.device == "meta"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("meta")