From 1e2a757cd314c1d094e5ab3414c2f40df9ee1ee8 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 23 Dec 2024 16:43:55 +0700 Subject: [PATCH] [Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578) --- .../classifier/configuration_classifier.py | 2 +- .../hilserl/classifier/modeling_classifier.py | 8 + poetry.lock | 153 ++++++++++- pyproject.toml | 3 + tests/conftest.py | 13 + .../check_hiserl_reward_classifier.py | 244 ++++++++++++++++++ .../classifier/test_modelling_classifier.py | 78 ++++++ 7 files changed, 499 insertions(+), 2 deletions(-) create mode 100644 tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py create mode 100644 tests/policies/hilserl/classifier/test_modelling_classifier.py diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index 209ff659..553e4262 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -13,7 +13,7 @@ class ClassifierConfig: hidden_dim: int = 256 dropout_rate: float = 0.1 model_name: str = "microsoft/resnet-50" - device: str = "cuda" if torch.cuda.is_available() else "mps" + device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" def save_pretrained(self, save_dir): diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index dbb434a7..0b8d66ac 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -22,6 +22,11 @@ class ClassifierOutput: self.probabilities = probabilities self.hidden_states = hidden_states + def __repr__(self): + return (f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})") + class Classifier( nn.Module, @@ -69,6 +74,8 @@ class Classifier( self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") + + self.encoder = self.encoder.to(self.config.device) def _freeze_encoder(self) -> None: """Freeze the encoder parameters.""" @@ -93,6 +100,7 @@ class Classifier( nn.ReLU(), nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes), ) + self.classifier_head = self.classifier_head.to(self.config.device) def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: """Extract the appropriate output from the encoder.""" diff --git a/poetry.lock b/poetry.lock index 8799e67c..919edd18 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3139,6 +3139,27 @@ dev = ["changelist (==0.5)"] lint = ["pre-commit (==3.7.0)"] test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"] +[[package]] +name = "lightning-utilities" +version = "0.11.9" +description = "Lightning toolbox for across the our ecosystem." 
+optional = true +python-versions = ">=3.8" +files = [ + {file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"}, + {file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"}, +] + +[package.dependencies] +packaging = ">=17.1" +setuptools = "*" +typing-extensions = "*" + +[package.extras] +cli = ["fire"] +docs = ["requests (>=2.0.0)"] +typing = ["mypy (>=1.0.0)", "types-setuptools"] + [[package]] name = "llvmlite" version = "0.43.0" @@ -6798,6 +6819,38 @@ webencodings = ">=0.4" doc = ["sphinx", "sphinx_rtd_theme"] test = ["pytest", "ruff"] +[[package]] +name = "tokenizers" +version = "0.21.0" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"}, + {file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"}, + {file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"}, + {file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"}, + {file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"}, +] + +[package.dependencies] +huggingface-hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] + [[package]] name = "tomli" version = "2.0.2" @@ -6863,6 +6916,34 @@ typing-extensions = ">=4.8.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.11.0)"] 
+[[package]] +name = "torchmetrics" +version = "1.6.0" +description = "PyTorch native Metrics" +optional = true +python-versions = ">=3.9" +files = [ + {file = "torchmetrics-1.6.0-py3-none-any.whl", hash = "sha256:a508cdd87766cedaaf55a419812bf9f493aff8fffc02cc19df5a8e2e7ccb942a"}, + {file = "torchmetrics-1.6.0.tar.gz", hash = "sha256:aebba248708fb90def20cccba6f55bddd134a58de43fb22b0c5ca0f3a89fa984"}, +] + +[package.dependencies] +lightning-utilities = ">=0.8.0" +numpy = ">1.20.0" +packaging = ">17.1" +torch = ">=2.0.0" + +[package.extras] +all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.13.0)", "nltk (>3.8.1)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"] +detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"] +dev = ["PyTDC (==0.4.1)", "SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "monai (==1.4.0)", "mypy (==1.13.0)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.0)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.1)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"] +multimodal = ["piq (<=0.8.0)", "transformers (>=4.42.3)"] +text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>4.4.0)"] +typing = ["mypy (==1.13.0)", "torch (==2.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"] + [[package]] name = "torchvision" version = "0.19.1" @@ -6956,6 +7037,75 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", 
"pytest-mypy-testing"] +[[package]] +name = "transformers" +version = "4.47.0" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = true +python-versions = ">=3.9.0" +files = [ + {file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"}, + {file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.24.0,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.4.1" +tokenizers = ">=0.21,<0.22" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.26.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +benchmark = ["optimum-benchmark (>=0.3.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 
(<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6,<0.15.0)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +ruff = ["ruff (==0.5.1)"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score 
(!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +tiktoken = ["blobfile", "tiktoken"] +timm = ["timm (<=1.0.11)"] +tokenizers = ["tokenizers (>=0.21,<0.22)"] +torch = ["accelerate (>=0.26.0)", "torch"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch", "tqdm (>=4.27)"] +video = ["av (==9.2.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + [[package]] name = "transforms3d" version = "0.4.2" @@ -7558,6 +7708,7 @@ dev = ["debugpy", "pre-commit"] dora = ["gym-dora"] dynamixel = ["dynamixel-sdk", "pynput"] feetech = ["feetech-servo-sdk", "pynput"] +hilserl = ["torchmetrics", "transformers"] intelrealsense = ["pyrealsense2"] pusht = ["gym-pusht"] stretch = ["hello-robot-stretch-body", "pynput", "pyrealsense2", "pyrender"] @@ -7569,4 +7720,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8" +content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda" diff --git a/pyproject.toml b/pyproject.toml index 59c2de8b..738903bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} pyserial = {version = ">=3.5", optional = true} jsonlines = ">=4.0.0" +transformers = {version = "^4.47.0", optional = true} +torchmetrics = {version = "^1.6.0", optional = true} [tool.poetry.extras] @@ -86,6 +88,7 @@ dynamixel = ["dynamixel-sdk", "pynput"] feetech = ["feetech-servo-sdk", "pynput"] intelrealsense = ["pyrealsense2"] stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"] +hilserl = ["transformers", "torchmetrics"] [tool.ruff] line-length = 110 diff --git a/tests/conftest.py b/tests/conftest.py index 2075c2aa..adf050aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import random
 import traceback
 
 import pytest
+import torch
 from serial import SerialException
 
 from lerobot import available_cameras, available_motors, available_robots
@@ -124,3 +126,14 @@ def patch_builtins_input(monkeypatch):
             print(text)
 
     monkeypatch.setattr("builtins.input", print_text)
+
+
+def pytest_addoption(parser):
+    parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
+
+
+@pytest.fixture(autouse=True)
+def set_random_seed(request):
+    seed = int(request.config.getoption("--seed"))
+    random.seed(seed)  # Python random
+    torch.manual_seed(seed)  # PyTorch
diff --git a/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py
new file mode 100644
index 00000000..55e6e381
--- /dev/null
+++ b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import numpy as np
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
+from torchvision.datasets import CIFAR10
+from torchvision.transforms import ToTensor
+
+from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
+
+BATCH_SIZE = 1000
+LR = 0.1
+EPOCH_NUM = 2
+
+if torch.cuda.is_available():
+    DEVICE = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    DEVICE = torch.device("mps")
+else:
+    DEVICE = torch.device("cpu")
+
+
+def train_evaluate_multiclass_classifier():
+    logging.info(
+        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
+    )
+    multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
+    multiclass_classifier = Classifier(multiclass_config)
+
+    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
+    testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
+
+    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
+    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
+
+    multiclass_num_classes = 10
+    epoch = 1
+
+    criterion = CrossEntropyLoss()
+    optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
+
+    multiclass_classifier.train()
+
+    logging.info("Start multiclass classifier training")
+
+    # Training loop
+    while epoch <= EPOCH_NUM:  # loop over the dataset multiple times
+        for i, data in enumerate(trainloader):
+            inputs, labels = data
+            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
+
+            # Zero the parameter gradients
+            optimizer.zero_grad()
+
+            # Forward pass
+            outputs = multiclass_classifier(inputs)
+
+            loss = criterion(outputs.logits, labels)
+            loss.backward()
+            optimizer.step()
+
+            if i % 10 == 0:  # 
print every 10 mini-batches
+                logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
+
+        epoch += 1
+
+    logging.info("Multiclass classifier training finished")
+
+    multiclass_classifier.eval()
+
+    test_loss = 0.0
+    test_labels = []
+    test_predictions = []
+    test_probs = []
+
+    with torch.no_grad():
+        for data in testloader:
+            images, labels = data
+            images, labels = images.to(DEVICE), labels.to(DEVICE)
+            outputs = multiclass_classifier(images)
+            loss = criterion(outputs.logits, labels)
+            test_loss += loss.item() * BATCH_SIZE
+
+            _, predicted = torch.max(outputs.logits, 1)
+            test_labels.extend(labels.cpu())
+            test_predictions.extend(predicted.cpu())
+            test_probs.extend(outputs.probabilities.cpu())
+
+    test_loss = test_loss / len(testset)
+
+    logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
+
+    test_labels = torch.stack(test_labels)
+    test_predictions = torch.stack(test_predictions)
+    test_probs = torch.stack(test_probs)
+
+    accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
+    precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
+    recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
+    f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
+    auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
+
+    # Calculate metrics
+    acc = accuracy(test_predictions, test_labels)
+    prec = precision(test_predictions, test_labels)
+    rec = recall(test_predictions, test_labels)
+    f1_score = f1(test_predictions, test_labels)
+    auroc_score = auroc(test_probs, test_labels)
+
+    logging.info(f"Accuracy: {acc:.2f}")
+    logging.info(f"Precision: {prec:.2f}")
+    logging.info(f"Recall: {rec:.2f}")
+    logging.info(f"F1 Score: {f1_score:.2f}")
+    logging.info(f"AUROC Score: {auroc_score:.2f}")
+
+
+def train_evaluate_binary_classifier():
+    logging.info(
+        f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
+    )
+
+    target_binary_class = 3
+
+    def one_vs_rest(dataset, target_class):
+        new_targets = []
+        for _, label in dataset:
+            new_label = 1.0 if label == target_class else 0.0
+            new_targets.append(new_label)
+
+        dataset.targets = new_targets  # Replace the original labels with the binary ones
+        return dataset
+
+    binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
+    binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
+
+    # Apply one-vs-rest labeling
+    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
+    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
+
+    binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+    binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
+
+    binary_epoch = 1
+
+    binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
+    binary_classifier = Classifier(binary_config)
+
+    class_counts = np.bincount(np.asarray(binary_train_dataset.targets, dtype=int))  # bincount needs ints
+    n = len(binary_train_dataset)
+    w0 = n / (2.0 * class_counts[0])  # balanced weight for the negative ("rest") class
+    w1 = n / (2.0 * class_counts[1])  # balanced weight for the positive (target) class
+
+    binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))  # w1 / w0 = num_neg / num_pos
+    binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
+
+    binary_classifier.train()
+
+    logging.info("Start binary classifier training")
+
+    # Training loop
+    while binary_epoch <= EPOCH_NUM:  # loop over the 
dataset multiple times
+        for i, data in enumerate(binary_trainloader):
+            inputs, labels = data
+            inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
+
+            # Zero the parameter gradients
+            binary_optimizer.zero_grad()
+
+            # Forward pass
+            outputs = binary_classifier(inputs)
+            loss = binary_criterion(outputs.logits, labels)
+            loss.backward()
+            binary_optimizer.step()
+
+            if i % 10 == 0:  # print every 10 mini-batches
+                logging.info(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
+        binary_epoch += 1
+
+    logging.info("Binary classifier training finished")
+    logging.info("Start binary classifier evaluation")
+
+    binary_classifier.eval()
+
+    test_loss = 0.0
+    test_labels = []
+    test_predictions = []
+    test_probs = []
+
+    with torch.no_grad():
+        for data in binary_testloader:
+            images, labels = data
+            images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
+            outputs = binary_classifier(images)
+            loss = binary_criterion(outputs.logits, labels)
+            test_loss += loss.item() * BATCH_SIZE
+
+            test_labels.extend(labels.cpu())
+            test_predictions.extend(outputs.logits.cpu())
+            test_probs.extend(outputs.probabilities.cpu())
+
+    test_loss = test_loss / len(binary_test_dataset)
+
+    logging.info(f"Binary classifier test loss {test_loss:.3f}")
+
+    test_labels = torch.stack(test_labels)
+    test_predictions = torch.stack(test_predictions)
+    test_probs = torch.stack(test_probs)
+
+    # Calculate metrics
+    acc = Accuracy(task="binary")(test_predictions, test_labels)
+    prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
+    rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
+    f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
+    auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
+
+    logging.info(f"Accuracy: {acc:.2f}")
+    logging.info(f"Precision: {prec:.2f}")
+    logging.info(f"Recall: {rec:.2f}")
+    logging.info(f"F1 Score: {f1_score:.2f}")
+    logging.info(f"AUROC Score: {auroc_score:.2f}")
+
+
+if __name__ == "__main__":
+    train_evaluate_multiclass_classifier()
+    train_evaluate_binary_classifier()
diff --git a/tests/policies/hilserl/classifier/test_modelling_classifier.py b/tests/policies/hilserl/classifier/test_modelling_classifier.py
new file mode 100644
index 00000000..014165eb
--- /dev/null
+++ b/tests/policies/hilserl/classifier/test_modelling_classifier.py
@@ -0,0 +1,78 @@
+import torch
+
+from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
+    Classifier,
+    ClassifierConfig,
+    ClassifierOutput,
+)
+from tests.utils import require_package
+
+
+def test_classifier_output():
+    output = ClassifierOutput(
+        logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
+    )
+
+    assert (
+        f"{output}"
+        == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
+    )
+
+
+@require_package("transformers")
+def test_binary_classifier_with_default_params():
+    config = ClassifierConfig()
+    classifier = Classifier(config)
+
+    batch_size = 10
+
+    input = torch.rand(batch_size, 3, 224, 224)
+    output = classifier(input)
+
+    assert output is not None
+    assert output.logits.shape == torch.Size([batch_size])
+    assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
+    assert output.probabilities.shape == torch.Size([batch_size])
+    assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
+    assert output.hidden_states.shape == 
torch.Size([batch_size, 2048]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_multiclass_classifier(): + num_classes = 5 + config = ClassifierConfig(num_classes=num_classes) + classifier = Classifier(config) + + batch_size = 10 + + input = torch.rand(batch_size, 3, 224, 224) + output = classifier(input) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 2048]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_default_device(): + config = ClassifierConfig() + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") + + +@require_package("transformers") +def test_explicit_device_setup(): + config = ClassifierConfig(device="meta") + assert config.device == "meta" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("meta")
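
An illustrative usage sketch (not part of the patch): it mirrors what test_binary_classifier_with_default_params exercises, and assumes the optional hilserl extra declared above is installed (e.g. poetry install --extras hilserl) so that transformers is importable, and that the pretrained microsoft/resnet-50 weights can be downloaded from the Hub.

import torch

from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig

# Default config: binary reward head on a pretrained "microsoft/resnet-50" backbone, on CPU.
config = ClassifierConfig()
classifier = Classifier(config)
classifier.eval()

# Dummy batch of four RGB images at the 224x224 resolution the ResNet encoder expects.
images = torch.rand(4, 3, 224, 224)
with torch.no_grad():
    output = classifier(images)

print(output.logits.shape)  # torch.Size([4]): one logit per image for the binary head
print(output.probabilities.shape)  # torch.Size([4])
print(output.hidden_states.shape)  # torch.Size([4, 2048]): last encoder channel dimension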