diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
new file mode 100644
index 00000000..132c21cb
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -0,0 +1,54 @@
+name: "\U0001F41B Bug Report"
+description: Submit a bug report to help us improve LeRobot
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thanks for taking the time to submit a bug report! 🐛
+        If this is not a bug related to the LeRobot library directly, but instead a general question about your code or about the library, please use our [discord](https://discord.gg/s3KuuzsPFb).
+
+  - type: textarea
+    id: system-info
+    attributes:
+      label: System Info
+      description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.commands.env` and copy-pasting its output below
+      render: Shell
+      placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
+    validations:
+      required: true
+
+  - type: checkboxes
+    id: information-scripts-examples
+    attributes:
+      label: Information
+      description: 'The problem arises when using:'
+      options:
+        - label: "One of the scripts in the examples/ folder of LeRobot"
+        - label: "My own task or dataset (give details below)"
+
+  - type: textarea
+    id: reproduction
+    validations:
+      required: true
+    attributes:
+      label: Reproduction
+      description: |
+        If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
+        Sharing error messages or stack traces could be useful as well!
+        Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
+        Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
+
+      placeholder: |
+        Steps to reproduce the behavior:
+
+        1.
+        2.
+        3.
+
+  - type: textarea
+    id: expected-behavior
+    validations:
+      required: true
+    attributes:
+      label: Expected behavior
+      description: "A clear and concise description of what you would expect to happen."
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..5084567b
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,15 @@
+# What does this PR do?
+
+Example: Fixes # (issue)
+
+
+## Before submitting
+- Read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
+- Provide a minimal code example for the reviewer to check out & try.
+- Explain how you tested your changes.
+
+
+## Who can review?
+
+Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
+members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock
index fe4ed7a0..dfd664b4 100644
--- a/.github/poetry/cpu/poetry.lock
+++ b/.github/poetry/cpu/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]] name = "absl-py" @@ -11,6 +11,116 @@ files = [ {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, ] +[[package]] +name = "aiohttp" +version = "3.9.4" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:76d32588ef7e4a3f3adff1956a0ba96faabbdee58f2407c122dd45aa6e34f372"}, + {file = "aiohttp-3.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:56181093c10dbc6ceb8a29dfeea1e815e1dfdc020169203d87fd8d37616f73f9"}, + {file = "aiohttp-3.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7a5b676d3c65e88b3aca41816bf72831898fcd73f0cbb2680e9d88e819d1e4d"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1df528a85fb404899d4207a8d9934cfd6be626e30e5d3a5544a83dbae6d8a7e"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f595db1bceabd71c82e92df212dd9525a8a2c6947d39e3c994c4f27d2fe15b11"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0b09d76e5a4caac3d27752027fbd43dc987b95f3748fad2b924a03fe8632ad"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689eb4356649ec9535b3686200b231876fb4cab4aca54e3bece71d37f50c1d13"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3666cf4182efdb44d73602379a66f5fdfd5da0db5e4520f0ac0dcca644a3497"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b65b0f8747b013570eea2f75726046fa54fa8e0c5db60f3b98dd5d161052004a"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1885d2470955f70dfdd33a02e1749613c5a9c5ab855f6db38e0b9389453dce7"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0593822dcdb9483d41f12041ff7c90d4d1033ec0e880bcfaf102919b715f47f1"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:47f6eb74e1ecb5e19a78f4a4228aa24df7fbab3b62d4a625d3f41194a08bd54f"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c8b04a3dbd54de6ccb7604242fe3ad67f2f3ca558f2d33fe19d4b08d90701a89"}, + {file = "aiohttp-3.9.4-cp310-cp310-win32.whl", hash = "sha256:8a78dfb198a328bfb38e4308ca8167028920fb747ddcf086ce706fbdd23b2926"}, + {file = "aiohttp-3.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:e78da6b55275987cbc89141a1d8e75f5070e577c482dd48bd9123a76a96f0bbb"}, + {file = "aiohttp-3.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c111b3c69060d2bafc446917534150fd049e7aedd6cbf21ba526a5a97b4402a5"}, + {file = "aiohttp-3.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:efbdd51872cf170093998c87ccdf3cb5993add3559341a8e5708bcb311934c94"}, + {file = "aiohttp-3.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7bfdb41dc6e85d8535b00d73947548a748e9534e8e4fddd2638109ff3fb081df"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2bd9d334412961125e9f68d5b73c1d0ab9ea3f74a58a475e6b119f5293eee7ba"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:35d78076736f4a668d57ade00c65d30a8ce28719d8a42471b2a06ccd1a2e3063"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:824dff4f9f4d0f59d0fa3577932ee9a20e09edec8a2f813e1d6b9f89ced8293f"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52b8b4e06fc15519019e128abedaeb56412b106ab88b3c452188ca47a25c4093"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eae569fb1e7559d4f3919965617bb39f9e753967fae55ce13454bec2d1c54f09"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:69b97aa5792428f321f72aeb2f118e56893371f27e0b7d05750bcad06fc42ca1"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4d79aad0ad4b980663316f26d9a492e8fab2af77c69c0f33780a56843ad2f89e"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:d6577140cd7db19e430661e4b2653680194ea8c22c994bc65b7a19d8ec834403"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:9860d455847cd98eb67897f5957b7cd69fbcb436dd3f06099230f16a66e66f79"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:69ff36d3f8f5652994e08bd22f093e11cfd0444cea310f92e01b45a4e46b624e"}, + {file = "aiohttp-3.9.4-cp311-cp311-win32.whl", hash = "sha256:e27d3b5ed2c2013bce66ad67ee57cbf614288bda8cdf426c8d8fe548316f1b5f"}, + {file = "aiohttp-3.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d6a67e26daa686a6fbdb600a9af8619c80a332556245fa8e86c747d226ab1a1e"}, + {file = "aiohttp-3.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c5ff8ff44825736a4065d8544b43b43ee4c6dd1530f3a08e6c0578a813b0aa35"}, + {file = "aiohttp-3.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d12a244627eba4e9dc52cbf924edef905ddd6cafc6513849b4876076a6f38b0e"}, + {file = "aiohttp-3.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dcad56c8d8348e7e468899d2fb3b309b9bc59d94e6db08710555f7436156097f"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7e69a7fd4b5ce419238388e55abd220336bd32212c673ceabc57ccf3d05b55"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4870cb049f10d7680c239b55428916d84158798eb8f353e74fa2c98980dcc0b"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2feaf1b7031ede1bc0880cec4b0776fd347259a723d625357bb4b82f62687b"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:939393e8c3f0a5bcd33ef7ace67680c318dc2ae406f15e381c0054dd658397de"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d2334e387b2adcc944680bebcf412743f2caf4eeebd550f67249c1c3696be04"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e0198ea897680e480845ec0ffc5a14e8b694e25b3f104f63676d55bf76a82f1a"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e40d2cd22914d67c84824045861a5bb0fb46586b15dfe4f046c7495bf08306b2"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:aba80e77c227f4234aa34a5ff2b6ff30c5d6a827a91d22ff6b999de9175d71bd"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:fb68dc73bc8ac322d2e392a59a9e396c4f35cb6fdbdd749e139d1d6c985f2527"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f3460a92638dce7e47062cf088d6e7663adb135e936cb117be88d5e6c48c9d53"}, + {file = "aiohttp-3.9.4-cp312-cp312-win32.whl", hash = 
"sha256:32dc814ddbb254f6170bca198fe307920f6c1308a5492f049f7f63554b88ef36"}, + {file = "aiohttp-3.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:63f41a909d182d2b78fe3abef557fcc14da50c7852f70ae3be60e83ff64edba5"}, + {file = "aiohttp-3.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c3770365675f6be220032f6609a8fbad994d6dcf3ef7dbcf295c7ee70884c9af"}, + {file = "aiohttp-3.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:305edae1dea368ce09bcb858cf5a63a064f3bff4767dec6fa60a0cc0e805a1d3"}, + {file = "aiohttp-3.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f121900131d116e4a93b55ab0d12ad72573f967b100e49086e496a9b24523ea"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b71e614c1ae35c3d62a293b19eface83d5e4d194e3eb2fabb10059d33e6e8cbf"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419f009fa4cfde4d16a7fc070d64f36d70a8d35a90d71aa27670bba2be4fd039"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b39476ee69cfe64061fd77a73bf692c40021f8547cda617a3466530ef63f947"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b33f34c9c7decdb2ab99c74be6443942b730b56d9c5ee48fb7df2c86492f293c"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c78700130ce2dcebb1a8103202ae795be2fa8c9351d0dd22338fe3dac74847d9"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:268ba22d917655d1259af2d5659072b7dc11b4e1dc2cb9662fdd867d75afc6a4"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:17e7c051f53a0d2ebf33013a9cbf020bb4e098c4bc5bce6f7b0c962108d97eab"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7be99f4abb008cb38e144f85f515598f4c2c8932bf11b65add0ff59c9c876d99"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:d58a54d6ff08d2547656356eea8572b224e6f9bbc0cf55fa9966bcaac4ddfb10"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7673a76772bda15d0d10d1aa881b7911d0580c980dbd16e59d7ba1422b2d83cd"}, + {file = "aiohttp-3.9.4-cp38-cp38-win32.whl", hash = "sha256:e4370dda04dc8951012f30e1ce7956a0a226ac0714a7b6c389fb2f43f22a250e"}, + {file = "aiohttp-3.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:eb30c4510a691bb87081192a394fb661860e75ca3896c01c6d186febe7c88530"}, + {file = "aiohttp-3.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:84e90494db7df3be5e056f91412f9fa9e611fbe8ce4aaef70647297f5943b276"}, + {file = "aiohttp-3.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7d4845f8501ab28ebfdbeab980a50a273b415cf69e96e4e674d43d86a464df9d"}, + {file = "aiohttp-3.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:69046cd9a2a17245c4ce3c1f1a4ff8c70c7701ef222fce3d1d8435f09042bba1"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b73a06bafc8dcc508420db43b4dd5850e41e69de99009d0351c4f3007960019"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:418bb0038dfafeac923823c2e63226179976c76f981a2aaad0ad5d51f2229bca"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:71a8f241456b6c2668374d5d28398f8e8cdae4cce568aaea54e0f39359cd928d"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:935c369bf8acc2dc26f6eeb5222768aa7c62917c3554f7215f2ead7386b33748"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74e4e48c8752d14ecfb36d2ebb3d76d614320570e14de0a3aa7a726ff150a03c"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:916b0417aeddf2c8c61291238ce25286f391a6acb6f28005dd9ce282bd6311b6"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9b6787b6d0b3518b2ee4cbeadd24a507756ee703adbac1ab6dc7c4434b8c572a"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:221204dbda5ef350e8db6287937621cf75e85778b296c9c52260b522231940ed"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:10afd99b8251022ddf81eaed1d90f5a988e349ee7d779eb429fb07b670751e8c"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2506d9f7a9b91033201be9ffe7d89c6a54150b0578803cce5cb84a943d075bc3"}, + {file = "aiohttp-3.9.4-cp39-cp39-win32.whl", hash = "sha256:e571fdd9efd65e86c6af2f332e0e95dad259bfe6beb5d15b3c3eca3a6eb5d87b"}, + {file = "aiohttp-3.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:7d29dd5319d20aa3b7749719ac9685fbd926f71ac8c77b2477272725f882072d"}, + {file = "aiohttp-3.9.4.tar.gz", hash = "sha256:6ff71ede6d9a5a58cfb7b6fffc83ab5d4a63138276c771ac91ceaaddf5459644"}, +] + +[package.dependencies] +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns", "brotlicffi"] + +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -43,59 +153,35 @@ files = [ ] [[package]] -name = "av" -version = "12.0.0" -description = "Pythonic bindings for FFmpeg's libraries." 
+name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "av-12.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9d0890553951f76c479a9f2bb952aebae902b1c7d52feea614d37e1cd728a44"}, - {file = "av-12.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5d7f229a253c2e3fea9682c09c5ae179bd6d5d2da38d89eb7f29ef7bed10cb2f"}, - {file = "av-12.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61b3555d143aacf02e0446f6030319403538eba4dc713c18dfa653a2a23e7f9c"}, - {file = "av-12.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:607e13b2c2b26159a37525d7b6f647a32ce78711fccff23d146d3e255ffa115f"}, - {file = "av-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f0b4cfb89f4f06b339c766f92648e798a96747d4163f2fa78660d1ab1f1b5e"}, - {file = "av-12.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:41dcb8c269fa58a56edf3a3c814c32a0c69586827f132b4e395a951b0ce14fad"}, - {file = "av-12.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fa78fbe0e4469226512380180063116105048c66cb12e18ab4b518466c57e6c"}, - {file = "av-12.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60a869be1d6af916e65ea461cb93922f5db0698655ed7a7eae7c3ecd4af4debb"}, - {file = "av-12.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df61811cc551c186f0a0e530d97b8b139453534d0f92c1790a923f666522ceda"}, - {file = "av-12.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99cd2fc53091ebfb9a2fa9dd3580267f5bd1c040d0efd99fbc1a162576b271cb"}, - {file = "av-12.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6d4f1e261df48932128e6495772faa4cc23f5dd1512eec73daab82ad9f3240"}, - {file = "av-12.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:6aec88e41a498b1e01e2dce5371557e20f9a51aae0c16decc5924ec0be2e22b6"}, - {file = "av-12.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90eb8f2d548e96cbc6f78e89c911cdb15a3d80fd944f31111660ce45939cd037"}, - {file = "av-12.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d7f3a02910e77d750dbd516256a16db15030e5371530ff5a5ae902dc03d9005d"}, - {file = "av-12.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2477cc51526aa50575313d66e5e8ad7ab944588469be5e557b360ed572ae536"}, - {file = "av-12.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a2f47149d3ca6deb79f3e515b8bef50e27ebdb160813e6d67dba77278d2a7883"}, - {file = "av-12.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3306e4a3ce8b5bfcc3075793d4ed3a2df69179d8fba22cb944a6164dc235dfb6"}, - {file = "av-12.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:dc1b742e7f6df1b499fb960bd6697d1dd8e7ada7484a041a8c20e70a87225f53"}, - {file = "av-12.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0183be6889e835e1b074b4037bfce4fd44671c606cf1c4ab92ea2f271b544aec"}, - {file = "av-12.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:57337f20b208292ec8d3b11e4d289d8688a43d728174850a81b865d3253fff2c"}, - {file = "av-12.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ec915e8f6521545a38566eefc281042ee504ea3cee0618d8558e4920588b3b2"}, - {file = "av-12.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:33ad5c0a23c45b72bd6bd47f3b2c1adcd2935ee3d0b6178ed66bba62b964ff31"}, - {file = 
"av-12.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc3a652b12c93120514d56cf025da47442c5ba51530cdf7ba3660257dbb0de1"}, - {file = "av-12.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:037f793dd1ef4a1f57f090191a7f803ad10ec82da0d04ea26bbe0b8a145fe927"}, - {file = "av-12.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc532376aa264722fae55063abd1871d17a563dc895978e142c8ecfcdeb3a2e8"}, - {file = "av-12.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:abf0c4bc40a0af8a30f4cd96f3be6f19fbce0f21222d7fcec148e085127153f7"}, - {file = "av-12.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81cedd1c072fbebf606724c406b1a1b00adc711f1dfd2bc04c633ce39d8439d8"}, - {file = "av-12.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02d60f48be9f15dcda37d50f3ce8d7249d9a455643d4322dd3449986bacfc628"}, - {file = "av-12.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d2619e4c26d661eecfc404f7d739d8b35f0dcef353fabe61512e030254b7031"}, - {file = "av-12.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:1892cc91c888d101777d5432d54e0554c11d1c3a2c65d02a2cae0a2256a8fbb9"}, - {file = "av-12.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4819e3ef6c3a44ef6f75907229133a1ee7f688245b2cf49b6b8e969a81ca72c9"}, - {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb16bb314cf1503b0250fc46b2c455ee196584231101be0123f4f78638227b62"}, - {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3e6a62bda9a1e144feeb59bbee046d7a2d98399634a30f57e4990197313c158"}, - {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08175ffbafa3a70c7b2f81083e160e34122a208cdf70f150b8f5d02c2de6965"}, - {file = "av-12.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e1d255be317b7c1ebdc4dae98935b9f3869161112dc829c625e54f90d8bdd7ab"}, - {file = "av-12.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:17964b36e08435910aabd5b3f7dca12f99536902529767d276026bc08f94ced7"}, - {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2d5f78de29edee06ddcdd4c2b759914575492d6a0cd4de2ce31ee63a4953eff"}, - {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:309b32bc97158d0f0c19e273b8e17a855a86806b7194aebc23bd497326cff11f"}, - {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c409c71bd9c7c2f8d018c822f36b1447cfa96eca158381a96f3319bb0ff6e79e"}, - {file = "av-12.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:08fc5eaef60a257d622998626e233bf3ff90d2f817f6695d6a27e0ffcfe9dcff"}, - {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746ab0eff8a7a21a6c6d16e6b6e61709527eba2ad1a524d92a01bb60d02a3df7"}, - {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:013b3ac3de3aa1c137af0cedafd364fd1c7524ab3e1cd53e04564fd1632ac04d"}, - {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fa55923527648f51ac005e44fe2797ebc67f53ad4850e0194d3753761ee33a2"}, - {file = "av-12.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:35d514f4dee0cf67e9e6b2a65fb4a28f98da88e71e8c7f7960bd04625d9fe965"}, - {file = "av-12.0.0.tar.gz", hash = "sha256:bcf21ebb722d4538b4099e5a78f730d78814dd70003511c185941dba5651b14d"}, + {file = 
"async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] +[[package]] +name = "attrs" +version = "23.2.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=3.7" +files = [ + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] + [[package]] name = "beautifulsoup4" version = "4.12.3" @@ -196,7 +282,7 @@ pycparser = "*" name = "cfgv" version = "3.4.0" description = "Validate configuration and produce human readable error messages." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, @@ -371,7 +457,7 @@ files = [ name = "coverage" version = "7.4.4" description = "Code coverage measurement for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"}, @@ -434,11 +520,55 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "datasets" +version = "2.18.0" +description = "HuggingFace community-driven open-source library of datasets" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"}, + {file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"}, +] + +[package.dependencies] +aiohttp = "*" +dill = ">=0.3.0,<0.3.9" +filelock = "*" +fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]} +huggingface-hub = ">=0.19.4" +multiprocess = "*" +numpy = ">=1.17" +packaging = "*" +pandas = "*" +pyarrow = ">=12.0.0" +pyarrow-hotfix = "*" +pyyaml = ">=5.1" +requests = ">=2.19.0" +tqdm = ">=4.62.1" +xxhash = "*" + +[package.extras] +apache-beam = ["apache-beam (>=2.26.0)"] +audio = ["librosa", "soundfile (>=0.12.1)"] +benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] 
+docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] +jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] +metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] +quality = ["ruff (>=0.3.0)"] +s3 = ["s3fs"] +tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] +tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +torch = ["torch"] +vision = ["Pillow (>=6.2.1)"] + [[package]] name = "debugpy" version = "1.8.1" description = "An implementation of the Debug Adapter Protocol for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "debugpy-1.8.1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3bda0f1e943d386cc7a0e71bfa59f4137909e2ed947fb3946c506e113000f741"}, @@ -465,17 +595,6 @@ files = [ {file = "debugpy-1.8.1.zip", hash = "sha256:f696d6be15be87aef621917585f9bb94b1dc9e8aced570db1b8a6fc14e8f9b42"}, ] -[[package]] -name = "decorator" -version = "4.4.2" -description = "Decorators for Humans" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*" -files = [ - {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"}, - {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, -] - [[package]] name = "diffusers" version = "0.26.3" @@ -506,11 +625,26 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi torch = ["accelerate (>=0.11.0)", "torch (>=1.4,<2.2.0)"] training = ["Jinja2", "accelerate (>=0.11.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] +[[package]] +name = "dill" +version = "0.3.8" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distlib" version = "0.3.8" description = "Distribution utilities" -optional = false +optional = true python-versions = "*" files = [ {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, @@ -658,7 +792,7 @@ files = [ name = "exceptiongroup" version = "1.2.0" description = "Backport of PEP 654 (exception groups)" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = 
"sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, @@ -706,17 +840,106 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "frozenlist" +version = "1.4.1" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = 
"frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = 
"frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + [[package]] name = "fsspec" -version = "2024.3.1" +version = "2024.2.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, - {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, + {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, + {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, ] +[package.dependencies] +aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} + [package.extras] abfs = ["adlfs"] adl = ["adlfs"] @@ -940,7 +1163,7 @@ mujoco = "^2.3.7" type = "git" url = "git@github.com:huggingface/gym-xarm.git" reference = "HEAD" -resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d" +resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c" [[package]] name = "gymnasium" @@ -1085,7 +1308,7 @@ packaging = "*" name = "identify" version = "2.5.35" description = "File identification library for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = 
"identify-2.5.35-py2.py3-none-any.whl", hash = "sha256:c4de0081837b211594f8e877a6b4fad7ca32bbfc1a9307fdd61c28bfe923f13e"}, @@ -1118,9 +1341,10 @@ files = [ ] [package.dependencies] -av = {version = "*", optional = true, markers = "extra == \"pyav\""} +imageio-ffmpeg = {version = "*", optional = true, markers = "extra == \"ffmpeg\""} numpy = "*" pillow = ">=8.3.2" +psutil = {version = "*", optional = true, markers = "extra == \"ffmpeg\""} [package.extras] all-plugins = ["astropy", "av", "imageio-ffmpeg", "pillow-heif", "psutil", "tifffile"] @@ -1180,7 +1404,7 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, @@ -1470,30 +1694,6 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] -[[package]] -name = "moviepy" -version = "1.0.3" -description = "Video editing with Python" -optional = false -python-versions = "*" -files = [ - {file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"}, -] - -[package.dependencies] -decorator = ">=4.0.2,<5.0" -imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""} -imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""} -numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""} -proglog = "<=1.0.0" -requests = ">=2.8.1,<3.0" -tqdm = ">=4.11.2,<5.0" - -[package.extras] -doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"] -optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"] -test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"] - [[package]] name = "mpmath" version = "1.3.0" @@ -1551,6 +1751,129 @@ glfw = "*" numpy = "*" pyopengl = "*" +[[package]] +name = "multidict" +version = "6.0.5" +description = "multidict implementation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = 
"multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = 
"sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = 
"sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + 
{file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, +] + +[[package]] +name = "multiprocess" +version = "0.70.16" +description = "better multiprocessing and multithreading in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + 
{file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, +] + +[package.dependencies] +dill = ">=0.3.8" + [[package]] name = "networkx" version = "3.2.1" @@ -1573,7 +1896,7 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] name = "nodeenv" version = "1.8.0" description = "Node.js virtual environment builder" -optional = false +optional = true python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, @@ -1754,47 +2077,47 @@ files = [ [[package]] name = "pandas" -version = "2.2.1" +version = "2.2.2" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" files = [ - {file = "pandas-2.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8df8612be9cd1c7797c93e1c5df861b2ddda0b48b08f2c3eaa0702cf88fb5f88"}, - {file = "pandas-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0f573ab277252ed9aaf38240f3b54cfc90fff8e5cab70411ee1d03f5d51f3944"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f02a3a6c83df4026e55b63c1f06476c9aa3ed6af3d89b4f04ea656ccdaaaa359"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c38ce92cb22a4bea4e3929429aa1067a454dcc9c335799af93ba9be21b6beb51"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c2ce852e1cf2509a69e98358e8458775f89599566ac3775e70419b98615f4b06"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53680dc9b2519cbf609c62db3ed7c0b499077c7fefda564e330286e619ff0dd9"}, - {file = "pandas-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:94e714a1cca63e4f5939cdce5f29ba8d415d85166be3441165edd427dc9f6bc0"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f821213d48f4ab353d20ebc24e4faf94ba40d76680642fb7ce2ea31a3ad94f9b"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c70e00c2d894cb230e5c15e4b1e1e6b2b478e09cf27cc593a11ef955b9ecc81a"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e97fbb5387c69209f134893abc788a6486dbf2f9e511070ca05eed4b930b1b02"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101d0eb9c5361aa0146f500773395a03839a5e6ecde4d4b6ced88b7e5a1a6403"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7d2ed41c319c9fb4fd454fe25372028dfa417aacb9790f68171b2e3f06eae8cd"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:af5d3c00557d657c8773ef9ee702c61dd13b9d7426794c9dfeb1dc4a0bf0ebc7"}, - {file = "pandas-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:06cf591dbaefb6da9de8472535b185cba556d0ce2e6ed28e21d919704fef1a9e"}, - {file = 
"pandas-2.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:88ecb5c01bb9ca927ebc4098136038519aa5d66b44671861ffab754cae75102c"}, - {file = "pandas-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04f6ec3baec203c13e3f8b139fb0f9f86cd8c0b94603ae3ae8ce9a422e9f5bee"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a935a90a76c44fe170d01e90a3594beef9e9a6220021acfb26053d01426f7dc2"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c391f594aae2fd9f679d419e9a4d5ba4bce5bb13f6a989195656e7dc4b95c8f0"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9d1265545f579edf3f8f0cb6f89f234f5e44ba725a34d86535b1a1d38decbccc"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:11940e9e3056576ac3244baef2fedade891977bcc1cb7e5cc8f8cc7d603edc89"}, - {file = "pandas-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acf681325ee1c7f950d058b05a820441075b0dd9a2adf5c4835b9bc056bf4fb"}, - {file = "pandas-2.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9bd8a40f47080825af4317d0340c656744f2bfdb6819f818e6ba3cd24c0e1397"}, - {file = "pandas-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df0c37ebd19e11d089ceba66eba59a168242fc6b7155cba4ffffa6eccdfb8f16"}, - {file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:739cc70eaf17d57608639e74d63387b0d8594ce02f69e7a0b046f117974b3019"}, - {file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d3558d263073ed95e46f4650becff0c5e1ffe0fc3a015de3c79283dfbdb3df"}, - {file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4aa1d8707812a658debf03824016bf5ea0d516afdea29b7dc14cf687bc4d4ec6"}, - {file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:76f27a809cda87e07f192f001d11adc2b930e93a2b0c4a236fde5429527423be"}, - {file = "pandas-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:1ba21b1d5c0e43416218db63037dbe1a01fc101dc6e6024bcad08123e48004ab"}, - {file = "pandas-2.2.1.tar.gz", hash = "sha256:0ab90f87093c13f3e8fa45b48ba9f39181046e8f3317d3aadb2fffbb1b978572"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, + {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, + {file = 
"pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, + {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, + {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, ] [package.dependencies] numpy = [ - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, - {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1939,7 +2262,7 @@ xmp = ["defusedxml"] name = "platformdirs" version = "4.2.0" description = "A 
small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"}, @@ -1954,7 +2277,7 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest- name = "pluggy" version = "1.4.0" description = "plugin and hook calling mechanisms for python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, @@ -1967,13 +2290,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pre-commit" -version = "3.6.2" +version = "3.7.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." -optional = false +optional = true python-versions = ">=3.9" files = [ - {file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"}, - {file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"}, + {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"}, + {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"}, ] [package.dependencies] @@ -1983,20 +2306,6 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" -[[package]] -name = "proglog" -version = "0.1.10" -description = "Log and progress bar manager for console, notebooks, web..." -optional = false -python-versions = "*" -files = [ - {file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"}, - {file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"}, -] - -[package.dependencies] -tqdm = "*" - [[package]] name = "protobuf" version = "4.25.3" @@ -2045,6 +2354,65 @@ files = [ [package.extras] test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +[[package]] +name = "pyarrow" +version = "15.0.2" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"}, + {file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"}, + {file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = 
"sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"}, + {file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"}, + {file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"}, + {file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"}, + {file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"}, + {file = "pyarrow-15.0.2.tar.gz", hash = "sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" + +[[package]] +name = "pyarrow-hotfix" +version = "0.6" +description = "" +optional = false +python-versions = ">=3.5" +files = [ + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, +] + [[package]] name = "pycparser" version = "2.21" @@ -2060,7 +2428,7 @@ files = [ name = "pygame" version = "2.5.2" description = "Python Game Development" -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "pygame-2.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a0769eb628c818761755eb0a0ca8216b95270ea8cbcbc82227e39ac9644643da"}, @@ -2234,7 +2602,7 @@ files = [ name = "pytest" version = "8.1.1" description = "pytest: simple powerful testing with Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, @@ -2256,7 +2624,7 @@ testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygm name = "pytest-cov" version = "5.0.0" description = "Pytest plugin for measuring coverage." 
-optional = false +optional = true python-versions = ">=3.8" files = [ {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, @@ -3076,7 +3444,7 @@ all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib" name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, @@ -3214,7 +3582,7 @@ zstd = ["zstandard (>=0.18.0)"] name = "virtualenv" version = "20.25.1" description = "Virtual Python Environment builder" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"}, @@ -3285,6 +3653,226 @@ MarkupSafe = ">=2.1.1" [package.extras] watchdog = ["watchdog (>=2.3)"] +[[package]] +name = "xxhash" +version = "3.4.1" +description = "Python binding for xxHash" +optional = false +python-versions = ">=3.7" +files = [ + {file = "xxhash-3.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:91dbfa55346ad3e18e738742236554531a621042e419b70ad8f3c1d9c7a16e7f"}, + {file = "xxhash-3.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:665a65c2a48a72068fcc4d21721510df5f51f1142541c890491afc80451636d2"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb11628470a6004dc71a09fe90c2f459ff03d611376c1debeec2d648f44cb693"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5bef2a7dc7b4f4beb45a1edbba9b9194c60a43a89598a87f1a0226d183764189"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0f7b2d547d72c7eda7aa817acf8791f0146b12b9eba1d4432c531fb0352228"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00f2fdef6b41c9db3d2fc0e7f94cb3db86693e5c45d6de09625caad9a469635b"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23cfd9ca09acaf07a43e5a695143d9a21bf00f5b49b15c07d5388cadf1f9ce11"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6a9ff50a3cf88355ca4731682c168049af1ca222d1d2925ef7119c1a78e95b3b"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f1d7c69a1e9ca5faa75546fdd267f214f63f52f12692f9b3a2f6467c9e67d5e7"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:672b273040d5d5a6864a36287f3514efcd1d4b1b6a7480f294c4b1d1ee1b8de0"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4178f78d70e88f1c4a89ff1ffe9f43147185930bb962ee3979dba15f2b1cc799"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9804b9eb254d4b8cc83ab5a2002128f7d631dd427aa873c8727dba7f1f0d1c2b"}, + {file = "xxhash-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c09c49473212d9c87261d22c74370457cfff5db2ddfc7fd1e35c80c31a8c14ce"}, + {file = "xxhash-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:ebbb1616435b4a194ce3466d7247df23499475c7ed4eb2681a1fa42ff766aff6"}, + {file = "xxhash-3.4.1-cp310-cp310-win_arm64.whl", hash = "sha256:25dc66be3db54f8a2d136f695b00cfe88018e59ccff0f3b8f545869f376a8a46"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:58c49083801885273e262c0f5bbeac23e520564b8357fbb18fb94ff09d3d3ea5"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b526015a973bfbe81e804a586b703f163861da36d186627e27524f5427b0d520"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36ad4457644c91a966f6fe137d7467636bdc51a6ce10a1d04f365c70d6a16d7e"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:248d3e83d119770f96003271fe41e049dd4ae52da2feb8f832b7a20e791d2920"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2070b6d5bbef5ee031666cf21d4953c16e92c2f8a24a94b5c240f8995ba3b1d0"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2746035f518f0410915e247877f7df43ef3372bf36cfa52cc4bc33e85242641"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ba6181514681c2591840d5632fcf7356ab287d4aff1c8dea20f3c78097088"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aac5010869240e95f740de43cd6a05eae180c59edd182ad93bf12ee289484fa"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4cb11d8debab1626181633d184b2372aaa09825bde709bf927704ed72765bed1"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b29728cff2c12f3d9f1d940528ee83918d803c0567866e062683f300d1d2eff3"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a15cbf3a9c40672523bdb6ea97ff74b443406ba0ab9bca10ceccd9546414bd84"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e66df260fed01ed8ea790c2913271641c58481e807790d9fca8bfd5a3c13844"}, + {file = "xxhash-3.4.1-cp311-cp311-win32.whl", hash = "sha256:e867f68a8f381ea12858e6d67378c05359d3a53a888913b5f7d35fbf68939d5f"}, + {file = "xxhash-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:200a5a3ad9c7c0c02ed1484a1d838b63edcf92ff538770ea07456a3732c577f4"}, + {file = "xxhash-3.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:1d03f1c0d16d24ea032e99f61c552cb2b77d502e545187338bea461fde253583"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c4bbba9b182697a52bc0c9f8ec0ba1acb914b4937cd4a877ad78a3b3eeabefb3"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9fd28a9da300e64e434cfc96567a8387d9a96e824a9be1452a1e7248b7763b78"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6066d88c9329ab230e18998daec53d819daeee99d003955c8db6fc4971b45ca3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93805bc3233ad89abf51772f2ed3355097a5dc74e6080de19706fc447da99cd3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64da57d5ed586ebb2ecdde1e997fa37c27fe32fe61a656b77fabbc58e6fbff6e"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a97322e9a7440bf3c9805cbaac090358b43f650516486746f7fa482672593df"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe750d512982ee7d831838a5dee9e9848f3fb440e4734cca3f298228cc957a6"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fd79d4087727daf4d5b8afe594b37d611ab95dc8e29fe1a7517320794837eb7d"}, + {file = 
"xxhash-3.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:743612da4071ff9aa4d055f3f111ae5247342931dedb955268954ef7201a71ff"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b41edaf05734092f24f48c0958b3c6cbaaa5b7e024880692078c6b1f8247e2fc"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:a90356ead70d715fe64c30cd0969072de1860e56b78adf7c69d954b43e29d9fa"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac56eebb364e44c85e1d9e9cc5f6031d78a34f0092fea7fc80478139369a8b4a"}, + {file = "xxhash-3.4.1-cp312-cp312-win32.whl", hash = "sha256:911035345932a153c427107397c1518f8ce456f93c618dd1c5b54ebb22e73747"}, + {file = "xxhash-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:f31ce76489f8601cc7b8713201ce94b4bd7b7ce90ba3353dccce7e9e1fee71fa"}, + {file = "xxhash-3.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:b5beb1c6a72fdc7584102f42c4d9df232ee018ddf806e8c90906547dfb43b2da"}, + {file = "xxhash-3.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6d42b24d1496deb05dee5a24ed510b16de1d6c866c626c2beb11aebf3be278b9"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b685fab18876b14a8f94813fa2ca80cfb5ab6a85d31d5539b7cd749ce9e3624"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419ffe34c17ae2df019a4685e8d3934d46b2e0bbe46221ab40b7e04ed9f11137"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e041ce5714f95251a88670c114b748bca3bf80cc72400e9f23e6d0d59cf2681"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc860d887c5cb2f524899fb8338e1bb3d5789f75fac179101920d9afddef284b"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:312eba88ffe0a05e332e3a6f9788b73883752be63f8588a6dc1261a3eaaaf2b2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e01226b6b6a1ffe4e6bd6d08cfcb3ca708b16f02eb06dd44f3c6e53285f03e4f"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9f3025a0d5d8cf406a9313cd0d5789c77433ba2004b1c75439b67678e5136537"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:6d3472fd4afef2a567d5f14411d94060099901cd8ce9788b22b8c6f13c606a93"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:43984c0a92f06cac434ad181f329a1445017c33807b7ae4f033878d860a4b0f2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a55e0506fdb09640a82ec4f44171273eeabf6f371a4ec605633adb2837b5d9d5"}, + {file = "xxhash-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:faec30437919555b039a8bdbaba49c013043e8f76c999670aef146d33e05b3a0"}, + {file = "xxhash-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:c9e1b646af61f1fc7083bb7b40536be944f1ac67ef5e360bca2d73430186971a"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:961d948b7b1c1b6c08484bbce3d489cdf153e4122c3dfb07c2039621243d8795"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:719a378930504ab159f7b8e20fa2aa1896cde050011af838af7e7e3518dd82de"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74fb5cb9406ccd7c4dd917f16630d2e5e8cbbb02fc2fca4e559b2a47a64f4940"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:5dab508ac39e0ab988039bc7f962c6ad021acd81fd29145962b068df4148c476"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c59f3e46e7daf4c589e8e853d700ef6607afa037bfad32c390175da28127e8c"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc07256eff0795e0f642df74ad096f8c5d23fe66bc138b83970b50fc7f7f6c5"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9f749999ed80f3955a4af0eb18bb43993f04939350b07b8dd2f44edc98ffee9"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7688d7c02149a90a3d46d55b341ab7ad1b4a3f767be2357e211b4e893efbaaf6"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a8b4977963926f60b0d4f830941c864bed16aa151206c01ad5c531636da5708e"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:8106d88da330f6535a58a8195aa463ef5281a9aa23b04af1848ff715c4398fb4"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4c76a77dbd169450b61c06fd2d5d436189fc8ab7c1571d39265d4822da16df22"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:11f11357c86d83e53719c592021fd524efa9cf024dc7cb1dfb57bbbd0d8713f2"}, + {file = "xxhash-3.4.1-cp38-cp38-win32.whl", hash = "sha256:0c786a6cd74e8765c6809892a0d45886e7c3dc54de4985b4a5eb8b630f3b8e3b"}, + {file = "xxhash-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:aabf37fb8fa27430d50507deeab2ee7b1bcce89910dd10657c38e71fee835594"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6127813abc1477f3a83529b6bbcfeddc23162cece76fa69aee8f6a8a97720562"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef2e194262f5db16075caea7b3f7f49392242c688412f386d3c7b07c7733a70a"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71be94265b6c6590f0018bbf73759d21a41c6bda20409782d8117e76cd0dfa8b"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10e0a619cdd1c0980e25eb04e30fe96cf8f4324758fa497080af9c21a6de573f"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa122124d2e3bd36581dd78c0efa5f429f5220313479fb1072858188bc2d5ff1"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17032f5a4fea0a074717fe33477cb5ee723a5f428de7563e75af64bfc1b1e10"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca7783b20e3e4f3f52f093538895863f21d18598f9a48211ad757680c3bd006f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d77d09a1113899fad5f354a1eb4f0a9afcf58cefff51082c8ad643ff890e30cf"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:21287bcdd299fdc3328cc0fbbdeaa46838a1c05391264e51ddb38a3f5b09611f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:dfd7a6cc483e20b4ad90224aeb589e64ec0f31e5610ab9957ff4314270b2bf31"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:543c7fcbc02bbb4840ea9915134e14dc3dc15cbd5a30873a7a5bf66039db97ec"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fe0a98d990e433013f41827b62be9ab43e3cf18e08b1483fcc343bda0d691182"}, + {file = "xxhash-3.4.1-cp39-cp39-win32.whl", hash = 
"sha256:b9097af00ebf429cc7c0e7d2fdf28384e4e2e91008130ccda8d5ae653db71e54"}, + {file = "xxhash-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:d699b921af0dcde50ab18be76c0d832f803034d80470703700cb7df0fbec2832"}, + {file = "xxhash-3.4.1-cp39-cp39-win_arm64.whl", hash = "sha256:2be491723405e15cc099ade1280133ccfbf6322d2ef568494fb7d07d280e7eee"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:431625fad7ab5649368c4849d2b49a83dc711b1f20e1f7f04955aab86cd307bc"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc6dbd5fc3c9886a9e041848508b7fb65fd82f94cc793253990f81617b61fe49"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ff8dbd0ec97aec842476cb8ccc3e17dd288cd6ce3c8ef38bff83d6eb927817"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef73a53fe90558a4096e3256752268a8bdc0322f4692ed928b6cd7ce06ad4fe3"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:450401f42bbd274b519d3d8dcf3c57166913381a3d2664d6609004685039f9d3"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a162840cf4de8a7cd8720ff3b4417fbc10001eefdd2d21541a8226bb5556e3bb"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b736a2a2728ba45017cb67785e03125a79d246462dfa892d023b827007412c52"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0ae4c2e7698adef58710d6e7a32ff518b66b98854b1c68e70eee504ad061d8"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6322c4291c3ff174dcd104fae41500e75dad12be6f3085d119c2c8a80956c51"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:dd59ed668801c3fae282f8f4edadf6dc7784db6d18139b584b6d9677ddde1b6b"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92693c487e39523a80474b0394645b393f0ae781d8db3474ccdcead0559ccf45"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4603a0f642a1e8d7f3ba5c4c25509aca6a9c1cc16f85091004a7028607ead663"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa45e8cbfbadb40a920fe9ca40c34b393e0b067082d94006f7f64e70c7490a6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:595b252943b3552de491ff51e5bb79660f84f033977f88f6ca1605846637b7c6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:562d8b8f783c6af969806aaacf95b6c7b776929ae26c0cd941d54644ea7ef51e"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:41ddeae47cf2828335d8d991f2d2b03b0bdc89289dc64349d712ff8ce59d0647"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c44d584afdf3c4dbb3277e32321d1a7b01d6071c1992524b6543025fb8f4206f"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd7bddb3a5b86213cc3f2c61500c16945a1b80ecd572f3078ddbbe68f9dabdfb"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:9ecb6c987b62437c2f99c01e97caf8d25660bf541fe79a481d05732e5236719c"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:696b4e18b7023527d5c50ed0626ac0520edac45a50ec7cf3fc265cd08b1f4c03"}, + {file = "xxhash-3.4.1.tar.gz", hash = "sha256:0379d6cf1ff987cd421609a264ce025e74f346e3e145dd106c0cc2e3ec3f99a9"}, +] + +[[package]] +name = "yarl" +version = "1.9.4" +description = "Yet another URL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = 
"yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", 
hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = 
"sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + [[package]] name = "zarr" version = "2.17.1" @@ -3323,10 +3911,12 @@ testing = ["big-O", 
"jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] aloha = ["gym-aloha"] +dev = ["debugpy", "pre-commit"] pusht = ["gym-pusht"] +test = ["pytest", "pytest-cov"] xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "8fa6dfc30e605741c24f5de58b89125d5b02153f550e5af7a44356956d6bb167" +content-hash = "bd9c506d2499d5e1e3b5e8b1a0f65df45c8feef38d89d0daeade56847fdb6a2e" diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index f5c439dc..b13a9e97 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -1,19 +1,25 @@ [tool.poetry] name = "lerobot" version = "0.1.0" -description = "Le robot is learning" +description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" authors = [ "Rémi Cadène ", + "Alexander Soare ", + "Quentin Gallouédec ", "Simon Alibert ", + "Thomas Wolf ", ] -repository = "https://github.com/Cadene/lerobot" +repository = "https://github.com/huggingface/lerobot" readme = "README.md" -license = "MIT" +license = "Apache-2.0" classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", "Topic :: Software Development :: Build Tools", - "License :: OSI Approved :: MIT License", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.10", ] packages = [{include = "lerobot"}] @@ -23,53 +29,39 @@ packages = [{include = "lerobot"}] python = "^3.10" termcolor = "^2.4.0" omegaconf = "^2.3.0" -pandas = "^2.2.1" wandb = "^0.16.3" -moviepy = "^1.0.3" -imageio = {extras = ["pyav"], version = "^2.34.0"} +imageio = {extras = ["ffmpeg"], version = "^2.34.0"} gdown = "^5.1.0" hydra-core = "^1.3.2" einops = "^0.7.0" -pygame = "^2.5.2" pymunk = "^6.6.0" zarr = "^2.17.0" numba = "^0.59.0" -mpmath = "^1.3.0" torch = {version = "^2.2.1", source = "torch-cpu"} opencv-python = "^4.9.0.80" diffusers = "^0.26.3" torchvision = {version = "^0.17.1", source = "torch-cpu"} h5py = "^3.10.0" -robomimic = "0.2.0" huggingface-hub = "^0.21.4" +robomimic = "0.2.0" gymnasium = "^0.29.1" cmake = "^3.29.0.1" gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true} gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true} -# gym-pusht = { path = "../gym-pusht", develop = true, optional = true} -# gym-xarm = { path = "../gym-xarm", develop = true, optional = true} -# gym-aloha = { path = "../gym-aloha", develop = true, optional = true} +pre-commit = {version = "^3.7.0", optional = true} +debugpy = {version = "^1.8.1", optional = true} +pytest = {version = "^8.1.0", optional = true} +pytest-cov = {version = "^5.0.0", optional = true} +datasets = "^2.18.0" [tool.poetry.extras] pusht = ["gym-pusht"] xarm = ["gym-xarm"] aloha = ["gym-aloha"] - - -[tool.poetry.group.dev] -optional = true - - -[tool.poetry.group.dev.dependencies] -pre-commit = "^3.6.2" -debugpy = "^1.8.1" - - -[tool.poetry.group.test.dependencies] -pytest = "^8.1.0" -pytest-cov = "^5.0.0" +dev = ["pre-commit", "debugpy"] +test = ["pytest", "pytest-cov"] [[tool.poetry.source]] @@ -110,10 +102,6 @@ exclude = [ select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] -[tool.poetry-dynamic-versioning] -enable = true - - [build-system] -requires = ["poetry-core>=1.0.0", 
"poetry-dynamic-versioning>=1.0.0,<2.0.0"] -build-backend = "poetry_dynamic_versioning.backend" +requires = ["poetry-core>=1.5.0"] +build-backend = "poetry.core.masonry.api" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index afdcc41f..a86193b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -142,10 +142,12 @@ jobs: wandb.enable=False \ offline_steps=2 \ online_steps=0 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ - horizon=20 \ + policy.n_action_steps=20 \ + policy.chunk_size=20 \ policy.batch_size=2 \ hydra.run.dir=tests/outputs/act/ @@ -159,17 +161,6 @@ jobs: device=cpu \ policy.pretrained_model_path=tests/outputs/act/models/2.pt - # TODO(aliberts): This takes ~2mn to run, needs to be improved - # - name: Test eval ACT on ALOHA end-to-end (policy is None) - # run: | - # source .venv/bin/activate - # python lerobot/scripts/eval.py \ - # --config lerobot/configs/default.yaml \ - # policy=act \ - # env=aloha \ - # eval_episodes=1 \ - # device=cpu - - name: Test train Diffusion on PushT end-to-end run: | source .venv/bin/activate @@ -179,9 +170,11 @@ jobs: wandb.enable=False \ offline_steps=2 \ online_steps=0 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ + policy.batch_size=2 \ hydra.run.dir=tests/outputs/diffusion/ - name: Test eval Diffusion on PushT end-to-end @@ -194,16 +187,6 @@ jobs: device=cpu \ policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt - - name: Test eval Diffusion on PushT end-to-end (policy is None) - run: | - source .venv/bin/activate - python lerobot/scripts/eval.py \ - --config lerobot/configs/default.yaml \ - policy=diffusion \ - env=pusht \ - eval_episodes=1 \ - device=cpu - - name: Test train TDMPC on Simxarm end-to-end run: | source .venv/bin/activate @@ -213,9 +196,11 @@ jobs: wandb.enable=False \ offline_steps=1 \ online_steps=1 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ + policy.batch_size=2 \ hydra.run.dir=tests/outputs/tdmpc/ - name: Test eval TDMPC on Simxarm end-to-end @@ -227,13 +212,3 @@ jobs: env.episode_length=8 \ device=cpu \ policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt - - - name: Test eval TDPMC on Simxarm end-to-end (policy is None) - run: | - source .venv/bin/activate - python lerobot/scripts/eval.py \ - --config lerobot/configs/default.yaml \ - policy=tdmpc \ - env=xarm \ - eval_episodes=1 \ - device=cpu diff --git a/.gitignore b/.gitignore index ad9892d4..3132aba0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ rl nautilus/*.yaml *.key +# Slurm +sbatch*.sh + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 765b678a..1d0fb555 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: python: python3.10 repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-added-large-files - id: debug-statements @@ -18,7 +18,7 @@ repos: hooks: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.4 + rev: v0.3.7 hooks: - id: ruff args: [--fix] diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..04a05275 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,133 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of 
age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official email address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +[feedback@huggingface.co](mailto:feedback@huggingface.co). +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. 
No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..0b40d81a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,254 @@ +# How to contribute to 🤗 LeRobot? + +Everyone is welcome to contribute, and we value everybody's contribution. Code +is thus not the only way to help the community. Answering questions, helping +others, reaching out and improving the documentations are immensely valuable to +the community. + +It also helps us if you spread the word: reference the library from blog posts +on the awesome projects it made possible, shout out on Twitter when it has +helped you, or simply ⭐️ the repo to say "thank you". + +Whichever way you choose to contribute, please be mindful to respect our +[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md). + +## You can contribute in so many ways! + +Some of the ways you can contribute to 🤗 LeRobot: +* Fixing outstanding issues with the existing code. +* Implementing new models, datasets or simulation environments. +* Contributing to the examples or to the documentation. +* Submitting issues related to bugs or desired new features. + +Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co). 
+
+If you are not sure how to contribute or want to know the next features we're working on, take a look at this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
+
+## Submitting a new issue or feature request
+
+Do your best to follow these guidelines when submitting an issue or a feature
+request. It will make it easier for us to come back to you quickly and with good
+feedback.
+
+### Did you find a bug?
+
+The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
+the problems they encounter. So thank you for reporting an issue.
+
+First, we would really appreciate it if you could **make sure the bug was not
+already reported** (use the search bar on GitHub under Issues).
+
+Did not find it? :( So we can act quickly on it, please follow these steps:
+
+* Include your **OS type and version**, the versions of **Python** and **PyTorch**.
+* A short, self-contained code snippet that allows us to reproduce the bug in
+  less than 30s.
+* The full traceback if an exception is raised.
+* Attach any other additional information, like screenshots, you think may help.
+
+### Do you want a new feature?
+
+A good feature request addresses the following points:
+
+1. Motivation first:
+* Is it related to a problem/frustration with the library? If so, please explain
+  why. Providing a code snippet that demonstrates the problem is best.
+* Is it related to something you would need for a project? We'd love to hear
+  about it!
+* Is it something you worked on and think could benefit the community?
+  Awesome! Tell us what problem it solved for you.
+2. Write a *paragraph* describing the feature.
+3. Provide a **code snippet** that demonstrates its future use.
+4. In case this is related to a paper, please attach a link.
+5. Attach any additional information (drawings, screenshots, etc.) you think may help.
+
+If your issue is well written, we're already 80% of the way there by the time you
+post it.
+
+## Submitting a pull request (PR)
+
+Before writing code, we strongly advise you to search through the existing PRs or
+issues to make sure that nobody is already working on the same thing. If you are
+unsure, it is always a good idea to open an issue to get some feedback.
+
+You will need basic `git` proficiency to be able to contribute to
+🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
+manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
+Git](https://git-scm.com/book/en/v2) is a very good reference.
+
+Follow these steps to start contributing:
+
+1. Fork the [repository](https://github.com/huggingface/lerobot) by
+   clicking on the 'Fork' button on the repository's page. This creates a copy of the code
+   under your GitHub user account.
+
+2. Clone your fork to your local disk, and add the base repository as a remote. The following command
+   assumes you have your public SSH key uploaded to GitHub. See the following guide for more
+   [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
+
+   ```bash
+   git clone git@github.com:<your-github-handle>/lerobot.git
+   cd lerobot
+   git remote add upstream https://github.com/huggingface/lerobot.git
+   ```
+
+3. Create a new branch to hold your development changes, and do this for every new PR you work on.
+
+   Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
+
+   ```bash
+   git checkout main
+   git fetch upstream
+   git rebase upstream/main
+   ```
+
+   Once your `main` branch is synchronized, create a new branch from it:
+
+   ```bash
+   git checkout -b a-descriptive-name-for-my-changes
+   ```
+
+   🚨 **Do not** work on the `main` branch.
+
+4. Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
+   If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
+   Set up a development environment inside a conda or virtual environment you've created for working on this library
+   by installing the project with dev dependencies and all environments:
+   ```bash
+   poetry install --sync --with dev --all-extras
+   ```
+   This command should be run when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the dependencies.
+
+   To selectively install environments (for example aloha and pusht) use:
+   ```bash
+   poetry install --sync --with dev --extras "aloha pusht"
+   ```
+
+   The equivalent of `pip install some-package` would just be:
+   ```bash
+   poetry add some-package
+   ```
+
+   When changes are made to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies:
+   ```bash
+   poetry lock --no-update
+   ```
+
+   **NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
+   ```bash
+   # Add the new package to your main poetry env
+   poetry add some-package
+   # Add the same package to the CPU-only env dedicated to CI
+   conda create -y -n lerobot-ci python=3.10
+   conda activate lerobot-ci
+   cd .github/poetry/cpu
+   poetry add some-package
+   ```
+
+5. Develop the features on your branch.
+
+   As you work on the features, you should make sure that the test suite
+   passes. You should run the tests impacted by your changes like this (see
+   below an explanation regarding the environment variable):
+
+   ```bash
+   pytest tests/<TEST_TO_RUN>.py
+   ```
+
+6. Follow our style.
+
+   `lerobot` relies on `ruff` to format its source code
+   consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
+   automatically as Git commit hooks.
+
+   Install `pre-commit` hooks:
+   ```bash
+   pre-commit install
+   ```
+
+   You can run these hooks whenever you need on staged files with:
+   ```bash
+   pre-commit
+   ```
+
+   Once you're happy with your changes, add changed files using `git add` and
+   make a commit with `git commit` to record your changes locally:
+
+   ```bash
+   git add modified_file.py
+   git commit
+   ```
+
+   Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
+
+   It is a good idea to sync your copy of the code with the original
+   repository regularly. This way you can quickly account for changes:
+
+   ```bash
+   git fetch upstream
+   git rebase upstream/main
+   ```
+
+   Push the changes to your account using:
+
+   ```bash
+   git push -u origin a-descriptive-name-for-my-changes
+   ```
+
+6. Once you are satisfied (**and the checklist below is happy too**), go to the
+   webpage of your fork on GitHub.
Click on 'Pull request' to send your changes + to the project maintainers for review. + +7. It's ok if maintainers ask you for changes. It happens to core contributors + too! So everyone can see the changes in the Pull request, work in your local + branch and push the changes to your fork. They will automatically appear in + the pull request. + + +### Checklist + +1. The title of your pull request should be a summary of its contribution; +2. If your pull request addresses an issue, please mention the issue number in + the pull request description to make sure they are linked (and people + consulting the issue know you are working on it); +3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark + the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate + it from PRs ready to be merged; +4. Make sure existing tests pass; + + +### Tests + +An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests). + +Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already). + +On Mac: +```bash +brew install git-lfs +git lfs install +``` + +On Ubuntu: +```bash +sudo apt-get install git-lfs +git lfs install +``` + +Pull artifacts if they're not in [tests/data](tests/data) +```bash +git lfs pull +``` + +We use `pytest` in order to run the tests. From the root of the +repository, here's how to run tests with `pytest` for the library: + +```bash +DATA_DIR="tests/data" python -m pytest -sv ./tests +``` + + +You can specify a smaller set of tests in order to test only the feature +you're working on. diff --git a/README.md b/README.md index 51e03d65..202b90e6 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ [![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/) [![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/) [![Examples](https://img.shields.io/badge/Examples-green.svg)](https://github.com/huggingface/lerobot/tree/main/examples) +[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1%20adopted-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) [![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb) @@ -120,34 +121,32 @@ wandb login You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities: ```python """ Copy pasted from `examples/1_visualize_dataset.py` """ +import os +from pathlib import Path + import lerobot from lerobot.common.datasets.aloha import AlohaDataset -from torchrl.data.replay_buffers import SamplerWithoutReplacement from lerobot.scripts.visualize_dataset import render_dataset print(lerobot.available_datasets) # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium'] -# we use this sampler to sample 1 frame after the other -sampler = SamplerWithoutReplacement(shuffle=False) - -dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler) +# TODO(rcadene): remove DATA_DIR +dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR"))) video_paths = render_dataset( dataset, out_dir="outputs/visualize_dataset/example", - max_num_samples=300, - fps=50, + max_num_episodes=1, ) print(video_paths) -# >>> 
['outputs/visualize_dataset/example/episode_0.mp4'] +# ['outputs/visualize_dataset/example/episode_0.mp4'] ``` Or you can achieve the same result by executing our script from the command line: ```bash python lerobot/scripts/visualize_dataset.py \ -env=aloha \ -task=sim_sim_transfer_cube_human \ +env=pusht \ hydra.run.dir=outputs/visualize_dataset/example # >>> ['outputs/visualize_dataset/example/episode_0.mp4'] ``` @@ -191,89 +190,7 @@ hydra.run.dir=outputs/train/aloha_act ## Contribute -Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co). - -### TODO - -If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46) - -### Follow our style - -```bash -# install if needed -pre-commit install -# apply style and linter checks before git commit -pre-commit -``` - -### Dependencies - -Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies. -If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it. - -Install the project with dev dependencies and all environments: -```bash -poetry install --sync --with dev --all-extras -``` -This command should be run when pulling code with and updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the dependencies. - -To selectively install environments (for example aloha and pusht) use: -```bash -poetry install --sync --with dev --extras "aloha pusht" -``` - -The equivalent of `pip install some-package`, would just be: -```bash -poetry add some-package -``` - -When changes are made to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies. -```bash -poetry lock --no-update -``` - - -**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example: -```bash -# Add the new package to your main poetry env -poetry add some-package -# Add the same package to the CPU-only env dedicated to CI -conda create -y -n lerobot-ci python=3.10 -conda activate lerobot-ci -cd .github/poetry/cpu -poetry add some-package -``` - -### Run tests locally - -Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already). - -On Mac: -```bash -brew install git-lfs -git lfs install -``` - -On Ubuntu: -```bash -sudo apt-get install git-lfs -git lfs install -``` - -Pull artifacts if they're not in [tests/data](tests/data) -```bash -git lfs pull -``` - -When adding a new dataset, mock it with -```bash -python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET -``` - -Run tests -```bash -DATA_DIR="tests/data" pytest -sx tests -``` +If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md). ### Add a new dataset diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py new file mode 100644 index 00000000..2d339221 --- /dev/null +++ b/download_and_upload_dataset.py @@ -0,0 +1,487 @@ +""" +This file contains all obsolete download scripts. 
They are centralized here so that we don't have to load
+unnecessary dependencies when using the datasets.
+"""
+
+import io
+import pickle
+import shutil
+from pathlib import Path
+
+import einops
+import h5py
+import numpy as np
+import torch
+import tqdm
+from datasets import Dataset, Features, Image, Sequence, Value
+from PIL import Image as PILImage
+
+
+def download_and_upload(root, root_tests, dataset_id):
+    if "pusht" in dataset_id:
+        download_and_upload_pusht(root, root_tests, dataset_id)
+    elif "xarm" in dataset_id:
+        download_and_upload_xarm(root, root_tests, dataset_id)
+    elif "aloha" in dataset_id:
+        download_and_upload_aloha(root, root_tests, dataset_id)
+    else:
+        raise ValueError(dataset_id)
+
+
+def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
+    import zipfile
+
+    import requests
+
+    print(f"downloading from {url}")
+    response = requests.get(url, stream=True)
+    if response.status_code == 200:
+        total_size = int(response.headers.get("content-length", 0))
+        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
+
+        zip_file = io.BytesIO()
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                zip_file.write(chunk)
+                progress_bar.update(len(chunk))
+
+        progress_bar.close()
+
+        zip_file.seek(0)
+
+        with zipfile.ZipFile(zip_file, "r") as zip_ref:
+            zip_ref.extractall(destination_folder)
+        return True
+    else:
+        return False
+
+
+def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
+    try:
+        import pymunk
+        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
+
+        from lerobot.common.datasets._diffusion_policy_replay_buffer import (
+            ReplayBuffer as DiffusionPolicyReplayBuffer,
+        )
+    except ModuleNotFoundError as e:
+        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[pusht]'`")
+        raise e
+
+    # as defined in the env
+    success_threshold = 0.95  # 95% coverage
+
+    pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
+    pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
+
+    root = Path(root)
+    raw_dir = root / f"{dataset_id}_raw"
+    zarr_path = (raw_dir / pusht_zarr).resolve()
+    if not zarr_path.is_dir():
+        raw_dir.mkdir(parents=True, exist_ok=True)
+        download_and_extract_zip(pusht_url, raw_dir)
+
+    # load the raw data into a replay buffer
+    dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
+
+    episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
+    num_episodes = dataset_dict.meta["episode_ends"].shape[0]
+    assert len(
+        {dataset_dict[key].shape[0] for key in dataset_dict.keys()}  # noqa: SIM118
+    ) == 1, "Some data types don't have the same number of total frames."
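+
+    # The replay buffer stores flat arrays over all episodes. Below, frames are sliced per episode
+    # using `episode_ends`, and reward/success are recomputed from the pymunk coverage of the goal
+    # region (using the same `success_threshold` as the env).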
+ + # TODO: verify that goal pose is expected to be fixed + goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) + goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) + + imgs = torch.from_numpy(dataset_dict["img"]) # b h w c + states = torch.from_numpy(dataset_dict["state"]) + actions = torch.from_numpy(dataset_dict["action"]) + + ep_dicts = [] + + id_from = 0 + for episode_id in tqdm.tqdm(range(num_episodes)): + id_to = dataset_dict.meta["episode_ends"][episode_id] + + num_frames = id_to - id_from + + assert (episode_ids[id_from:id_to] == episode_id).all() + + image = imgs[id_from:id_to] + assert image.min() >= 0.0 + assert image.max() <= 255.0 + image = image.type(torch.uint8) + + state = states[id_from:id_to] + agent_pos = state[:, :2] + block_pos = state[:, 2:4] + block_angle = state[:, 4] + + reward = torch.zeros(num_frames) + success = torch.zeros(num_frames, dtype=torch.bool) + done = torch.zeros(num_frames, dtype=torch.bool) + for i in range(num_frames): + space = pymunk.Space() + space.gravity = 0, 0 + space.damping = 0 + + # Add walls. + walls = [ + PushTEnv.add_segment(space, (5, 506), (5, 5), 2), + PushTEnv.add_segment(space, (5, 5), (506, 5), 2), + PushTEnv.add_segment(space, (506, 5), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (506, 506), 2), + ] + space.add(*walls) + + block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) + goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) + block_geom = pymunk_to_shapely(block_body, block_body.shapes) + intersection_area = goal_geom.intersection(block_geom).area + goal_area = goal_geom.area + coverage = intersection_area / goal_area + reward[i] = np.clip(coverage / success_threshold, 0, 1) + success[i] = coverage > success_threshold + + # last step of demonstration is considered done + done[-1] = True + + ep_dict = { + "observation.image": [PILImage.fromarray(x.numpy()) for x in image], + "observation.state": agent_pos, + "action": actions[id_from:id_to], + "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + # "next.observation.image": image[1:], + # "next.observation.state": agent_pos[1:], + # TODO(rcadene): verify that reward and done are aligned with image and agent_pos + "next.reward": torch.cat([reward[1:], reward[[-1]]]), + "next.done": torch.cat([done[1:], done[[-1]]]), + "next.success": torch.cat([success[1:], success[[-1]]]), + "episode_data_index_from": torch.tensor([id_from] * num_frames), + "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), + } + ep_dicts.append(ep_dict) + + id_from += num_frames + + data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + + total_frames = id_from + data_dict["index"] = torch.arange(0, total_frames, 1) + + features = { + "observation.image": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_id": Value(dtype="int64", id=None), + "frame_id": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + 
"next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + "next.success": Value(dtype="bool", id=None), + "index": Value(dtype="int64", id=None), + "episode_data_index_from": Value(dtype="int64", id=None), + "episode_data_index_to": Value(dtype="int64", id=None), + } + features = Features(features) + dataset = Dataset.from_dict(data_dict, features=features) + dataset = dataset.with_format("torch") + + num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] + dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") + dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) + dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") + + +def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): + root = Path(root) + raw_dir = root / f"{dataset_id}_raw" + if not raw_dir.exists(): + import zipfile + + import gdown + + raw_dir.mkdir(parents=True, exist_ok=True) + url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" + zip_path = raw_dir / "data.zip" + gdown.download(url, str(zip_path), quiet=False) + print("Extracting...") + with zipfile.ZipFile(str(zip_path), "r") as zip_f: + for member in zip_f.namelist(): + if member.startswith("data/xarm") and member.endswith(".pkl"): + print(member) + zip_f.extract(member=member) + zip_path.unlink() + + dataset_path = root / f"{dataset_id}" / "buffer.pkl" + print(f"Using offline dataset '{dataset_path}'") + with open(dataset_path, "rb") as f: + dataset_dict = pickle.load(f) + + total_frames = dataset_dict["actions"].shape[0] + + ep_dicts = [] + + id_from = 0 + id_to = 0 + episode_id = 0 + for i in tqdm.tqdm(range(total_frames)): + id_to += 1 + + if not dataset_dict["dones"][i]: + continue + + num_frames = id_to - id_from + + image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to]) + image = einops.rearrange(image, "b c h w -> b h w c") + state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to]) + action = torch.tensor(dataset_dict["actions"][id_from:id_to]) + # TODO(rcadene): we have a missing last frame which is the observation when the env is done + # it is critical to have this frame for tdmpc to predict a "done observation/state" + # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to]) + # next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to]) + next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to]) + next_done = torch.tensor(dataset_dict["dones"][id_from:id_to]) + + ep_dict = { + "observation.image": [PILImage.fromarray(x.numpy()) for x in image], + "observation.state": state, + "action": action, + "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + # "next.observation.image": next_image, + # "next.observation.state": next_state, + "next.reward": next_reward, + "next.done": next_done, + "episode_data_index_from": torch.tensor([id_from] * num_frames), + "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), + } + ep_dicts.append(ep_dict) + + id_from = id_to + episode_id += 1 + + data_dict = {} + keys = ep_dicts[0].keys() + for key in keys: + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + + 
total_frames = id_from + data_dict["index"] = torch.arange(0, total_frames, 1) + + features = { + "observation.image": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_id": Value(dtype="int64", id=None), + "frame_id": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + "next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + #'next.success': Value(dtype='bool', id=None), + "index": Value(dtype="int64", id=None), + "episode_data_index_from": Value(dtype="int64", id=None), + "episode_data_index_to": Value(dtype="int64", id=None), + } + features = Features(features) + dataset = Dataset.from_dict(data_dict, features=features) + dataset = dataset.with_format("torch") + + num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] + dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") + dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) + dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") + + +def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): + folder_urls = { + "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", + "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N", + "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo", + "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj", + } + + ep48_urls = { + "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link", + "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link", + "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link", + "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link", + } + + ep49_urls = { + "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link", + "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link", + "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link", + "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link", + } + + num_episodes = { + "aloha_sim_insertion_human": 50, + "aloha_sim_insertion_scripted": 50, + "aloha_sim_transfer_cube_human": 50, + "aloha_sim_transfer_cube_scripted": 50, + } + + episode_len = { + "aloha_sim_insertion_human": 500, + "aloha_sim_insertion_scripted": 400, + "aloha_sim_transfer_cube_human": 400, + "aloha_sim_transfer_cube_scripted": 400, + } + + cameras = { + "aloha_sim_insertion_human": ["top"], + "aloha_sim_insertion_scripted": ["top"], + "aloha_sim_transfer_cube_human": ["top"], + "aloha_sim_transfer_cube_scripted": ["top"], + } + + root = Path(root) + raw_dir = root / f"{dataset_id}_raw" + if not raw_dir.is_dir(): + import gdown + + assert dataset_id in folder_urls + assert dataset_id in ep48_urls + assert dataset_id in ep49_urls 
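+
+        # Each simulated Aloha dataset is fetched from a Google Drive folder plus two standalone
+        # files (episodes 48 and 49), all downloaded below.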
+ + raw_dir.mkdir(parents=True, exist_ok=True) + + gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir)) + + # because of the 50 files limit per directory, two files episode 48 and 49 were missing + gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True) + gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True) + + ep_dicts = [] + + id_from = 0 + for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])): + ep_path = raw_dir / f"episode_{ep_id}.hdf5" + with h5py.File(ep_path, "r") as ep: + num_frames = ep["/action"].shape[0] + assert episode_len[dataset_id] == num_frames + + # last step of demonstration is considered done + done = torch.zeros(num_frames, dtype=torch.bool) + done[-1] = True + + state = torch.from_numpy(ep["/observations/qpos"][:]) + action = torch.from_numpy(ep["/action"][:]) + + ep_dict = {} + + for cam in cameras[dataset_id]: + image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c + # image = einops.rearrange(image, "b h w c -> b c h w").contiguous() + ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image] + # ep_dict[f"next.observation.images.{cam}"] = image + + ep_dict.update( + { + "observation.state": state, + "action": action, + "episode_id": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + # "next.observation.state": state, + # TODO(rcadene): compute reward and success + # "next.reward": reward, + "next.done": done, + # "next.success": success, + "episode_data_index_from": torch.tensor([id_from] * num_frames), + "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), + } + ) + + assert isinstance(ep_id, int) + ep_dicts.append(ep_dict) + + id_from += num_frames + + data_dict = {} + + data_dict = {} + keys = ep_dicts[0].keys() + for key in keys: + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + + total_frames = id_from + data_dict["index"] = torch.arange(0, total_frames, 1) + + features = { + "observation.images.top": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_id": Value(dtype="int64", id=None), + "frame_id": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + #'next.reward': Value(dtype='float32', id=None), + "next.done": Value(dtype="bool", id=None), + #'next.success': Value(dtype='bool', id=None), + "index": Value(dtype="int64", id=None), + "episode_data_index_from": Value(dtype="int64", id=None), + "episode_data_index_to": Value(dtype="int64", id=None), + } + features = Features(features) + dataset = Dataset.from_dict(data_dict, features=features) + dataset = dataset.with_format("torch") + + num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] + dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") + dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) + dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") + + +if __name__ == "__main__": + root = "data" + root_tests = "tests/data" + + dataset_ids = [ + # "pusht", + # "xarm_lift_medium", + # 
"aloha_sim_insertion_human", + # "aloha_sim_insertion_scripted", + # "aloha_sim_transfer_cube_human", + "aloha_sim_transfer_cube_scripted", + ] + for dataset_id in dataset_ids: + download_and_upload(root, root_tests, dataset_id) + # assume stats have been precomputed + shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth") diff --git a/examples/1_visualize_dataset.py b/examples/1_visualize_dataset.py index f52ab76a..15e0e54d 100644 --- a/examples/1_visualize_dataset.py +++ b/examples/1_visualize_dataset.py @@ -1,24 +1,20 @@ import os - -from torchrl.data.replay_buffers import SamplerWithoutReplacement +from pathlib import Path import lerobot -from lerobot.common.datasets.aloha import AlohaDataset +from lerobot.common.datasets.pusht import PushtDataset from lerobot.scripts.visualize_dataset import render_dataset print(lerobot.available_datasets) # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium'] -# we use this sampler to sample 1 frame after the other -sampler = SamplerWithoutReplacement(shuffle=False) - -dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR")) +# TODO(rcadene): remove DATA_DIR +dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR"))) video_paths = render_dataset( dataset, out_dir="outputs/visualize_dataset/example", - max_num_samples=300, - fps=50, + max_num_episodes=1, ) print(video_paths) # ['outputs/visualize_dataset/example/episode_0.mp4'] diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index be6abd1b..b3d13f74 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -11,6 +11,7 @@ from lerobot.common.utils import init_hydra_config from lerobot.scripts.eval import eval # Get a pretrained policy from the hub. +# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs. hub_id = "lerobot/diffusion_policy_pusht_image" folder = Path(snapshot_download(hub_id)) # OR uncomment the following to evaluate a policy from the local outputs/train folder. diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 6e01a5d5..7a7a7aaf 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -9,47 +9,60 @@ from pathlib import Path import torch from omegaconf import OmegaConf -from tqdm import trange -from lerobot.common.datasets.factory import make_offline_buffer -from lerobot.common.policies.diffusion.policy import DiffusionPolicy +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.utils import init_hydra_config output_directory = Path("outputs/train/example_pusht_diffusion") os.makedirs(output_directory, exist_ok=True) -overrides = [ - "env=pusht", - "policy=diffusion", - # Adjust as you prefer. 5000 steps are needed to get something worth evaluating. - "offline_steps=5000", - "log_freq=250", - "device=cuda", -] +# Number of offline training steps (we'll only do offline training for this example. +# Adjust as you prefer. 5000 steps are needed to get something worth evaluating. 
+training_steps = 5000 +device = torch.device("cuda") +log_freq = 250 -cfg = init_hydra_config("lerobot/configs/default.yaml", overrides) +# Set up the dataset. +hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"]) +dataset = make_dataset(hydra_cfg) -policy = DiffusionPolicy( - cfg=cfg.policy, - cfg_device=cfg.device, - cfg_noise_scheduler=cfg.noise_scheduler, - cfg_rgb_model=cfg.rgb_model, - cfg_obs_encoder=cfg.obs_encoder, - cfg_optimizer=cfg.optimizer, - cfg_ema=cfg.ema, - n_action_steps=cfg.n_action_steps, - **cfg.policy, -) +# Set up the the policy. +# Policies are initialized with a configuration class, in this case `DiffusionConfig`. +# For this example, no arguments need to be passed because the defaults are set up for PushT. +# If you're doing something different, you will likely need to change at least some of the defaults. +cfg = DiffusionConfig() +# TODO(alexander-soare): Remove LR scheduler from the policy. +policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps) policy.train() +policy.to(device) -offline_buffer = make_offline_buffer(cfg) +# Create dataloader for offline training. +dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=cfg.batch_size, + shuffle=True, + pin_memory=device != torch.device("cpu"), + drop_last=True, +) -for offline_step in trange(cfg.offline_steps): - train_info = policy.update(offline_buffer, offline_step) - if offline_step % cfg.log_freq == 0: - print(train_info) +# Run training loop. +step = 0 +done = False +while not done: + for batch in dataloader: + batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} + info = policy.update(batch) + if step % log_freq == 0: + print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)") + step += 1 + if step >= training_steps: + done = True + break # Save the policy, configuration, and normalization stats for later use. policy.save(output_directory / "model.pt") -OmegaConf.save(cfg, output_directory / "config.yaml") -torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth") +OmegaConf.save(hydra_cfg, output_directory / "config.yaml") +torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth") diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 4673aab0..8ab95df8 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -12,14 +12,11 @@ Example: print(lerobot.available_policies) ``` -Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class +When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: +- Set the required class attributes: `available_datasets`. +- Set the required class attributes: `name`. +- Update variables in `lerobot/__init__.py` (e.g. 
`available_envs`, `available_datasets_per_envs`, `available_policies`) +- Update variables in `tests/test_available.py` by importing your new class """ from lerobot.__version__ import __version__ # noqa: F401 @@ -32,11 +29,11 @@ available_envs = [ available_tasks_per_env = { "aloha": [ - "sim_insertion", - "sim_transfer_cube", + "AlohaInsertion-v0", + "AlohaTransferCube-v0", ], - "pusht": ["pusht"], - "xarm": ["lift"], + "pusht": ["PushT-v0"], + "xarm": ["XarmLift-v0"], } available_datasets_per_env = { diff --git a/lerobot/commands/env.py b/lerobot/commands/env.py new file mode 100644 index 00000000..1a7e9508 --- /dev/null +++ b/lerobot/commands/env.py @@ -0,0 +1,43 @@ +import platform + +import huggingface_hub + +# import dataset +import numpy as np +import torch + +from lerobot import __version__ as version + +pt_version = torch.__version__ +pt_cuda_available = torch.cuda.is_available() +pt_cuda_available = torch.cuda.is_available() +cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A" + + +# TODO(aliberts): refactor into an actual command `lerobot env` +def get_env_info() -> dict: + """Run this to get basic system info to help for tracking issues & bugs.""" + info = { + "`lerobot` version": version, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "Huggingface_hub version": huggingface_hub.__version__, + # TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged + # "Dataset version": dataset.__version__, + "Numpy version": np.__version__, + "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Cuda version": cuda_version, + "Using GPU in script?": "", + "Using distributed or parallel set-up in script?": "", + } + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") + print(format_dict(info)) + return info + + +def format_dict(d: dict) -> str: + return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" + + +if __name__ == "__main__": + get_env_info() diff --git a/lerobot/common/policies/diffusion/replay_buffer.py b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py similarity index 98% rename from lerobot/common/policies/diffusion/replay_buffer.py rename to lerobot/common/datasets/_diffusion_policy_replay_buffer.py index 7fccf74d..2f532650 100644 --- a/lerobot/common/policies/diffusion/replay_buffer.py +++ b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py @@ -1,3 +1,8 @@ +"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/) + +Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script. 
+""" + from __future__ import annotations import math diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 50bf819a..0b7ed24b 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -1,72 +1,19 @@ -import logging from pathlib import Path -import einops -import gdown -import h5py import torch -import tqdm +from datasets import load_dataset, load_from_disk -from lerobot.common.datasets.utils import load_data_with_delta_timestamps - -FOLDER_URLS = { - "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", - "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N", - "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo", - "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj", -} - -EP48_URLS = { - "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link", - "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link", - "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link", - "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link", -} - -EP49_URLS = { - "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link", - "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link", - "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link", - "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link", -} - -NUM_EPISODES = { - "aloha_sim_insertion_human": 50, - "aloha_sim_insertion_scripted": 50, - "aloha_sim_transfer_cube_human": 50, - "aloha_sim_transfer_cube_scripted": 50, -} - -EPISODE_LEN = { - "aloha_sim_insertion_human": 500, - "aloha_sim_insertion_scripted": 400, - "aloha_sim_transfer_cube_human": 400, - "aloha_sim_transfer_cube_scripted": 400, -} - -CAMERAS = { - "aloha_sim_insertion_human": ["top"], - "aloha_sim_insertion_scripted": ["top"], - "aloha_sim_transfer_cube_human": ["top"], - "aloha_sim_transfer_cube_scripted": ["top"], -} - - -def download(data_dir, dataset_id): - assert dataset_id in FOLDER_URLS - assert dataset_id in EP48_URLS - assert dataset_id in EP49_URLS - - data_dir.mkdir(parents=True, exist_ok=True) - - gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir)) - - # because of the 50 files limit per directory, two files episode 48 and 49 were missing - gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True) - gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True) +from lerobot.common.datasets.utils import load_previous_and_future_frames class AlohaDataset(torch.utils.data.Dataset): + """ + https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human + https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted + https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human + https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted + """ + available_datasets = [ "aloha_sim_insertion_human", 
"aloha_sim_insertion_scripted", @@ -79,8 +26,9 @@ class AlohaDataset(torch.utils.data.Dataset): def __init__( self, dataset_id: str, - version: str | None = "v1.2", + version: str | None = "v1.0", root: Path | None = None, + split: str = "train", transform: callable = None, delta_timestamps: dict[list[float]] | None = None, ): @@ -88,120 +36,48 @@ class AlohaDataset(torch.utils.data.Dataset): self.dataset_id = dataset_id self.version = version self.root = root + self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - - self.data_dir = self.root / f"{self.dataset_id}" - if (self.data_dir / "data_dict.pth").exists() and ( - self.data_dir / "data_ids_per_episode.pth" - ).exists(): - self.data_dict = torch.load(self.data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") + if self.root is not None: + self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split) else: - self._download_and_preproc_obsolete() - self.data_dir.mkdir(parents=True, exist_ok=True) - torch.save(self.data_dict, self.data_dir / "data_dict.pth") - torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") + self.data_dict = load_dataset( + f"lerobot/{self.dataset_id}", revision=self.version, split=self.split + ) + self.data_dict = self.data_dict.with_format("torch") @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict) @property def num_episodes(self) -> int: - return len(self.data_ids_per_episode) + return len(self.data_dict.unique("episode_id")) def __len__(self): return self.num_samples def __getitem__(self, idx): - item = {} + item = self.data_dict[idx] - # get episode id and timestamp of the sampled frame - current_ts = self.data_dict["timestamp"][idx].item() - episode = self.data_dict["episode"][idx].item() + if self.delta_timestamps is not None: + item = load_previous_and_future_frames( + item, + self.data_dict, + self.delta_timestamps, + ) - for key in self.data_dict: - if self.delta_timestamps is not None and key in self.delta_timestamps: - data, is_pad = load_data_with_delta_timestamps( - self.data_dict, - self.data_ids_per_episode, - self.delta_timestamps, - key, - current_ts, - episode, - ) - item[key] = data - item[f"{key}_is_pad"] = is_pad + # convert images from channel last (PIL) to channel first (pytorch) + for key in self.image_keys: + if item[key].ndim == 3: + item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w + elif item[key].ndim == 4: + item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w else: - item[key] = self.data_dict[key][idx] + raise ValueError(item[key].ndim) if self.transform is not None: item = self.transform(item) return item - - def _download_and_preproc_obsolete(self): - assert self.root is not None - raw_dir = self.root / f"{self.dataset_id}_raw" - if not raw_dir.is_dir(): - download(raw_dir, self.dataset_id) - - total_frames = 0 - logging.info("Compute total number of frames to initialize offline buffer") - for ep_id in range(NUM_EPISODES[self.dataset_id]): - ep_path = raw_dir / f"episode_{ep_id}.hdf5" - with h5py.File(ep_path, "r") as ep: - total_frames += ep["/action"].shape[0] - 1 - logging.info(f"{total_frames=}") - - self.data_ids_per_episode = {} - ep_dicts = [] - - frame_idx = 0 - for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])): - ep_path = raw_dir / f"episode_{ep_id}.hdf5" - with h5py.File(ep_path, "r") as ep: - num_frames = ep["/action"].shape[0] - - # last step of 
demonstration is considered done - done = torch.zeros(num_frames, dtype=torch.bool) - done[-1] = True - - state = torch.from_numpy(ep["/observations/qpos"][:]) - action = torch.from_numpy(ep["/action"][:]) - - ep_dict = { - "observation.state": state, - "action": action, - "episode": torch.tensor([ep_id] * num_frames), - "frame_id": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / self.fps, - # "next.observation.state": state, - # TODO(rcadene): compute reward and success - # "next.reward": reward[1:], - "next.done": done[1:], - # "next.success": success[1:], - } - - for cam in CAMERAS[self.dataset_id]: - image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) - image = einops.rearrange(image, "b h w c -> b c h w").contiguous() - ep_dict[f"observation.images.{cam}"] = image[:-1] - # ep_dict[f"next.observation.images.{cam}"] = image[1:] - - assert isinstance(ep_id, int) - self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1) - assert len(self.data_ids_per_episode[ep_id]) == num_frames - - ep_dicts.append(ep_dict) - - frame_idx += num_frames - - self.data_dict = {} - - keys = ep_dicts[0].keys() - for key in keys: - self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) - - self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 0dab5d4b..07afb614 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,15 +1,13 @@ +import logging import os from pathlib import Path import torch from torchvision.transforms import v2 -from lerobot.common.datasets.utils import compute_or_load_stats +from lerobot.common.datasets.utils import compute_stats from lerobot.common.transforms import NormalizeTransform, Prod -# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and -# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` -# to load a subset of our datasets for faster continuous integration. DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None @@ -18,6 +16,7 @@ def make_dataset( # set normalize=False to remove all transformations and keep images unnormalized in [0,255] normalize=True, stats_path=None, + split="train", ): if cfg.env.name == "xarm": from lerobot.common.datasets.xarm import XarmDataset @@ -40,7 +39,8 @@ def make_dataset( if normalize: # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, # min_max_from_spec - # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path) + # TODO(rcadene): remove this and put it in config. 
Ideally we want to reproduce SOTA results just with mean_std + normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": stats = {} @@ -51,21 +51,32 @@ def make_dataset( stats["action"] = {} stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + elif stats_path is None: + # load stats if the file exists already or compute stats and save it + if DATA_DIR is None: + # TODO(rcadene): clean stats + precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" + else: + precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth" + if precomputed_stats_path.exists(): + stats = torch.load(precomputed_stats_path) + else: + logging.info(f"compute_stats and save to {precomputed_stats_path}") + # Create a dataset for stats computation. + stats_dataset = clsfunc( + dataset_id=cfg.dataset_id, + split="train", + root=DATA_DIR, + transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), + ) + stats = compute_stats(stats_dataset) + precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(stats, precomputed_stats_path) else: - # instantiate a one frame dataset with light transform - stats_dataset = clsfunc( - dataset_id=cfg.dataset_id, - root=DATA_DIR, - transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), - ) - stats = compute_or_load_stats(stats_dataset) - - # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std - normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" + stats = torch.load(stats_path) transforms = v2.Compose( [ - # TODO(rcadene): we need to do something about image_keys Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), NormalizeTransform( stats, @@ -86,6 +97,7 @@ def make_dataset( dataset = clsfunc( dataset_id=cfg.dataset_id, + split=split, root=DATA_DIR, delta_timestamps=delta_timestamps, transform=transforms, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index b468637e..93a4a002 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -1,83 +1,14 @@ from pathlib import Path -import einops -import numpy as np -import pygame -import pymunk import torch -import tqdm -from gym_pusht.envs.pusht import pymunk_to_shapely +from datasets import load_dataset, load_from_disk -from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps -from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer - -# as define in env -SUCCESS_THRESHOLD = 0.95 # 95% coverage, - -PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" -PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr") - - -def get_goal_pose_body(pose): - mass = 1 - inertia = pymunk.moment_for_box(mass, (50, 100)) - body = pymunk.Body(mass, inertia) - # preserving the legacy assignment order for compatibility - # the order here doesn't matter somehow, maybe because CoM is aligned with body origin - body.position = pose[:2].tolist() - body.angle = pose[2] - return body - - -def add_segment(space, a, b, radius): - shape = pymunk.Segment(space.static_body, a, b, radius) - shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names - return shape - - -def add_tee( - space, - position, - angle, - scale=30, - color="LightSlateGray", - mask=None, -): - if mask is None: - mask = 
pymunk.ShapeFilter.ALL_MASKS() - mass = 1 - length = 4 - vertices1 = [ - (-length * scale / 2, scale), - (length * scale / 2, scale), - (length * scale / 2, 0), - (-length * scale / 2, 0), - ] - inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1) - vertices2 = [ - (-scale / 2, scale), - (-scale / 2, length * scale), - (scale / 2, length * scale), - (scale / 2, scale), - ] - inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1) - body = pymunk.Body(mass, inertia1 + inertia2) - shape1 = pymunk.Poly(body, vertices1) - shape2 = pymunk.Poly(body, vertices2) - shape1.color = pygame.Color(color) - shape2.color = pygame.Color(color) - shape1.filter = pymunk.ShapeFilter(mask=mask) - shape2.filter = pymunk.ShapeFilter(mask=mask) - body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2 - body.position = position - body.angle = angle - body.friction = 1 - space.add(body, shape1, shape2) - return body +from lerobot.common.datasets.utils import load_previous_and_future_frames class PushtDataset(torch.utils.data.Dataset): """ + https://huggingface.co/datasets/lerobot/pusht Arguments ---------- @@ -93,8 +24,9 @@ class PushtDataset(torch.utils.data.Dataset): def __init__( self, dataset_id: str, - version: str | None = "v1.2", + version: str | None = "v1.0", root: Path | None = None, + split: str = "train", transform: callable = None, delta_timestamps: dict[list[float]] | None = None, ): @@ -102,177 +34,48 @@ class PushtDataset(torch.utils.data.Dataset): self.dataset_id = dataset_id self.version = version self.root = root + self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - - self.data_dir = self.root / f"{self.dataset_id}" - if (self.data_dir / "data_dict.pth").exists() and ( - self.data_dir / "data_ids_per_episode.pth" - ).exists(): - self.data_dict = torch.load(self.data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") + if self.root is not None: + self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split) else: - self._download_and_preproc_obsolete() - self.data_dir.mkdir(parents=True, exist_ok=True) - torch.save(self.data_dict, self.data_dir / "data_dict.pth") - torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") + self.data_dict = load_dataset( + f"lerobot/{self.dataset_id}", revision=self.version, split=self.split + ) + self.data_dict = self.data_dict.with_format("torch") @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict) @property def num_episodes(self) -> int: - return len(self.data_ids_per_episode) + return len(self.data_dict.unique("episode_id")) def __len__(self): return self.num_samples def __getitem__(self, idx): - item = {} + item = self.data_dict[idx] - # get episode id and timestamp of the sampled frame - current_ts = self.data_dict["timestamp"][idx].item() - episode = self.data_dict["episode"][idx].item() + if self.delta_timestamps is not None: + item = load_previous_and_future_frames( + item, + self.data_dict, + self.delta_timestamps, + ) - for key in self.data_dict: - if self.delta_timestamps is not None and key in self.delta_timestamps: - data, is_pad = load_data_with_delta_timestamps( - self.data_dict, - self.data_ids_per_episode, - self.delta_timestamps, - key, - current_ts, - episode, - ) - item[key] = data - item[f"{key}_is_pad"] = is_pad + # convert images from channel last (PIL) to channel first (pytorch) + for key in self.image_keys: + if 
item[key].ndim == 3: + item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w + elif item[key].ndim == 4: + item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w else: - item[key] = self.data_dict[key][idx] + raise ValueError(item[key].ndim) if self.transform is not None: item = self.transform(item) return item - - def _download_and_preproc_obsolete(self): - assert self.root is not None - raw_dir = self.root / f"{self.dataset_id}_raw" - zarr_path = (raw_dir / PUSHT_ZARR).resolve() - if not zarr_path.is_dir(): - raw_dir.mkdir(parents=True, exist_ok=True) - download_and_extract_zip(PUSHT_URL, raw_dir) - - # load - dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path( - zarr_path - ) # , keys=['img', 'state', 'action']) - - episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) - num_episodes = dataset_dict.meta["episode_ends"].shape[0] - total_frames = dataset_dict["action"].shape[0] - # to create test artifact - # num_episodes = 1 - # total_frames = 50 - assert len( - {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118 - ), "Some data type dont have the same number of total frames." - - # TODO: verify that goal pose is expected to be fixed - goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) - goal_body = get_goal_pose_body(goal_pos_angle) - - imgs = torch.from_numpy(dataset_dict["img"]) - imgs = einops.rearrange(imgs, "b h w c -> b c h w") - states = torch.from_numpy(dataset_dict["state"]) - actions = torch.from_numpy(dataset_dict["action"]) - - self.data_ids_per_episode = {} - ep_dicts = [] - - idx0 = 0 - for episode_id in tqdm.tqdm(range(num_episodes)): - idx1 = dataset_dict.meta["episode_ends"][episode_id] - - num_frames = idx1 - idx0 - - assert (episode_ids[idx0:idx1] == episode_id).all() - - image = imgs[idx0:idx1] - - state = states[idx0:idx1] - agent_pos = state[:, :2] - block_pos = state[:, 2:4] - block_angle = state[:, 4] - - reward = torch.zeros(num_frames) - success = torch.zeros(num_frames, dtype=torch.bool) - done = torch.zeros(num_frames, dtype=torch.bool) - for i in range(num_frames): - space = pymunk.Space() - space.gravity = 0, 0 - space.damping = 0 - - # Add walls. 
- walls = [ - add_segment(space, (5, 506), (5, 5), 2), - add_segment(space, (5, 5), (506, 5), 2), - add_segment(space, (506, 5), (506, 506), 2), - add_segment(space, (5, 506), (506, 506), 2), - ] - space.add(*walls) - - block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item()) - goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) - block_geom = pymunk_to_shapely(block_body, block_body.shapes) - intersection_area = goal_geom.intersection(block_geom).area - goal_area = goal_geom.area - coverage = intersection_area / goal_area - reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1) - success[i] = coverage > SUCCESS_THRESHOLD - - # last step of demonstration is considered done - done[-1] = True - - ep_dict = { - "observation.image": image, - "observation.state": agent_pos, - "action": actions[idx0:idx1], - "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / self.fps, - # "next.observation.image": image[1:], - # "next.observation.state": agent_pos[1:], - # TODO(rcadene): verify that reward and done are aligned with image and agent_pos - "next.reward": torch.cat([reward[1:], reward[[-1]]]), - "next.done": torch.cat([done[1:], done[[-1]]]), - "next.success": torch.cat([success[1:], success[[-1]]]), - } - ep_dicts.append(ep_dict) - - assert isinstance(episode_id, int) - self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) - assert len(self.data_ids_per_episode[episode_id]) == num_frames - - idx0 = idx1 - - self.data_dict = {} - - keys = ep_dicts[0].keys() - for key in keys: - self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) - - self.data_dict["index"] = torch.arange(0, total_frames, 1) - - -if __name__ == "__main__": - dataset = PushtDataset( - "pusht", - root=Path("data"), - delta_timestamps={ - "observation.image": [0, -1, -0.2, -0.1], - "observation.state": [0, -1, -0.2, -0.1], - "action": [-0.1, 0, 1, 2, 3], - }, - ) - dataset[10] diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 3b4aacfc..1b353e69 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,115 +1,93 @@ -import io -import logging -import zipfile from copy import deepcopy from math import ceil -from pathlib import Path import einops -import requests import torch import tqdm -def download_and_extract_zip(url: str, destination_folder: Path) -> bool: - print(f"downloading from {url}") - response = requests.get(url, stream=True) - if response.status_code == 200: - total_size = int(response.headers.get("content-length", 0)) - progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) +def load_previous_and_future_frames( + item: dict[str, torch.Tensor], + data_dict: dict[str, torch.Tensor], + delta_timestamps: dict[str, list[float]], + tol: float = 0.04, +) -> dict[torch.Tensor]: + """ + Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), + this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset. - zip_file = io.BytesIO() - for chunk in response.iter_content(chunk_size=1024): - if chunk: - zip_file.write(chunk) - progress_bar.update(len(chunk)) + Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. 
tol=0.04), this function raises an AssertionError. + When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp, + the violation of the tolerance doesn't raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range. + For instance, this boolean array is useful during batched training to not supervise actions associated with timestamps coming after the end of the episode, + or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode. - progress_bar.close() + Parameters: + - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps. + - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04. - zip_file.seek(0) + Returns: + - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad"). - with zipfile.ZipFile(zip_file, "r") as zip_ref: - zip_ref.extractall(destination_folder) - return True - else: - return False - - -def euclidean_distance_matrix(mat0, mat1): - # Compute the square of the distance matrix - sq0 = torch.sum(mat0**2, dim=1, keepdim=True) - sq1 = torch.sum(mat1**2, dim=1, keepdim=True) - distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1) - - # Taking the square root to get the euclidean distance - distance = torch.sqrt(torch.clamp(distance_sq, min=0)) - return distance - - -def is_contiguously_true_or_false(bool_vector): - assert bool_vector.ndim == 1 - assert bool_vector.dtype == torch.bool - - # Compare each element with its neighbor to find changes - changes = bool_vector[1:] != bool_vector[:-1] - - # Count the number of changes - num_changes = changes.sum().item() - - # If there's more than one change, the list is not contiguous - return num_changes <= 1 - - # examples = [ - # ([True, False, True, False, False, False], False), - # ([True, True, True, False, False, False], True), - # ([False, False, False, False, False, False], True) - # ] - # for bool_list, expected in examples: - # result = is_contiguously_true_or_false(bool_list) - - -def load_data_with_delta_timestamps( - data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode -): + Raises: + - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
+ """ # get indices of the frames associated to the episode, and their timestamps - ep_data_ids = data_ids_per_episode[episode] - ep_timestamps = data_dict["timestamp"][ep_data_ids] + ep_data_id_from = item["episode_data_index_from"].item() + ep_data_id_to = item["episode_data_index_to"].item() + ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) - # get timestamps used as query to retrieve data of previous/future frames - delta_ts = delta_timestamps[key] - query_ts = current_ts + torch.tensor(delta_ts) + # load timestamps + ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"] - # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode - dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None]) - min_, argmin_ = dist.min(1) + # we make the assumption that the timestamps are sorted + ep_first_ts = ep_timestamps[0] + ep_last_ts = ep_timestamps[-1] + current_ts = item["timestamp"].item() - # get the indices of the data that are closest to the query timestamps - data_ids = ep_data_ids[argmin_] - # closest_ts = ep_timestamps[argmin_] + for key in delta_timestamps: + # get timestamps used as query to retrieve data of previous/future frames + delta_ts = delta_timestamps[key] + query_ts = current_ts + torch.tensor(delta_ts) - # get the data - data = data_dict[key][data_ids].clone() + # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode + dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1) + min_, argmin_ = dist.min(1) - # TODO(rcadene): synchronize timestamps + interpolation if needed + # TODO(rcadene): synchronize timestamps + interpolation if needed - tol = 0.04 - is_pad = min_ > tol + is_pad = min_ > tol - assert is_contiguously_true_or_false(is_pad), ( - f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})." - "This might be due to synchronization issues with timestamps during data collection." - ) + # check violated query timestamps are all outside the episode range + assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), ( + f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range." + "This might be due to synchronization issues with timestamps during data collection." 
+ ) - return data, is_pad + # get dataset indices corresponding to frames to be loaded + data_ids = ep_data_ids[argmin_] + + # load frames modality + item[key] = data_dict.select_columns(key)[data_ids][key] + item[f"{key}_is_pad"] = is_pad + + return item -def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): - stats_path = dataset.data_dir / "stats.pth" - if stats_path.exists(): - return torch.load(stats_path) +def get_stats_einops_patterns(dataset): + """These einops patterns will be used to aggregate batches and compute statistics.""" + stats_patterns = { + "action": "b c -> c", + "observation.state": "b c -> c", + } + for key in dataset.image_keys: + stats_patterns[key] = "b c h w -> c 1 1" + return stats_patterns - logging.info(f"compute_stats and save to {stats_path}") +def compute_stats(dataset, batch_size=32, max_num_samples=None): if max_num_samples is None: max_num_samples = len(dataset) else: @@ -124,13 +102,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): drop_last=False, ) - # these einops patterns will be used to aggregate batches and compute statistics - stats_patterns = { - "action": "b c -> c", - "observation.state": "b c -> c", - } - for key in dataset.image_keys: - stats_patterns[key] = "b c h w -> c 1 1" + # get einops patterns to aggregate batches and compute statistics + stats_patterns = get_stats_einops_patterns(dataset) # mean and std will be computed incrementally while max and min will track the running value. mean, std, max, min = {}, {}, {}, {} @@ -201,11 +174,14 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): "min": min[key], } - torch.save(stats, stats_path) return stats def cycle(iterable): + """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. + + See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. 
+ """ iterator = iter(iterable) while True: try: diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 733267ab..605dd1eb 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -1,30 +1,16 @@ -import pickle -import zipfile from pathlib import Path import torch -import tqdm +from datasets import load_dataset, load_from_disk -from lerobot.common.datasets.utils import load_data_with_delta_timestamps - - -def download(raw_dir): - import gdown - - raw_dir.mkdir(parents=True, exist_ok=True) - url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - zip_path = raw_dir / "data.zip" - gdown.download(url, str(zip_path), quiet=False) - print("Extracting...") - with zipfile.ZipFile(str(zip_path), "r") as zip_f: - for member in zip_f.namelist(): - if member.startswith("data/xarm") and member.endswith(".pkl"): - print(member) - zip_f.extract(member=member) - zip_path.unlink() +from lerobot.common.datasets.utils import load_previous_and_future_frames class XarmDataset(torch.utils.data.Dataset): + """ + https://huggingface.co/datasets/lerobot/xarm_lift_medium + """ + available_datasets = [ "xarm_lift_medium", ] @@ -34,8 +20,9 @@ class XarmDataset(torch.utils.data.Dataset): def __init__( self, dataset_id: str, - version: str | None = "v1.1", + version: str | None = "v1.0", root: Path | None = None, + split: str = "train", transform: callable = None, delta_timestamps: dict[list[float]] | None = None, ): @@ -43,120 +30,48 @@ class XarmDataset(torch.utils.data.Dataset): self.dataset_id = dataset_id self.version = version self.root = root + self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - - self.data_dir = self.root / f"{self.dataset_id}" - if (self.data_dir / "data_dict.pth").exists() and ( - self.data_dir / "data_ids_per_episode.pth" - ).exists(): - self.data_dict = torch.load(self.data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") + if self.root is not None: + self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split) else: - self._download_and_preproc_obsolete() - self.data_dir.mkdir(parents=True, exist_ok=True) - torch.save(self.data_dict, self.data_dir / "data_dict.pth") - torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") + self.data_dict = load_dataset( + f"lerobot/{self.dataset_id}", revision=self.version, split=self.split + ) + self.data_dict = self.data_dict.with_format("torch") @property def num_samples(self) -> int: - return len(self.data_dict["index"]) + return len(self.data_dict) @property def num_episodes(self) -> int: - return len(self.data_ids_per_episode) + return len(self.data_dict.unique("episode_id")) def __len__(self): return self.num_samples def __getitem__(self, idx): - item = {} + item = self.data_dict[idx] - # get episode id and timestamp of the sampled frame - current_ts = self.data_dict["timestamp"][idx].item() - episode = self.data_dict["episode"][idx].item() + if self.delta_timestamps is not None: + item = load_previous_and_future_frames( + item, + self.data_dict, + self.delta_timestamps, + ) - for key in self.data_dict: - if self.delta_timestamps is not None and key in self.delta_timestamps: - data, is_pad = load_data_with_delta_timestamps( - self.data_dict, - self.data_ids_per_episode, - self.delta_timestamps, - key, - current_ts, - episode, - ) - item[key] = data - item[f"{key}_is_pad"] = is_pad + # convert images from channel last (PIL) to 
channel first (pytorch) + for key in self.image_keys: + if item[key].ndim == 3: + item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w + elif item[key].ndim == 4: + item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w else: - item[key] = self.data_dict[key][idx] + raise ValueError(item[key].ndim) if self.transform is not None: item = self.transform(item) return item - - def _download_and_preproc_obsolete(self): - assert self.root is not None - raw_dir = self.root / f"{self.dataset_id}_raw" - if not raw_dir.exists(): - download(raw_dir) - - dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl" - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - - total_frames = dataset_dict["actions"].shape[0] - - self.data_ids_per_episode = {} - ep_dicts = [] - - idx0 = 0 - idx1 = 0 - episode_id = 0 - for i in tqdm.tqdm(range(total_frames)): - idx1 += 1 - - if not dataset_dict["dones"][i]: - continue - - num_frames = idx1 - idx0 - - image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) - state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) - action = torch.tensor(dataset_dict["actions"][idx0:idx1]) - # TODO(rcadene): concat the last "next_observations" to "observations" - # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) - # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) - next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) - next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) - - ep_dict = { - "observation.image": image, - "observation.state": state, - "action": action, - "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / self.fps, - # "next.observation.image": next_image, - # "next.observation.state": next_state, - "next.reward": next_reward, - "next.done": next_done, - } - ep_dicts.append(ep_dict) - - assert isinstance(episode_id, int) - self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) - assert len(self.data_ids_per_episode[episode_id]) == num_frames - - idx0 = idx1 - episode_id += 1 - - self.data_dict = {} - - keys = ep_dicts[0].keys() - for key in keys: - self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) - - self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 42353aa6..f82347f5 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -14,11 +14,12 @@ def preprocess_observation(observation, transform=None): imgs = {"observation.image": observation["pixels"]} for imgkey, img in imgs.items(): - img = torch.from_numpy(img).float() + img = torch.from_numpy(img) # convert to (b c h w) torch format img = einops.rearrange(img, "b h w c -> b c h w").contiguous() obs[imgkey] = img + # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos" obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float() # apply same transforms as in training diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py new file mode 100644 index 00000000..211a8ed0 --- /dev/null +++ b/lerobot/common/policies/act/configuration_act.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass, field + + +@dataclass +class ActionChunkingTransformerConfig: 
+ """Configuration class for the Action Chunking Transformers policy. + + Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `state_dim`, `action_dim` and `camera_names`. + + Args: + state_dim: Dimensionality of the observation state space (excluding images). + action_dim: Dimensionality of the action space. + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + camera_names: The (unique) set of names for the cameras. + chunk_size: The size of the action prediction "chunks" in units of environment steps. + n_action_steps: The number of action steps to run in the environment for one invocation of the policy. + This should be no greater than the chunk size. For example, if the chunk size size 100, you may + set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the + environment, and throws the other 50 out. + image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in + [0, 1]) for normalization. + image_normalization_std: Value by which to divide the input image pixels (after the mean has been + subtracted). + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from + torchvision. + replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated + convolution. + pre_norm: Whether to use "pre-norm" in the transformer blocks. + d_model: The transformer blocks' main hidden dimension. + n_heads: The number of heads to use in the transformer blocks' multi-head attention. + dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward + layers. + feedforward_activation: The activation to use in the transformer block's feed-forward layers. + n_encoder_layers: The number of transformer layers to use for the transformer encoder. + n_decoder_layers: The number of transformer layers to use for the transformer decoder. + use_vae: Whether to use a variational objective during training. This introduces another transformer + which is used as the VAE's encoder (not to be confused with the transformer encoder - see + documentation in the policy class). + latent_dim: The VAE's latent dimension. + n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. + use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given + environment step. + dropout: Dropout to use in the transformer layers (see code for details). + kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective + is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. + """ + + # Environment. + state_dim: int = 14 + action_dim: int = 14 + + # Inputs / output structure. + n_obs_steps: int = 1 + camera_names: tuple[str] = ("top",) + chunk_size: int = 100 + n_action_steps: int = 100 + + # Vision preprocessing. + image_normalization_mean: tuple[float, float, float] = field( + default_factory=lambda: [0.485, 0.456, 0.406] + ) + image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225]) + + # Architecture. + # Vision backbone. 
+ vision_backbone: str = "resnet18" + use_pretrained_backbone: bool = True + replace_final_stride_with_dilation: int = False + # Transformer layers. + pre_norm: bool = False + d_model: int = 512 + n_heads: int = 8 + dim_feedforward: int = 3200 + feedforward_activation: str = "relu" + n_encoder_layers: int = 4 + n_decoder_layers: int = 1 + # VAE. + use_vae: bool = True + latent_dim: int = 32 + n_vae_encoder_layers: int = 4 + + # Inference. + use_temporal_aggregation: bool = False + + # Training and loss computation. + dropout: float = 0.1 + kl_weight: float = 10.0 + + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: int = 8 + lr: float = 1e-5 + lr_backbone: float = 1e-5 + weight_decay: float = 1e-4 + grad_clip_norm: float = 10 + utd: int = 1 + + def __post_init__(self): + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + if self.use_temporal_aggregation: + raise NotImplementedError("Temporal aggregation is not yet implemented.") + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + if self.camera_names != ["top"]: + raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.") + if len(set(self.camera_names)) != len(self.camera_names): + raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.") diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/modeling_act.py similarity index 63% rename from lerobot/common/policies/act/policy.py rename to lerobot/common/policies/act/modeling_act.py index 25b814ed..c1af4ef4 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,7 +20,7 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.utils import get_safe_torch_device +from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig class ActionChunkingTransformerPolicy(nn.Module): @@ -61,91 +61,75 @@ class ActionChunkingTransformerPolicy(nn.Module): """ name = "act" - _multiple_obs_steps_not_handled_msg = ( - "ActionChunkingTransformerPolicy does not handle multiple observation steps." - ) - def __init__(self, cfg, device): + def __init__(self, cfg: ActionChunkingTransformerConfig | None = None): """ - TODO(alexander-soare): Add documentation for all parameters once we have model configs established. + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. 
""" super().__init__() - if getattr(cfg, "n_obs_steps", 1) != 1: - raise ValueError(self._multiple_obs_steps_not_handled_msg) + if cfg is None: + cfg = ActionChunkingTransformerConfig() self.cfg = cfg - self.n_action_steps = cfg.n_action_steps - self.device = get_safe_torch_device(device) - self.camera_names = cfg.camera_names - self.use_vae = cfg.use_vae - self.horizon = cfg.horizon - self.d_model = cfg.d_model - - transformer_common_kwargs = dict( # noqa: C408 - d_model=self.d_model, - num_heads=cfg.num_heads, - dim_feedforward=cfg.dim_feedforward, - dropout=cfg.dropout, - activation=cfg.activation, - normalize_before=cfg.pre_norm, - ) # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). - if self.use_vae: - self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs) - self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model) + if self.cfg.use_vae: + self.vae_encoder = _TransformerEncoder(cfg) + self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) # Projection layer for joint-space configuration to hidden dimension. - self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) + self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) # Projection layer for action (joint-space target) to hidden dimension. - self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model) + self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) self.latent_dim = cfg.latent_dim # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2) + self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch # dimension. self.register_buffer( "vae_encoder_pos_enc", - _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0), + _create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0), ) # Backbone for image feature extraction. self.image_normalizer = transforms.Normalize( - mean=cfg.image_normalization.mean, std=cfg.image_normalization.std + mean=cfg.image_normalization_mean, std=cfg.image_normalization_std ) - backbone_model = getattr(torchvision.models, cfg.backbone)( - replace_stride_with_dilation=[False, False, cfg.dilation], - pretrained=cfg.pretrained_backbone, + backbone_model = getattr(torchvision.models, cfg.vision_backbone)( + replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], + pretrained=cfg.use_pretrained_backbone, norm_layer=FrozenBatchNorm2d, ) + # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature + # map). # Note: The forward method of this returns a dict: {"feature_map": output}. self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) # Transformer (acts as VAE decoder when training with the variational objective). - self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs) - self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs) + self.encoder = _TransformerEncoder(cfg) + self.decoder = _TransformerDecoder(cfg) # Transformer encoder input projections. 
The tokens will be structured like # [latent, robot_state, image_feature_map_pixels]. - self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) - self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model) + self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) self.encoder_img_feat_input_proj = nn.Conv2d( - backbone_model.fc.in_features, self.d_model, kernel_size=1 + backbone_model.fc.in_features, cfg.d_model, kernel_size=1 ) # Transformer encoder positional embeddings. - self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model) - self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2) + self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model) + self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). - self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model) + self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(self.d_model, cfg.action_dim) + self.action_head = nn.Linear(cfg.d_model, cfg.action_dim) self._reset_parameters() - self._create_optimizer() - self.to(self.device) def _create_optimizer(self): optimizer_params_dicts = [ @@ -173,96 +157,58 @@ class ActionChunkingTransformerPolicy(nn.Module): def reset(self): """This should be called whenever the environment is reset.""" - if self.n_action_steps is not None: - self._action_queue = deque([], maxlen=self.n_action_steps) + if self.cfg.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.cfg.n_action_steps) + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + """Select a single action given environment observations. - def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor: - """ This method wraps `select_actions` in order to return one action at a time for execution in the environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ + self.eval() if len(self._action_queue) == 0: - # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape - # (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) + # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively + # has shape (n_action_steps, batch_size, *), hence the transpose. 
+ self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1)) return self._action_queue.popleft() - @torch.no_grad() - def select_actions(self, batch: dict[str, Tensor]) -> Tensor: - """Use the action chunking transformer to generate a sequence of actions.""" - self.eval() - self._preprocess_batch(batch, add_obs_steps_dim=True) + def forward(self, batch, **_) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss for training or validation.""" + actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch) - action = self.forward(batch, return_loss=False) + l1_loss = ( + F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() - if self.cfg.temporal_agg: - # TODO(rcadene): implement temporal aggregation - raise NotImplementedError() - # all_time_actions[[t], t:t+num_queries] = action - # actions_for_curr_step = all_time_actions[:, t] - # actions_populated = torch.all(actions_for_curr_step != 0, axis=1) - # actions_for_curr_step = actions_for_curr_step[actions_populated] - # k = 0.01 - # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) - # exp_weights = exp_weights / exp_weights.sum() - # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) - # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + loss_dict = {"l1_loss": l1_loss} + if self.cfg.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss_dict["kld_loss"] = mean_kld + loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight + else: + loss_dict["loss"] = l1_loss - return action[: self.n_action_steps] + return loss_dict - def __call__(self, *args, **kwargs) -> dict: - # TODO(now): Temporary bridge until we know what to do about the `update` method. - return self.update(*args, **kwargs) - - def _preprocess_batch( - self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False - ) -> dict[str, Tensor]: - """ - This function expects `batch` to have (at least): - { - "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). - "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. - "action": (B, H, J) tensor of actions (positional target for robot joint configuration) - "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds. - } - """ - if add_obs_steps_dim: - # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now, - # this just amounts to an unsqueeze. - for k in batch: - if k.startswith("observation."): - batch[k] = batch[k].unsqueeze(1) - - if batch["observation.state"].shape[1] != 1: - raise ValueError(self._multiple_obs_steps_not_handled_msg) - batch["observation.state"] = batch["observation.state"].squeeze(1) - # TODO(alexander-soare): generalize this to multiple images. - assert ( - sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1 - ), "ACT only handles one image for now." - # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get - # the image index dimension. 
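To make the new `select_action` queueing concrete, here is a minimal sketch of a rollout-style loop against the refactored interface, assuming the default configuration values. The batch size of 1, the 480x640 image resolution, and the list-form `camera_names` argument (which is what the `__post_init__` comparison in `configuration_act.py` expects) are illustrative assumptions, not anything prescribed by this patch.

```python
import torch

from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

# Pass camera_names as a list: __post_init__ compares against ["top"].
cfg = ActionChunkingTransformerConfig(camera_names=["top"])
policy = ActionChunkingTransformerPolicy(cfg)

policy.reset()  # (re)creates the action queue; call this on every environment reset

# Placeholder observations for a batch of 1; a real rollout would build these from env outputs.
obs = {
    "observation.state": torch.randn(1, cfg.state_dim),
    "observation.images.top": torch.rand(1, 3, 480, 640),
}

for _ in range(cfg.n_action_steps):
    # The transformer runs only on the first call; the remaining calls pop
    # precomputed (1, action_dim) actions from the internal queue.
    action = policy.select_action(obs)
```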
- - def update(self, batch, *_, **__) -> dict: + def update(self, batch, **_) -> dict: + """Run the model in train mode, compute the loss, and do an optimization step.""" start_time = time.time() - self._preprocess_batch(batch) - self.train() - - num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices - - assert batch_size % self.cfg.horizon == 0 - assert batch_size % num_slices == 0 - - loss = self.forward(batch, return_loss=True)["loss"] + loss_dict = self.forward(batch) + loss = loss_dict["loss"] loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( - self.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, + self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False ) self.optimizer.step() @@ -277,67 +223,64 @@ class ActionChunkingTransformerPolicy(nn.Module): return info - def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor: - images = self.image_normalizer(batch["observation.images.top"]) + def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Stacks all the images in a batch and puts them in a new key: "observation.images". - if return_loss: # training time - actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward( - batch["observation.state"], images, batch["action"] - ) - - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") - * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() - - loss_dict = {} - loss_dict["l1"] = l1_loss - if self.cfg.use_vae: - # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for - # each dimension independently, we sum over the latent dimension to get the total - # KL-divergence per batch element, then take the mean over the batch. - # (See App. B of https://arxiv.org/abs/1312.6114 for more details). - mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() - ) - loss_dict["kl"] = mean_kld - loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight - else: - loss_dict["loss"] = loss_dict["l1"] - return loss_dict - else: - action, _ = self._forward(batch["observation.state"], images) - return action - - def _forward( - self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None - ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]: + This function expects `batch` to have (at least): + { + "observation.state": (B, state_dim) batch of robot states. + "observation.images.{name}": (B, C, H, W) tensor of images. + } """ - Args: - robot_state: (B, J) batch of robot joint configurations. - image: (B, N, C, H, W) batch of N camera frames. - actions: (B, S, A) batch of actions from the target dataset which must be provided if the - VAE is enabled and the model is in training mode. + # Check that there is only one image. + # TODO(alexander-soare): generalize this to multiple images. + provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")} + if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0: + raise ValueError( + f"The following camera images are missing from the provided batch: {missing}. Check the " + "configuration parameter: `camera_names`." + ) + # Stack images in the order dictated by the camera names. 
+ batch["observation.images"] = torch.stack( + [batch[f"observation.images.{name}"] for name in self.cfg.camera_names], + dim=-4, + ) + + def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + """A forward pass through the Action Chunking Transformer (with optional VAE encoder). + + `batch` should have the following structure: + + { + "observation.state": (B, state_dim) batch of robot states. + "observation.images": (B, n_cameras, C, H, W) batch of images. + "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. + } + Returns: - (B, S, A) batch of action sequences + (B, chunk_size, action_dim) batch of action sequences Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the latent dimension. """ - if self.use_vae and self.training: + if self.cfg.use_vae and self.training: assert ( - actions is not None + "action" in batch ), "actions must be provided when using the variational objective in training mode." - batch_size = robot_state.shape[0] + self._stack_images(batch) + + batch_size = batch["observation.state"].shape[0] # Prepare the latent for input to the transformer encoder. - if self.use_vae and actions is not None: + if self.cfg.use_vae and "action" in batch: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) - robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze( + 1 + ) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) # Prepare fixed positional embedding. @@ -359,15 +302,16 @@ class ActionChunkingTransformerPolicy(nn.Module): # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( - robot_state.device + batch["observation.state"].device ) # Prepare all other transformer encoder inputs. # Camera observation features and positional embeddings. all_cam_features = [] all_cam_pos_embeds = [] - for cam_id, _ in enumerate(self.camera_names): - cam_features = self.backbone(image[:, cam_id])["feature_map"] + images = self.image_normalizer(batch["observation.images"]) + for cam_index in range(len(self.cfg.camera_names)): + cam_features = self.backbone(images[:, cam_index])["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) all_cam_features.append(cam_features) @@ -377,7 +321,7 @@ class ActionChunkingTransformerPolicy(nn.Module): cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(robot_state) + robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) latent_embed = self.encoder_latent_input_proj(latent_sample) # Stack encoder input and positional embeddings moving to (S, B, C). 
@@ -398,7 +342,9 @@ class ActionChunkingTransformerPolicy(nn.Module): # Forward pass through the transformer modules. encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) decoder_in = torch.zeros( - (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device + (self.cfg.chunk_size, batch_size, self.cfg.d_model), + dtype=pos_embed.dtype, + device=pos_embed.device, ) decoder_out = self.decoder( decoder_in, @@ -425,16 +371,10 @@ class ActionChunkingTransformerPolicy(nn.Module): class _TransformerEncoder(nn.Module): """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - def __init__(self, num_layers: int, **encoder_layer_kwargs: dict): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() - self.layers = nn.ModuleList( - [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)] - ) - self.norm = ( - nn.LayerNorm(encoder_layer_kwargs["d_model"]) - if encoder_layer_kwargs["normalize_before"] - else nn.Identity() - ) + self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) + self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: for layer in self.layers: @@ -444,39 +384,31 @@ class _TransformerEncoder(nn.Module): class _TransformerEncoderLayer(nn.Module): - def __init__( - self, - d_model: int, - num_heads: int, - dim_feedforward: int, - dropout: float, - activation: str, - normalize_before: bool, - ): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) # Feed forward layers. 
- self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) + self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) + self.dropout = nn.Dropout(cfg.dropout) + self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(cfg.d_model) + self.norm2 = nn.LayerNorm(cfg.d_model) + self.dropout1 = nn.Dropout(cfg.dropout) + self.dropout2 = nn.Dropout(cfg.dropout) - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before + self.activation = _get_activation_fn(cfg.feedforward_activation) + self.pre_norm = cfg.pre_norm def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: skip = x - if self.normalize_before: + if self.pre_norm: x = self.norm1(x) q = k = x if pos_embed is None else x + pos_embed x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = skip + self.dropout1(x) - if self.normalize_before: + if self.pre_norm: skip = x x = self.norm2(x) else: @@ -484,20 +416,17 @@ class _TransformerEncoderLayer(nn.Module): skip = x x = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = skip + self.dropout2(x) - if not self.normalize_before: + if not self.pre_norm: x = self.norm2(x) return x class _TransformerDecoder(nn.Module): - def __init__(self, num_layers: int, **decoder_layer_kwargs): + def __init__(self, cfg: ActionChunkingTransformerConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() - self.layers = nn.ModuleList( - [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)] - ) - self.num_layers = num_layers - self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"]) + self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) + self.norm = nn.LayerNorm(cfg.d_model) def forward( self, @@ -516,33 +445,25 @@ class _TransformerDecoder(nn.Module): class _TransformerDecoderLayer(nn.Module): - def __init__( - self, - d_model: int, - num_heads: int, - dim_feedforward: int, - dropout: float, - activation: str, - normalize_before: bool, - ): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) # Feed forward layers. 
- self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) + self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) + self.dropout = nn.Dropout(cfg.dropout) + self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(cfg.d_model) + self.norm2 = nn.LayerNorm(cfg.d_model) + self.norm3 = nn.LayerNorm(cfg.d_model) + self.dropout1 = nn.Dropout(cfg.dropout) + self.dropout2 = nn.Dropout(cfg.dropout) + self.dropout3 = nn.Dropout(cfg.dropout) - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before + self.activation = _get_activation_fn(cfg.feedforward_activation) + self.pre_norm = cfg.pre_norm def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: return tensor if pos_embed is None else tensor + pos_embed @@ -565,12 +486,12 @@ class _TransformerDecoderLayer(nn.Module): (DS, B, C) tensor of decoder output features. """ skip = x - if self.normalize_before: + if self.pre_norm: x = self.norm1(x) q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = skip + self.dropout1(x) - if self.normalize_before: + if self.pre_norm: skip = x x = self.norm2(x) else: @@ -582,7 +503,7 @@ class _TransformerDecoderLayer(nn.Module): value=encoder_out, )[0] # select just the output, not the attention weights x = skip + self.dropout2(x) - if self.normalize_before: + if self.pre_norm: skip = x x = self.norm3(x) else: @@ -590,7 +511,7 @@ class _TransformerDecoderLayer(nn.Module): skip = x x = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = skip + self.dropout3(x) - if not self.normalize_before: + if not self.pre_norm: x = self.norm3(x) return x diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py new file mode 100644 index 00000000..d8820a0b --- /dev/null +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass + + +@dataclass +class DiffusionConfig: + """Configuration class for Diffusion Policy. + + Defaults are configured for training with PushT providing proprioceptive and single camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `state_dim`, `action_dim` and `image_size`. + + Args: + state_dim: Dimensionality of the observation state space (excluding images). + action_dim: Dimensionality of the action space. + image_size: (H, W) size of the input images. + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. + n_action_steps: The number of action steps to run in the environment for one invocation of the policy. + See `DiffusionPolicy.select_action` for more details. + image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in + [0, 1]) for normalization. 
+ image_normalization_std: Value by which to divide the input image pixels (after the mean has been + subtracted). + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit + within the image size. If None, no cropping is done. + crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval + mode). + use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from + torchvision. + use_group_norm: Whether to replace batch normalization with group normalization in the backbone. + The number of groups is set to feature_dim // 16, so the group sizes are about 16. + spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. + down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet. + You may provide a variable number of dimensions, therefore also controlling the degree of + downsampling. + kernel_size: The convolutional kernel size of the diffusion modeling Unet. + n_groups: Number of groups used in the group norm of the Unet's convolutional blocks. + diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear + network. This is the output dimension of that network, i.e., the embedding dimension. + use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. + Bias modulation is used by default, while this parameter indicates whether to also use scale + modulation. + num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. + beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. + beta_start: Beta value for the first forward-diffusion step. + beta_end: Beta value for the last forward-diffusion step. + prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon" + or "sample". These have equivalent outcomes from a latent variable modeling perspective, but + "epsilon" has been shown to work better in many deep neural network settings. + clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each + denoising step at inference time. WARNING: you will need to make sure your action-space is + normalized to fit within this range. + clip_sample_range: The magnitude of the clipping range as described above. + num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly + spaced). If not provided, this defaults to be the same as `num_train_timesteps`. + """ + + # Environment. + # Inherit these from the environment config. + state_dim: int = 2 + action_dim: int = 2 + image_size: tuple[int, int] = (96, 96) + + # Inputs / output structure. + n_obs_steps: int = 2 + horizon: int = 16 + n_action_steps: int = 8 + + # Vision preprocessing. + image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5) + image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5) + + # Architecture / modeling. + # Vision backbone. + vision_backbone: str = "resnet18" + crop_shape: tuple[int, int] | None = (84, 84) + crop_is_random: bool = True + use_pretrained_backbone: bool = False + use_group_norm: bool = True + spatial_softmax_num_keypoints: int = 32 + # Unet. + down_dims: tuple[int, ...]
= (512, 1024, 2048) + kernel_size: int = 5 + n_groups: int = 8 + diffusion_step_embed_dim: int = 128 + use_film_scale_modulation: bool = True + # Noise scheduler. + num_train_timesteps: int = 100 + beta_schedule: str = "squaredcos_cap_v2" + beta_start: float = 0.0001 + beta_end: float = 0.02 + prediction_type: str = "epsilon" + clip_sample: bool = True + clip_sample_range: float = 1.0 + + # Inference + num_inference_steps: int | None = None + + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: int = 64 + grad_clip_norm: int = 10 + lr: float = 1.0e-4 + lr_scheduler: str = "cosine" + lr_warmup_steps: int = 500 + adam_betas: tuple[float, float] = (0.95, 0.999) + adam_eps: float = 1.0e-8 + adam_weight_decay: float = 1.0e-6 + utd: int = 1 + use_ema: bool = True + ema_update_after_step: int = 0 + ema_min_alpha: float = 0.0 + ema_max_alpha: float = 0.9999 + ema_inv_gamma: float = 1.0 + ema_power: float = 0.75 + + def __post_init__(self): + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + if self.crop_shape is not None and (self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]): + raise ValueError( + f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and " + f"{self.image_size} for `image_size`." + ) + supported_prediction_types = ["epsilon", "sample"] + if self.prediction_type not in supported_prediction_types: + raise ValueError( + f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." + ) diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py deleted file mode 100644 index f7432db3..00000000 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Code from the original diffusion policy project. - -Notes on how to load a checkpoint from the original repository: - -In the original repository, run the eval and use a breakpoint to extract the policy weights. - -``` -torch.save(policy.state_dict(), "weights.pt") -``` - -In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights: - -``` -loaded = torch.load("weights.pt") -aligned = {} -their_prefix = "obs_encoder.obs_nets.image.backbone" -our_prefix = "obs_encoder.key_model_map.image.backbone" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -their_prefix = "obs_encoder.obs_nets.image.pool" -our_prefix = "obs_encoder.key_model_map.image.pool" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -their_prefix = "obs_encoder.obs_nets.image.nets.3" -our_prefix = "obs_encoder.key_model_map.image.out" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) -# Note: here you are loading into the ema model.
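As a quick illustration of the `DiffusionConfig` dataclass introduced above, here is a minimal usage sketch (outside the patch itself). It assumes only the fields and defaults shown in the diff; the override values for `state_dim`, `action_dim`, `image_size` and `crop_shape` are illustrative, not taken from the repository.

```python
# Sketch: constructing the new DiffusionConfig and exercising its __post_init__ validation.
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

# Defaults target PushT: 2-D state/action, a single 96x96 camera, horizon 16, 8 action steps.
cfg = DiffusionConfig()
assert cfg.horizon == 16 and cfg.n_action_steps == 8

# Environment-dependent fields are overridden at construction time (illustrative values).
cfg = DiffusionConfig(state_dim=7, action_dim=7, image_size=(240, 320), crop_shape=(216, 288))

# Invalid combinations raise ValueError from __post_init__, e.g. a crop larger than the image.
try:
    DiffusionConfig(image_size=(96, 96), crop_shape=(128, 128))
except ValueError as err:
    print(err)
```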
-missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False) -assert all('_dummy_variable' in k for k in missing_keys) -assert len(unexpected_keys) == 0 -``` - -Then in that same runtime you can also save the weights with the new aligned state_dict: - -``` -policy.save("weights.pt") -``` - -Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint. - -""" - -from typing import Dict - -import torch -import torch.nn.functional as F # noqa: N812 -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler -from einops import reduce - -from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D -from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder -from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer -from lerobot.common.policies.diffusion.pytorch_utils import dict_apply - - -class BaseImagePolicy(ModuleAttrMixin): - # init accepts keyword argument shape_meta, see config/task/*_image.yaml - - def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - obs_dict: - str: B,To,* - return: B,Ta,Da - """ - raise NotImplementedError() - - # reset state for stateful policies - def reset(self): - pass - - # ========== training =========== - # no standard training interface except setting normalizer - def set_normalizer(self, normalizer: LinearNormalizer): - raise NotImplementedError() - - -class DiffusionUnetImagePolicy(BaseImagePolicy): - def __init__( - self, - shape_meta: dict, - noise_scheduler: DDPMScheduler, - obs_encoder: MultiImageObsEncoder, - horizon, - n_action_steps, - n_obs_steps, - num_inference_steps=None, - obs_as_global_cond=True, - diffusion_step_embed_dim=256, - down_dims=(256, 512, 1024), - kernel_size=5, - n_groups=8, - cond_predict_scale=True, - # parameters passed to step - **kwargs, - ): - super().__init__() - - # parse shapes - action_shape = shape_meta["action"]["shape"] - assert len(action_shape) == 1 - action_dim = action_shape[0] - # get feature dim - obs_feature_dim = obs_encoder.output_shape()[0] - - # create diffusion model - input_dim = action_dim + obs_feature_dim - global_cond_dim = None - if obs_as_global_cond: - input_dim = action_dim - global_cond_dim = obs_feature_dim * n_obs_steps - - model = ConditionalUnet1D( - input_dim=input_dim, - local_cond_dim=None, - global_cond_dim=global_cond_dim, - diffusion_step_embed_dim=diffusion_step_embed_dim, - down_dims=down_dims, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ) - - self.obs_encoder = obs_encoder - self.model = model - self.noise_scheduler = noise_scheduler - self.mask_generator = LowdimMaskGenerator( - action_dim=action_dim, - obs_dim=0 if obs_as_global_cond else obs_feature_dim, - max_n_obs_steps=n_obs_steps, - fix_obs_steps=True, - action_visible=False, - ) - self.horizon = horizon - self.obs_feature_dim = obs_feature_dim - self.action_dim = action_dim - self.n_action_steps = n_action_steps - self.n_obs_steps = n_obs_steps - self.obs_as_global_cond = obs_as_global_cond - self.kwargs = kwargs - - if num_inference_steps is None: - num_inference_steps = noise_scheduler.config.num_train_timesteps - self.num_inference_steps = num_inference_steps - - # ========= inference 
============ - def conditional_sample( - self, - condition_data, - condition_mask, - local_cond=None, - global_cond=None, - generator=None, - # keyword arguments to scheduler.step - **kwargs, - ): - model = self.model - scheduler = self.noise_scheduler - - trajectory = torch.randn( - size=condition_data.shape, - dtype=condition_data.dtype, - device=condition_data.device, - generator=generator, - ) - - # set step values - scheduler.set_timesteps(self.num_inference_steps) - - for t in scheduler.timesteps: - # 1. apply conditioning - trajectory[condition_mask] = condition_data[condition_mask] - - # 2. predict model output - model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond) - - # 3. compute previous image: x_t -> x_t-1 - trajectory = scheduler.step( - model_output, - t, - trajectory, - generator=generator, - # **kwargs # TODO(rcadene): in diffusion_policy, expected to be {} - ).prev_sample - - # finally make sure conditioning is enforced - trajectory[condition_mask] = condition_data[condition_mask] - - return trajectory - - def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - obs_dict: must include "obs" key - result: must include "action" key - """ - assert "past_action" not in obs_dict # not implemented yet - nobs = obs_dict - value = next(iter(nobs.values())) - bsize, n_obs_steps = value.shape[:2] - horizon = self.horizon - action_dim = self.action_dim - obs_dim = self.obs_feature_dim - assert self.n_obs_steps == n_obs_steps - - # build input - device = self.device - dtype = self.dtype - - # handle different ways of passing observation - local_cond = None - global_cond = None - if self.obs_as_global_cond: - # condition through global feature - this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, Do - global_cond = nobs_features.reshape(bsize, -1) - # empty data for action - cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype) - cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) - else: - # condition through impainting - this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1) - cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype) - cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) - cond_data[:, :n_obs_steps, action_dim:] = nobs_features - cond_mask[:, :n_obs_steps, action_dim:] = True - - # run sampling - nsample = self.conditional_sample( - cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond - ) - - action_pred = nsample[..., :action_dim] - # get action - start = n_obs_steps - 1 - end = start + self.n_action_steps - action = action_pred[:, start:end] - - result = {"action": action, "action_pred": action_pred} - return result - - def compute_loss(self, batch): - nobs = { - "image": batch["observation.image"], - "agent_pos": batch["observation.state"], - } - nactions = batch["action"] - batch_size = nactions.shape[0] - horizon = nactions.shape[1] - - # handle different ways of passing observation - local_cond = None - global_cond = None - trajectory = nactions - cond_data = trajectory - if self.obs_as_global_cond: - # reshape B, T, ... 
to B*T - this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, Do - global_cond = nobs_features.reshape(batch_size, -1) - else: - # reshape B, T, ... to B*T - this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(batch_size, horizon, -1) - cond_data = torch.cat([nactions, nobs_features], dim=-1) - trajectory = cond_data.detach() - - # generate impainting mask - condition_mask = self.mask_generator(trajectory.shape) - - # Sample noise that we'll add to the images - noise = torch.randn(trajectory.shape, device=trajectory.device) - bsz = trajectory.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device - ).long() - # Add noise to the clean images according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps) - - # compute loss mask - loss_mask = ~condition_mask - - # apply conditioning - noisy_trajectory[condition_mask] = cond_data[condition_mask] - - # Predict the noise residual - pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond) - - pred_type = self.noise_scheduler.config.prediction_type - if pred_type == "epsilon": - target = noise - elif pred_type == "sample": - target = trajectory - else: - raise ValueError(f"Unsupported prediction type {pred_type}") - - loss = F.mse_loss(pred, target, reduction="none") - loss = loss * loss_mask.type(loss.dtype) - - if "action_is_pad" in batch: - in_episode_bound = ~batch["action_is_pad"] - loss = loss * in_episode_bound[:, :, None].type(loss.dtype) - - loss = reduce(loss, "b t c -> b", "mean", b=batch_size) - loss = loss.mean() - return loss diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py deleted file mode 100644 index d2971d38..00000000 --- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py +++ /dev/null @@ -1,286 +0,0 @@ -import logging -from typing import Union - -import einops -import torch -import torch.nn as nn -from einops.layers.torch import Rearrange - -from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d -from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb - -logger = logging.getLogger(__name__) - - -class ConditionalResidualBlock1D(nn.Module): - def __init__( - self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False - ): - super().__init__() - - self.blocks = nn.ModuleList( - [ - Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), - Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), - ] - ) - - # FiLM modulation https://arxiv.org/abs/1709.07871 - # predicts per-channel scale and bias - cond_channels = out_channels - if cond_predict_scale: - cond_channels = out_channels * 2 - self.cond_predict_scale = cond_predict_scale - self.out_channels = out_channels - self.cond_encoder = nn.Sequential( - nn.Mish(), - nn.Linear(cond_dim, cond_channels), - Rearrange("batch t -> batch t 1"), - ) - - # make sure dimensions compatible - self.residual_conv = ( - nn.Conv1d(in_channels, 
out_channels, 1) if in_channels != out_channels else nn.Identity() - ) - - def forward(self, x, cond): - """ - x : [ batch_size x in_channels x horizon ] - cond : [ batch_size x cond_dim] - - returns: - out : [ batch_size x out_channels x horizon ] - """ - out = self.blocks[0](x) - embed = self.cond_encoder(cond) - if self.cond_predict_scale: - embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) - scale = embed[:, 0, ...] - bias = embed[:, 1, ...] - out = scale * out + bias - else: - out = out + embed - out = self.blocks[1](out) - out = out + self.residual_conv(x) - return out - - -class ConditionalUnet1D(nn.Module): - def __init__( - self, - input_dim, - local_cond_dim=None, - global_cond_dim=None, - diffusion_step_embed_dim=256, - down_dims=None, - kernel_size=3, - n_groups=8, - cond_predict_scale=False, - ): - super().__init__() - if down_dims is None: - down_dims = [256, 512, 1024] - - all_dims = [input_dim] + list(down_dims) - start_dim = down_dims[0] - - dsed = diffusion_step_embed_dim - diffusion_step_encoder = nn.Sequential( - SinusoidalPosEmb(dsed), - nn.Linear(dsed, dsed * 4), - nn.Mish(), - nn.Linear(dsed * 4, dsed), - ) - cond_dim = dsed - if global_cond_dim is not None: - cond_dim += global_cond_dim - - in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False)) - - local_cond_encoder = None - if local_cond_dim is not None: - _, dim_out = in_out[0] - dim_in = local_cond_dim - local_cond_encoder = nn.ModuleList( - [ - # down encoder - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - # up encoder - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ] - ) - - mid_dim = all_dims[-1] - self.mid_modules = nn.ModuleList( - [ - ConditionalResidualBlock1D( - mid_dim, - mid_dim, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ConditionalResidualBlock1D( - mid_dim, - mid_dim, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ] - ) - - down_modules = nn.ModuleList([]) - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (len(in_out) - 1) - down_modules.append( - nn.ModuleList( - [ - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ConditionalResidualBlock1D( - dim_out, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - Downsample1d(dim_out) if not is_last else nn.Identity(), - ] - ) - ) - - up_modules = nn.ModuleList([]) - for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): - is_last = ind >= (len(in_out) - 1) - up_modules.append( - nn.ModuleList( - [ - ConditionalResidualBlock1D( - dim_out * 2, - dim_in, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ConditionalResidualBlock1D( - dim_in, - dim_in, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - Upsample1d(dim_in) if not is_last else nn.Identity(), - ] - ) - ) - - final_conv = nn.Sequential( - Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), - nn.Conv1d(start_dim, input_dim, 1), - ) - - 
self.diffusion_step_encoder = diffusion_step_encoder - self.local_cond_encoder = local_cond_encoder - self.up_modules = up_modules - self.down_modules = down_modules - self.final_conv = final_conv - - logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - local_cond=None, - global_cond=None, - **kwargs, - ): - """ - x: (B,T,input_dim) - timestep: (B,) or int, diffusion step - local_cond: (B,T,local_cond_dim) - global_cond: (B,global_cond_dim) - output: (B,T,input_dim) - """ - sample = einops.rearrange(sample, "b h t -> b t h") - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - global_feature = self.diffusion_step_encoder(timesteps) - - if global_cond is not None: - global_feature = torch.cat([global_feature, global_cond], axis=-1) - - # encode local features - h_local = [] - if local_cond is not None: - local_cond = einops.rearrange(local_cond, "b h t -> b t h") - resnet, resnet2 = self.local_cond_encoder - x = resnet(local_cond, global_feature) - h_local.append(x) - x = resnet2(local_cond, global_feature) - h_local.append(x) - - x = sample - h = [] - for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): - x = resnet(x, global_feature) - if idx == 0 and len(h_local) > 0: - x = x + h_local[0] - x = resnet2(x, global_feature) - h.append(x) - x = downsample(x) - - for mid_module in self.mid_modules: - x = mid_module(x, global_feature) - - for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): - x = torch.cat((x, h.pop()), dim=1) - x = resnet(x, global_feature) - # The correct condition should be: - # if idx == (len(self.up_modules)-1) and len(h_local) > 0: - # However this change will break compatibility with published checkpoints. - # Therefore it is left as a comment. 
- if idx == len(self.up_modules) and len(h_local) > 0: - x = x + h_local[1] - x = resnet2(x, global_feature) - x = upsample(x) - - x = self.final_conv(x) - - x = einops.rearrange(x, "b t h -> b h t") - return x diff --git a/lerobot/common/policies/diffusion/model/conv1d_components.py b/lerobot/common/policies/diffusion/model/conv1d_components.py deleted file mode 100644 index 3c21eaf6..00000000 --- a/lerobot/common/policies/diffusion/model/conv1d_components.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch.nn as nn - -# from einops.layers.torch import Rearrange - - -class Downsample1d(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = nn.Conv1d(dim, dim, 3, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class Upsample1d(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), - # Rearrange('batch channels horizon -> batch channels 1 horizon'), - nn.GroupNorm(n_groups, out_channels), - # Rearrange('batch channels 1 horizon -> batch channels horizon'), - nn.Mish(), - ) - - def forward(self, x): - return self.block(x) - - -# def test(): -# cb = Conv1dBlock(256, 128, kernel_size=3) -# x = torch.zeros((1,256,16)) -# o = cb(x) diff --git a/lerobot/common/policies/diffusion/model/crop_randomizer.py b/lerobot/common/policies/diffusion/model/crop_randomizer.py deleted file mode 100644 index 2e60f353..00000000 --- a/lerobot/common/policies/diffusion/model/crop_randomizer.py +++ /dev/null @@ -1,294 +0,0 @@ -import torch -import torch.nn as nn -import torchvision.transforms.functional as ttf - -import lerobot.common.policies.diffusion.model.tensor_utils as tu - - -class CropRandomizer(nn.Module): - """ - Randomly sample crops at input, and then average across crop features at output. - """ - - def __init__( - self, - input_shape, - crop_height, - crop_width, - num_crops=1, - pos_enc=False, - ): - """ - Args: - input_shape (tuple, list): shape of input (not including batch dimension) - crop_height (int): crop height - crop_width (int): crop width - num_crops (int): number of random crops to take - pos_enc (bool): if True, add 2 channels to the output to encode the spatial - location of the cropped pixels in the source image - """ - super().__init__() - - assert len(input_shape) == 3 # (C, H, W) - assert crop_height < input_shape[1] - assert crop_width < input_shape[2] - - self.input_shape = input_shape - self.crop_height = crop_height - self.crop_width = crop_width - self.num_crops = num_crops - self.pos_enc = pos_enc - - def output_shape_in(self, input_shape=None): - """ - Function to compute output shape from inputs to this module. Corresponds to - the @forward_in operation, where raw inputs (usually observation modalities) - are passed in. - - Args: - input_shape (iterable of int): shape of input. Does not include batch dimension. - Some modules may not need this argument, if their output does not depend - on the size of the input, or if they assume fixed size input. 
- - Returns: - out_shape ([int]): list of integers corresponding to output shape - """ - - # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because - # the number of crops are reshaped into the batch dimension, increasing the batch - # size from B to B * N - out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0] - return [out_c, self.crop_height, self.crop_width] - - def output_shape_out(self, input_shape=None): - """ - Function to compute output shape from inputs to this module. Corresponds to - the @forward_out operation, where processed inputs (usually encoded observation - modalities) are passed in. - - Args: - input_shape (iterable of int): shape of input. Does not include batch dimension. - Some modules may not need this argument, if their output does not depend - on the size of the input, or if they assume fixed size input. - - Returns: - out_shape ([int]): list of integers corresponding to output shape - """ - - # since the forward_out operation splits [B * N, ...] -> [B, N, ...] - # and then pools to result in [B, ...], only the batch dimension changes, - # and so the other dimensions retain their shape. - return list(input_shape) - - def forward_in(self, inputs): - """ - Samples N random crops for each input in the batch, and then reshapes - inputs to [B * N, ...]. - """ - assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions - if self.training: - # generate random crops - out, _ = sample_random_image_crops( - images=inputs, - crop_height=self.crop_height, - crop_width=self.crop_width, - num_crops=self.num_crops, - pos_enc=self.pos_enc, - ) - # [B, N, ...] -> [B * N, ...] - return tu.join_dimensions(out, 0, 1) - else: - # take center crop during eval - out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width)) - if self.num_crops > 1: - B, C, H, W = out.shape # noqa: N806 - out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W) - # [B * N, ...] - return out - - def forward_out(self, inputs): - """ - Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N - to result in shape [B, ...] to make sure the network output is consistent with - what would have happened if there were no randomization. - """ - if self.num_crops <= 1: - return inputs - else: - batch_size = inputs.shape[0] // self.num_crops - out = tu.reshape_dimensions( - inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops) - ) - return out.mean(dim=1) - - def forward(self, inputs): - return self.forward_in(inputs) - - def __repr__(self): - """Pretty print network.""" - header = "{}".format(str(self.__class__.__name__)) - msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format( - self.input_shape, self.crop_height, self.crop_width, self.num_crops - ) - return msg - - -def crop_image_from_indices(images, crop_indices, crop_height, crop_width): - """ - Crops images at the locations specified by @crop_indices. Crops will be - taken across all channels. - - Args: - images (torch.Tensor): batch of images of shape [..., C, H, W] - - crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where - N is the number of crops to take per image and each entry corresponds - to the pixel height and width of where to take the crop. Note that - the indices can also be of shape [..., 2] if only 1 crop should - be taken per image. Leading dimensions must be consistent with - @images argument. Each index specifies the top left of the crop. 
- Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where - H and W are the height and width of @images and CH and CW are - @crop_height and @crop_width. - - crop_height (int): height of crop to take - - crop_width (int): width of crop to take - - Returns: - crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width] - """ - - # make sure length of input shapes is consistent - assert crop_indices.shape[-1] == 2 - ndim_im_shape = len(images.shape) - ndim_indices_shape = len(crop_indices.shape) - assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2) - - # maybe pad so that @crop_indices is shape [..., N, 2] - is_padded = False - if ndim_im_shape == ndim_indices_shape + 2: - crop_indices = crop_indices.unsqueeze(-2) - is_padded = True - - # make sure leading dimensions between images and indices are consistent - assert images.shape[:-3] == crop_indices.shape[:-2] - - device = images.device - image_c, image_h, image_w = images.shape[-3:] - num_crops = crop_indices.shape[-2] - - # make sure @crop_indices are in valid range - assert (crop_indices[..., 0] >= 0).all().item() - assert (crop_indices[..., 0] < (image_h - crop_height)).all().item() - assert (crop_indices[..., 1] >= 0).all().item() - assert (crop_indices[..., 1] < (image_w - crop_width)).all().item() - - # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window. - - # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW] - crop_ind_grid_h = torch.arange(crop_height).to(device) - crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1) - # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW] - crop_ind_grid_w = torch.arange(crop_width).to(device) - crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0) - # combine into shape [CH, CW, 2] - crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1) - - # Add above grid with the offset index of each sampled crop to get 2d indices for each crop. - # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2] - # shape array that tells us which pixels from the corresponding source image to grab. - grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2] - all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape) - - # For using @torch.gather, convert to flat indices from 2D indices, and also - # repeat across the channel dimension. 
To get flat index of each pixel to grab for - # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind - all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW] - all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW] - all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW] - - # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds - images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4) - images_to_crop = tu.flatten(images_to_crop, begin_axis=-2) - crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds) - # [..., N, C, CH * CW] -> [..., N, C, CH, CW] - reshape_axis = len(crops.shape) - 1 - crops = tu.reshape_dimensions( - crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width) - ) - - if is_padded: - # undo padding -> [..., C, CH, CW] - crops = crops.squeeze(-4) - return crops - - -def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False): - """ - For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from - @images. - - Args: - images (torch.Tensor): batch of images of shape [..., C, H, W] - - crop_height (int): height of crop to take - - crop_width (int): width of crop to take - - num_crops (n): number of crops to sample - - pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial - encoding of the original source pixel locations. This means that the - output crops will contain information about where in the source image - it was sampled from. - - Returns: - crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width) - if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width) - - crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2) - """ - device = images.device - - # maybe add 2 channels of spatial encoding to the source image - source_im = images - if pos_enc: - # spatial encoding [y, x] in [0, 1] - h, w = source_im.shape[-2:] - pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w)) - pos_y = pos_y.float().to(device) / float(h) - pos_x = pos_x.float().to(device) / float(w) - position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W] - - # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W] - leading_shape = source_im.shape[:-3] - position_enc = position_enc[(None,) * len(leading_shape)] - position_enc = position_enc.expand(*leading_shape, -1, -1, -1) - - # concat across channel dimension with input - source_im = torch.cat((source_im, position_enc), dim=-3) - - # make sure sample boundaries ensure crops are fully within the images - image_c, image_h, image_w = source_im.shape[-3:] - max_sample_h = image_h - crop_height - max_sample_w = image_w - crop_width - - # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W]. - # Each gets @num_crops samples - typically this will just be the batch dimension (B), so - # we will sample [B, N] indices, but this supports having more than one leading dimension, - # or possibly no leading dimension. 
- # - # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints - crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() - crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() - crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2] - - crops = crop_image_from_indices( - images=source_im, - crop_indices=crop_inds, - crop_height=crop_height, - crop_width=crop_width, - ) - - return crops, crop_inds diff --git a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py b/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py deleted file mode 100644 index d1356006..00000000 --- a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -import torch.nn as nn - - -class DictOfTensorMixin(nn.Module): - def __init__(self, params_dict=None): - super().__init__() - if params_dict is None: - params_dict = nn.ParameterDict() - self.params_dict = params_dict - - @property - def device(self): - return next(iter(self.parameters())).device - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - def dfs_add(dest, keys, value: torch.Tensor): - if len(keys) == 1: - dest[keys[0]] = value - return - - if keys[0] not in dest: - dest[keys[0]] = nn.ParameterDict() - dfs_add(dest[keys[0]], keys[1:], value) - - def load_dict(state_dict, prefix): - out_dict = nn.ParameterDict() - for key, value in state_dict.items(): - value: torch.Tensor - if key.startswith(prefix): - param_keys = key[len(prefix) :].split(".")[1:] - # if len(param_keys) == 0: - # import pdb; pdb.set_trace() - dfs_add(out_dict, param_keys, value.clone()) - return out_dict - - self.params_dict = load_dict(state_dict, prefix + "params_dict") - self.params_dict.requires_grad_(False) - return diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py deleted file mode 100644 index 6dc128de..00000000 --- a/lerobot/common/policies/diffusion/model/ema_model.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -from torch.nn.modules.batchnorm import _BatchNorm - - -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999 - ): - """ - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. - """ - - self.averaged_model = model - self.averaged_model.eval() - self.averaged_model.requires_grad_(False) - - self.update_after_step = update_after_step - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value - - self.decay = 0.0 - self.optimization_step = 0 - - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. 
- """ - step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power - - if step <= 0: - return 0.0 - - return max(self.min_value, min(value, self.max_value)) - - @torch.no_grad() - def step(self, new_model): - self.decay = self.get_decay(self.optimization_step) - - # old_all_dataptrs = set() - # for param in new_model.parameters(): - # data_ptr = param.data_ptr() - # if data_ptr != 0: - # old_all_dataptrs.add(data_ptr) - - # all_dataptrs = set() - for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False): - for param, ema_param in zip( - module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False - ): - # iterative over immediate parameters only. - if isinstance(param, dict): - raise RuntimeError("Dict parameter not supported") - - # data_ptr = param.data_ptr() - # if data_ptr != 0: - # all_dataptrs.add(data_ptr) - - if isinstance(module, _BatchNorm): - # skip batchnorms - ema_param.copy_(param.to(dtype=ema_param.dtype).data) - elif not param.requires_grad: - ema_param.copy_(param.to(dtype=ema_param.dtype).data) - else: - ema_param.mul_(self.decay) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) - - # verify that iterating over module and then parameters is identical to parameters recursively. - # assert old_all_dataptrs == all_dataptrs - self.optimization_step += 1 diff --git a/lerobot/common/policies/diffusion/model/lr_scheduler.py b/lerobot/common/policies/diffusion/model/lr_scheduler.py deleted file mode 100644 index 084b3a36..00000000 --- a/lerobot/common/policies/diffusion/model/lr_scheduler.py +++ /dev/null @@ -1,46 +0,0 @@ -from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union - - -def get_scheduler( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - **kwargs, -): - """ - Added kwargs vs diffuser's original implementation - - Unified API to get any scheduler from its name. - - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. 
- """ - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer, **kwargs) - - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) - - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs - ) diff --git a/lerobot/common/policies/diffusion/model/mask_generator.py b/lerobot/common/policies/diffusion/model/mask_generator.py deleted file mode 100644 index 63306dea..00000000 --- a/lerobot/common/policies/diffusion/model/mask_generator.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin - - -class LowdimMaskGenerator(ModuleAttrMixin): - def __init__( - self, - action_dim, - obs_dim, - # obs mask setup - max_n_obs_steps=2, - fix_obs_steps=True, - # action mask - action_visible=False, - ): - super().__init__() - self.action_dim = action_dim - self.obs_dim = obs_dim - self.max_n_obs_steps = max_n_obs_steps - self.fix_obs_steps = fix_obs_steps - self.action_visible = action_visible - - @torch.no_grad() - def forward(self, shape, seed=None): - device = self.device - B, T, D = shape # noqa: N806 - assert (self.action_dim + self.obs_dim) == D - - # create all tensors on this device - rng = torch.Generator(device=device) - if seed is not None: - rng = rng.manual_seed(seed) - - # generate dim mask - dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device) - is_action_dim = dim_mask.clone() - is_action_dim[..., : self.action_dim] = True - is_obs_dim = ~is_action_dim - - # generate obs mask - if self.fix_obs_steps: - obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device) - else: - obs_steps = torch.randint( - low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device - ) - - steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T) - obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) - obs_mask = obs_mask & is_obs_dim - - # generate action mask - if self.action_visible: - action_steps = torch.maximum( - obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device) - ) - action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) - action_mask = action_mask & is_action_dim - - mask = obs_mask - if self.action_visible: - mask = mask | action_mask - - return mask diff --git a/lerobot/common/policies/diffusion/model/module_attr_mixin.py b/lerobot/common/policies/diffusion/model/module_attr_mixin.py deleted file mode 100644 index 5d2cf4ea..00000000 --- a/lerobot/common/policies/diffusion/model/module_attr_mixin.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch.nn as nn - - -class ModuleAttrMixin(nn.Module): - def __init__(self): - super().__init__() - self._dummy_variable = nn.Parameter() - - @property - def device(self): - return next(iter(self.parameters())).device - - @property - def dtype(self): - return next(iter(self.parameters())).dtype diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py 
b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py deleted file mode 100644 index d724cd49..00000000 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ /dev/null @@ -1,214 +0,0 @@ -import copy -from typing import Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torchvision -from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax - -from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin -from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules - - -class RgbEncoder(nn.Module): - """Following `VisualCore` from Robomimic 0.2.0.""" - - def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32): - """ - input_shape: channel-first input shape (C, H, W) - resnet_name: a timm model name. - pretrained: whether to use timm pretrained weights. - relu: whether to use relu as a final step. - num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). - """ - super().__init__() - self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained) - # Figure out the feature map shape. - with torch.inference_mode(): - feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) - self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2) - self.relu = nn.ReLU() if relu else nn.Identity() - - def forward(self, x): - return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) - - -class MultiImageObsEncoder(ModuleAttrMixin): - def __init__( - self, - shape_meta: dict, - rgb_model: Union[nn.Module, Dict[str, nn.Module]], - resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None, - crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None, - random_crop: bool = True, - # replace BatchNorm with GroupNorm - use_group_norm: bool = False, - # use single rgb model for all rgb inputs - share_rgb_model: bool = False, - # renormalize rgb input with imagenet normalization - # assuming input in [0,1] - norm_mean_std: Optional[tuple[float, float]] = None, - ): - """ - Assumes rgb input: B,C,H,W - Assumes low_dim input: B,D - """ - super().__init__() - - rgb_keys = [] - low_dim_keys = [] - key_model_map = nn.ModuleDict() - key_transform_map = nn.ModuleDict() - key_shape_map = {} - - # handle sharing vision backbone - if share_rgb_model: - assert isinstance(rgb_model, nn.Module) - key_model_map["rgb"] = rgb_model - - obs_shape_meta = shape_meta["obs"] - for key, attr in obs_shape_meta.items(): - shape = tuple(attr["shape"]) - type = attr.get("type", "low_dim") - key_shape_map[key] = shape - if type == "rgb": - rgb_keys.append(key) - # configure model for this key - this_model = None - if not share_rgb_model: - if isinstance(rgb_model, dict): - # have provided model for each key - this_model = rgb_model[key] - else: - assert isinstance(rgb_model, nn.Module) - # have a copy of the rgb model - this_model = copy.deepcopy(rgb_model) - - if this_model is not None: - if use_group_norm: - this_model = replace_submodules( - root_module=this_model, - predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm( - num_groups=x.num_features // 16, num_channels=x.num_features - ), - ) - key_model_map[key] = this_model - - # configure resize - input_shape = shape - this_resizer = nn.Identity() - if 
resize_shape is not None: - if isinstance(resize_shape, dict): - h, w = resize_shape[key] - else: - h, w = resize_shape - this_resizer = torchvision.transforms.Resize(size=(h, w)) - input_shape = (shape[0], h, w) - - # configure randomizer - this_randomizer = nn.Identity() - if crop_shape is not None: - if isinstance(crop_shape, dict): - h, w = crop_shape[key] - else: - h, w = crop_shape - if random_crop: - this_randomizer = CropRandomizer( - input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False - ) - else: - this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) - # configure normalizer - this_normalizer = nn.Identity() - if norm_mean_std is not None: - this_normalizer = torchvision.transforms.Normalize( - mean=norm_mean_std[0], std=norm_mean_std[1] - ) - - this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) - key_transform_map[key] = this_transform - elif type == "low_dim": - low_dim_keys.append(key) - else: - raise RuntimeError(f"Unsupported obs type: {type}") - rgb_keys = sorted(rgb_keys) - low_dim_keys = sorted(low_dim_keys) - - self.shape_meta = shape_meta - self.key_model_map = key_model_map - self.key_transform_map = key_transform_map - self.share_rgb_model = share_rgb_model - self.rgb_keys = rgb_keys - self.low_dim_keys = low_dim_keys - self.key_shape_map = key_shape_map - - def forward(self, obs_dict): - batch_size = None - features = [] - - # process lowdim input - for key in self.low_dim_keys: - data = obs_dict[key] - if batch_size is None: - batch_size = data.shape[0] - else: - assert batch_size == data.shape[0] - assert data.shape[1:] == self.key_shape_map[key] - features.append(data) - - # process rgb input - if self.share_rgb_model: - # pass all rgb obs to rgb model - imgs = [] - for key in self.rgb_keys: - img = obs_dict[key] - if batch_size is None: - batch_size = img.shape[0] - else: - assert batch_size == img.shape[0] - assert img.shape[1:] == self.key_shape_map[key] - img = self.key_transform_map[key](img) - imgs.append(img) - # (N*B,C,H,W) - imgs = torch.cat(imgs, dim=0) - # (N*B,D) - feature = self.key_model_map["rgb"](imgs) - # (N,B,D) - feature = feature.reshape(-1, batch_size, *feature.shape[1:]) - # (B,N,D) - feature = torch.moveaxis(feature, 0, 1) - # (B,N*D) - feature = feature.reshape(batch_size, -1) - features.append(feature) - else: - # run each rgb obs to independent models - for key in self.rgb_keys: - img = obs_dict[key] - if batch_size is None: - batch_size = img.shape[0] - else: - assert batch_size == img.shape[0] - assert img.shape[1:] == self.key_shape_map[key] - img = self.key_transform_map[key](img) - feature = self.key_model_map[key](img) - features.append(feature) - - # concatenate all features - result = torch.cat(features, dim=-1) - return result - - @torch.no_grad() - def output_shape(self): - example_obs_dict = {} - obs_shape_meta = self.shape_meta["obs"] - batch_size = 1 - for key, attr in obs_shape_meta.items(): - shape = tuple(attr["shape"]) - this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device) - example_obs_dict[key] = this_obs - example_output = self.forward(example_obs_dict) - output_shape = example_output.shape[1:] - return output_shape diff --git a/lerobot/common/policies/diffusion/model/normalizer.py b/lerobot/common/policies/diffusion/model/normalizer.py deleted file mode 100644 index 0e4d79ab..00000000 --- a/lerobot/common/policies/diffusion/model/normalizer.py +++ /dev/null @@ -1,358 +0,0 @@ -from typing import Dict, Union - -import 
numpy as np -import torch -import torch.nn as nn -import zarr - -from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin -from lerobot.common.policies.diffusion.pytorch_utils import dict_apply - - -class LinearNormalizer(DictOfTensorMixin): - avaliable_modes = ["limits", "gaussian"] - - @torch.no_grad() - def fit( - self, - data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, - ): - if isinstance(data, dict): - for key, value in data.items(): - self.params_dict[key] = _fit( - value, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - else: - self.params_dict["_default"] = _fit( - data, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - - def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self.normalize(x) - - def __getitem__(self, key: str): - return SingleFieldLinearNormalizer(self.params_dict[key]) - - def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"): - self.params_dict[key] = value.params_dict - - def _normalize_impl(self, x, forward=True): - if isinstance(x, dict): - result = {} - for key, value in x.items(): - params = self.params_dict[key] - result[key] = _normalize(value, params, forward=forward) - return result - else: - if "_default" not in self.params_dict: - raise RuntimeError("Not initialized") - params = self.params_dict["_default"] - return _normalize(x, params, forward=forward) - - def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self._normalize_impl(x, forward=True) - - def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self._normalize_impl(x, forward=False) - - def get_input_stats(self) -> Dict: - if len(self.params_dict) == 0: - raise RuntimeError("Not initialized") - if len(self.params_dict) == 1 and "_default" in self.params_dict: - return self.params_dict["_default"]["input_stats"] - - result = {} - for key, value in self.params_dict.items(): - if key != "_default": - result[key] = value["input_stats"] - return result - - def get_output_stats(self, key="_default"): - input_stats = self.get_input_stats() - if "min" in input_stats: - # no dict - return dict_apply(input_stats, self.normalize) - - result = {} - for key, group in input_stats.items(): - this_dict = {} - for name, value in group.items(): - this_dict[name] = self.normalize({key: value})[key] - result[key] = this_dict - return result - - -class SingleFieldLinearNormalizer(DictOfTensorMixin): - avaliable_modes = ["limits", "gaussian"] - - @torch.no_grad() - def fit( - self, - data: Union[torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, - ): - self.params_dict = _fit( - data, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - - @classmethod - def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs): - obj = cls() - obj.fit(data, **kwargs) - return obj - - @classmethod - def create_manual( - cls, - scale: Union[torch.Tensor, np.ndarray], - offset: 
Union[torch.Tensor, np.ndarray], - input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]], - ): - def to_tensor(x): - if not isinstance(x, torch.Tensor): - x = torch.from_numpy(x) - x = x.flatten() - return x - - # check - for x in [offset] + list(input_stats_dict.values()): - assert x.shape == scale.shape - assert x.dtype == scale.dtype - - params_dict = nn.ParameterDict( - { - "scale": to_tensor(scale), - "offset": to_tensor(offset), - "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)), - } - ) - return cls(params_dict) - - @classmethod - def create_identity(cls, dtype=torch.float32): - scale = torch.tensor([1], dtype=dtype) - offset = torch.tensor([0], dtype=dtype) - input_stats_dict = { - "min": torch.tensor([-1], dtype=dtype), - "max": torch.tensor([1], dtype=dtype), - "mean": torch.tensor([0], dtype=dtype), - "std": torch.tensor([1], dtype=dtype), - } - return cls.create_manual(scale, offset, input_stats_dict) - - def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return _normalize(x, self.params_dict, forward=True) - - def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return _normalize(x, self.params_dict, forward=False) - - def get_input_stats(self): - return self.params_dict["input_stats"] - - def get_output_stats(self): - return dict_apply(self.params_dict["input_stats"], self.normalize) - - def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return self.normalize(x) - - -def _fit( - data: Union[torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, -): - assert mode in ["limits", "gaussian"] - assert last_n_dims >= 0 - assert output_max > output_min - - # convert data to torch and type - if isinstance(data, zarr.Array): - data = data[:] - if isinstance(data, np.ndarray): - data = torch.from_numpy(data) - if dtype is not None: - data = data.type(dtype) - - # convert shape - dim = 1 - if last_n_dims > 0: - dim = np.prod(data.shape[-last_n_dims:]) - data = data.reshape(-1, dim) - - # compute input stats min max mean std - input_min, _ = data.min(axis=0) - input_max, _ = data.max(axis=0) - input_mean = data.mean(axis=0) - input_std = data.std(axis=0) - - # compute scale and offset - if mode == "limits": - if fit_offset: - # unit scale - input_range = input_max - input_min - ignore_dim = input_range < range_eps - input_range[ignore_dim] = output_max - output_min - scale = (output_max - output_min) / input_range - offset = output_min - scale * input_min - offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim] - # ignore dims scaled to mean of output max and min - else: - # use this when data is pre-zero-centered. 
- assert output_max > 0 - assert output_min < 0 - # unit abs - output_abs = min(abs(output_min), abs(output_max)) - input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max)) - ignore_dim = input_abs < range_eps - input_abs[ignore_dim] = output_abs - # don't scale constant channels - scale = output_abs / input_abs - offset = torch.zeros_like(input_mean) - elif mode == "gaussian": - ignore_dim = input_std < range_eps - scale = input_std.clone() - scale[ignore_dim] = 1 - scale = 1 / scale - - offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean) - - # save - this_params = nn.ParameterDict( - { - "scale": scale, - "offset": offset, - "input_stats": nn.ParameterDict( - {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std} - ), - } - ) - for p in this_params.parameters(): - p.requires_grad_(False) - return this_params - - -def _normalize(x, params, forward=True): - assert "scale" in params - if isinstance(x, np.ndarray): - x = torch.from_numpy(x) - scale = params["scale"] - offset = params["offset"] - x = x.to(device=scale.device, dtype=scale.dtype) - src_shape = x.shape - x = x.reshape(-1, scale.shape[0]) - x = x * scale + offset if forward else (x - offset) / scale - x = x.reshape(src_shape) - return x - - -def test(): - data = torch.zeros((100, 10, 9, 2)).uniform_() - data[..., 0, 0] = 0 - - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=2) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0) - assert np.allclose(datan.min(), -1.0) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0, atol=1e-3) - assert np.allclose(datan.min(), 0.0, atol=1e-3) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - data = torch.zeros((100, 10, 9, 2)).uniform_() - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="gaussian", last_n_dims=0) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.mean(), 0.0, atol=1e-3) - assert np.allclose(datan.std(), 1.0, atol=1e-3) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - # dict - data = torch.zeros((100, 10, 9, 2)).uniform_() - data[..., 0, 0] = 0 - - normalizer = LinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=2) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0) - assert np.allclose(datan.min(), -1.0) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - data = { - "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512, - "action": torch.zeros((1000, 128, 2)).uniform_() * 512, - } - normalizer = LinearNormalizer() - normalizer.fit(data) - datan = normalizer.normalize(data) - dataun = normalizer.unnormalize(datan) - for key in data: - assert torch.allclose(data[key], dataun[key], atol=1e-4) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - state_dict = normalizer.state_dict() - n = LinearNormalizer() - 
n.load_state_dict(state_dict) - datan = n.normalize(data) - dataun = n.unnormalize(datan) - for key in data: - assert torch.allclose(data[key], dataun[key], atol=1e-4) diff --git a/lerobot/common/policies/diffusion/model/positional_embedding.py b/lerobot/common/policies/diffusion/model/positional_embedding.py deleted file mode 100644 index 65fc97bd..00000000 --- a/lerobot/common/policies/diffusion/model/positional_embedding.py +++ /dev/null @@ -1,19 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb diff --git a/lerobot/common/policies/diffusion/model/tensor_utils.py b/lerobot/common/policies/diffusion/model/tensor_utils.py deleted file mode 100644 index df9a568a..00000000 --- a/lerobot/common/policies/diffusion/model/tensor_utils.py +++ /dev/null @@ -1,972 +0,0 @@ -""" -A collection of utilities for working with nested tensor structures consisting -of numpy arrays and torch tensors. -""" - -import collections - -import numpy as np -import torch - - -def recursive_dict_list_tuple_apply(x, type_func_dict): - """ - Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of - {data_type: function_to_apply}. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - type_func_dict (dict): a mapping from data types to the functions to be - applied for each data type. - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - assert list not in type_func_dict - assert tuple not in type_func_dict - assert dict not in type_func_dict - - if isinstance(x, (dict, collections.OrderedDict)): - new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {} - for k, v in x.items(): - new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict) - return new_x - elif isinstance(x, (list, tuple)): - ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x] - if isinstance(x, tuple): - ret = tuple(ret) - return ret - else: - for t, f in type_func_dict.items(): - if isinstance(x, t): - return f(x) - else: - raise NotImplementedError("Cannot handle data type %s" % str(type(x))) - - -def map_tensor(x, func): - """ - Apply function @func to torch.Tensor objects in a nested dictionary or - list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - func (function): function to apply to each tensor - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: func, - type(None): lambda x: x, - }, - ) - - -def map_ndarray(x, func): - """ - Apply function @func to np.ndarray objects in a nested dictionary or - list or tuple. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - func (function): function to apply to each array - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - np.ndarray: func, - type(None): lambda x: x, - }, - ) - - -def map_tensor_ndarray(x, tensor_func, ndarray_func): - """ - Apply function @tensor_func to torch.Tensor objects and @ndarray_func to - np.ndarray objects in a nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - tensor_func (function): function to apply to each tensor - ndarray_Func (function): function to apply to each array - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: tensor_func, - np.ndarray: ndarray_func, - type(None): lambda x: x, - }, - ) - - -def clone(x): - """ - Clones all torch tensors and numpy arrays in nested dictionary or list - or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.clone(), - np.ndarray: lambda x: x.copy(), - type(None): lambda x: x, - }, - ) - - -def detach(x): - """ - Detaches all torch tensors in nested dictionary or list - or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.detach(), - }, - ) - - -def to_batch(x): - """ - Introduces a leading batch dimension of 1 for all torch tensors and numpy - arrays in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[None, ...], - np.ndarray: lambda x: x[None, ...], - type(None): lambda x: x, - }, - ) - - -def to_sequence(x): - """ - Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy - arrays in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[:, None, ...], - np.ndarray: lambda x: x[:, None, ...], - type(None): lambda x: x, - }, - ) - - -def index_at_time(x, ind): - """ - Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in - nested dictionary or list or tuple and returns a new nested structure. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - ind (int): index - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[:, ind, ...], - np.ndarray: lambda x: x[:, ind, ...], - type(None): lambda x: x, - }, - ) - - -def unsqueeze(x, dim): - """ - Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays - in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - dim (int): dimension - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.unsqueeze(dim=dim), - np.ndarray: lambda x: np.expand_dims(x, axis=dim), - type(None): lambda x: x, - }, - ) - - -def contiguous(x): - """ - Makes all torch tensors and numpy arrays contiguous in nested dictionary or - list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.contiguous(), - np.ndarray: lambda x: np.ascontiguousarray(x), - type(None): lambda x: x, - }, - ) - - -def to_device(x, device): - """ - Sends all torch tensors in nested dictionary or list or tuple to device - @device, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - device (torch.Device): device to send tensors to - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, d=device: x.to(d), - type(None): lambda x: x, - }, - ) - - -def to_tensor(x): - """ - Converts all numpy arrays in nested dictionary or list or tuple to - torch tensors (and leaves existing torch Tensors as-is), and returns - a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x, - np.ndarray: lambda x: torch.from_numpy(x), - type(None): lambda x: x, - }, - ) - - -def to_numpy(x): - """ - Converts all torch tensors in nested dictionary or list or tuple to - numpy (and leaves existing numpy arrays as-is), and returns - a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - - def f(tensor): - if tensor.is_cuda: - return tensor.detach().cpu().numpy() - else: - return tensor.detach().numpy() - - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: f, - np.ndarray: lambda x: x, - type(None): lambda x: x, - }, - ) - - -def to_list(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to a list, and returns a new nested structure. Useful for - json encoding. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - - def f(tensor): - if tensor.is_cuda: - return tensor.detach().cpu().numpy().tolist() - else: - return tensor.detach().numpy().tolist() - - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: f, - np.ndarray: lambda x: x.tolist(), - type(None): lambda x: x, - }, - ) - - -def to_float(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to float type entries, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.float(), - np.ndarray: lambda x: x.astype(np.float32), - type(None): lambda x: x, - }, - ) - - -def to_uint8(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to uint8 type entries, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.byte(), - np.ndarray: lambda x: x.astype(np.uint8), - type(None): lambda x: x, - }, - ) - - -def to_torch(x, device): - """ - Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to - torch tensors on device @device and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - device (torch.Device): device to send tensors to - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return to_device(to_float(to_tensor(x)), device) - - -def to_one_hot_single(tensor, num_class): - """ - Convert tensor to one-hot representation, assuming a certain number of total class labels. - - Args: - tensor (torch.Tensor): tensor containing integer labels - num_class (int): number of classes - - Returns: - x (torch.Tensor): tensor containing one-hot representation of labels - """ - x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device) - x.scatter_(-1, tensor.unsqueeze(-1), 1) - return x - - -def to_one_hot(tensor, num_class): - """ - Convert all tensors in nested dictionary or list or tuple to one-hot representation, - assuming a certain number of total class labels. - - Args: - tensor (dict or list or tuple): a possibly nested dictionary or list or tuple - num_class (int): number of classes - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc)) - - -def flatten_single(x, begin_axis=1): - """ - Flatten a tensor in all dimensions from @begin_axis onwards. - - Args: - x (torch.Tensor): tensor to flatten - begin_axis (int): which axis to flatten from - - Returns: - y (torch.Tensor): flattened tensor - """ - fixed_size = x.size()[:begin_axis] - _s = list(fixed_size) + [-1] - return x.reshape(*_s) - - -def flatten(x, begin_axis=1): - """ - Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): which axis to flatten from - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b), - }, - ) - - -def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): - """ - Reshape selected dimensions in a tensor to a target dimension. - - Args: - x (torch.Tensor): tensor to reshape - begin_axis (int): begin dimension - end_axis (int): end dimension - target_dims (tuple or list): target shape for the range of dimensions - (@begin_axis, @end_axis) - - Returns: - y (torch.Tensor): reshaped tensor - """ - assert begin_axis <= end_axis - assert begin_axis >= 0 - assert end_axis < len(x.shape) - assert isinstance(target_dims, (tuple, list)) - s = x.shape - final_s = [] - for i in range(len(s)): - if i == begin_axis: - final_s.extend(target_dims) - elif i < begin_axis or i > end_axis: - final_s.append(s[i]) - return x.reshape(*final_s) - - -def reshape_dimensions(x, begin_axis, end_axis, target_dims): - """ - Reshape selected dimensions for all tensors in nested dictionary or list or tuple - to a target dimension. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): begin dimension - end_axis (int): end dimension - target_dims (tuple or list): target shape for the range of dimensions - (@begin_axis, @end_axis) - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=t - ), - np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=t - ), - type(None): lambda x: x, - }, - ) - - -def join_dimensions(x, begin_axis, end_axis): - """ - Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for - all tensors in nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): begin dimension - end_axis (int): end dimension - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=[-1] - ), - np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=[-1] - ), - type(None): lambda x: x, - }, - ) - - -def expand_at_single(x, size, dim): - """ - Expand a tensor at a single dimension @dim by @size - - Args: - x (torch.Tensor): input tensor - size (int): size to expand - dim (int): dimension to expand - - Returns: - y (torch.Tensor): expanded tensor - """ - assert dim < x.ndimension() - assert x.shape[dim] == 1 - expand_dims = [-1] * x.ndimension() - expand_dims[dim] = size - return x.expand(*expand_dims) - - -def expand_at(x, size, dim): - """ - Expand all tensors in nested dictionary or list or tuple at a single - dimension @dim by @size. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size to expand - dim (int): dimension to expand - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) - - -def unsqueeze_expand_at(x, size, dim): - """ - Unsqueeze and expand a tensor at a dimension @dim by @size. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size to expand - dim (int): dimension to unsqueeze and expand - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - x = unsqueeze(x, dim) - return expand_at(x, size, dim) - - -def repeat_by_expand_at(x, repeats, dim): - """ - Repeat a dimension by combining expand and reshape operations. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - repeats (int): number of times to repeat the target dimension - dim (int): dimension to repeat on - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - x = unsqueeze_expand_at(x, repeats, dim + 1) - return join_dimensions(x, dim, dim + 1) - - -def named_reduce_single(x, reduction, dim): - """ - Reduce tensor at a dimension by named reduction functions. - - Args: - x (torch.Tensor): tensor to be reduced - reduction (str): one of ["sum", "max", "mean", "flatten"] - dim (int): dimension to be reduced (or begin axis for flatten) - - Returns: - y (torch.Tensor): reduced tensor - """ - assert x.ndimension() > dim - assert reduction in ["sum", "max", "mean", "flatten"] - if reduction == "flatten": - x = flatten(x, begin_axis=dim) - elif reduction == "max": - x = torch.max(x, dim=dim)[0] # [B, D] - elif reduction == "sum": - x = torch.sum(x, dim=dim) - else: - x = torch.mean(x, dim=dim) - return x - - -def named_reduce(x, reduction, dim): - """ - Reduces all tensors in nested dictionary or list or tuple at a dimension - using a named reduction function. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - reduction (str): one of ["sum", "max", "mean", "flatten"] - dim (int): dimension to be reduced (or begin axis for flatten) - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d)) - - -def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): - """ - This function indexes out a target dimension of a tensor in a structured way, - by allowing a different value to be selected for each member of a flat index - tensor (@indices) corresponding to a source dimension. This can be interpreted - as moving along the source dimension, using the corresponding index value - in @indices to select values for all other dimensions outside of the - source and target dimensions. A common use case is to gather values - in target dimension 1 for each batch member (target dimension 0). 
- - Args: - x (torch.Tensor): tensor to gather values for - target_dim (int): dimension to gather values along - source_dim (int): dimension to hold constant and use for gathering values - from the other dimensions - indices (torch.Tensor): flat index tensor with same shape as tensor @x along - @source_dim - - Returns: - y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out - """ - assert len(indices.shape) == 1 - assert x.shape[source_dim] == indices.shape[0] - - # unsqueeze in all dimensions except the source dimension - new_shape = [1] * x.ndimension() - new_shape[source_dim] = -1 - indices = indices.reshape(*new_shape) - - # repeat in all dimensions - but preserve shape of source dimension, - # and make sure target_dimension has singleton dimension - expand_shape = list(x.shape) - expand_shape[source_dim] = -1 - expand_shape[target_dim] = 1 - indices = indices.expand(*expand_shape) - - out = x.gather(dim=target_dim, index=indices) - return out.squeeze(target_dim) - - -def gather_along_dim_with_dim(x, target_dim, source_dim, indices): - """ - Apply @gather_along_dim_with_dim_single to all tensors in a nested - dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - target_dim (int): dimension to gather values along - source_dim (int): dimension to hold constant and use for gathering values - from the other dimensions - indices (torch.Tensor): flat index tensor with same shape as tensor @x along - @source_dim - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor( - x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i) - ) - - -def gather_sequence_single(seq, indices): - """ - Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in - the batch given an index for each sequence. - - Args: - seq (torch.Tensor): tensor with leading dimensions [B, T, ...] - indices (torch.Tensor): tensor indices of shape [B] - - Return: - y (torch.Tensor): indexed tensor of shape [B, ....] - """ - return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices) - - -def gather_sequence(seq, indices): - """ - Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch - for tensors with leading dimensions [B, T, ...]. - - Args: - seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - indices (torch.Tensor): tensor indices of shape [B] - - Returns: - y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] - """ - return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices) - - -def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None): - """ - Pad input tensor or array @seq in the time dimension (dimension 1). - - Args: - seq (np.ndarray or torch.Tensor): sequence to be padded - padding (tuple): begin and end padding, e.g. 
[1, 1] pads both begin and end of the sequence by 1 - batched (bool): if sequence has the batch dimension - pad_same (bool): if pad by duplicating - pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same - - Returns: - padded sequence (np.ndarray or torch.Tensor) - """ - assert isinstance(seq, (np.ndarray, torch.Tensor)) - assert pad_same or pad_values is not None - if pad_values is not None: - assert isinstance(pad_values, float) - repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave - concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat - ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like - seq_dim = 1 if batched else 0 - - begin_pad = [] - end_pad = [] - - if padding[0] > 0: - pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values - begin_pad.append(repeat_func(pad, padding[0], seq_dim)) - if padding[1] > 0: - pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values - end_pad.append(repeat_func(pad, padding[1], seq_dim)) - - return concat_func(begin_pad + [seq] + end_pad, seq_dim) - - -def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None): - """ - Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). - - Args: - seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 - batched (bool): if sequence has the batch dimension - pad_same (bool): if pad by duplicating - pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same - - Returns: - padded sequence (dict or list or tuple) - """ - return recursive_dict_list_tuple_apply( - seq, - { - torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( - x, p, b, ps, pv - ), - np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( - x, p, b, ps, pv - ), - type(None): lambda x: x, - }, - ) - - -def assert_size_at_dim_single(x, size, dim, msg): - """ - Ensure that array or tensor @x has size @size in dim @dim. - - Args: - x (np.ndarray or torch.Tensor): input array or tensor - size (int): size that tensors should have at @dim - dim (int): dimension to check - msg (str): text to display if assertion fails - """ - assert x.shape[dim] == size, msg - - -def assert_size_at_dim(x, size, dim, msg): - """ - Ensure that arrays and tensors in nested dictionary or list or tuple have - size @size in dim @dim. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size that tensors should have at @dim - dim (int): dimension to check - """ - map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) - - -def get_shape(x): - """ - Get all shapes of arrays and tensors in nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple that contains each array or - tensor's shape - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.shape, - np.ndarray: lambda x: x.shape, - type(None): lambda x: x, - }, - ) - - -def list_of_flat_dict_to_dict_of_list(list_of_dict): - """ - Helper function to go from a list of flat dictionaries to a dictionary of lists. 
- By "flat" we mean that none of the values are dictionaries, but are numpy arrays, - floats, etc. - - Args: - list_of_dict (list): list of flat dictionaries - - Returns: - dict_of_list (dict): dictionary of lists - """ - assert isinstance(list_of_dict, list) - dic = collections.OrderedDict() - for i in range(len(list_of_dict)): - for k in list_of_dict[i]: - if k not in dic: - dic[k] = [] - dic[k].append(list_of_dict[i][k]) - return dic - - -def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""): - """ - Flatten a nested dict or list to a list. - - For example, given a dict - { - a: 1 - b: { - c: 2 - } - c: 3 - } - - the function would return [(a, 1), (b_c, 2), (c, 3)] - - Args: - d (dict, list): a nested dict or list to be flattened - parent_key (str): recursion helper - sep (str): separator for nesting keys - item_key (str): recursion helper - Returns: - list: a list of (key, value) tuples - """ - items = [] - if isinstance(d, (tuple, list)): - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - for i, v in enumerate(d): - items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) - return items - elif isinstance(d, dict): - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - for k, v in d.items(): - assert isinstance(k, str) - items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) - return items - else: - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - return [(new_key, d)] - - -def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs): - """ - Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the - batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. - Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping - outputs to [B, T, ...]. - - Args: - inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - op: a layer op that accepts inputs - activation: activation to apply at the output - inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op - inputs_as_args (bool) whether to feed input as a args list to the op - kwargs (dict): other kwargs to supply to the op - - Returns: - outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. - """ - batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2] - inputs = join_dimensions(inputs, 0, 1) - if inputs_as_kwargs: - outputs = op(**inputs, **kwargs) - elif inputs_as_args: - outputs = op(*inputs, **kwargs) - else: - outputs = op(inputs, **kwargs) - - if activation is not None: - outputs = map_tensor(outputs, activation) - outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len)) - return outputs diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py new file mode 100644 index 00000000..e7cc62f4 --- /dev/null +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -0,0 +1,723 @@ +"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" + +TODO(alexander-soare): + - Remove reliance on Robomimic for SpatialSoftmax. + - Remove reliance on diffusers for DDPMScheduler and LR scheduler. + - Move EMA out of policy. 
+ - Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy. + - One more pass on comments and documentation. +""" + +import copy +import logging +import math +import time +from collections import deque +from itertools import chain +from typing import Callable + +import einops +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from diffusers.optimization import get_scheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from robomimic.models.base_nets import SpatialSoftmax +from torch import Tensor, nn +from torch.nn.modules.batchnorm import _BatchNorm + +from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.utils import ( + get_device_from_parameters, + get_dtype_from_parameters, + populate_queues, +) + + +class DiffusionPolicy(nn.Module): + """ + Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" + (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). + """ + + name = "diffusion" + + def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0): + """ + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. + """ + super().__init__() + # TODO(alexander-soare): LR scheduler will be removed. + assert lr_scheduler_num_training_steps > 0 + if cfg is None: + cfg = DiffusionConfig() + self.cfg = cfg + + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None + + self.diffusion = _DiffusionUnetImagePolicy(cfg) + + # TODO(alexander-soare): This should probably be managed outside of the policy class. + self.ema_diffusion = None + self.ema = None + if self.cfg.use_ema: + self.ema_diffusion = copy.deepcopy(self.diffusion) + self.ema = _EMA(cfg, model=self.ema_diffusion) + + # TODO(alexander-soare): Move optimizer out of policy. + self.optimizer = torch.optim.Adam( + self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay + ) + + # TODO(alexander-soare): Move LR scheduler out of policy. + # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps + self.global_step = 0 + + # configure lr scheduler + self.lr_scheduler = get_scheduler( + cfg.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=lr_scheduler_num_training_steps, + # pytorch assumes stepping LRScheduler every epoch + # however huggingface diffusers steps it every batch + last_epoch=self.global_step - 1, + ) + + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=self.cfg.n_obs_steps), + "observation.state": deque(maxlen=self.cfg.n_obs_steps), + "action": deque(maxlen=self.cfg.n_action_steps), + } + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + """Select a single action given environment observations. + + This method handles caching a history of observations and an action trajectory generated by the + underlying diffusion model. Here's how it works: + - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is + copied `n_obs_steps` times to fill the cache). + - The diffusion model generates `horizon` steps worth of actions. 
+ - `n_action_steps` worth of actions are actually kept for execution, starting from the current step. + Schematically this looks like: + ---------------------------------------------------------------------------------------------- + (legend: o = n_obs_steps, h = horizon, a = n_action_steps) + |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h| + |observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO | + |action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES | + |action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO | + ---------------------------------------------------------------------------------------------- + Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that + "horizon" may not the best name to describe what the variable actually means, because this period is + actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. + + Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model + weights. + """ + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 + + self._queues = populate_queues(self._queues, batch) + + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + if not self.training and self.ema_diffusion is not None: + actions = self.ema_diffusion.generate_actions(batch) + else: + actions = self.diffusion.generate_actions(batch) + self._queues["action"].extend(actions.transpose(0, 1)) + + action = self._queues["action"].popleft() + return action + + def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss for training or validation.""" + loss = self.diffusion.compute_loss(batch) + return {"loss": loss} + + def update(self, batch: dict[str, Tensor], **_) -> dict: + """Run the model in train mode, compute the loss, and do an optimization step.""" + start_time = time.time() + + self.diffusion.train() + + loss = self.forward(batch)["loss"] + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.diffusion.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + + if self.ema is not None: + self.ema.step(self.diffusion) + + info = { + "loss": loss.item(), + "grad_norm": float(grad_norm), + "lr": self.lr_scheduler.get_last_lr()[0], + "update_s": time.time() - start_time, + } + + return info + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) + if len(missing_keys) > 0: + assert all(k.startswith("ema_diffusion.") for k in missing_keys) + logging.warning( + "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found." 
+ ) + assert len(unexpected_keys) == 0 + + +class _DiffusionUnetImagePolicy(nn.Module): + def __init__(self, cfg: DiffusionConfig): + super().__init__() + self.cfg = cfg + + self.rgb_encoder = _RgbEncoder(cfg) + self.unet = _ConditionalUnet1D( + cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps + ) + + self.noise_scheduler = DDPMScheduler( + num_train_timesteps=cfg.num_train_timesteps, + beta_start=cfg.beta_start, + beta_end=cfg.beta_end, + beta_schedule=cfg.beta_schedule, + variance_type="fixed_small", + clip_sample=cfg.clip_sample, + clip_sample_range=cfg.clip_sample_range, + prediction_type=cfg.prediction_type, + ) + + if cfg.num_inference_steps is None: + self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps + else: + self.num_inference_steps = cfg.num_inference_steps + + # ========= inference ============ + def conditional_sample( + self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None + ) -> Tensor: + device = get_device_from_parameters(self) + dtype = get_dtype_from_parameters(self) + + # Sample prior. + sample = torch.randn( + size=(batch_size, self.cfg.horizon, self.cfg.action_dim), + dtype=dtype, + device=device, + generator=generator, + ) + + self.noise_scheduler.set_timesteps(self.num_inference_steps) + + for t in self.noise_scheduler.timesteps: + # Predict model output. + model_output = self.unet( + sample, + torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), + global_cond=global_cond, + ) + # Compute previous image: x_t -> x_t-1 + sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample + + return sample + + def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + } + """ + assert set(batch).issuperset({"observation.state", "observation.image"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.cfg.n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + + # run sampling + sample = self.conditional_sample(batch_size, global_cond=global_cond) + + # `horizon` steps worth of actions (from the first observation). + actions = sample[..., : self.cfg.action_dim] + # Extract `n_action_steps` steps worth of actions (from the current observation). + start = n_obs_steps - 1 + end = start + self.cfg.n_action_steps + actions = actions[:, start:end] + + return actions + + def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + "action": (B, horizon, action_dim) + "action_is_pad": (B, horizon) + } + """ + # Input validation. 
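+        # For example, assuming batch_size=64, n_obs_steps=2, horizon=16, state_dim=2, action_dim=2 and
+        # 96x96 RGB images (illustrative values only; the real shapes come from the config), this expects:
+        #   batch["observation.state"]:  (64, 2, 2)
+        #   batch["observation.image"]:  (64, 2, 3, 96, 96)
+        #   batch["action"]:             (64, 16, 2)
+        #   batch["action_is_pad"]:      (64, 16)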
+ assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + horizon = batch["action"].shape[1] + assert horizon == self.cfg.horizon + assert n_obs_steps == self.cfg.n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + + trajectory = batch["action"] + + # Forward diffusion. + # Sample noise to add to the trajectory. + eps = torch.randn(trajectory.shape, device=trajectory.device) + # Sample a random noising timestep for each item in the batch. + timesteps = torch.randint( + low=0, + high=self.noise_scheduler.config.num_train_timesteps, + size=(trajectory.shape[0],), + device=trajectory.device, + ).long() + # Add noise to the clean trajectories according to the noise magnitude at each timestep. + noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps) + + # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise). + pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond) + + # Compute the loss. + # The target is either the original trajectory, or the noise. + if self.cfg.prediction_type == "epsilon": + target = eps + elif self.cfg.prediction_type == "sample": + target = batch["action"] + else: + raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}") + + loss = F.mse_loss(pred, target, reduction="none") + + # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). + if "action_is_pad" in batch: + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound.unsqueeze(-1) + + return loss.mean() + + +class _RgbEncoder(nn.Module): + """Encoder an RGB image into a 1D feature vector. + + Includes the ability to normalize and crop the image first. + """ + + def __init__(self, cfg: DiffusionConfig): + super().__init__() + # Set up optional preprocessing. + if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)): + self.normalizer = nn.Identity() + else: + self.normalizer = torchvision.transforms.Normalize( + mean=cfg.image_normalization_mean, std=cfg.image_normalization_std + ) + if cfg.crop_shape is not None: + self.do_crop = True + # Always use center crop for eval + self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape) + if cfg.crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + # Set up backbone. + backbone_model = getattr(torchvision.models, cfg.vision_backbone)( + pretrained=cfg.use_pretrained_backbone + ) + # Note: This assumes that the layer4 feature map is children()[-3] + # TODO(alexander-soare): Use a safer alternative. + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + if cfg.use_group_norm: + if cfg.use_pretrained_backbone: + raise ValueError( + "You can't replace BatchNorm in a pretrained model without ruining the weights!" 
+ ) + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + ) + + # Set up pooling and final layers. + # Use a dry run to get the feature map shape. + with torch.inference_mode(): + feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:]) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints) + self.feature_dim = cfg.spatial_softmax_num_keypoints * 2 + self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B, C, H, W) image tensor with pixel values in [0, 1]. + Returns: + (B, D) image feature. + """ + # Preprocess: normalize and maybe crop (if it was set up in the __init__). + x = self.normalizer(x) + if self.do_crop: + if self.training: # noqa: SIM108 + x = self.maybe_random_crop(x) + else: + # Always use center crop for eval. + x = self.center_crop(x) + # Extract backbone feature. + x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) + # Final linear layer with non-linearity. + x = self.relu(self.out(x)) + return x + + +def _replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + Args: + root_module: The module for which the submodules need to be replaced + predicate: Takes a module as an argument and must return True if the that module is to be replaced. + func: Takes a module as an argument and returns a new module to replace it with. + Returns: + The root module with its submodules replaced. + """ + if predicate(root_module): + return func(root_module) + + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parents, k in replace_list: + parent_module = root_module + if len(parents) > 0: + parent_module = root_module.get_submodule(".".join(parents)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + return root_module + + +class _SinusoidalPosEmb(nn.Module): + """1D sinusoidal positional embeddings as in Attention is All You Need.""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x.unsqueeze(-1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class _Conv1dBlock(nn.Module): + """Conv1d --> GroupNorm --> Mish""" + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class _ConditionalUnet1D(nn.Module): + """A 1D convolutional UNet with FiLM modulation for conditioning. 
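+
+    FiLM (https://arxiv.org/abs/1709.07871) conditions each residual block on the diffusion step embedding
+    and the global conditioning features by predicting a per-channel bias (and, if configured, a scale),
+    so that roughly out = scale * features + bias. See _ConditionalResidualBlock1D below for the details.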
+ + Note: this removes local conditioning as compared to the original diffusion policy code. + """ + + def __init__(self, cfg: DiffusionConfig, global_cond_dim: int): + super().__init__() + + self.cfg = cfg + + # Encoder for the diffusion timestep. + self.diffusion_step_encoder = nn.Sequential( + _SinusoidalPosEmb(cfg.diffusion_step_embed_dim), + nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4), + nn.Mish(), + nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim), + ) + + # The FiLM conditioning dimension. + cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim + + # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we + # just reverse these. + in_out = [(cfg.action_dim, cfg.down_dims[0])] + list( + zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True) + ) + + # Unet encoder. + common_res_block_kwargs = { + "cond_dim": cond_dim, + "kernel_size": cfg.kernel_size, + "n_groups": cfg.n_groups, + "use_film_scale_modulation": cfg.use_film_scale_modulation, + } + self.down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + self.down_modules.append( + nn.ModuleList( + [ + _ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs), + _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), + # Downsample as long as it is not the last block. + nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + # Processing in the middle of the auto-encoder. + self.mid_modules = nn.ModuleList( + [ + _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), + _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), + ] + ) + + # Unet decoder. + self.up_modules = nn.ModuleList([]) + for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + self.up_modules.append( + nn.ModuleList( + [ + # dim_in * 2, because it takes the encoder's skip connection as well + _ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs), + _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), + # Upsample as long as it is not the last block. + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + self.final_conv = nn.Sequential( + _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), + nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1), + ) + + def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: + """ + Args: + x: (B, T, input_dim) tensor for input to the Unet. + timestep: (B,) tensor of (timestep_we_are_denoising_from - 1). + global_cond: (B, global_cond_dim) + output: (B, T, input_dim) + Returns: + (B, T, input_dim) diffusion model prediction. + """ + # For 1D convolutions we'll need feature dimension first. + x = einops.rearrange(x, "b t d -> b d t") + + timesteps_embed = self.diffusion_step_encoder(timestep) + + # If there is a global conditioning feature, concatenate it to the timestep embedding. + if global_cond is not None: + global_feature = torch.cat([timesteps_embed, global_cond], axis=-1) + else: + global_feature = timesteps_embed + + # Run encoder, keeping track of skip features to pass to the decoder. 
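+        # As a sketch of the shapes (assuming cfg.down_dims has three entries; the exact sizes depend on
+        # the config): the encoder takes (B, action_dim, T) through (B, down_dims[0], T/2) and
+        # (B, down_dims[1], T/4) to (B, down_dims[2], T/4); the decoder then mirrors this back up,
+        # concatenating the stored skip features along the channel dimension, and `final_conv` projects
+        # the result to (B, action_dim, T).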
+ encoder_skip_features: list[Tensor] = [] + for resnet, resnet2, downsample in self.down_modules: + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + encoder_skip_features.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + # Run decoder, using the skip features from the encoder. + for resnet, resnet2, upsample in self.up_modules: + x = torch.cat((x, encoder_skip_features.pop()), dim=1) + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + x = einops.rearrange(x, "b d t -> b t d") + return x + + +class _ConditionalResidualBlock1D(nn.Module): + """ResNet style 1D convolutional block with FiLM modulation for conditioning.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + cond_dim: int, + kernel_size: int = 3, + n_groups: int = 8, + # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning + # FiLM just modulates bias). + use_film_scale_modulation: bool = False, + ): + super().__init__() + + self.use_film_scale_modulation = use_film_scale_modulation + self.out_channels = out_channels + + self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + + # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. + cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels + self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) + + self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + + # A final convolution for dimension matching the residual (if needed). + self.residual_conv = ( + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + ) + + def forward(self, x: Tensor, cond: Tensor) -> Tensor: + """ + Args: + x: (B, in_channels, T) + cond: (B, cond_dim) + Returns: + (B, out_channels, T) + """ + out = self.conv1(x) + + # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1). + cond_embed = self.cond_encoder(cond).unsqueeze(-1) + if self.use_film_scale_modulation: + # Treat the embedding as a list of scales and biases. + scale = cond_embed[:, : self.out_channels] + bias = cond_embed[:, self.out_channels :] + out = scale * out + bias + else: + # Treat the embedding as biases. + out = out + cond_embed + + out = self.conv2(out) + out = out + self.residual_conv(x) + return out + + +class _EMA: + """ + Exponential Moving Average of models weights + """ + + def __init__(self, cfg: DiffusionConfig, model: nn.Module): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_alpha (float): The minimum EMA decay rate. Default: 0. 
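+
+            As a rough summary of `get_decay` below: at optimization step t (offset by update_after_step),
+            the decay is 1 - (1 + t / inv_gamma) ** -power, clamped to the range [min_alpha, max_alpha].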
+ """ + + self.averaged_model = model + self.averaged_model.eval() + self.averaged_model.requires_grad_(False) + + self.update_after_step = cfg.ema_update_after_step + self.inv_gamma = cfg.ema_inv_gamma + self.power = cfg.ema_power + self.min_alpha = cfg.ema_min_alpha + self.max_alpha = cfg.ema_max_alpha + + self.alpha = 0.0 + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_alpha, min(value, self.max_alpha)) + + @torch.no_grad() + def step(self, new_model): + self.alpha = self.get_decay(self.optimization_step) + + for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True): + # Iterate over immediate parameters only. + for param, ema_param in zip( + module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True + ): + if isinstance(param, dict): + raise RuntimeError("Dict parameter not supported") + if isinstance(module, _BatchNorm) or not param.requires_grad: + # Copy BatchNorm parameters, and non-trainable parameters directly. + ema_param.copy_(param.to(dtype=ema_param.dtype).data) + else: + ema_param.mul_(self.alpha) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha) + + self.optimization_step += 1 diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py deleted file mode 100644 index 9785358b..00000000 --- a/lerobot/common/policies/diffusion/policy.py +++ /dev/null @@ -1,195 +0,0 @@ -import copy -import logging -import time -from collections import deque - -import hydra -import torch -from torch import nn - -from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy -from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder -from lerobot.common.policies.utils import populate_queues -from lerobot.common.utils import get_safe_torch_device - - -class DiffusionPolicy(nn.Module): - name = "diffusion" - - def __init__( - self, - cfg, - cfg_device, - cfg_noise_scheduler, - cfg_rgb_model, - cfg_obs_encoder, - cfg_optimizer, - cfg_ema, - shape_meta: dict, - horizon, - n_action_steps, - n_obs_steps, - num_inference_steps=None, - obs_as_global_cond=True, - diffusion_step_embed_dim=256, - down_dims=(256, 512, 1024), - kernel_size=5, - n_groups=8, - cond_predict_scale=True, - # parameters passed to step - **kwargs, - ): - super().__init__() - self.cfg = cfg - self.n_obs_steps = n_obs_steps - self.n_action_steps = n_action_steps - # queues are populated during rollout of the policy, they contain the n latest observations and actions - self._queues = None - - noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) - rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape) - if cfg_obs_encoder.crop_shape is not None: - rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape - rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model) - obs_encoder = MultiImageObsEncoder( - rgb_model=rgb_model, - **cfg_obs_encoder, - ) - - self.diffusion = DiffusionUnetImagePolicy( - shape_meta=shape_meta, - noise_scheduler=noise_scheduler, - obs_encoder=obs_encoder, - horizon=horizon, - n_action_steps=n_action_steps, - 
n_obs_steps=n_obs_steps, - num_inference_steps=num_inference_steps, - obs_as_global_cond=obs_as_global_cond, - diffusion_step_embed_dim=diffusion_step_embed_dim, - down_dims=down_dims, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - # parameters passed to step - **kwargs, - ) - - self.device = get_safe_torch_device(cfg_device) - self.diffusion.to(self.device) - - self.ema_diffusion = None - self.ema = None - if self.cfg.use_ema: - self.ema_diffusion = copy.deepcopy(self.diffusion) - self.ema = hydra.utils.instantiate( - cfg_ema, - model=self.ema_diffusion, - ) - - self.optimizer = hydra.utils.instantiate( - cfg_optimizer, - params=self.diffusion.parameters(), - ) - - # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps - self.global_step = 0 - - # configure lr scheduler - self.lr_scheduler = get_scheduler( - cfg.lr_scheduler, - optimizer=self.optimizer, - num_warmup_steps=cfg.lr_warmup_steps, - num_training_steps=cfg.offline_steps, - # pytorch assumes stepping LRScheduler every epoch - # however huggingface diffusers steps it every batch - last_epoch=self.global_step - 1, - ) - - def reset(self): - """ - Clear observation and action queues. Should be called on `env.reset()` - """ - self._queues = { - "observation.image": deque(maxlen=self.n_obs_steps), - "observation.state": deque(maxlen=self.n_obs_steps), - "action": deque(maxlen=self.n_action_steps), - } - - @torch.no_grad() - def select_action(self, batch, step): - """ - Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights. - """ - # TODO(rcadene): remove unused step - del step - assert "observation.image" in batch - assert "observation.state" in batch - assert len(batch) == 2 - - self._queues = populate_queues(self._queues, batch) - - if len(self._queues["action"]) == 0: - # stack n latest observations from the queue - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} - - obs_dict = { - "image": batch["observation.image"], - "agent_pos": batch["observation.state"], - } - if self.training: - out = self.diffusion.predict_action(obs_dict) - else: - out = self.ema_diffusion.predict_action(obs_dict) - self._queues["action"].extend(out["action"].transpose(0, 1)) - - action = self._queues["action"].popleft() - return action - - def forward(self, batch, step): - start_time = time.time() - - self.diffusion.train() - - loss = self.diffusion.compute_loss(batch) - loss.backward() - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.diffusion.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, - ) - - self.optimizer.step() - self.optimizer.zero_grad() - self.lr_scheduler.step() - - if self.ema is not None: - self.ema.step(self.diffusion) - - info = { - "loss": loss.item(), - "grad_norm": float(grad_norm), - "lr": self.lr_scheduler.get_last_lr()[0], - "update_s": time.time() - start_time, - } - - # TODO(rcadene): remove hardcoding - # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64 - if step % 168 == 0: - self.global_step += 1 - - return info - - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) - if len(missing_keys) > 0: - assert all(k.startswith("ema_diffusion.") for k in missing_keys) - logging.warning( - "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found." 
- ) - assert len(unexpected_keys) == 0 diff --git a/lerobot/common/policies/diffusion/pytorch_utils.py b/lerobot/common/policies/diffusion/pytorch_utils.py deleted file mode 100644 index ed5dc23a..00000000 --- a/lerobot/common/policies/diffusion/pytorch_utils.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Callable, Dict - -import torch -import torch.nn as nn -import torchvision - - -def get_resnet(name, weights=None, **kwargs): - """ - name: resnet18, resnet34, resnet50 - weights: "IMAGENET1K_V1", "r3m" - """ - # load r3m weights - if (weights == "r3m") or (weights == "R3M"): - return get_r3m(name=name, **kwargs) - - func = getattr(torchvision.models, name) - resnet = func(weights=weights, **kwargs) - resnet.fc = torch.nn.Identity() - return resnet - - -def get_r3m(name, **kwargs): - """ - name: resnet18, resnet34, resnet50 - """ - import r3m - - r3m.device = "cpu" - model = r3m.load_r3m(name) - r3m_model = model.module - resnet_model = r3m_model.convnet - resnet_model = resnet_model.to("cpu") - return resnet_model - - -def dict_apply( - x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor] -) -> Dict[str, torch.Tensor]: - result = {} - for key, value in x.items(): - if isinstance(value, dict): - result[key] = dict_apply(value, func) - else: - result[key] = func(value) - return result - - -def replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] -) -> nn.Module: - """ - predicate: Return true if the module is to be replaced. - func: Return new module to use. - """ - if predicate(root_module): - return func(root_module) - - bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] - for *parent, k in bn_list: - parent_module = root_module - if len(parent) > 0: - parent_module = root_module.get_submodule(".".join(parent)) - if isinstance(parent_module, nn.Sequential): - src_module = parent_module[int(k)] - else: - src_module = getattr(parent_module, k) - tgt_module = func(src_module) - if isinstance(parent_module, nn.Sequential): - parent_module[int(k)] = tgt_module - else: - setattr(parent_module, k, tgt_module) - # verify that all BN are replaced - bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] - assert len(bn_list) == 0 - return root_module diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a1cbea9a..b5b5f861 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,42 +1,61 @@ -def make_policy(cfg): - if cfg.policy.name == "tdmpc": +import inspect + +from omegaconf import DictConfig, OmegaConf + +from lerobot.common.utils import get_safe_torch_device + + +def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): + expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) + assert set(hydra_cfg.policy).issuperset( + expected_kwargs + ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + policy_cfg = policy_cfg_class( + **{ + k: v + for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items() + if k in expected_kwargs + } + ) + return policy_cfg + + +def make_policy(hydra_cfg: DictConfig): + if hydra_cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPCPolicy policy = TDMPCPolicy( - cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device + hydra_cfg.policy, + 
n_obs_steps=hydra_cfg.n_obs_steps, + n_action_steps=hydra_cfg.n_action_steps, + device=hydra_cfg.device, ) - elif cfg.policy.name == "diffusion": - from lerobot.common.policies.diffusion.policy import DiffusionPolicy + elif hydra_cfg.policy.name == "diffusion": + from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig + from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy - policy = DiffusionPolicy( - cfg=cfg.policy, - cfg_device=cfg.device, - cfg_noise_scheduler=cfg.noise_scheduler, - cfg_rgb_model=cfg.rgb_model, - cfg_obs_encoder=cfg.obs_encoder, - cfg_optimizer=cfg.optimizer, - cfg_ema=cfg.ema, - # n_obs_steps=cfg.n_obs_steps, - # n_action_steps=cfg.n_action_steps, - **cfg.policy, - ) - elif cfg.policy.name == "act": - from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy + policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) + policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps) + policy.to(get_safe_torch_device(hydra_cfg.device)) + elif hydra_cfg.policy.name == "act": + from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig + from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy - policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device) - policy.to(cfg.device) + policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) + policy = ActionChunkingTransformerPolicy(policy_cfg) + policy.to(get_safe_torch_device(hydra_cfg.device)) else: - raise ValueError(cfg.policy.name) + raise ValueError(hydra_cfg.policy.name) - if cfg.policy.pretrained_model_path: + if hydra_cfg.policy.pretrained_model_path: # TODO(rcadene): hack for old pretrained models from fowm - if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path: - if "offline" in cfg.pretrained_model_path: + if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path: + if "offline" in hydra_cfg.policy.pretrained_model_path: policy.step[0] = 25000 - elif "final" in cfg.pretrained_model_path: + elif "final" in hydra_cfg.policy.pretrained_model_path: policy.step[0] = 100000 else: raise NotImplementedError() - policy.load(cfg.policy.pretrained_model_path) + policy.load(hydra_cfg.policy.pretrained_model_path) return policy diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py new file mode 100644 index 00000000..6401c734 --- /dev/null +++ b/lerobot/common/policies/policy_protocol.py @@ -0,0 +1,45 @@ +"""A protocol that all policies should follow. + +This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes +subclass a base class. + +The protocol structure, method signatures, and docstrings should be used by developers as a reference for +how to implement new policies. +""" + +from typing import Protocol, runtime_checkable + +from torch import Tensor + + +@runtime_checkable +class Policy(Protocol): + """The required interface for implementing a policy.""" + + name: str + + def reset(self): + """To be called whenever the environment is reset. + + Does things like clearing caches. + """ + + def forward(self, batch: dict[str, Tensor]) -> dict: + """Run the batch through the model and compute the loss for training or validation. + + Returns a dictionary with "loss" and maybe other information. 
+ """ + + def select_action(self, batch: dict[str, Tensor]): + """Return one action to run in the environment (potentially in batch mode). + + When the model uses a history of observations, or outputs a sequence of actions, this method deals + with caching. + """ + + def update(self, batch): + """Does compute_loss then an optimization step. + + TODO(alexander-soare): We will move the optimization step back into the training loop, so this will + disappear. + """ diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 39c9bee0..4cb3741a 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -335,97 +335,13 @@ class TDMPCPolicy(nn.Module): return td_target def forward(self, batch, step): + # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. + raise NotImplementedError() + + def update(self, batch, step): """Main update function. Corresponds to one iteration of the model learning.""" start_time = time.time() - # num_slices = self.cfg.batch_size - # batch_size = self.cfg.horizon * num_slices - - # if demo_buffer is None: - # demo_batch_size = 0 - # else: - # # Update oversampling ratio - # demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step) - # demo_num_slices = int(demo_pc_batch * self.batch_size) - # demo_batch_size = self.cfg.horizon * demo_num_slices - # batch_size -= demo_batch_size - # num_slices -= demo_num_slices - # replay_buffer._sampler.num_slices = num_slices - # demo_buffer._sampler.num_slices = demo_num_slices - - # assert demo_batch_size % self.cfg.horizon == 0 - # assert demo_batch_size % demo_num_slices == 0 - - # assert batch_size % self.cfg.horizon == 0 - # assert batch_size % num_slices == 0 - - # # Sample from interaction dataset - - # def process_batch(batch, horizon, num_slices): - # # trajectory t = 256, horizon h = 5 - # # (t h) ... -> h t ... - # batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() - - # obs = { - # "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True), - # "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True), - # } - # action = batch["action"].to(self.device, non_blocking=True) - # next_obses = { - # "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True), - # "state": batch["next", "observation", "state"].to(self.device, non_blocking=True), - # } - # reward = batch["next", "reward"].to(self.device, non_blocking=True) - - # idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True) - # weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True) - - # # TODO(rcadene): rearrange directly in offline dataset - # if reward.ndim == 2: - # reward = einops.rearrange(reward, "h t -> h t 1") - - # assert reward.ndim == 3 - # assert reward.shape == (horizon, num_slices, 1) - # # We dont use `batch["next", "done"]` since it only indicates the end of an - # # episode, but not the end of the trajectory of an episode. 
- # # Neither does `batch["next", "terminated"]` - # done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) - # mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - # return obs, action, next_obses, reward, mask, done, idxs, weights - - # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() - - # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch( - # batch, self.cfg.horizon, num_slices - # ) - - # Sample from demonstration dataset - # if demo_batch_size > 0: - # demo_batch = demo_buffer.sample(demo_batch_size) - # ( - # demo_obs, - # demo_action, - # demo_next_obses, - # demo_reward, - # demo_mask, - # demo_done, - # demo_idxs, - # demo_weights, - # ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices) - - # if isinstance(obs, dict): - # obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} - # next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses} - # else: - # obs = torch.cat([obs, demo_obs]) - # next_obses = torch.cat([next_obses, demo_next_obses], dim=1) - # action = torch.cat([action, demo_action], dim=1) - # reward = torch.cat([reward, demo_reward], dim=1) - # mask = torch.cat([mask, demo_mask], dim=1) - # done = torch.cat([done, demo_done], dim=1) - # idxs = torch.cat([idxs, demo_idxs]) - # weights = torch.cat([weights, demo_weights]) - batch_size = batch["index"].shape[0] # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels) @@ -539,6 +455,7 @@ class TDMPCPolicy(nn.Module): ) self.optim.step() + # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion # if self.cfg.per: # # Update priorities # priorities = priority_loss.clamp(max=1e4).detach() diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index b0503fe4..b23c1336 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -1,3 +1,7 @@ +import torch +from torch import nn + + def populate_queues(queues, batch): for key in batch: if len(queues[key]) != queues[key].maxlen: @@ -8,3 +12,19 @@ def populate_queues(queues, batch): # add latest observation to the queue queues[key].append(batch[key]) return queues + + +def get_device_from_parameters(module: nn.Module) -> torch.device: + """Get a module's device by checking one of its parameters. + + Note: assumes that all parameters have the same device + """ + return next(iter(module.parameters())).device + + +def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: + """Get a module's parameter dtype by checking one of its parameters. + + Note: assumes that all parameters have the same dtype. 
+ """ + return next(iter(module.parameters())).dtype diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index e3e22832..81b3d986 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -11,6 +11,7 @@ from omegaconf import DictConfig def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: + """Given a string, return a torch.device with checks on whether the device is available.""" match cfg_device: case "cuda": assert torch.cuda.is_available() @@ -98,6 +99,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D def print_cuda_memory_usage(): + """Use this function to locate and debug memory leak.""" import gc gc.collect() diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 7a8d8b58..6b836795 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -18,7 +18,6 @@ env: from_pixels: True pixels_only: False image_size: [3, 480, 640] - action_repeat: 1 episode_length: 400 fps: ${fps} diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index a5fbcc25..a7097ffd 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -18,7 +18,6 @@ env: from_pixels: True pixels_only: False image_size: 96 - action_repeat: 1 episode_length: 300 fps: ${fps} diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 8b3c72ef..bcba659e 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -17,7 +17,6 @@ env: from_pixels: True pixels_only: False image_size: 84 - # action_repeat: 2 # we can remove if policy has n_action_steps=2 episode_length: 25 fps: ${fps} diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index e2074b46..eb4e512b 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -8,61 +8,63 @@ eval_freq: 10000 save_freq: 100000 log_freq: 250 -horizon: 100 n_obs_steps: 1 # when temporal_agg=False, n_action_steps=horizon -n_action_steps: ${horizon} +# See `configuration_act.py` for more details. policy: name: act pretrained_model_path: + # Environment. + # Inherit these from the environment config. + state_dim: ??? + action_dim: ??? + + # Inputs / output structure. + n_obs_steps: ${n_obs_steps} + camera_names: [top] # [top, front_close, left_pillar, right_pillar] + chunk_size: 100 # chunk_size + n_action_steps: 100 + + # Vision preprocessing. + image_normalization_mean: [0.485, 0.456, 0.406] + image_normalization_std: [0.229, 0.224, 0.225] + + # Architecture. + # Vision backbone. + vision_backbone: resnet18 + use_pretrained_backbone: true + replace_final_stride_with_dilation: false + # Transformer layers. + pre_norm: false + d_model: 512 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + n_encoder_layers: 4 + n_decoder_layers: 1 + # VAE. + use_vae: true + latent_dim: 32 + n_vae_encoder_layers: 4 + + # Inference. + use_temporal_aggregation: false + + # Training and loss computation. + dropout: 0.1 + kl_weight: 10.0 + + # --- + # TODO(alexander-soare): Remove these from the policy config. 
+ batch_size: 8 lr: 1e-5 lr_backbone: 1e-5 - pretrained_backbone: true weight_decay: 1e-4 grad_clip_norm: 10 - backbone: resnet18 - horizon: ${horizon} # chunk_size - kl_weight: 10 - d_model: 512 - dim_feedforward: 3200 - vae_enc_layers: 4 - enc_layers: 4 - dec_layers: 1 - num_heads: 8 - #camera_names: [top, front_close, left_pillar, right_pillar] - camera_names: [top] - dilation: false - dropout: 0.1 - pre_norm: false - activation: relu - latent_dim: 32 - - use_vae: true - - batch_size: 8 - - per_alpha: 0.6 - per_beta: 0.4 - - balanced_sampling: false utd: 1 - n_obs_steps: ${n_obs_steps} - n_action_steps: ${n_action_steps} - - temporal_agg: false - - state_dim: 14 - action_dim: 14 - - image_normalization: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - delta_timestamps: - observation.images.top: [0.0] - observation.state: [0.0] - action: "[i / ${fps} for i in range(${horizon})]" + action: "[i / ${fps} for i in range(${policy.chunk_size})]" diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 811ee824..44746dfc 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -1,17 +1,5 @@ # @package _global_ -shape_meta: - # acceptable types: rgb, low_dim - obs: - image: - shape: [3, 96, 96] - type: rgb - agent_pos: - shape: [2] - type: low_dim - action: - shape: [2] - seed: 100000 horizon: 16 n_obs_steps: 2 @@ -19,7 +7,6 @@ n_action_steps: 8 dataset_obs_steps: ${n_obs_steps} past_action_visible: False keypoint_visible_rate: 1.0 -obs_as_global_cond: True eval_episodes: 50 eval_freq: 5000 @@ -34,76 +21,70 @@ offline_prioritized_sampler: true policy: name: diffusion - shape_meta: ${shape_meta} + pretrained_model_path: - horizon: ${horizon} + # Environment. + # Inherit these from the environment config. + state_dim: ??? + action_dim: ??? + image_size: + - ${env.image_size} # height + - ${env.image_size} # width + + # Inputs / output structure. n_obs_steps: ${n_obs_steps} + horizon: ${horizon} n_action_steps: ${n_action_steps} - num_inference_steps: 100 - obs_as_global_cond: ${obs_as_global_cond} - # crop_shape: null - diffusion_step_embed_dim: 128 + + # Vision preprocessing. + image_normalization_mean: [0.5, 0.5, 0.5] + image_normalization_std: [0.5, 0.5, 0.5] + + # Architecture / modeling. + # Vision backbone. + vision_backbone: resnet18 + crop_shape: [84, 84] + crop_is_random: True + use_pretrained_backbone: false + use_group_norm: True + spatial_softmax_num_keypoints: 32 + # Unet. down_dims: [512, 1024, 2048] kernel_size: 5 n_groups: 8 - cond_predict_scale: True - - pretrained_model_path: - - batch_size: 64 - - per_alpha: 0.6 - per_beta: 0.4 - - balanced_sampling: false - utd: 1 - offline_steps: ${offline_steps} - use_ema: true - lr_scheduler: cosine - lr_warmup_steps: 500 - grad_clip_norm: 10 - - delta_timestamps: - observation.image: [-0.1, 0] - observation.state: [-0.1, 0] - action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4] - -noise_scheduler: - _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler + diffusion_step_embed_dim: 128 + use_film_scale_modulation: True + # Noise scheduler. 
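+  # (For reference: `squaredcos_cap_v2` selects diffusers' cosine beta schedule, and
+  # `prediction_type: epsilon` trains the network to predict the added noise rather than the sample.)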
num_train_timesteps: 100 + beta_schedule: squaredcos_cap_v2 beta_start: 0.0001 beta_end: 0.02 - beta_schedule: squaredcos_cap_v2 - variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan - clip_sample: True # required when predict_epsilon=False - prediction_type: epsilon # or sample + prediction_type: epsilon # epsilon / sample + clip_sample: True + clip_sample_range: 1.0 -obs_encoder: - shape_meta: ${shape_meta} - # resize_shape: null - crop_shape: [84, 84] - # constant center crop - random_crop: True - use_group_norm: True - share_rgb_model: False - norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) + # Inference + num_inference_steps: 100 -rgb_model: - pretrained: false - num_keypoints: 32 - relu: true - -ema: - _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel - update_after_step: 0 - inv_gamma: 1.0 - power: 0.75 - min_value: 0.0 - max_value: 0.9999 - -optimizer: - _target_: torch.optim.AdamW + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: 64 + grad_clip_norm: 10 lr: 1.0e-4 - betas: [0.95, 0.999] - eps: 1.0e-8 - weight_decay: 1.0e-6 + lr_scheduler: cosine + lr_warmup_steps: 500 + adam_betas: [0.95, 0.999] + adam_eps: 1.0e-8 + adam_weight_decay: 1.0e-6 + utd: 1 + use_ema: true + ema_update_after_step: 0 + ema_min_alpha: 0.0 + ema_max_alpha: 0.9999 + ema_inv_gamma: 1.0 + ema_power: 0.75 + + delta_timestamps: + observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]" diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 6a06ef51..4fd2b6bb 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -36,6 +36,7 @@ policy: log_std_max: 2 # learning + batch_size: 256 max_buffer_size: 10000 horizon: 5 reward_coef: 0.5 @@ -82,5 +83,3 @@ policy: observation.state: "[i / ${fps} for i in range(6)]" action: "[i / ${fps} for i in range(5)]" next.reward: "[i / ${fps} for i in range(5)]" - - batch_size: 256 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 72daaf70..e3e439d5 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -32,6 +32,7 @@ import json import logging import threading import time +from copy import deepcopy from datetime import datetime as dt from pathlib import Path @@ -40,7 +41,9 @@ import gymnasium as gym import imageio import numpy as np import torch +from datasets import Dataset from huggingface_hub import snapshot_download +from PIL import Image as PILImage from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env @@ -56,15 +59,15 @@ def write_video(video_path, stacked_frames, fps): def eval_policy( env: gym.vector.VectorEnv, - policy, - save_video: bool = False, + policy: torch.nn.Module, + max_episodes_rendered: int = 0, video_dir: Path = None, # TODO(rcadene): make it possible to overwrite fps? we should use env.fps - fps: int = 15, - return_first_video: bool = False, transform: callable = None, seed=None, ): + fps = env.unwrapped.metadata["render_fps"] + if policy is not None: policy.eval() device = "cpu" if policy is None else next(policy.parameters()).device @@ -83,14 +86,11 @@ def eval_policy( # needed as I'm currently taking a ceil. 
ep_frames = [] - def maybe_render_frame(env): - if save_video: # noqa: B023 - if return_first_video: - visu = env.envs[0].render() - visu = visu[None, ...] # add batch dim - else: - visu = np.stack([env.render() for env in env.envs]) - ep_frames.append(visu) # noqa: B023 + def render_frame(env): + # noqa: B023 + eps_rendered = min(max_episodes_rendered, len(env.envs)) + visu = np.stack([env.envs[i].render() for i in range(eps_rendered)]) + ep_frames.append(visu) # noqa: B023 for _ in range(num_episodes): seeds.append("TODO") @@ -104,8 +104,14 @@ def eval_policy( # reset the environment observation, info = env.reset(seed=seed) - maybe_render_frame(env) + if max_episodes_rendered > 0: + render_frame(env) + observations = [] + actions = [] + # episode + # frame_id + # timestamp rewards = [] successes = [] dones = [] @@ -113,25 +119,32 @@ def eval_policy( done = torch.tensor([False for _ in env.envs]) step = 0 while not done.all(): + # format from env keys to lerobot keys + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + # apply transform to normalize the observations - observation = preprocess_observation(observation, transform) + for key in observation: + observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]]) # send observation to device/gpu observation = {key: observation[key].to(device, non_blocking=True) for key in observation} # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step) + action = policy.select_action(observation, step=step) # apply inverse transform to unnormalize the action action = postprocess_action(action, transform) action = np.array([[0, 0, 0, 0]], dtype=np.float32) - # apply the next + # apply the next action observation, reward, terminated, truncated, info = env.step(action) - maybe_render_frame(env) + if max_episodes_rendered > 0: + render_frame(env) # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?) + action = torch.from_numpy(action) reward = torch.from_numpy(reward) terminated = torch.from_numpy(terminated) truncated = torch.from_numpy(truncated) @@ -148,12 +161,24 @@ def eval_policy( success = [False for _ in env.envs] success = torch.tensor(success) + actions.append(action) rewards.append(reward) dones.append(done) successes.append(success) step += 1 + env.close() + + # add the last observation when the env is done + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + + new_obses = {} + for key in observations[0].keys(): # noqa: SIM118 + new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1) + observations = new_obses + actions = torch.stack(actions, dim=1) rewards = torch.stack(rewards, dim=1) successes = torch.stack(successes, dim=1) dones = torch.stack(dones, dim=1) @@ -173,29 +198,71 @@ def eval_policy( max_rewards.extend(batch_max_reward.tolist()) all_successes.extend(batch_success.tolist()) - env.close() + # similar logic is implemented in dataset preprocessing + ep_dicts = [] + num_episodes = dones.shape[0] + total_frames = 0 + idx_from = 0 + for ep_id in range(num_episodes): + num_frames = done_indices[ep_id].item() + 1 + total_frames += num_frames - if save_video or return_first_video: + # TODO(rcadene): We need to add a missing last frame which is the observation + # of a done state. 
it is critical to have this frame for tdmpc to predict a "done observation/state" + ep_dict = { + "action": actions[ep_id, :num_frames], + "episode_id": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + "next.done": dones[ep_id, :num_frames], + "next.reward": rewards[ep_id, :num_frames].type(torch.float32), + "episode_data_index_from": torch.tensor([idx_from] * num_frames), + "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames), + } + for key in observations: + ep_dict[key] = observations[key][ep_id][:num_frames] + ep_dicts.append(ep_dict) + + idx_from += num_frames + + # similar logic is implemented in dataset preprocessing + data_dict = {} + keys = ep_dicts[0].keys() + for key in keys: + if "image" not in key: + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + # c h w -> h w c + img = PILImage.fromarray(x.permute(1, 2, 0).numpy()) + data_dict[key].append(img) + + data_dict["index"] = torch.arange(0, total_frames, 1) + + data_dict = Dataset.from_dict(data_dict).with_format("torch") + + if max_episodes_rendered > 0: batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) - if save_video: - for stacked_frames, done_index in zip( - batch_stacked_frames, done_indices.flatten().tolist(), strict=False - ): - if episode_counter >= num_episodes: - continue - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{episode_counter}.mp4" - thread = threading.Thread( - target=write_video, - args=(str(video_path), stacked_frames[:done_index], fps), - ) - thread.start() - threads.append(thread) - episode_counter += 1 + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames[:done_index], fps), + ) + thread.start() + threads.append(thread) + episode_counter += 1 - if return_first_video: - first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) + videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w") for thread in threads: thread.join() @@ -226,9 +293,10 @@ def eval_policy( "eval_s": time.time() - start, "eval_ep_s": (time.time() - start) / num_episodes, }, + "episodes": data_dict, } - if return_first_video: - return info, first_video + if max_episodes_rendered > 0: + info["videos"] = videos return info @@ -256,16 +324,14 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making environment.") env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) - # when policy is None, rollout a random policy - policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None + logging.info("Making policy.") + policy = make_policy(cfg) info = eval_policy( env, - policy=policy, - save_video=True, + policy, + max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", - fps=cfg.env.fps, - # TODO(rcadene): what should we do with the transform? 
transform=transform, seed=cfg.seed, ) @@ -273,6 +339,9 @@ # Save info with open(Path(out_dir) / "eval_info.json", "w") as f: + # remove pytorch tensors which are not serializable to save the evaluation results only + del info["episodes"] + del info["videos"] json.dump(info, f, indent=2) logging.info("End of eval") diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 9a1472d5..be25faf7 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,9 +1,11 @@ import logging +from copy import deepcopy from pathlib import Path import hydra -import numpy as np import torch +from datasets import concatenate_datasets +from datasets.utils.logging import disable_progress_bar from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle @@ -108,6 +110,68 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): logger.log_dict(info, step, mode="eval") + + +def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): + """ + Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average). + + Parameters: + - n_off (int): Number of offline samples, each with a sampling weight of 1. + - n_on (int): Number of online samples. + - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5). + + The total weight of offline samples is n_off * 1.0. + The total weight of online samples is n_on * w. + The total combined weight of all samples is n_off + n_on * w. + The fraction of the weight that is online is n_on * w / (n_off + n_on * w). + We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+ The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1)) + """ + assert 0.0 <= pc_on <= 1.0 + return -(n_off * pc_on) / (n_on * (pc_on - 1)) + + +def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples): + first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item() + first_index = data_dict.select_columns("index")[0]["index"].item() + assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}" + assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}" + + if len(online_dataset) == 0: + # initialize online dataset + online_dataset.data_dict = data_dict + else: + # find episode index and data frame indices according to previous episode in online_dataset + start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1 + start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1 + + def shift_indices(example): + # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to + example["episode_id"] += start_episode + example["index"] += start_index + example["episode_data_index_from"] += start_index + example["episode_data_index_to"] += start_index + return example + + disable_progress_bar() # map has a tqdm progress bar + data_dict = data_dict.map(shift_indices) + + # extend online dataset + online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict]) + + # update the concatenated dataset length used during sampling + concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) + + # update the sampling weights for each frame so that online frames get sampled a certain percentage of times + len_online = len(online_dataset) + len_offline = len(concat_dataset) - len_online + weight_offline = 1.0 + weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) + sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset)) + + # update the total number of samples used during sampling + sampler.num_samples = len(concat_dataset) + + def train(cfg: dict, out_dir=None, job_name=None): if out_dir is None: raise NotImplementedError() @@ -127,26 +191,7 @@ def train(cfg: dict, out_dir=None, job_name=None): set_global_seed(cfg.seed) logging.info("make_dataset") - dataset = make_dataset(cfg) - - # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy - # if cfg.policy.balanced_sampling: - # logging.info("make online_buffer") - # num_traj_per_batch = cfg.policy.batch_size - - # online_sampler = PrioritizedSliceSampler( - # max_capacity=100_000, - # alpha=cfg.policy.per_alpha, - # beta=cfg.policy.per_beta, - # num_slices=num_traj_per_batch, - # strict_length=True, - # ) - - # online_buffer = TensorDictReplayBuffer( - # storage=LazyMemmapStorage(100_000), - # sampler=online_sampler, - # transform=dataset.transform, - # ) + offline_dataset = make_dataset(cfg) logging.info("make_env") env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) @@ -164,10 +209,8 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") - # TODO(now): uncomment - # logging.info(f"{cfg.env.action_repeat=}") - logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})") - logging.info(f"{dataset.num_episodes=}") + 
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})") + logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") @@ -175,18 +218,17 @@ def train(cfg: dict, out_dir=None, job_name=None): def _maybe_eval_and_maybe_save(step): if step % cfg.eval_freq == 0: logging.info(f"Eval policy at step {step}") - eval_info, first_video = eval_policy( + eval_info = eval_policy( env, policy, - return_first_video=True, video_dir=Path(out_dir) / "eval", - save_video=True, - transform=dataset.transform, + max_episodes_rendered=4, + transform=offline_dataset.transform, seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline) + log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) if cfg.wandb.enable: - logger.log_video(first_video, step, mode="eval") + logger.log_video(eval_info["videos"][0], step, mode="eval") logging.info("Resume training") if cfg.save_model and step % cfg.save_freq == 0: @@ -194,18 +236,19 @@ def train(cfg: dict, out_dir=None, job_name=None): logger.save_model(policy, identifier=step) logging.info("Resume training") - step = 0 # number of policy update (forward + backward + optim) - - is_offline = True + # create dataloader for offline training dataloader = torch.utils.data.DataLoader( - dataset, + offline_dataset, num_workers=4, batch_size=cfg.policy.batch_size, shuffle=True, pin_memory=cfg.device != "cpu", - drop_last=True, + drop_last=False, ) dl_iter = cycle(dataloader) + + step = 0 # number of policy update (forward + backward + optim) + is_offline = True for offline_step in range(cfg.offline_steps): if offline_step == 0: logging.info("Start offline training on a fixed dataset") @@ -215,11 +258,11 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step) + train_info = policy.update(batch, step=step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, dataset, is_offline) + log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in # step + 1. 
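A standalone sanity check of the online/offline sampling weights set up above (illustrative numbers; a sketch, not part of the patch):

import torch

def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float) -> float:
    # Same formula as above: solve n_on * w / (n_off + n_on * w) = pc_on for w.
    assert 0.0 <= pc_on <= 1.0
    return -(n_off * pc_on) / (n_on * (pc_on - 1))

# Suppose 10_000 offline frames, 500 online frames, and we want ~50% of each batch to be online.
w = calculate_online_sample_weight(10_000, 500, 0.5)  # -> 20.0
weights = [1.0] * 10_000 + [w] * 500
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
# Online frames now carry half of the total sampling mass: 500 * 20.0 == 10_000 * 1.0.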
@@ -227,61 +270,59 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 - raise NotImplementedError() + # create an env dedicated to online episodes collection from policy rollout + rollout_env = make_env(cfg, num_parallel_envs=1) + + # create an empty online dataset similar to offline dataset + online_dataset = deepcopy(offline_dataset) + online_dataset.data_dict = {} + + # create dataloader for online training + concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) + weights = [1.0] * len(concat_dataset) + sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples=len(concat_dataset), replacement=True + ) + dataloader = torch.utils.data.DataLoader( + concat_dataset, + num_workers=4, + batch_size=cfg.policy.batch_size, + sampler=sampler, + pin_memory=cfg.device != "cpu", + drop_last=False, + ) + dl_iter = cycle(dataloader) - demo_buffer = dataset if cfg.policy.balanced_sampling else None online_step = 0 is_offline = False for env_step in range(cfg.online_steps): if env_step == 0: logging.info("Start online training by interacting with environment") - # TODO: add configurable number of rollout? (default=1) + with torch.no_grad(): - rollout = env.rollout( - max_steps=cfg.env.episode_length, - policy=policy, - auto_cast_to_device=True, + eval_info = eval_policy( + rollout_env, + policy, + transform=offline_dataset.transform, + seed=cfg.seed, ) - assert ( - len(rollout.batch_size) == 2 - ), "2 dimensions expected: number of env in parallel x max number of steps during rollout" - - num_parallel_env = rollout.batch_size[0] - if num_parallel_env != 1: - # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests - raise NotImplementedError() - - num_max_steps = rollout.batch_size[1] - assert num_max_steps <= cfg.env.episode_length - - # reshape to have a list of steps to insert into online_buffer - rollout = rollout.reshape(num_parallel_env * num_max_steps) - - # set same episode index for all time steps contained in this rollout - rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) - # online_buffer.extend(rollout) - - ep_sum_reward = rollout["next", "reward"].sum() - ep_max_reward = rollout["next", "reward"].max() - ep_success = rollout["next", "success"].any() - rollout_info = { - "avg_sum_reward": np.nanmean(ep_sum_reward), - "avg_max_reward": np.nanmean(ep_max_reward), - "pc_success": np.nanmean(ep_success) * 100, - "env_step": env_step, - "ep_length": len(rollout), - } + online_pc_sampling = cfg.get("demo_schedule", 0.5) + add_episodes_inplace( + eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling + ) for _ in range(cfg.policy.utd): - train_info = policy.update( - # online_buffer, - step, - demo_buffer=demo_buffer, - ) + policy.train() + batch = next(dl_iter) + + for key in batch: + batch[key] = batch[key].to(cfg.device, non_blocking=True) + + train_info = policy.update(batch, step) + if step % cfg.log_freq == 0: - train_info.update(rollout_info) - log_train_info(logger, train_info, step, cfg, dataset, is_offline) + log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass # in step + 1. 
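Both training loops above draw batches through `dl_iter = cycle(dataloader)` / `batch = next(dl_iter)`. A minimal sketch of what such a helper looks like (the real one lives in `lerobot.common.datasets.utils`; this version is an assumption, shown only to make the loop self-explanatory):

from typing import Iterable, Iterator

def cycle(iterable: Iterable) -> Iterator:
    # Restart the underlying iterable whenever it is exhausted, yielding batches forever.
    while True:
        yield from iterable

# Usage mirroring the training loop:
#   dl_iter = cycle(dataloader)
#   batch = next(dl_iter)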
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 93315e90..739115e9 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -6,9 +6,6 @@ import einops import hydra import imageio import torch -from torchrl.data.replay_buffers import ( - SamplerWithoutReplacement, -) from lerobot.common.datasets.factory import make_dataset from lerobot.common.logger import log_output_dir @@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None): init_logging() log_output_dir(out_dir) - # we expect frames of each episode to be stored next to each others sequentially - sampler = SamplerWithoutReplacement( - shuffle=False, - ) - logging.info("make_dataset") dataset = make_dataset( cfg, - overwrite_sampler=sampler, # remove all transformations such as rescale images from [0,255] to [0,1] or normalization normalize=False, - overwrite_batch_size=1, - overwrite_prefetch=12, ) logging.info("Start rendering episodes from offline buffer") @@ -60,64 +49,51 @@ def visualize_dataset(cfg: dict, out_dir=None): logging.info(video_path) -def render_dataset(dataset, out_dir, max_num_samples, fps): +def render_dataset(dataset, out_dir, max_num_episodes): out_dir = Path(out_dir) video_paths = [] threads = [] - frames = {} - current_ep_idx = 0 - logging.info(f"Visualizing episode {current_ep_idx}") - for i in range(max_num_samples): - # TODO(rcadene): make it work with bsize > 1 - ep_td = dataset.sample(1) - ep_idx = ep_td["episode"][FIRST_FRAME].item() - # TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames - num_frames_left = dataset._sampler._sample_list.numel() - episode_is_done = ep_idx != current_ep_idx + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=1, + shuffle=False, + ) + dl_iter = iter(dataloader) - if episode_is_done: - logging.info(f"Rendering episode {current_ep_idx}") + for ep_id in range(min(max_num_episodes, dataset.num_episodes)): + logging.info(f"Rendering episode {ep_id}") - for im_key in dataset.image_keys: - if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1): + frames = {} + end_of_episode = False + while not end_of_episode: + item = next(dl_iter) + + for im_key in dataset.image_keys: # when first frame of episode, initialize frames dict if im_key not in frames: frames[im_key] = [] # add current frame to list of frames to render - frames[im_key].append(ep_td[im_key]) + frames[im_key].append(item[im_key]) + + end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1 + + out_dir.mkdir(parents=True, exist_ok=True) + for im_key in dataset.image_keys: + if len(dataset.image_keys) > 1: + im_name = im_key.replace("observation.images.", "") + video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4" else: - # When episode has no more frame in its list of observation, - # one frame still remains. It is the result of the last action taken. - # It is stored in `"next"`, so we add it to the list of frames to render. 
- frames[im_key].append(ep_td["next"][im_key]) + video_path = out_dir / f"episode_{ep_id}.mp4" + video_paths.append(video_path) - out_dir.mkdir(parents=True, exist_ok=True) - if len(dataset.image_keys) > 1: - camera = im_key[-1] - video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4" - else: - video_path = out_dir / f"episode_{current_ep_idx}.mp4" - video_paths.append(str(video_path)) - - thread = threading.Thread( - target=cat_and_write_video, - args=(str(video_path), frames[im_key], fps), - ) - thread.start() - threads.append(thread) - - current_ep_idx = ep_idx - - # reset list of frames - del frames[im_key] - - if num_frames_left == 0: - logging.info("Ran out of frames") - break - - if current_ep_idx == NUM_EPISODES_TO_RENDER: - break + thread = threading.Thread( + target=cat_and_write_video, + args=(str(video_path), frames[im_key], dataset.fps), + ) + thread.start() + threads.append(thread) for thread in threads: thread.join() diff --git a/poetry.lock b/poetry.lock index faeb70f1..a70e404a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -11,6 +11,116 @@ files = [ {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, ] +[[package]] +name = "aiohttp" +version = "3.9.4" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:76d32588ef7e4a3f3adff1956a0ba96faabbdee58f2407c122dd45aa6e34f372"}, + {file = "aiohttp-3.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:56181093c10dbc6ceb8a29dfeea1e815e1dfdc020169203d87fd8d37616f73f9"}, + {file = "aiohttp-3.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7a5b676d3c65e88b3aca41816bf72831898fcd73f0cbb2680e9d88e819d1e4d"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1df528a85fb404899d4207a8d9934cfd6be626e30e5d3a5544a83dbae6d8a7e"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f595db1bceabd71c82e92df212dd9525a8a2c6947d39e3c994c4f27d2fe15b11"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0b09d76e5a4caac3d27752027fbd43dc987b95f3748fad2b924a03fe8632ad"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689eb4356649ec9535b3686200b231876fb4cab4aca54e3bece71d37f50c1d13"}, + {file = "aiohttp-3.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3666cf4182efdb44d73602379a66f5fdfd5da0db5e4520f0ac0dcca644a3497"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b65b0f8747b013570eea2f75726046fa54fa8e0c5db60f3b98dd5d161052004a"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1885d2470955f70dfdd33a02e1749613c5a9c5ab855f6db38e0b9389453dce7"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0593822dcdb9483d41f12041ff7c90d4d1033ec0e880bcfaf102919b715f47f1"}, + {file = "aiohttp-3.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:47f6eb74e1ecb5e19a78f4a4228aa24df7fbab3b62d4a625d3f41194a08bd54f"}, + {file = 
"aiohttp-3.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c8b04a3dbd54de6ccb7604242fe3ad67f2f3ca558f2d33fe19d4b08d90701a89"}, + {file = "aiohttp-3.9.4-cp310-cp310-win32.whl", hash = "sha256:8a78dfb198a328bfb38e4308ca8167028920fb747ddcf086ce706fbdd23b2926"}, + {file = "aiohttp-3.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:e78da6b55275987cbc89141a1d8e75f5070e577c482dd48bd9123a76a96f0bbb"}, + {file = "aiohttp-3.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c111b3c69060d2bafc446917534150fd049e7aedd6cbf21ba526a5a97b4402a5"}, + {file = "aiohttp-3.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:efbdd51872cf170093998c87ccdf3cb5993add3559341a8e5708bcb311934c94"}, + {file = "aiohttp-3.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7bfdb41dc6e85d8535b00d73947548a748e9534e8e4fddd2638109ff3fb081df"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2bd9d334412961125e9f68d5b73c1d0ab9ea3f74a58a475e6b119f5293eee7ba"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:35d78076736f4a668d57ade00c65d30a8ce28719d8a42471b2a06ccd1a2e3063"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:824dff4f9f4d0f59d0fa3577932ee9a20e09edec8a2f813e1d6b9f89ced8293f"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52b8b4e06fc15519019e128abedaeb56412b106ab88b3c452188ca47a25c4093"}, + {file = "aiohttp-3.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eae569fb1e7559d4f3919965617bb39f9e753967fae55ce13454bec2d1c54f09"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:69b97aa5792428f321f72aeb2f118e56893371f27e0b7d05750bcad06fc42ca1"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4d79aad0ad4b980663316f26d9a492e8fab2af77c69c0f33780a56843ad2f89e"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:d6577140cd7db19e430661e4b2653680194ea8c22c994bc65b7a19d8ec834403"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:9860d455847cd98eb67897f5957b7cd69fbcb436dd3f06099230f16a66e66f79"}, + {file = "aiohttp-3.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:69ff36d3f8f5652994e08bd22f093e11cfd0444cea310f92e01b45a4e46b624e"}, + {file = "aiohttp-3.9.4-cp311-cp311-win32.whl", hash = "sha256:e27d3b5ed2c2013bce66ad67ee57cbf614288bda8cdf426c8d8fe548316f1b5f"}, + {file = "aiohttp-3.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d6a67e26daa686a6fbdb600a9af8619c80a332556245fa8e86c747d226ab1a1e"}, + {file = "aiohttp-3.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c5ff8ff44825736a4065d8544b43b43ee4c6dd1530f3a08e6c0578a813b0aa35"}, + {file = "aiohttp-3.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d12a244627eba4e9dc52cbf924edef905ddd6cafc6513849b4876076a6f38b0e"}, + {file = "aiohttp-3.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dcad56c8d8348e7e468899d2fb3b309b9bc59d94e6db08710555f7436156097f"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7e69a7fd4b5ce419238388e55abd220336bd32212c673ceabc57ccf3d05b55"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4870cb049f10d7680c239b55428916d84158798eb8f353e74fa2c98980dcc0b"}, + {file = 
"aiohttp-3.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2feaf1b7031ede1bc0880cec4b0776fd347259a723d625357bb4b82f62687b"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:939393e8c3f0a5bcd33ef7ace67680c318dc2ae406f15e381c0054dd658397de"}, + {file = "aiohttp-3.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d2334e387b2adcc944680bebcf412743f2caf4eeebd550f67249c1c3696be04"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e0198ea897680e480845ec0ffc5a14e8b694e25b3f104f63676d55bf76a82f1a"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e40d2cd22914d67c84824045861a5bb0fb46586b15dfe4f046c7495bf08306b2"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:aba80e77c227f4234aa34a5ff2b6ff30c5d6a827a91d22ff6b999de9175d71bd"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:fb68dc73bc8ac322d2e392a59a9e396c4f35cb6fdbdd749e139d1d6c985f2527"}, + {file = "aiohttp-3.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f3460a92638dce7e47062cf088d6e7663adb135e936cb117be88d5e6c48c9d53"}, + {file = "aiohttp-3.9.4-cp312-cp312-win32.whl", hash = "sha256:32dc814ddbb254f6170bca198fe307920f6c1308a5492f049f7f63554b88ef36"}, + {file = "aiohttp-3.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:63f41a909d182d2b78fe3abef557fcc14da50c7852f70ae3be60e83ff64edba5"}, + {file = "aiohttp-3.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c3770365675f6be220032f6609a8fbad994d6dcf3ef7dbcf295c7ee70884c9af"}, + {file = "aiohttp-3.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:305edae1dea368ce09bcb858cf5a63a064f3bff4767dec6fa60a0cc0e805a1d3"}, + {file = "aiohttp-3.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f121900131d116e4a93b55ab0d12ad72573f967b100e49086e496a9b24523ea"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b71e614c1ae35c3d62a293b19eface83d5e4d194e3eb2fabb10059d33e6e8cbf"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419f009fa4cfde4d16a7fc070d64f36d70a8d35a90d71aa27670bba2be4fd039"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b39476ee69cfe64061fd77a73bf692c40021f8547cda617a3466530ef63f947"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b33f34c9c7decdb2ab99c74be6443942b730b56d9c5ee48fb7df2c86492f293c"}, + {file = "aiohttp-3.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c78700130ce2dcebb1a8103202ae795be2fa8c9351d0dd22338fe3dac74847d9"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:268ba22d917655d1259af2d5659072b7dc11b4e1dc2cb9662fdd867d75afc6a4"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:17e7c051f53a0d2ebf33013a9cbf020bb4e098c4bc5bce6f7b0c962108d97eab"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7be99f4abb008cb38e144f85f515598f4c2c8932bf11b65add0ff59c9c876d99"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:d58a54d6ff08d2547656356eea8572b224e6f9bbc0cf55fa9966bcaac4ddfb10"}, + {file = "aiohttp-3.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7673a76772bda15d0d10d1aa881b7911d0580c980dbd16e59d7ba1422b2d83cd"}, + 
{file = "aiohttp-3.9.4-cp38-cp38-win32.whl", hash = "sha256:e4370dda04dc8951012f30e1ce7956a0a226ac0714a7b6c389fb2f43f22a250e"}, + {file = "aiohttp-3.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:eb30c4510a691bb87081192a394fb661860e75ca3896c01c6d186febe7c88530"}, + {file = "aiohttp-3.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:84e90494db7df3be5e056f91412f9fa9e611fbe8ce4aaef70647297f5943b276"}, + {file = "aiohttp-3.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7d4845f8501ab28ebfdbeab980a50a273b415cf69e96e4e674d43d86a464df9d"}, + {file = "aiohttp-3.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:69046cd9a2a17245c4ce3c1f1a4ff8c70c7701ef222fce3d1d8435f09042bba1"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b73a06bafc8dcc508420db43b4dd5850e41e69de99009d0351c4f3007960019"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:418bb0038dfafeac923823c2e63226179976c76f981a2aaad0ad5d51f2229bca"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:71a8f241456b6c2668374d5d28398f8e8cdae4cce568aaea54e0f39359cd928d"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:935c369bf8acc2dc26f6eeb5222768aa7c62917c3554f7215f2ead7386b33748"}, + {file = "aiohttp-3.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74e4e48c8752d14ecfb36d2ebb3d76d614320570e14de0a3aa7a726ff150a03c"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:916b0417aeddf2c8c61291238ce25286f391a6acb6f28005dd9ce282bd6311b6"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9b6787b6d0b3518b2ee4cbeadd24a507756ee703adbac1ab6dc7c4434b8c572a"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:221204dbda5ef350e8db6287937621cf75e85778b296c9c52260b522231940ed"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:10afd99b8251022ddf81eaed1d90f5a988e349ee7d779eb429fb07b670751e8c"}, + {file = "aiohttp-3.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2506d9f7a9b91033201be9ffe7d89c6a54150b0578803cce5cb84a943d075bc3"}, + {file = "aiohttp-3.9.4-cp39-cp39-win32.whl", hash = "sha256:e571fdd9efd65e86c6af2f332e0e95dad259bfe6beb5d15b3c3eca3a6eb5d87b"}, + {file = "aiohttp-3.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:7d29dd5319d20aa3b7749719ac9685fbd926f71ac8c77b2477272725f882072d"}, + {file = "aiohttp-3.9.4.tar.gz", hash = "sha256:6ff71ede6d9a5a58cfb7b6fffc83ab5d4a63138276c771ac91ceaaddf5459644"}, +] + +[package.dependencies] +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns", "brotlicffi"] + +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -43,59 +153,35 @@ files = [ ] [[package]] 
-name = "av" -version = "12.0.0" -description = "Pythonic bindings for FFmpeg's libraries." +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "av-12.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9d0890553951f76c479a9f2bb952aebae902b1c7d52feea614d37e1cd728a44"}, - {file = "av-12.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5d7f229a253c2e3fea9682c09c5ae179bd6d5d2da38d89eb7f29ef7bed10cb2f"}, - {file = "av-12.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61b3555d143aacf02e0446f6030319403538eba4dc713c18dfa653a2a23e7f9c"}, - {file = "av-12.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:607e13b2c2b26159a37525d7b6f647a32ce78711fccff23d146d3e255ffa115f"}, - {file = "av-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f0b4cfb89f4f06b339c766f92648e798a96747d4163f2fa78660d1ab1f1b5e"}, - {file = "av-12.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:41dcb8c269fa58a56edf3a3c814c32a0c69586827f132b4e395a951b0ce14fad"}, - {file = "av-12.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fa78fbe0e4469226512380180063116105048c66cb12e18ab4b518466c57e6c"}, - {file = "av-12.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60a869be1d6af916e65ea461cb93922f5db0698655ed7a7eae7c3ecd4af4debb"}, - {file = "av-12.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df61811cc551c186f0a0e530d97b8b139453534d0f92c1790a923f666522ceda"}, - {file = "av-12.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99cd2fc53091ebfb9a2fa9dd3580267f5bd1c040d0efd99fbc1a162576b271cb"}, - {file = "av-12.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6d4f1e261df48932128e6495772faa4cc23f5dd1512eec73daab82ad9f3240"}, - {file = "av-12.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:6aec88e41a498b1e01e2dce5371557e20f9a51aae0c16decc5924ec0be2e22b6"}, - {file = "av-12.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90eb8f2d548e96cbc6f78e89c911cdb15a3d80fd944f31111660ce45939cd037"}, - {file = "av-12.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d7f3a02910e77d750dbd516256a16db15030e5371530ff5a5ae902dc03d9005d"}, - {file = "av-12.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2477cc51526aa50575313d66e5e8ad7ab944588469be5e557b360ed572ae536"}, - {file = "av-12.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a2f47149d3ca6deb79f3e515b8bef50e27ebdb160813e6d67dba77278d2a7883"}, - {file = "av-12.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3306e4a3ce8b5bfcc3075793d4ed3a2df69179d8fba22cb944a6164dc235dfb6"}, - {file = "av-12.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:dc1b742e7f6df1b499fb960bd6697d1dd8e7ada7484a041a8c20e70a87225f53"}, - {file = "av-12.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0183be6889e835e1b074b4037bfce4fd44671c606cf1c4ab92ea2f271b544aec"}, - {file = "av-12.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:57337f20b208292ec8d3b11e4d289d8688a43d728174850a81b865d3253fff2c"}, - {file = "av-12.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ec915e8f6521545a38566eefc281042ee504ea3cee0618d8558e4920588b3b2"}, - {file = "av-12.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:33ad5c0a23c45b72bd6bd47f3b2c1adcd2935ee3d0b6178ed66bba62b964ff31"}, - {file = "av-12.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc3a652b12c93120514d56cf025da47442c5ba51530cdf7ba3660257dbb0de1"}, - {file = "av-12.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:037f793dd1ef4a1f57f090191a7f803ad10ec82da0d04ea26bbe0b8a145fe927"}, - {file = "av-12.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc532376aa264722fae55063abd1871d17a563dc895978e142c8ecfcdeb3a2e8"}, - {file = "av-12.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:abf0c4bc40a0af8a30f4cd96f3be6f19fbce0f21222d7fcec148e085127153f7"}, - {file = "av-12.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81cedd1c072fbebf606724c406b1a1b00adc711f1dfd2bc04c633ce39d8439d8"}, - {file = "av-12.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02d60f48be9f15dcda37d50f3ce8d7249d9a455643d4322dd3449986bacfc628"}, - {file = "av-12.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d2619e4c26d661eecfc404f7d739d8b35f0dcef353fabe61512e030254b7031"}, - {file = "av-12.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:1892cc91c888d101777d5432d54e0554c11d1c3a2c65d02a2cae0a2256a8fbb9"}, - {file = "av-12.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4819e3ef6c3a44ef6f75907229133a1ee7f688245b2cf49b6b8e969a81ca72c9"}, - {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb16bb314cf1503b0250fc46b2c455ee196584231101be0123f4f78638227b62"}, - {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3e6a62bda9a1e144feeb59bbee046d7a2d98399634a30f57e4990197313c158"}, - {file = "av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08175ffbafa3a70c7b2f81083e160e34122a208cdf70f150b8f5d02c2de6965"}, - {file = "av-12.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e1d255be317b7c1ebdc4dae98935b9f3869161112dc829c625e54f90d8bdd7ab"}, - {file = "av-12.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:17964b36e08435910aabd5b3f7dca12f99536902529767d276026bc08f94ced7"}, - {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2d5f78de29edee06ddcdd4c2b759914575492d6a0cd4de2ce31ee63a4953eff"}, - {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:309b32bc97158d0f0c19e273b8e17a855a86806b7194aebc23bd497326cff11f"}, - {file = "av-12.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c409c71bd9c7c2f8d018c822f36b1447cfa96eca158381a96f3319bb0ff6e79e"}, - {file = "av-12.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:08fc5eaef60a257d622998626e233bf3ff90d2f817f6695d6a27e0ffcfe9dcff"}, - {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746ab0eff8a7a21a6c6d16e6b6e61709527eba2ad1a524d92a01bb60d02a3df7"}, - {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:013b3ac3de3aa1c137af0cedafd364fd1c7524ab3e1cd53e04564fd1632ac04d"}, - {file = "av-12.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fa55923527648f51ac005e44fe2797ebc67f53ad4850e0194d3753761ee33a2"}, - {file = "av-12.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:35d514f4dee0cf67e9e6b2a65fb4a28f98da88e71e8c7f7960bd04625d9fe965"}, - {file = "av-12.0.0.tar.gz", hash = 
"sha256:bcf21ebb722d4538b4099e5a78f730d78814dd70003511c185941dba5651b14d"}, + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] +[[package]] +name = "attrs" +version = "23.2.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=3.7" +files = [ + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] + [[package]] name = "beautifulsoup4" version = "4.12.3" @@ -196,7 +282,7 @@ pycparser = "*" name = "cfgv" version = "3.4.0" description = "Validate configuration and produce human readable error messages." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, @@ -371,7 +457,7 @@ files = [ name = "coverage" version = "7.4.4" description = "Code coverage measurement for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"}, @@ -434,11 +520,55 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "datasets" +version = "2.18.0" +description = "HuggingFace community-driven open-source library of datasets" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"}, + {file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"}, +] + +[package.dependencies] +aiohttp = "*" +dill = ">=0.3.0,<0.3.9" +filelock = "*" +fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]} +huggingface-hub = ">=0.19.4" +multiprocess = "*" +numpy = ">=1.17" +packaging = "*" +pandas = "*" +pyarrow = ">=12.0.0" +pyarrow-hotfix = "*" +pyyaml = ">=5.1" +requests = ">=2.19.0" +tqdm = ">=4.62.1" +xxhash = "*" + +[package.extras] +apache-beam = ["apache-beam (>=2.26.0)"] +audio = ["librosa", "soundfile (>=0.12.1)"] +benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", 
"torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] +jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] +metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] +quality = ["ruff (>=0.3.0)"] +s3 = ["s3fs"] +tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] +tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +torch = ["torch"] +vision = ["Pillow (>=6.2.1)"] + [[package]] name = "debugpy" version = "1.8.1" description = "An implementation of the Debug Adapter Protocol for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "debugpy-1.8.1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3bda0f1e943d386cc7a0e71bfa59f4137909e2ed947fb3946c506e113000f741"}, @@ -465,17 +595,6 @@ files = [ {file = "debugpy-1.8.1.zip", hash = "sha256:f696d6be15be87aef621917585f9bb94b1dc9e8aced570db1b8a6fc14e8f9b42"}, ] -[[package]] -name = "decorator" -version = "4.4.2" -description = "Decorators for Humans" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*" -files = [ - {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"}, - {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, -] - [[package]] name = "diffusers" version = "0.26.3" @@ -506,11 +625,26 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi torch = ["accelerate (>=0.11.0)", "torch (>=1.4,<2.2.0)"] training = ["Jinja2", "accelerate (>=0.11.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] +[[package]] +name = "dill" +version = "0.3.8" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distlib" version = "0.3.8" description = "Distribution utilities" -optional = false +optional = true python-versions = "*" files = [ {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, @@ -658,7 +792,7 @@ files = [ name = "exceptiongroup" version = "1.2.0" description = "Backport of PEP 654 (exception groups)" -optional = false +optional = true python-versions = ">=3.7" files = 
[ {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, @@ -706,17 +840,106 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "frozenlist" +version = "1.4.1" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = 
"frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = 
"sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + [[package]] name = "fsspec" -version = "2024.3.1" +version = "2024.2.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, - {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, + {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, + {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, ] +[package.dependencies] +aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} + [package.extras] abfs = ["adlfs"] adl = ["adlfs"] @@ -921,7 +1144,7 @@ shapely = "^2.0.3" type = "git" url = "git@github.com:huggingface/gym-pusht.git" reference = "HEAD" -resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1" +resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06" [[package]] name = "gym-xarm" @@ -941,7 +1164,7 @@ mujoco = "^2.3.7" type = "git" url = "git@github.com:huggingface/gym-xarm.git" reference = "HEAD" -resolved_reference = 
"ce294c0d30def08414d9237e2bf9f373d448ca07" +resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c" [[package]] name = "gymnasium" @@ -1033,78 +1256,6 @@ files = [ [package.dependencies] numpy = ">=1.17.3" -[[package]] -name = "hf-transfer" -version = "0.1.6" -description = "" -optional = false -python-versions = ">=3.7" -files = [ - {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, - {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, - {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, - {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, - {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, - {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, - {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, - {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = 
"sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, - {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, - {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, - {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, - {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, - {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, - {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, - {file = 
"hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, - {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, - {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, - {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, - {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, - {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, - {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, - {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, - {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, - {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, - {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, - {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, - {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, - {file = "hf_transfer-0.1.6.tar.gz", hash = 
"sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, -] - [[package]] name = "huggingface-hub" version = "0.21.4" @@ -1119,7 +1270,6 @@ files = [ [package.dependencies] filelock = "*" fsspec = ">=2023.5.0" -hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf_transfer\""} packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -1159,7 +1309,7 @@ packaging = "*" name = "identify" version = "2.5.35" description = "File identification library for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "identify-2.5.35-py2.py3-none-any.whl", hash = "sha256:c4de0081837b211594f8e877a6b4fad7ca32bbfc1a9307fdd61c28bfe923f13e"}, @@ -1192,9 +1342,10 @@ files = [ ] [package.dependencies] -av = {version = "*", optional = true, markers = "extra == \"pyav\""} +imageio-ffmpeg = {version = "*", optional = true, markers = "extra == \"ffmpeg\""} numpy = "*" pillow = ">=8.3.2" +psutil = {version = "*", optional = true, markers = "extra == \"ffmpeg\""} [package.extras] all-plugins = ["astropy", "av", "imageio-ffmpeg", "pillow-heif", "psutil", "tifffile"] @@ -1254,7 +1405,7 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, @@ -1622,30 +1773,6 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] -[[package]] -name = "moviepy" -version = "1.0.3" -description = "Video editing with Python" -optional = false -python-versions = "*" -files = [ - {file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"}, -] - -[package.dependencies] -decorator = ">=4.0.2,<5.0" -imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""} -imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""} -numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""} -proglog = "<=1.0.0" -requests = ">=2.8.1,<3.0" -tqdm = ">=4.11.2,<5.0" - -[package.extras] -doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"] -optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"] -test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"] - [[package]] name = "mpmath" version = "1.3.0" @@ -1703,6 +1830,129 @@ glfw = "*" numpy = "*" pyopengl = "*" +[[package]] +name = "multidict" +version = "6.0.5" +description = "multidict implementation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = 
"sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = 
"multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, +] + +[[package]] +name = "multiprocess" +version = "0.70.16" +description = "better multiprocessing and multithreading in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, +] + +[package.dependencies] +dill = ">=0.3.8" + [[package]] name = "networkx" version = "3.3" @@ -1725,7 +1975,7 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] name = "nodeenv" version = "1.8.0" description = "Node.js virtual environment builder" -optional = false +optional = true python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, @@ -2047,47 +2297,47 @@ files = [ [[package]] name = "pandas" -version = "2.2.1" +version = "2.2.2" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" files = [ - {file = "pandas-2.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8df8612be9cd1c7797c93e1c5df861b2ddda0b48b08f2c3eaa0702cf88fb5f88"}, - {file = "pandas-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0f573ab277252ed9aaf38240f3b54cfc90fff8e5cab70411ee1d03f5d51f3944"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f02a3a6c83df4026e55b63c1f06476c9aa3ed6af3d89b4f04ea656ccdaaaa359"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c38ce92cb22a4bea4e3929429aa1067a454dcc9c335799af93ba9be21b6beb51"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c2ce852e1cf2509a69e98358e8458775f89599566ac3775e70419b98615f4b06"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53680dc9b2519cbf609c62db3ed7c0b499077c7fefda564e330286e619ff0dd9"}, - {file = "pandas-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:94e714a1cca63e4f5939cdce5f29ba8d415d85166be3441165edd427dc9f6bc0"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f821213d48f4ab353d20ebc24e4faf94ba40d76680642fb7ce2ea31a3ad94f9b"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c70e00c2d894cb230e5c15e4b1e1e6b2b478e09cf27cc593a11ef955b9ecc81a"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:e97fbb5387c69209f134893abc788a6486dbf2f9e511070ca05eed4b930b1b02"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101d0eb9c5361aa0146f500773395a03839a5e6ecde4d4b6ced88b7e5a1a6403"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7d2ed41c319c9fb4fd454fe25372028dfa417aacb9790f68171b2e3f06eae8cd"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:af5d3c00557d657c8773ef9ee702c61dd13b9d7426794c9dfeb1dc4a0bf0ebc7"}, - {file = "pandas-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:06cf591dbaefb6da9de8472535b185cba556d0ce2e6ed28e21d919704fef1a9e"}, - {file = "pandas-2.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:88ecb5c01bb9ca927ebc4098136038519aa5d66b44671861ffab754cae75102c"}, - {file = "pandas-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04f6ec3baec203c13e3f8b139fb0f9f86cd8c0b94603ae3ae8ce9a422e9f5bee"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a935a90a76c44fe170d01e90a3594beef9e9a6220021acfb26053d01426f7dc2"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c391f594aae2fd9f679d419e9a4d5ba4bce5bb13f6a989195656e7dc4b95c8f0"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9d1265545f579edf3f8f0cb6f89f234f5e44ba725a34d86535b1a1d38decbccc"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:11940e9e3056576ac3244baef2fedade891977bcc1cb7e5cc8f8cc7d603edc89"}, - {file = "pandas-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acf681325ee1c7f950d058b05a820441075b0dd9a2adf5c4835b9bc056bf4fb"}, - {file = "pandas-2.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9bd8a40f47080825af4317d0340c656744f2bfdb6819f818e6ba3cd24c0e1397"}, - {file = "pandas-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df0c37ebd19e11d089ceba66eba59a168242fc6b7155cba4ffffa6eccdfb8f16"}, - {file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:739cc70eaf17d57608639e74d63387b0d8594ce02f69e7a0b046f117974b3019"}, - {file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d3558d263073ed95e46f4650becff0c5e1ffe0fc3a015de3c79283dfbdb3df"}, - {file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4aa1d8707812a658debf03824016bf5ea0d516afdea29b7dc14cf687bc4d4ec6"}, - {file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:76f27a809cda87e07f192f001d11adc2b930e93a2b0c4a236fde5429527423be"}, - {file = "pandas-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:1ba21b1d5c0e43416218db63037dbe1a01fc101dc6e6024bcad08123e48004ab"}, - {file = "pandas-2.2.1.tar.gz", hash = "sha256:0ab90f87093c13f3e8fa45b48ba9f39181046e8f3317d3aadb2fffbb1b978572"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, + {file = 
"pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, + {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, + {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, + {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = 
"sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, ] [package.dependencies] numpy = [ - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, - {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2233,7 +2483,7 @@ xmp = ["defusedxml"] name = "platformdirs" version = "4.2.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"}, @@ -2248,7 +2498,7 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest- name = "pluggy" version = "1.4.0" description = "plugin and hook calling mechanisms for python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, @@ -2263,7 +2513,7 @@ testing = ["pytest", "pytest-benchmark"] name = "pre-commit" version = "3.7.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." -optional = false +optional = true python-versions = ">=3.9" files = [ {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"}, @@ -2277,20 +2527,6 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" -[[package]] -name = "proglog" -version = "0.1.10" -description = "Log and progress bar manager for console, notebooks, web..." 
-optional = false -python-versions = "*" -files = [ - {file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"}, - {file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"}, -] - -[package.dependencies] -tqdm = "*" - [[package]] name = "protobuf" version = "4.25.3" @@ -2339,6 +2575,65 @@ files = [ [package.extras] test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +[[package]] +name = "pyarrow" +version = "15.0.2" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"}, + {file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"}, + {file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"}, + {file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"}, + {file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"}, + {file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"}, + {file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"}, + {file = "pyarrow-15.0.2.tar.gz", hash = "sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" + +[[package]] +name = "pyarrow-hotfix" +version = "0.6" +description = "" +optional = false +python-versions = ">=3.5" +files = [ + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, +] + [[package]] name = "pycparser" version = "2.22" @@ -2354,7 +2649,7 @@ files = [ name = "pygame" version = "2.5.2" description = "Python Game Development" -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "pygame-2.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a0769eb628c818761755eb0a0ca8216b95270ea8cbcbc82227e39ac9644643da"}, @@ -2528,7 +2823,7 @@ files = [ name = "pytest" version = "8.1.1" description = "pytest: simple powerful testing with Python" -optional = false +optional 
= true python-versions = ">=3.8" files = [ {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, @@ -2550,7 +2845,7 @@ testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygm name = "pytest-cov" version = "5.0.0" description = "Pytest plugin for measuring coverage." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, @@ -3370,7 +3665,7 @@ all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib" name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, @@ -3563,7 +3858,7 @@ zstd = ["zstandard (>=0.18.0)"] name = "virtualenv" version = "20.25.1" description = "Virtual Python Environment builder" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"}, @@ -3634,6 +3929,226 @@ MarkupSafe = ">=2.1.1" [package.extras] watchdog = ["watchdog (>=2.3)"] +[[package]] +name = "xxhash" +version = "3.4.1" +description = "Python binding for xxHash" +optional = false +python-versions = ">=3.7" +files = [ + {file = "xxhash-3.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:91dbfa55346ad3e18e738742236554531a621042e419b70ad8f3c1d9c7a16e7f"}, + {file = "xxhash-3.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:665a65c2a48a72068fcc4d21721510df5f51f1142541c890491afc80451636d2"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb11628470a6004dc71a09fe90c2f459ff03d611376c1debeec2d648f44cb693"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5bef2a7dc7b4f4beb45a1edbba9b9194c60a43a89598a87f1a0226d183764189"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0f7b2d547d72c7eda7aa817acf8791f0146b12b9eba1d4432c531fb0352228"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00f2fdef6b41c9db3d2fc0e7f94cb3db86693e5c45d6de09625caad9a469635b"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23cfd9ca09acaf07a43e5a695143d9a21bf00f5b49b15c07d5388cadf1f9ce11"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6a9ff50a3cf88355ca4731682c168049af1ca222d1d2925ef7119c1a78e95b3b"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f1d7c69a1e9ca5faa75546fdd267f214f63f52f12692f9b3a2f6467c9e67d5e7"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:672b273040d5d5a6864a36287f3514efcd1d4b1b6a7480f294c4b1d1ee1b8de0"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4178f78d70e88f1c4a89ff1ffe9f43147185930bb962ee3979dba15f2b1cc799"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9804b9eb254d4b8cc83ab5a2002128f7d631dd427aa873c8727dba7f1f0d1c2b"}, + {file = "xxhash-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c09c49473212d9c87261d22c74370457cfff5db2ddfc7fd1e35c80c31a8c14ce"}, + {file = 
"xxhash-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:ebbb1616435b4a194ce3466d7247df23499475c7ed4eb2681a1fa42ff766aff6"}, + {file = "xxhash-3.4.1-cp310-cp310-win_arm64.whl", hash = "sha256:25dc66be3db54f8a2d136f695b00cfe88018e59ccff0f3b8f545869f376a8a46"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58c49083801885273e262c0f5bbeac23e520564b8357fbb18fb94ff09d3d3ea5"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b526015a973bfbe81e804a586b703f163861da36d186627e27524f5427b0d520"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36ad4457644c91a966f6fe137d7467636bdc51a6ce10a1d04f365c70d6a16d7e"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:248d3e83d119770f96003271fe41e049dd4ae52da2feb8f832b7a20e791d2920"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2070b6d5bbef5ee031666cf21d4953c16e92c2f8a24a94b5c240f8995ba3b1d0"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2746035f518f0410915e247877f7df43ef3372bf36cfa52cc4bc33e85242641"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ba6181514681c2591840d5632fcf7356ab287d4aff1c8dea20f3c78097088"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aac5010869240e95f740de43cd6a05eae180c59edd182ad93bf12ee289484fa"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4cb11d8debab1626181633d184b2372aaa09825bde709bf927704ed72765bed1"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b29728cff2c12f3d9f1d940528ee83918d803c0567866e062683f300d1d2eff3"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a15cbf3a9c40672523bdb6ea97ff74b443406ba0ab9bca10ceccd9546414bd84"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e66df260fed01ed8ea790c2913271641c58481e807790d9fca8bfd5a3c13844"}, + {file = "xxhash-3.4.1-cp311-cp311-win32.whl", hash = "sha256:e867f68a8f381ea12858e6d67378c05359d3a53a888913b5f7d35fbf68939d5f"}, + {file = "xxhash-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:200a5a3ad9c7c0c02ed1484a1d838b63edcf92ff538770ea07456a3732c577f4"}, + {file = "xxhash-3.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:1d03f1c0d16d24ea032e99f61c552cb2b77d502e545187338bea461fde253583"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c4bbba9b182697a52bc0c9f8ec0ba1acb914b4937cd4a877ad78a3b3eeabefb3"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9fd28a9da300e64e434cfc96567a8387d9a96e824a9be1452a1e7248b7763b78"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6066d88c9329ab230e18998daec53d819daeee99d003955c8db6fc4971b45ca3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93805bc3233ad89abf51772f2ed3355097a5dc74e6080de19706fc447da99cd3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64da57d5ed586ebb2ecdde1e997fa37c27fe32fe61a656b77fabbc58e6fbff6e"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a97322e9a7440bf3c9805cbaac090358b43f650516486746f7fa482672593df"}, + {file = 
"xxhash-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe750d512982ee7d831838a5dee9e9848f3fb440e4734cca3f298228cc957a6"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fd79d4087727daf4d5b8afe594b37d611ab95dc8e29fe1a7517320794837eb7d"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:743612da4071ff9aa4d055f3f111ae5247342931dedb955268954ef7201a71ff"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b41edaf05734092f24f48c0958b3c6cbaaa5b7e024880692078c6b1f8247e2fc"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:a90356ead70d715fe64c30cd0969072de1860e56b78adf7c69d954b43e29d9fa"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac56eebb364e44c85e1d9e9cc5f6031d78a34f0092fea7fc80478139369a8b4a"}, + {file = "xxhash-3.4.1-cp312-cp312-win32.whl", hash = "sha256:911035345932a153c427107397c1518f8ce456f93c618dd1c5b54ebb22e73747"}, + {file = "xxhash-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:f31ce76489f8601cc7b8713201ce94b4bd7b7ce90ba3353dccce7e9e1fee71fa"}, + {file = "xxhash-3.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:b5beb1c6a72fdc7584102f42c4d9df232ee018ddf806e8c90906547dfb43b2da"}, + {file = "xxhash-3.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6d42b24d1496deb05dee5a24ed510b16de1d6c866c626c2beb11aebf3be278b9"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b685fab18876b14a8f94813fa2ca80cfb5ab6a85d31d5539b7cd749ce9e3624"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419ffe34c17ae2df019a4685e8d3934d46b2e0bbe46221ab40b7e04ed9f11137"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e041ce5714f95251a88670c114b748bca3bf80cc72400e9f23e6d0d59cf2681"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc860d887c5cb2f524899fb8338e1bb3d5789f75fac179101920d9afddef284b"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:312eba88ffe0a05e332e3a6f9788b73883752be63f8588a6dc1261a3eaaaf2b2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e01226b6b6a1ffe4e6bd6d08cfcb3ca708b16f02eb06dd44f3c6e53285f03e4f"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9f3025a0d5d8cf406a9313cd0d5789c77433ba2004b1c75439b67678e5136537"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:6d3472fd4afef2a567d5f14411d94060099901cd8ce9788b22b8c6f13c606a93"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:43984c0a92f06cac434ad181f329a1445017c33807b7ae4f033878d860a4b0f2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a55e0506fdb09640a82ec4f44171273eeabf6f371a4ec605633adb2837b5d9d5"}, + {file = "xxhash-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:faec30437919555b039a8bdbaba49c013043e8f76c999670aef146d33e05b3a0"}, + {file = "xxhash-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:c9e1b646af61f1fc7083bb7b40536be944f1ac67ef5e360bca2d73430186971a"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:961d948b7b1c1b6c08484bbce3d489cdf153e4122c3dfb07c2039621243d8795"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:719a378930504ab159f7b8e20fa2aa1896cde050011af838af7e7e3518dd82de"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74fb5cb9406ccd7c4dd917f16630d2e5e8cbbb02fc2fca4e559b2a47a64f4940"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dab508ac39e0ab988039bc7f962c6ad021acd81fd29145962b068df4148c476"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c59f3e46e7daf4c589e8e853d700ef6607afa037bfad32c390175da28127e8c"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc07256eff0795e0f642df74ad096f8c5d23fe66bc138b83970b50fc7f7f6c5"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9f749999ed80f3955a4af0eb18bb43993f04939350b07b8dd2f44edc98ffee9"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7688d7c02149a90a3d46d55b341ab7ad1b4a3f767be2357e211b4e893efbaaf6"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a8b4977963926f60b0d4f830941c864bed16aa151206c01ad5c531636da5708e"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:8106d88da330f6535a58a8195aa463ef5281a9aa23b04af1848ff715c4398fb4"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4c76a77dbd169450b61c06fd2d5d436189fc8ab7c1571d39265d4822da16df22"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:11f11357c86d83e53719c592021fd524efa9cf024dc7cb1dfb57bbbd0d8713f2"}, + {file = "xxhash-3.4.1-cp38-cp38-win32.whl", hash = "sha256:0c786a6cd74e8765c6809892a0d45886e7c3dc54de4985b4a5eb8b630f3b8e3b"}, + {file = "xxhash-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:aabf37fb8fa27430d50507deeab2ee7b1bcce89910dd10657c38e71fee835594"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6127813abc1477f3a83529b6bbcfeddc23162cece76fa69aee8f6a8a97720562"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef2e194262f5db16075caea7b3f7f49392242c688412f386d3c7b07c7733a70a"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71be94265b6c6590f0018bbf73759d21a41c6bda20409782d8117e76cd0dfa8b"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10e0a619cdd1c0980e25eb04e30fe96cf8f4324758fa497080af9c21a6de573f"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa122124d2e3bd36581dd78c0efa5f429f5220313479fb1072858188bc2d5ff1"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17032f5a4fea0a074717fe33477cb5ee723a5f428de7563e75af64bfc1b1e10"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca7783b20e3e4f3f52f093538895863f21d18598f9a48211ad757680c3bd006f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d77d09a1113899fad5f354a1eb4f0a9afcf58cefff51082c8ad643ff890e30cf"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:21287bcdd299fdc3328cc0fbbdeaa46838a1c05391264e51ddb38a3f5b09611f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:dfd7a6cc483e20b4ad90224aeb589e64ec0f31e5610ab9957ff4314270b2bf31"}, + {file = 
"xxhash-3.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:543c7fcbc02bbb4840ea9915134e14dc3dc15cbd5a30873a7a5bf66039db97ec"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fe0a98d990e433013f41827b62be9ab43e3cf18e08b1483fcc343bda0d691182"}, + {file = "xxhash-3.4.1-cp39-cp39-win32.whl", hash = "sha256:b9097af00ebf429cc7c0e7d2fdf28384e4e2e91008130ccda8d5ae653db71e54"}, + {file = "xxhash-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:d699b921af0dcde50ab18be76c0d832f803034d80470703700cb7df0fbec2832"}, + {file = "xxhash-3.4.1-cp39-cp39-win_arm64.whl", hash = "sha256:2be491723405e15cc099ade1280133ccfbf6322d2ef568494fb7d07d280e7eee"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:431625fad7ab5649368c4849d2b49a83dc711b1f20e1f7f04955aab86cd307bc"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc6dbd5fc3c9886a9e041848508b7fb65fd82f94cc793253990f81617b61fe49"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ff8dbd0ec97aec842476cb8ccc3e17dd288cd6ce3c8ef38bff83d6eb927817"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef73a53fe90558a4096e3256752268a8bdc0322f4692ed928b6cd7ce06ad4fe3"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:450401f42bbd274b519d3d8dcf3c57166913381a3d2664d6609004685039f9d3"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a162840cf4de8a7cd8720ff3b4417fbc10001eefdd2d21541a8226bb5556e3bb"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b736a2a2728ba45017cb67785e03125a79d246462dfa892d023b827007412c52"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0ae4c2e7698adef58710d6e7a32ff518b66b98854b1c68e70eee504ad061d8"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6322c4291c3ff174dcd104fae41500e75dad12be6f3085d119c2c8a80956c51"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:dd59ed668801c3fae282f8f4edadf6dc7784db6d18139b584b6d9677ddde1b6b"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92693c487e39523a80474b0394645b393f0ae781d8db3474ccdcead0559ccf45"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4603a0f642a1e8d7f3ba5c4c25509aca6a9c1cc16f85091004a7028607ead663"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa45e8cbfbadb40a920fe9ca40c34b393e0b067082d94006f7f64e70c7490a6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:595b252943b3552de491ff51e5bb79660f84f033977f88f6ca1605846637b7c6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:562d8b8f783c6af969806aaacf95b6c7b776929ae26c0cd941d54644ea7ef51e"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:41ddeae47cf2828335d8d991f2d2b03b0bdc89289dc64349d712ff8ce59d0647"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c44d584afdf3c4dbb3277e32321d1a7b01d6071c1992524b6543025fb8f4206f"}, + {file = 
"xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd7bddb3a5b86213cc3f2c61500c16945a1b80ecd572f3078ddbbe68f9dabdfb"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9ecb6c987b62437c2f99c01e97caf8d25660bf541fe79a481d05732e5236719c"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:696b4e18b7023527d5c50ed0626ac0520edac45a50ec7cf3fc265cd08b1f4c03"}, + {file = "xxhash-3.4.1.tar.gz", hash = "sha256:0379d6cf1ff987cd421609a264ce025e74f346e3e145dd106c0cc2e3ec3f99a9"}, +] + +[[package]] +name = "yarl" +version = "1.9.4" +description = "Yet another URL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = 
"yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = 
"sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = 
"yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = 
"sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + [[package]] name = "zarr" version = "2.17.2" @@ -3672,10 +4187,12 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] aloha = ["gym-aloha"] +dev = ["debugpy", "pre-commit"] pusht = ["gym-pusht"] +test = ["pytest", "pytest-cov"] xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "7ec0310f8dd0ffa4d92fa78e06513bce98c3657692b3753ff34aadd297a3766c" +content-hash = "01ad4eb04061ec9f785d4574bf66d3e5cb4549e2ea11ab175895f94cb62c1f1c" diff --git a/pyproject.toml b/pyproject.toml index 743dece8..09348989 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,25 @@ [tool.poetry] name = "lerobot" version = "0.1.0" -description = "Le robot is learning" +description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" authors = [ "Rémi Cadène ", + "Alexander Soare ", + "Quentin Gallouédec ", "Simon Alibert ", + "Thomas Wolf ", ] -repository = "https://github.com/Cadene/lerobot" +repository = "https://github.com/huggingface/lerobot" readme = "README.md" -license = "MIT" +license = "Apache-2.0" classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", "Topic :: Software Development :: Build Tools", - "License :: OSI Approved :: MIT License", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.10", ] packages = [{include = "lerobot"}] @@ -23,52 +29,38 @@ packages = [{include = "lerobot"}] python = "^3.10" termcolor = "^2.4.0" omegaconf = "^2.3.0" -pandas = "^2.2.1" wandb = "^0.16.3" -moviepy = "^1.0.3" -imageio = {extras = ["pyav"], version = "^2.34.0"} +imageio = {extras = ["ffmpeg"], version = "^2.34.0"} gdown = "^5.1.0" hydra-core = "^1.3.2" einops = "^0.7.0" -pygame = "^2.5.2" pymunk = "^6.6.0" zarr = "^2.17.0" numba = "^0.59.0" -mpmath = "^1.3.0" torch = "^2.2.1" opencv-python = "^4.9.0.80" diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" -huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} +huggingface-hub = "^0.21.4" robomimic = "0.2.0" gymnasium = "^0.29.1" cmake = "^3.29.0.1" gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true} gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true} -# gym-pusht = { path = "../gym-pusht", develop = true, optional = true} -# gym-xarm = { path = "../gym-xarm", develop = true, optional = true} -# gym-aloha = { path = "../gym-aloha", develop = true, optional = true} +pre-commit = {version = "^3.7.0", optional = true} +debugpy = {version = "^1.8.1", optional = true} +pytest = {version = "^8.1.0", optional = true} +pytest-cov = {version = "^5.0.0", optional = true} +datasets = "^2.18.0" [tool.poetry.extras] pusht = ["gym-pusht"] xarm = ["gym-xarm"] aloha = ["gym-aloha"] - - -[tool.poetry.group.dev] -optional = true - - -[tool.poetry.group.dev.dependencies] -pre-commit = "^3.6.2" -debugpy = "^1.8.1" - - 
-[tool.poetry.group.test.dependencies] -pytest = "^8.1.0" -pytest-cov = "^5.0.0" +dev = ["pre-commit", "debugpy"] +test = ["pytest", "pytest-cov"] [tool.ruff] @@ -103,13 +95,7 @@ exclude = [ select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] ignore-init-module-imports = true -[tool.poetry-dynamic-versioning] -enable = true - [build-system] -requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] -build-backend = "poetry_dynamic_versioning.backend" - -[tool.black] -line-length = 110 +requires = ["poetry-core>=1.5.0"] +build-backend = "poetry.core.masonry.api" diff --git a/sbatch.sh b/sbatch.sh deleted file mode 100644 index c08f7055..00000000 --- a/sbatch.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -#SBATCH --nodes=1 # total number of nodes (N to be defined) -#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU) -#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs) -#SBATCH --cpus-per-task=8 # number of cores per task (8x8 = 64 cores, or all the cores) -#SBATCH --time=2-00:00:00 -#SBATCH --output=/home/rcadene/slurm/%j.out -#SBATCH --error=/home/rcadene/slurm/%j.err -#SBATCH --qos=low -#SBATCH --mail-user=re.cadene@gmail.com -#SBATCH --mail-type=ALL - -CMD=$@ -echo "command: $CMD" - -apptainer exec --nv \ -~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL - -source ~/.bashrc -#conda activate fowm -conda activate lerobot - -export DATA_DIR="data" - -srun $CMD diff --git a/sbatch_hopper.sh b/sbatch_hopper.sh deleted file mode 100644 index cc410048..00000000 --- a/sbatch_hopper.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -#SBATCH --nodes=1 # total number of nodes (N to be defined) -#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU) -#SBATCH --qos=normal # number of GPUs reserved per node (here 8, or all the GPUs) -#SBATCH --partition=hopper-prod -#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs) -#SBATCH --cpus-per-task=12 # number of cores per task -#SBATCH --mem-per-cpu=11G -#SBATCH --time=12:00:00 -#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out -#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err -#SBATCH --mail-user=remi_cadene@huggingface.co -#SBATCH --mail-type=ALL - -CMD=$@ -echo "command: $CMD" -srun $CMD diff --git a/tests/data/aloha_sim_insertion_human/data_dict.pth b/tests/data/aloha_sim_insertion_human/data_dict.pth deleted file mode 100644 index 1370c9ea..00000000 Binary files a/tests/data/aloha_sim_insertion_human/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth b/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..165298cf Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_human/train/dataset_info.json b/tests/data/aloha_sim_insertion_human/train/dataset_info.json new file mode 100644 index 00000000..542c7bf1 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.images.top": { + 
"_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_index_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_index_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/train/state.json b/tests/data/aloha_sim_insertion_human/train/state.json new file mode 100644 index 00000000..39101fd5 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "d79cf82ffc86f110", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/data_dict.pth b/tests/data/aloha_sim_insertion_scripted/data_dict.pth deleted file mode 100644 index 00c9f335..00000000 Binary files a/tests/data/aloha_sim_insertion_scripted/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..034f759f Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json new file mode 100644 index 00000000..542c7bf1 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.images.top": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_index_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_index_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/train/state.json b/tests/data/aloha_sim_insertion_scripted/train/state.json new file mode 100644 index 00000000..ecaa8fd8 --- /dev/null +++ 
b/tests/data/aloha_sim_insertion_scripted/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "d8e4a817b5449498", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/data_dict.pth b/tests/data/aloha_sim_transfer_cube_human/data_dict.pth deleted file mode 100644 index ab851779..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_human/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..9682f005 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json new file mode 100644 index 00000000..542c7bf1 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.images.top": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_index_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_index_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/train/state.json b/tests/data/aloha_sim_transfer_cube_human/train/state.json new file mode 100644 index 00000000..0167986b --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "f03482befa767127", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth deleted file mode 100644 index bd308bb0..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth and /dev/null differ diff --git 
a/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..567191d5 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json new file mode 100644 index 00000000..542c7bf1 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.images.top": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_index_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_index_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json new file mode 100644 index 00000000..56005bc9 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "93e03c6320c7d56e", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/pusht/data_dict.pth b/tests/data/pusht/data_dict.pth deleted file mode 100644 index 40d96a51..00000000 Binary files a/tests/data/pusht/data_dict.pth and /dev/null differ diff --git a/tests/data/pusht/data_ids_per_episode.pth b/tests/data/pusht/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/pusht/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/pusht/train/data-00000-of-00001.arrow b/tests/data/pusht/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..9a36a8db Binary files /dev/null and b/tests/data/pusht/train/data-00000-of-00001.arrow differ diff --git a/tests/data/pusht/train/dataset_info.json b/tests/data/pusht/train/dataset_info.json new file mode 100644 index 00000000..667e06f7 --- /dev/null +++ b/tests/data/pusht/train/dataset_info.json @@ -0,0 +1,63 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.image": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 2, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 2, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.reward": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + 
"dtype": "bool", + "_type": "Value" + }, + "next.success": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_index_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_index_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/pusht/train/state.json b/tests/data/pusht/train/state.json new file mode 100644 index 00000000..7e0ff574 --- /dev/null +++ b/tests/data/pusht/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "21bb9a76ed78a475", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/data_dict.pth b/tests/data/xarm_lift_medium/data_dict.pth deleted file mode 100644 index 5c166576..00000000 Binary files a/tests/data/xarm_lift_medium/data_dict.pth and /dev/null differ diff --git a/tests/data/xarm_lift_medium/data_ids_per_episode.pth b/tests/data/xarm_lift_medium/data_ids_per_episode.pth deleted file mode 100644 index 21095017..00000000 Binary files a/tests/data/xarm_lift_medium/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..45d527e0 Binary files /dev/null and b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow differ diff --git a/tests/data/xarm_lift_medium/train/dataset_info.json b/tests/data/xarm_lift_medium/train/dataset_info.json new file mode 100644 index 00000000..bb647c41 --- /dev/null +++ b/tests/data/xarm_lift_medium/train/dataset_info.json @@ -0,0 +1,59 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.image": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.reward": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_index_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_index_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/train/state.json b/tests/data/xarm_lift_medium/train/state.json new file mode 100644 index 00000000..c930c52c --- /dev/null +++ b/tests/data/xarm_lift_medium/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "a95cbec45e3bb9d6", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/test_available.py b/tests/test_available.py index 8df2c945..373cc1a7 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -1,64 +1,53 @@ """ This test verifies that all environments, datasets, policies 
listed in `lerobot/__init__.py` can be successfully -imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds. +imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid. -Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class +When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: +- Set the required class attributes: `available_datasets` (for dataset classes). +- Set the required class attributes: `name` (for policy classes). +- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`). +- Update variables in `tests/test_available.py` by importing your new class. """ +import importlib import pytest import lerobot +import gymnasium as gym -# from lerobot.common.envs.aloha.env import AlohaEnv -# from gym_pusht.envs import PushtEnv -# from gym_xarm.envs import SimxarmEnv +from lerobot.common.datasets.xarm import XarmDataset +from lerobot.common.datasets.aloha import AlohaDataset +from lerobot.common.datasets.pusht import PushtDataset -# from lerobot.common.datasets.xarm import SimxarmDataset -# from lerobot.common.datasets.aloha import AlohaDataset -# from lerobot.common.datasets.pusht import PushtDataset - -# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy -# from lerobot.common.policies.diffusion.policy import DiffusionPolicy -# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy +from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy +from lerobot.common.policies.tdmpc.policy import TDMPCPolicy -# def test_available(): -# pol_classes = [ -# ActionChunkingTransformerPolicy, -# DiffusionPolicy, -# TDMPCPolicy, -# ] +def test_available(): + policy_classes = [ + ActionChunkingTransformerPolicy, + DiffusionPolicy, + TDMPCPolicy, + ] -# env_classes = [ -# AlohaEnv, -# PushtEnv, -# SimxarmEnv, -# ] - -# dat_classes = [ -# AlohaDataset, -# PushtDataset, -# SimxarmDataset, -# ] + dataset_class_per_env = { + "aloha": AlohaDataset, + "pusht": PushtDataset, + "xarm": XarmDataset, + } -# policies = [pol_cls.name for pol_cls in pol_classes] -# assert set(policies) == set(lerobot.available_policies) + policies = [pol_cls.name for pol_cls in policy_classes] + assert set(policies) == set(lerobot.available_policies), policies -# envs = [env_cls.name for env_cls in env_classes] -# assert set(envs) == set(lerobot.available_envs) + for env_name in lerobot.available_envs: + for task_name in lerobot.available_tasks_per_env[env_name]: + package_name = f"gym_{env_name}" + importlib.import_module(package_name) + gym_handle = f"{package_name}/{task_name}" + assert gym_handle in gym.envs.registry.keys(), gym_handle -# tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes} -# for env in envs: -# assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env]) -
-# datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)} -# for env in envs: -# assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env]) + dataset_class = dataset_class_per_env[env_name] + available_datasets = lerobot.available_datasets_per_env[env_name] + assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e24d7b4d..18d1e9d7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,10 +1,15 @@ +import os +from pathlib import Path +import einops import pytest import torch +from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_previous_and_future_frames +from lerobot.common.transforms import Prod from lerobot.common.utils import init_hydra_config import logging from lerobot.common.datasets.factory import make_dataset - +from datasets import Dataset from .utils import DEVICE, DEFAULT_CONFIG_PATH @@ -32,7 +37,7 @@ def test_factory(env_name, dataset_id, policy_name): keys_ndim_required = [ ("action", 1, True), - ("episode", 0, True), + ("episode_id", 0, True), ("frame_id", 0, True), ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? @@ -45,6 +50,7 @@ def test_factory(env_name, dataset_id, policy_name): keys_ndim_required.append( (key, 3, True), ) + assert dataset.data_dict[key].dtype == torch.uint8, f"{key}" # test number of dimensions for key, ndim, required in keys_ndim_required: @@ -81,28 +87,115 @@ def test_factory(env_name, dataset_id, policy_name): assert key in item, f"{key}" -# def test_compute_stats(): -# """Check that the statistics are computed correctly according to the stats_patterns property. +def test_compute_stats(): + """Check that the statistics are computed correctly according to the stats_patterns property. + + We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do + because we are working with a small dataset). + """ + from lerobot.common.datasets.xarm import XarmDataset + + DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None + + # get transform to convert images from uint8 [0,255] to float32 [0,1] + transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) + + dataset = XarmDataset( + dataset_id="xarm_lift_medium", + root=DATA_DIR, + transform=transform, + ) + + # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched + # computation of the statistics. While doing this, we also make sure it works when we don't divide the + # dataset into even batches. 
+ computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25)) + + # get einops patterns to aggregate batches and compute statistics + stats_patterns = get_stats_einops_patterns(dataset) + + # get all frames from the dataset in the same dtype and range as during compute_stats + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=8, + batch_size=len(dataset), + shuffle=False, + ) + data_dict = next(iter(dataloader)) + + # compute stats based on all frames from the dataset without any batching + expected_stats = {} + for k, pattern in stats_patterns.items(): + expected_stats[k] = {} + expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean") + expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")) + expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min") + expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max") + + # test computed stats match expected stats + for k in stats_patterns: + assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"]) + assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"]) + assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"]) + assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"]) + + # TODO(rcadene): check that the stats used for training are correct too + # # load stats that are expected to match the ones returned by computed_stats + # assert (dataset.data_dir / "stats.pth").exists() + # loaded_stats = torch.load(dataset.data_dir / "stats.pth") + + # # test loaded stats match expected stats + # for k in stats_patterns: + # assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"]) + # assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"]) + # assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"]) + # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) + + +def test_load_previous_and_future_frames_within_tolerance(): + data_dict = Dataset.from_dict({ + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], + "index": [0, 1, 2, 3, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [5, 5, 5, 5, 5], + }) + data_dict = data_dict.with_format("torch") + item = data_dict[2] + delta_timestamps = {"index": [-0.2, 0, 0.139]} + tol = 0.04 + item = load_previous_and_future_frames(item, data_dict, delta_timestamps, tol) + data, is_pad = item["index"], item["index_is_pad"] + assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values" + assert not is_pad.any(), "Unexpected padding detected" + +def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(): + data_dict = Dataset.from_dict({ + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], + "index": [0, 1, 2, 3, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [5, 5, 5, 5, 5], + }) + data_dict = data_dict.with_format("torch") + item = data_dict[2] + delta_timestamps = {"index": [-0.2, 0, 0.141]} + tol = 0.04 + with pytest.raises(AssertionError): + load_previous_and_future_frames(item, data_dict, delta_timestamps, tol) + +def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range(): + data_dict = Dataset.from_dict({ + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], + "index": [0, 1, 2, 3, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [5, 5, 5, 5, 5], + }) + data_dict = data_dict.with_format("torch") + item = data_dict[2] + 
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} + tol = 0.04 + item = load_previous_and_future_frames(item, data_dict, delta_timestamps, tol) + data, is_pad = item["index"], item["index_is_pad"] + assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" + assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values" + -# We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do -# because we are working with a small dataset). -# """ -# cfg = init_hydra_config( -# DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"] -# ) -# dataset = make_dataset(cfg) -# # Get all of the data. -# all_data = dataset.data_dict -# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched -# # computation of the statistics. While doing this, we also make sure it works when we don't divide the -# # dataset into even batches. -# computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75)) -# for k, pattern in buffer.stats_patterns.items(): -# expected_mean = einops.reduce(all_data[k], pattern, "mean") -# assert torch.allclose(computed_stats[k]["mean"], expected_mean) -# assert torch.allclose( -# computed_stats[k]["std"], -# torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")) -# ) -# assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min")) -# assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max")) diff --git a/tests/test_examples.py b/tests/test_examples.py index 4263e452..c510eb1e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,8 +1,8 @@ from pathlib import Path -def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str: - for f, r in zip(finds, replaces): +def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str: + for f, r in finds_and_replaces: assert f in text text = text.replace(f, r) return text @@ -29,14 +29,19 @@ def test_examples_3_and_2(): with open(path, "r") as file: file_contents = file.read() - # Do less steps and use CPU. + # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. file_contents = _find_and_replace( file_contents, - ['"offline_steps=5000"', '"device=cuda"'], - ['"offline_steps=1"', '"device=cpu"'], + [ + ("training_steps = 5000", "training_steps = 1"), + ("num_workers=4", "num_workers=0"), + ('device = torch.device("cuda")', 'device = torch.device("cpu")'), + ("batch_size=cfg.batch_size", "batch_size=1"), + ], ) - exec(file_contents) + # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. 
+ exec(file_contents, {}) for file_name in ["model.pt", "stats.pth", "config.yaml"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() @@ -50,20 +55,15 @@ def test_examples_3_and_2(): file_contents = _find_and_replace( file_contents, [ - '"eval_episodes=10"', - '"rollout_batch_size=10"', - '"device=cuda"', - '# folder = Path("outputs/train/example_pusht_diffusion")', - 'hub_id = "lerobot/diffusion_policy_pusht_image"', - "folder = Path(snapshot_download(hub_id)", - ], - [ - '"eval_episodes=1"', - '"rollout_batch_size=1"', - '"device=cpu"', - 'folder = Path("outputs/train/example_pusht_diffusion")', - "", - "", + ('"eval_episodes=10"', '"eval_episodes=1"'), + ('"rollout_batch_size=10"', '"rollout_batch_size=1"'), + ('"device=cuda"', '"device=cpu"'), + ( + '# folder = Path("outputs/train/example_pusht_diffusion")', + 'folder = Path("outputs/train/example_pusht_diffusion")', + ), + ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""), + ("folder = Path(snapshot_download(hub_id)", ""), ], ) diff --git a/tests/test_policies.py b/tests/test_policies.py index 8ccc7c62..f53e402a 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -4,11 +4,13 @@ import torch from lerobot.common.datasets.utils import cycle from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.policy_protocol import Policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_dataset from lerobot.common.utils import init_hydra_config from .utils import DEVICE, DEFAULT_CONFIG_PATH + @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ @@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides): """ Tests: - Making the policy object. + - Checking that the policy follows the correct protocol. - Updating the policy. - Using the policy to select actions at inference time. - Test the action can be applied to the policy @@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides): f"policy={policy_name}", f"device={DEVICE}", ] - + extra_overrides + + extra_overrides, ) # Check that we can make the policy object. policy = make_policy(cfg) + # Check that the policy follows the required protocol. + assert isinstance( + policy, Policy + ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}." # Check that we run select_actions and get the appropriate output. dataset = make_dataset(cfg) env = make_env(cfg, num_parallel_envs=2) @@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides): batch[key] = batch[key].to(DEVICE, non_blocking=True) # Test updating the policy - policy(batch, step=0) + policy.update(batch, step=0) # reset the policy and environment policy.reset() @@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides): # Test step through policy env.step(action) -
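Note on the new `load_previous_and_future_frames` tests in `tests/test_datasets.py`: they encode a nearest-frame lookup with a timestamp tolerance, where queries that fall outside the episode's time range are clamped to the first or last frame and flagged as padding, while queries inside the range that miss every frame by more than the tolerance are treated as a data synchronization error. The snippet below is a minimal, self-contained sketch of that behavior for illustration only; the helper name and the plain-dict episode layout are assumptions, not lerobot's actual implementation, which works on a Hugging Face dataset and `episode_data_index_from`/`episode_data_index_to`.

import torch


def load_frames_with_tolerance(item, episode, delta_timestamps, tol):
    # Hypothetical helper, for illustration only; it mirrors the tolerance
    # logic that the tests above exercise.
    ep_ts = episode["timestamp"]  # (num_frames,) timestamps of one episode
    for key, deltas in delta_timestamps.items():
        query_ts = item["timestamp"] + torch.tensor(deltas)
        dist = torch.abs(query_ts[:, None] - ep_ts[None, :])  # (num_queries, num_frames)
        min_dist, nearest = dist.min(dim=1)
        is_pad = min_dist > tol
        # A query inside the episode's time range must land within tolerance,
        # otherwise the timestamps are considered out of sync with the frames.
        inside = (query_ts >= ep_ts[0]) & (query_ts <= ep_ts[-1])
        assert not (is_pad & inside).any(), "timestamps unexpectedly far from any frame"
        item[key] = episode[key][nearest]   # gather the nearest frame for every query
        item[f"{key}_is_pad"] = is_pad      # flag out-of-range queries as padding
    return item


# Reproduces the expectations of the third test case above.
episode = {
    "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
    "index": torch.tensor([0, 1, 2, 3, 4]),
}
item = {"timestamp": torch.tensor(0.3), "index": torch.tensor(2)}
item = load_frames_with_tolerance(item, episode, {"index": [-0.3, -0.24, 0, 0.26, 0.3]}, tol=0.04)
assert torch.equal(item["index"], torch.tensor([0, 0, 2, 4, 4]))
assert torch.equal(item["index_is_pad"], torch.tensor([True, False, False, True, True]))

Run against the third test case's inputs, the sketch reproduces the expected [0, 0, 2, 4, 4] frames and [True, False, False, True, True] padding mask.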
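Similarly, the new `test_compute_stats` asserts that statistics computed over mini-batches (including a final, smaller batch) match a single full-pass reduction. The sketch below illustrates that aggregation idea only; the function name is made up, and it computes scalar statistics rather than the per-einops-pattern statistics lerobot's `compute_stats` produces.

import torch


def batched_min_max_mean_std(batches):
    # Illustration only: accumulate running sums per batch so the final
    # statistics match a single pass over all of the data, regardless of
    # batch size or whether the last batch is smaller than the others.
    n = 0
    sum_x = torch.tensor(0.0)
    sum_x2 = torch.tensor(0.0)
    min_x, max_x = float("inf"), float("-inf")
    for x in batches:
        x = x.float()
        n += x.numel()
        sum_x = sum_x + x.sum()
        sum_x2 = sum_x2 + (x ** 2).sum()
        min_x = min(min_x, x.min().item())
        max_x = max(max_x, x.max().item())
    mean = sum_x / n
    std = torch.sqrt(sum_x2 / n - mean ** 2)  # Var[x] = E[x^2] - E[x]^2
    return {"min": min_x, "max": max_x, "mean": mean.item(), "std": std.item()}


# Uneven batches on purpose, echoing the test's batch_size = int(0.25 * len(dataset)).
data = torch.arange(10, dtype=torch.float32)
stats = batched_min_max_mean_std(data.split(4))  # batches of 4, 4, 2
assert abs(stats["mean"] - data.mean().item()) < 1e-6
assert abs(stats["std"] - data.std(unbiased=False).item()) < 1e-5

Min and max aggregate trivially across batches; mean and std follow from the running sums of x and x^2.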