Compare commits
60 Commits
Author | SHA1 | Date |
---|---|---|
|
14208c6d60 | |
|
959ecf9be8 | |
|
5e8884fcf3 | |
|
00dbc71db9 | |
|
a8b40fa813 | |
|
8d5a38222b | |
|
5340e77e76 | |
|
f584cb25d1 | |
|
275af1ed9e | |
|
995428b426 | |
|
baf8270fc5 | |
|
e9faa50b9e | |
|
93a6513504 | |
|
93f3ed9895 | |
|
9c8f020b3f | |
|
3e60fd7738 | |
|
a9e9cfb220 | |
|
391512f68c | |
|
0c63e9a11b | |
|
2883b2243e | |
|
0465437abc | |
|
7917c5f7cc | |
|
4f14468e19 | |
|
1d5c7e1542 | |
|
79df82ebea | |
|
cd7d5f31b5 | |
|
c812e45f35 | |
|
9fe4c7fccf | |
|
18d7db35a7 | |
|
6eb03ecbff | |
|
98eeeb17af | |
|
994535fe3e | |
|
da9ffa9521 | |
|
c0682408c5 | |
|
5da818b9d9 | |
|
592312ab8c | |
|
39d7aff90a | |
|
6fb8a19fd5 | |
|
58e763fdb6 | |
|
d01860176e | |
|
016442272e | |
|
ff0e11866d | |
|
632409da1e | |
|
4e355e9ab9 | |
|
af1ad0aed8 | |
|
677227145e | |
|
dc94e87620 | |
|
dce3085231 | |
|
c2ce2e25a4 | |
|
78324506fb | |
|
d384aaaa1c | |
|
f1d6821d62 | |
|
8dd3441fcd | |
|
b902d3244c | |
|
1fa1620c5e | |
|
a1ae58ffa7 | |
|
bf4e4b0251 | |
|
6508a9160c | |
|
5a4a459ad5 | |
|
6294f64795 |
|
@ -15,3 +15,4 @@ pretrained
|
||||||
*.mp4
|
*.mp4
|
||||||
.DS_Store
|
.DS_Store
|
||||||
workspace/log_ngp.txt
|
workspace/log_ngp.txt
|
||||||
|
.idea
|
214
LICENSE
214
LICENSE
|
@ -1,21 +1,201 @@
|
||||||
MIT License
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
Copyright (c) 2023 LiHengzhong
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
1. Definitions.
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
copies or substantial portions of the Software.
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
the copyright owner that is granting the License.
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
other entities that control, are controlled by, or are under common
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
control with that entity. For the purposes of this definition,
|
||||||
SOFTWARE.
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
152
README.md
152
README.md
|
@ -1,12 +1,14 @@
|
||||||
A streaming digital human based on the Ernerf model, realize audio video synchronous dialogue. It can basically achieve commercial effects.
|
Real time interactive streaming digital human, realize audio video synchronous dialogue. It can basically achieve commercial effects.
|
||||||
基于ernerf模型的流式数字人,实现音视频同步对话。基本可以达到商用效果
|
实时交互流式数字人,实现音视频同步对话。基本可以达到商用效果
|
||||||
|
|
||||||
[效果演示](https://www.bilibili.com/video/BV1PM4m1y7Q2/)
|
[ernerf效果](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk效果](https://www.bilibili.com/video/BV1gm421N7vQ/) [wav2lip效果](https://www.bilibili.com/video/BV1Bw4m1e74P/)
|
||||||
|
|
||||||
|
## 为避免与3d数字人混淆,原项目metahuman-stream改名为livetalking,原有链接地址继续可用
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
1. 支持声音克隆
|
1. 支持多种数字人模型: ernerf、musetalk、wav2lip
|
||||||
2. 支持大模型对话
|
2. 支持声音克隆
|
||||||
3. 支持多种音频特征驱动:wav2vec、hubert
|
3. 支持数字人说话被打断
|
||||||
4. 支持全身视频拼接
|
4. 支持全身视频拼接
|
||||||
5. 支持rtmp和webrtc
|
5. 支持rtmp和webrtc
|
||||||
6. 支持视频编排:不说话时播放自定义视频
|
6. 支持视频编排:不说话时播放自定义视频
|
||||||
|
@ -22,17 +24,19 @@ conda create -n nerfstream python=3.10
|
||||||
conda activate nerfstream
|
conda activate nerfstream
|
||||||
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
|
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
#如果不训练ernerf模型,不需要安装下面的库
|
||||||
pip install "git+https://github.com/facebookresearch/pytorch3d.git"
|
pip install "git+https://github.com/facebookresearch/pytorch3d.git"
|
||||||
pip install tensorflow-gpu==2.8.0
|
pip install tensorflow-gpu==2.8.0
|
||||||
pip install --upgrade "protobuf<=3.20.1"
|
pip install --upgrade "protobuf<=3.20.1"
|
||||||
```
|
```
|
||||||
|
如果用pytorch2.1,torchvision用0.16(可以去torchvision官网根据pytorch版本找匹配的),cudatoolkit可以不用装
|
||||||
安装常见问题[FAQ](/assets/faq.md)
|
安装常见问题[FAQ](/assets/faq.md)
|
||||||
linux cuda环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/674972886
|
linux cuda环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/674972886
|
||||||
|
|
||||||
|
|
||||||
## 2. Quick Start
|
## 2. Quick Start
|
||||||
默认采用webrtc推流到srs
|
默认采用ernerf模型,webrtc推流到srs
|
||||||
### 2.1 运行rtmpserver (srs)
|
### 2.1 运行srs
|
||||||
```
|
```
|
||||||
export CANDIDATE='<服务器外网ip>'
|
export CANDIDATE='<服务器外网ip>'
|
||||||
docker run --rm --env CANDIDATE=$CANDIDATE \
|
docker run --rm --env CANDIDATE=$CANDIDATE \
|
||||||
|
@ -52,135 +56,45 @@ python app.py
|
||||||
export HF_ENDPOINT=https://hf-mirror.com
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
```
|
```
|
||||||
|
|
||||||
用浏览器打开http://serverip:8010/rtcpush.html, 在文本框输入任意文字,提交。数字人播报该段文字
|
用浏览器打开http://serverip:8010/rtcpushapi.html, 在文本框输入任意文字,提交。数字人播报该段文字
|
||||||
备注:服务端需要开放端口 tcp:8000,8010,1985; udp:8000
|
备注:服务端需要开放端口 tcp:8000,8010,1985; udp:8000
|
||||||
|
|
||||||
## 3. More Usage
|
## 3. More Usage
|
||||||
### 3.1 使用LLM模型进行数字人对话
|
使用说明: <https://livetalking-doc.readthedocs.io/>
|
||||||
|
|
||||||
目前借鉴数字人对话系统[LinlyTalker](https://github.com/Kedreamix/Linly-Talker)的方式,LLM模型支持Chatgpt,Qwen和GeminiPro。需要在app.py中填入自己的api_key。
|
|
||||||
|
|
||||||
用浏览器打开http://serverip:8010/rtcpushchat.html
|
|
||||||
|
|
||||||
### 3.2 声音克隆
|
|
||||||
可以任意选用下面两种服务,推荐用gpt-sovits
|
|
||||||
#### 3.2.1 gpt-sovits
|
|
||||||
服务部署参照[gpt-sovits](/tts/README.md)
|
|
||||||
运行
|
|
||||||
```
|
|
||||||
python app.py --tts gpt-sovits --TTS_SERVER http://127.0.0.1:5000 --CHARACTER test --EMOTION default
|
|
||||||
```
|
|
||||||
#### 3.2.2 xtts
|
|
||||||
运行xtts服务,参照 https://github.com/coqui-ai/xtts-streaming-server
|
|
||||||
```
|
|
||||||
docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 9000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest
|
|
||||||
```
|
|
||||||
然后运行,其中ref.wav为需要克隆的声音文件
|
|
||||||
```
|
|
||||||
python app.py --tts xtts --REF_FILE data/ref.wav --TTS_SERVER http://localhost:9000
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.3 音频特征用hubert
|
|
||||||
如果训练模型时用的hubert提取音频特征,用如下命令启动数字人
|
|
||||||
```
|
|
||||||
python app.py --asr_model facebook/hubert-large-ls960-ft
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.4 设置背景图片
|
|
||||||
```
|
|
||||||
python app.py --bg_img bg.jpg
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.5 全身视频拼接
|
|
||||||
#### 3.5.1 切割训练用的视频
|
|
||||||
```
|
|
||||||
ffmpeg -i fullbody.mp4 -vf crop="400:400:100:5" train.mp4
|
|
||||||
```
|
|
||||||
用train.mp4训练模型
|
|
||||||
#### 3.5.2 提取全身图片
|
|
||||||
```
|
|
||||||
ffmpeg -i fullbody.mp4 -vf fps=25 -qmin 1 -q:v 1 -start_number 0 data/fullbody/img/%d.jpg
|
|
||||||
```
|
|
||||||
#### 3.5.2 启动数字人
|
|
||||||
```
|
|
||||||
python app.py --fullbody --fullbody_img data/fullbody/img --fullbody_offset_x 100 --fullbody_offset_y 5 --fullbody_width 580 --fullbody_height 1080 --W 400 --H 400
|
|
||||||
```
|
|
||||||
- --fullbody_width、--fullbody_height 全身视频的宽、高
|
|
||||||
- --W、--H 训练视频的宽、高
|
|
||||||
- ernerf训练第三步torso如果训练的不好,在拼接处会有接缝。可以在上面的命令加上--torso_imgs data/xxx/torso_imgs,torso不用模型推理,直接用训练数据集里的torso图片。这种方式可能头颈处会有些人工痕迹。
|
|
||||||
|
|
||||||
### 3.6 不说话时用自定义视频替代
|
|
||||||
- 提取自定义视频图片
|
|
||||||
```
|
|
||||||
ffmpeg -i silence.mp4 -vf fps=25 -qmin 1 -q:v 1 -start_number 0 data/customvideo/img/%d.png
|
|
||||||
```
|
|
||||||
- 运行数字人
|
|
||||||
```
|
|
||||||
python app.py --customvideo --customvideo_img data/customvideo/img --customvideo_imgnum 100
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.7 webrtc p2p
|
|
||||||
此种模式不需要srs
|
|
||||||
```
|
|
||||||
python app.py --transport webrtc
|
|
||||||
```
|
|
||||||
用浏览器打开http://serverip:8010/webrtc.html
|
|
||||||
|
|
||||||
### 3.8 rtmp推送到srs
|
|
||||||
- 安装rtmpstream库
|
|
||||||
参照 https://github.com/lipku/python_rtmpstream
|
|
||||||
|
|
||||||
- 启动srs
|
|
||||||
```
|
|
||||||
docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5
|
|
||||||
```
|
|
||||||
- 运行数字人
|
|
||||||
```python
|
|
||||||
python app.py --transport rtmp --push_url 'rtmp://localhost/live/livestream'
|
|
||||||
```
|
|
||||||
用浏览器打开http://serverip:8010/echo.html
|
|
||||||
|
|
||||||
## 4. Docker Run
|
## 4. Docker Run
|
||||||
不需要第1步的安装,直接运行。
|
不需要前面的安装,直接运行。
|
||||||
```
|
```
|
||||||
docker run --gpus all -it --network=host --rm registry.cn-hangzhou.aliyuncs.com/lipku/nerfstream:v1.3
|
docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:vjo1Y6NJ3N
|
||||||
```
|
```
|
||||||
docker版本已经不是最新代码,可以作为一个空环境,把最新代码拷进去运行。
|
代码在/root/metahuman-stream,先git pull拉一下最新代码,然后执行命令同第2、3步
|
||||||
另外提供autodl镜像:https://www.codewithgpu.com/i/lipku/metahuman-stream/base
|
|
||||||
|
|
||||||
## 5. Data flow
|
提供如下镜像
|
||||||

|
- autodl镜像: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
|
||||||
|
[autodl教程](autodl/README.md)
|
||||||
|
|
||||||
## 6. 数字人模型文件
|
|
||||||
可以替换成自己训练的模型(https://github.com/Fictionarry/ER-NeRF)
|
|
||||||
```python
|
|
||||||
.
|
|
||||||
├── data
|
|
||||||
│ ├── data_kf.json
|
|
||||||
│ ├── au.csv
|
|
||||||
│ ├── pretrained
|
|
||||||
│ └── └── ngp_kf.pth
|
|
||||||
|
|
||||||
```
|
## 5. 性能分析
|
||||||
|
|
||||||
## 7. 性能分析
|
|
||||||
1. 帧率
|
1. 帧率
|
||||||
在Tesla T4显卡上测试整体fps为18左右,如果去掉音视频编码推流,帧率在20左右。用4090显卡可以达到40多帧/秒。
|
在Tesla T4显卡上测试整体fps为18左右,如果去掉音视频编码推流,帧率在20左右。用4090显卡可以达到40多帧/秒。
|
||||||
优化:新开一个线程运行音视频编码推流
|
|
||||||
2. 延时
|
2. 延时
|
||||||
整体延时3s左右
|
整体延时3s左右
|
||||||
(1)tts延时1.7s左右,目前用的edgetts,需要将每句话转完后一次性输入,可以优化tts改成流式输入
|
(1)tts延时1.7s左右,目前用的edgetts,需要将每句话转完后一次性输入,可以优化tts改成流式输入
|
||||||
(2)wav2vec延时0.4s,需要缓存18帧音频做计算
|
(2)wav2vec延时0.4s,需要缓存18帧音频做计算
|
||||||
(3)srs转发延时,设置srs服务器减少缓冲延时。具体配置可看 https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency
|
(3)srs转发延时,设置srs服务器减少缓冲延时。具体配置可看 https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency
|
||||||
|
|
||||||
## 8. TODO
|
|
||||||
|
## 6. TODO
|
||||||
- [x] 添加chatgpt实现数字人对话
|
- [x] 添加chatgpt实现数字人对话
|
||||||
- [x] 声音克隆
|
- [x] 声音克隆
|
||||||
- [x] 数字人静音时用一段视频代替
|
- [x] 数字人静音时用一段视频代替
|
||||||
- [ ] MuseTalk
|
- [x] MuseTalk
|
||||||
|
- [x] Wav2Lip
|
||||||
|
- [ ] TalkingGaussian
|
||||||
|
|
||||||
|
---
|
||||||
|
如果本项目对你有帮助,帮忙点个star。也欢迎感兴趣的朋友一起来完善该项目.
|
||||||
|
* 知识星球: https://t.zsxq.com/7NMyO 沉淀高质量常见问题、最佳实践经验、问题解答
|
||||||
|
* 微信公众号:数字人技术
|
||||||
|

|
||||||
|
|
||||||
如果本项目对你有帮助,帮忙点个star。也欢迎感兴趣的朋友一起来完善该项目。
|
|
||||||
Email: lipku@foxmail.com
|
|
||||||
知识星球: https://t.zsxq.com/7NMyO
|
|
||||||
微信公众号:数字人技术
|
|
||||||

|
|
||||||
|
|
486
app.py
486
app.py
|
@ -17,142 +17,20 @@ from aiohttp import web
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import aiohttp_cors
|
import aiohttp_cors
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||||
|
from aiortc.rtcrtpsender import RTCRtpSender
|
||||||
from webrtc import HumanPlayer
|
from webrtc import HumanPlayer
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from ernerf.nerf_triplane.provider import NeRFDataset_Test
|
|
||||||
from ernerf.nerf_triplane.utils import *
|
|
||||||
from ernerf.nerf_triplane.network import NeRFNetwork
|
|
||||||
from nerfreal import NeRFReal
|
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
import asyncio
|
import asyncio
|
||||||
import edge_tts
|
import string
|
||||||
from typing import Iterator
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
sockets = Sockets(app)
|
sockets = Sockets(app)
|
||||||
global nerfreal
|
nerfreals = []
|
||||||
global tts_type
|
statreals = []
|
||||||
global gspeaker
|
|
||||||
|
|
||||||
|
|
||||||
async def main(voicename: str, text: str, render):
|
|
||||||
communicate = edge_tts.Communicate(text, voicename)
|
|
||||||
|
|
||||||
#with open(OUTPUT_FILE, "wb") as file:
|
|
||||||
first = True
|
|
||||||
async for chunk in communicate.stream():
|
|
||||||
if first:
|
|
||||||
#render.before_push_audio()
|
|
||||||
first = False
|
|
||||||
if chunk["type"] == "audio":
|
|
||||||
render.push_audio(chunk["data"])
|
|
||||||
#file.write(chunk["data"])
|
|
||||||
elif chunk["type"] == "WordBoundary":
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_speaker(ref_audio,server_url):
|
|
||||||
files = {"wav_file": ("reference.wav", open(ref_audio, "rb"))}
|
|
||||||
response = requests.post(f"{server_url}/clone_speaker", files=files)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def xtts(text, speaker, language, server_url, stream_chunk_size) -> Iterator[bytes]:
|
|
||||||
start = time.perf_counter()
|
|
||||||
speaker["text"] = text
|
|
||||||
speaker["language"] = language
|
|
||||||
speaker["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
|
|
||||||
res = requests.post(
|
|
||||||
f"{server_url}/tts_stream",
|
|
||||||
json=speaker,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
end = time.perf_counter()
|
|
||||||
print(f"xtts Time to make POST: {end-start}s")
|
|
||||||
|
|
||||||
if res.status_code != 200:
|
|
||||||
print("Error:", res.text)
|
|
||||||
return
|
|
||||||
|
|
||||||
first = True
|
|
||||||
for chunk in res.iter_content(chunk_size=960): #24K*20ms*2
|
|
||||||
if first:
|
|
||||||
end = time.perf_counter()
|
|
||||||
print(f"xtts Time to first chunk: {end-start}s")
|
|
||||||
first = False
|
|
||||||
if chunk:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
print("xtts response.elapsed:", res.elapsed)
|
|
||||||
|
|
||||||
def gpt_sovits(text, character, language, server_url, emotion) -> Iterator[bytes]:
|
|
||||||
start = time.perf_counter()
|
|
||||||
req={}
|
|
||||||
req["text"] = text
|
|
||||||
req["text_language"] = language
|
|
||||||
req["character"] = character
|
|
||||||
req["emotion"] = emotion
|
|
||||||
#req["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
|
|
||||||
req["stream"] = True
|
|
||||||
res = requests.post(
|
|
||||||
f"{server_url}/tts",
|
|
||||||
json=req,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
end = time.perf_counter()
|
|
||||||
print(f"gpt_sovits Time to make POST: {end-start}s")
|
|
||||||
|
|
||||||
if res.status_code != 200:
|
|
||||||
print("Error:", res.text)
|
|
||||||
return
|
|
||||||
|
|
||||||
first = True
|
|
||||||
for chunk in res.iter_content(chunk_size=32000): # 1280 32K*20ms*2
|
|
||||||
if first:
|
|
||||||
end = time.perf_counter()
|
|
||||||
print(f"gpt_sovits Time to first chunk: {end-start}s")
|
|
||||||
first = False
|
|
||||||
if chunk:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
print("gpt_sovits response.elapsed:", res.elapsed)
|
|
||||||
|
|
||||||
def stream_tts(audio_stream,render):
|
|
||||||
for chunk in audio_stream:
|
|
||||||
if chunk is not None:
|
|
||||||
render.push_audio(chunk)
|
|
||||||
|
|
||||||
def txt_to_audio(text_):
|
|
||||||
if tts_type == "edgetts":
|
|
||||||
voicename = "zh-CN-YunxiaNeural"
|
|
||||||
text = text_
|
|
||||||
t = time.time()
|
|
||||||
asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
|
|
||||||
print(f'-------edge tts time:{time.time()-t:.4f}s')
|
|
||||||
elif tts_type == "gpt-sovits": #gpt_sovits
|
|
||||||
stream_tts(
|
|
||||||
gpt_sovits(
|
|
||||||
text_,
|
|
||||||
app.config['CHARACTER'], #"test", #character
|
|
||||||
"zh", #en args.language,
|
|
||||||
app.config['TTS_SERVER'], #"http://127.0.0.1:5000", #args.server_url,
|
|
||||||
app.config['EMOTION'], #emotion
|
|
||||||
),
|
|
||||||
nerfreal
|
|
||||||
)
|
|
||||||
else: #xtts
|
|
||||||
stream_tts(
|
|
||||||
xtts(
|
|
||||||
text_,
|
|
||||||
gspeaker,
|
|
||||||
"zh-cn", #en args.language,
|
|
||||||
app.config['TTS_SERVER'], #"http://localhost:9000", #args.server_url,
|
|
||||||
"20" #args.stream_chunk_size
|
|
||||||
),
|
|
||||||
nerfreal
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@sockets.route('/humanecho')
|
@sockets.route('/humanecho')
|
||||||
|
@ -172,17 +50,61 @@ def echo_socket(ws):
|
||||||
if not message or len(message)==0:
|
if not message or len(message)==0:
|
||||||
return '输入信息为空'
|
return '输入信息为空'
|
||||||
else:
|
else:
|
||||||
txt_to_audio(message)
|
nerfreal.put_msg_txt(message)
|
||||||
|
|
||||||
|
|
||||||
def llm_response(message):
|
# def llm_response(message):
|
||||||
from llm.LLM import LLM
|
# from llm.LLM import LLM
|
||||||
# llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
|
# # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
|
||||||
# llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
|
# # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
|
||||||
llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
|
# llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
|
||||||
response = llm.chat(message)
|
# response = llm.chat(message)
|
||||||
print(response)
|
# print(response)
|
||||||
return response
|
# return response
|
||||||
|
|
||||||
|
def llm_response(message,nerfreal):
|
||||||
|
start = time.perf_counter()
|
||||||
|
from openai import OpenAI
|
||||||
|
client = OpenAI(
|
||||||
|
# 如果您没有配置环境变量,请在此处用您的API Key进行替换
|
||||||
|
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
||||||
|
# 填写DashScope SDK的base_url
|
||||||
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
)
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"llm Time init: {end-start}s")
|
||||||
|
completion = client.chat.completions.create(
|
||||||
|
model="qwen-plus",
|
||||||
|
messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
|
||||||
|
{'role': 'user', 'content': message}],
|
||||||
|
stream=True,
|
||||||
|
# 通过以下设置,在流式输出的最后一行展示token使用信息
|
||||||
|
stream_options={"include_usage": True}
|
||||||
|
)
|
||||||
|
result=""
|
||||||
|
first = True
|
||||||
|
for chunk in completion:
|
||||||
|
if len(chunk.choices)>0:
|
||||||
|
#print(chunk.choices[0].delta.content)
|
||||||
|
if first:
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"llm Time to first chunk: {end-start}s")
|
||||||
|
first = False
|
||||||
|
msg = chunk.choices[0].delta.content
|
||||||
|
lastpos=0
|
||||||
|
#msglist = re.split('[,.!;:,。!?]',msg)
|
||||||
|
for i, char in enumerate(msg):
|
||||||
|
if char in ",.!;:,。!?:;" :
|
||||||
|
result = result+msg[lastpos:i+1]
|
||||||
|
lastpos = i+1
|
||||||
|
if len(result)>10:
|
||||||
|
print(result)
|
||||||
|
nerfreal.put_msg_txt(result)
|
||||||
|
result=""
|
||||||
|
result = result+msg[lastpos:]
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"llm Time to last chunk: {end-start}s")
|
||||||
|
nerfreal.put_msg_txt(result)
|
||||||
|
|
||||||
@sockets.route('/humanchat')
|
@sockets.route('/humanchat')
|
||||||
def chat_socket(ws):
|
def chat_socket(ws):
|
||||||
|
@ -202,47 +124,26 @@ def chat_socket(ws):
|
||||||
return '输入信息为空'
|
return '输入信息为空'
|
||||||
else:
|
else:
|
||||||
res=llm_response(message)
|
res=llm_response(message)
|
||||||
txt_to_audio(res)
|
nerfreal.put_msg_txt(res)
|
||||||
|
|
||||||
#####webrtc###############################
|
#####webrtc###############################
|
||||||
pcs = set()
|
pcs = set()
|
||||||
|
|
||||||
async def txt_to_audio_async(text_):
|
|
||||||
if tts_type == "edgetts":
|
|
||||||
voicename = "zh-CN-YunxiaNeural"
|
|
||||||
text = text_
|
|
||||||
t = time.time()
|
|
||||||
#asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
|
|
||||||
await main(voicename,text,nerfreal)
|
|
||||||
print(f'-------edge tts time:{time.time()-t:.4f}s')
|
|
||||||
elif tts_type == "gpt-sovits": #gpt_sovits
|
|
||||||
stream_tts(
|
|
||||||
gpt_sovits(
|
|
||||||
text_,
|
|
||||||
app.config['CHARACTER'], #"test", #character
|
|
||||||
"zh", #en args.language,
|
|
||||||
app.config['TTS_SERVER'], #"http://127.0.0.1:5000", #args.server_url,
|
|
||||||
app.config['EMOTION'], #emotion
|
|
||||||
),
|
|
||||||
nerfreal
|
|
||||||
)
|
|
||||||
else: #xtts
|
|
||||||
stream_tts(
|
|
||||||
xtts(
|
|
||||||
text_,
|
|
||||||
gspeaker,
|
|
||||||
"zh-cn", #en args.language,
|
|
||||||
app.config['TTS_SERVER'], #"http://localhost:9000", #args.server_url,
|
|
||||||
"20" #args.stream_chunk_size
|
|
||||||
),
|
|
||||||
nerfreal
|
|
||||||
)
|
|
||||||
|
|
||||||
#@app.route('/offer', methods=['POST'])
|
#@app.route('/offer', methods=['POST'])
|
||||||
async def offer(request):
|
async def offer(request):
|
||||||
params = await request.json()
|
params = await request.json()
|
||||||
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||||
|
|
||||||
|
sessionid = len(nerfreals)
|
||||||
|
for index,value in enumerate(statreals):
|
||||||
|
if value == 0:
|
||||||
|
sessionid = index
|
||||||
|
break
|
||||||
|
if sessionid>=len(nerfreals):
|
||||||
|
print('reach max session')
|
||||||
|
return -1
|
||||||
|
statreals[sessionid] = 1
|
||||||
|
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
pcs.add(pc)
|
pcs.add(pc)
|
||||||
|
|
||||||
|
@ -252,10 +153,20 @@ async def offer(request):
|
||||||
if pc.connectionState == "failed":
|
if pc.connectionState == "failed":
|
||||||
await pc.close()
|
await pc.close()
|
||||||
pcs.discard(pc)
|
pcs.discard(pc)
|
||||||
|
statreals[sessionid] = 0
|
||||||
|
if pc.connectionState == "closed":
|
||||||
|
pcs.discard(pc)
|
||||||
|
statreals[sessionid] = 0
|
||||||
|
|
||||||
player = HumanPlayer(nerfreal)
|
player = HumanPlayer(nerfreals[sessionid])
|
||||||
audio_sender = pc.addTrack(player.audio)
|
audio_sender = pc.addTrack(player.audio)
|
||||||
video_sender = pc.addTrack(player.video)
|
video_sender = pc.addTrack(player.video)
|
||||||
|
capabilities = RTCRtpSender.getCapabilities("video")
|
||||||
|
preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
|
||||||
|
preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
|
||||||
|
preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
|
||||||
|
transceiver = pc.getTransceivers()[1]
|
||||||
|
transceiver.setCodecPreferences(preferences)
|
||||||
|
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
|
||||||
|
@ -267,18 +178,22 @@ async def offer(request):
|
||||||
return web.Response(
|
return web.Response(
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
text=json.dumps(
|
text=json.dumps(
|
||||||
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid":sessionid}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def human(request):
|
async def human(request):
|
||||||
params = await request.json()
|
params = await request.json()
|
||||||
|
|
||||||
|
sessionid = params.get('sessionid',0)
|
||||||
|
if params.get('interrupt'):
|
||||||
|
nerfreals[sessionid].pause_talk()
|
||||||
|
|
||||||
if params['type']=='echo':
|
if params['type']=='echo':
|
||||||
await txt_to_audio_async(params['text'])
|
nerfreals[sessionid].put_msg_txt(params['text'])
|
||||||
elif params['type']=='chat':
|
elif params['type']=='chat':
|
||||||
res=llm_response(params['text'])
|
res=await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'],nerfreals[sessionid])
|
||||||
await txt_to_audio_async(res)
|
#nerfreals[sessionid].put_msg_txt(res)
|
||||||
|
|
||||||
return web.Response(
|
return web.Response(
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
|
@ -287,6 +202,70 @@ async def human(request):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def humanaudio(request):
|
||||||
|
try:
|
||||||
|
form= await request.post()
|
||||||
|
sessionid = int(form.get('sessionid',0))
|
||||||
|
fileobj = form["file"]
|
||||||
|
filename=fileobj.filename
|
||||||
|
filebytes=fileobj.file.read()
|
||||||
|
nerfreals[sessionid].put_audio_file(filebytes)
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
content_type="application/json",
|
||||||
|
text=json.dumps(
|
||||||
|
{"code": 0, "msg":"ok"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return web.Response(
|
||||||
|
content_type="application/json",
|
||||||
|
text=json.dumps(
|
||||||
|
{"code": -1, "msg":"err","data": ""+e.args[0]+""}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def set_audiotype(request):
|
||||||
|
params = await request.json()
|
||||||
|
|
||||||
|
sessionid = params.get('sessionid',0)
|
||||||
|
nerfreals[sessionid].set_curr_state(params['audiotype'],params['reinit'])
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
content_type="application/json",
|
||||||
|
text=json.dumps(
|
||||||
|
{"code": 0, "data":"ok"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def record(request):
|
||||||
|
params = await request.json()
|
||||||
|
|
||||||
|
sessionid = params.get('sessionid',0)
|
||||||
|
if params['type']=='start_record':
|
||||||
|
# nerfreals[sessionid].put_msg_txt(params['text'])
|
||||||
|
nerfreals[sessionid].start_recording("data/record_lasted.mp4")
|
||||||
|
elif params['type']=='end_record':
|
||||||
|
nerfreals[sessionid].stop_recording()
|
||||||
|
return web.Response(
|
||||||
|
content_type="application/json",
|
||||||
|
text=json.dumps(
|
||||||
|
{"code": 0, "data":"ok"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def is_speaking(request):
|
||||||
|
params = await request.json()
|
||||||
|
|
||||||
|
sessionid = params.get('sessionid',0)
|
||||||
|
return web.Response(
|
||||||
|
content_type="application/json",
|
||||||
|
text=json.dumps(
|
||||||
|
{"code": 0, "data": nerfreals[sessionid].is_speaking()}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def on_shutdown(app):
|
async def on_shutdown(app):
|
||||||
# close peer connections
|
# close peer connections
|
||||||
coros = [pc.close() for pc in pcs]
|
coros = [pc.close() for pc in pcs]
|
||||||
|
@ -312,17 +291,18 @@ async def run(push_url):
|
||||||
await pc.close()
|
await pc.close()
|
||||||
pcs.discard(pc)
|
pcs.discard(pc)
|
||||||
|
|
||||||
player = HumanPlayer(nerfreal)
|
player = HumanPlayer(nerfreals[0])
|
||||||
audio_sender = pc.addTrack(player.audio)
|
audio_sender = pc.addTrack(player.audio)
|
||||||
video_sender = pc.addTrack(player.video)
|
video_sender = pc.addTrack(player.video)
|
||||||
|
|
||||||
await pc.setLocalDescription(await pc.createOffer())
|
await pc.setLocalDescription(await pc.createOffer())
|
||||||
answer = await post(push_url,pc.localDescription.sdp)
|
answer = await post(push_url,pc.localDescription.sdp)
|
||||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
|
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
|
||||||
##########################################
|
##########################################
|
||||||
|
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
|
||||||
|
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
multiprocessing.set_start_method('spawn')
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
|
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
|
||||||
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
|
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
|
||||||
|
@ -419,9 +399,6 @@ if __name__ == '__main__':
|
||||||
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
|
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
|
||||||
# parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
|
# parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
|
||||||
|
|
||||||
parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
|
|
||||||
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
|
|
||||||
|
|
||||||
parser.add_argument('--asr_save_feats', action='store_true')
|
parser.add_argument('--asr_save_feats', action='store_true')
|
||||||
# audio FPS
|
# audio FPS
|
||||||
parser.add_argument('--fps', type=int, default=50)
|
parser.add_argument('--fps', type=int, default=50)
|
||||||
|
@ -437,73 +414,108 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--fullbody_offset_x', type=int, default=0)
|
parser.add_argument('--fullbody_offset_x', type=int, default=0)
|
||||||
parser.add_argument('--fullbody_offset_y', type=int, default=0)
|
parser.add_argument('--fullbody_offset_y', type=int, default=0)
|
||||||
|
|
||||||
parser.add_argument('--customvideo', action='store_true', help="custom video")
|
#musetalk opt
|
||||||
parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
|
parser.add_argument('--avatar_id', type=str, default='avator_1')
|
||||||
parser.add_argument('--customvideo_imgnum', type=int, default=1)
|
parser.add_argument('--bbox_shift', type=int, default=5)
|
||||||
|
parser.add_argument('--batch_size', type=int, default=16)
|
||||||
|
|
||||||
parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits
|
# parser.add_argument('--customvideo', action='store_true', help="custom video")
|
||||||
|
# parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
|
||||||
|
# parser.add_argument('--customvideo_imgnum', type=int, default=1)
|
||||||
|
|
||||||
|
parser.add_argument('--customvideo_config', type=str, default='')
|
||||||
|
|
||||||
|
parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits cosyvoice
|
||||||
parser.add_argument('--REF_FILE', type=str, default=None)
|
parser.add_argument('--REF_FILE', type=str, default=None)
|
||||||
parser.add_argument('--TTS_SERVER', type=str, default='http://localhost:9000') #http://127.0.0.1:5000
|
parser.add_argument('--REF_TEXT', type=str, default=None)
|
||||||
parser.add_argument('--CHARACTER', type=str, default='test')
|
parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
|
||||||
parser.add_argument('--EMOTION', type=str, default='default')
|
# parser.add_argument('--CHARACTER', type=str, default='test')
|
||||||
|
# parser.add_argument('--EMOTION', type=str, default='default')
|
||||||
|
|
||||||
|
parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
|
||||||
|
|
||||||
|
parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
|
||||||
|
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
|
||||||
|
|
||||||
|
parser.add_argument('--max_session', type=int, default=1) #multi session count
|
||||||
parser.add_argument('--listenport', type=int, default=8010)
|
parser.add_argument('--listenport', type=int, default=8010)
|
||||||
|
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
app.config.from_object(opt)
|
#app.config.from_object(opt)
|
||||||
print(app.config)
|
#print(app.config)
|
||||||
|
opt.customopt = []
|
||||||
|
if opt.customvideo_config!='':
|
||||||
|
with open(opt.customvideo_config,'r') as file:
|
||||||
|
opt.customopt = json.load(file)
|
||||||
|
|
||||||
tts_type = opt.tts
|
if opt.model == 'ernerf':
|
||||||
if tts_type == "xtts":
|
from ernerf.nerf_triplane.provider import NeRFDataset_Test
|
||||||
print("Computing the latents for a new reference...")
|
from ernerf.nerf_triplane.utils import *
|
||||||
gspeaker = get_speaker(opt.REF_FILE, opt.TTS_SERVER)
|
from ernerf.nerf_triplane.network import NeRFNetwork
|
||||||
|
from nerfreal import NeRFReal
|
||||||
|
# assert test mode
|
||||||
|
opt.test = True
|
||||||
|
opt.test_train = False
|
||||||
|
#opt.train_camera =True
|
||||||
|
# explicit smoothing
|
||||||
|
opt.smooth_path = True
|
||||||
|
opt.smooth_lips = True
|
||||||
|
|
||||||
# assert test mode
|
assert opt.pose != '', 'Must provide a pose source'
|
||||||
opt.test = True
|
|
||||||
opt.test_train = False
|
|
||||||
#opt.train_camera =True
|
|
||||||
# explicit smoothing
|
|
||||||
opt.smooth_path = True
|
|
||||||
opt.smooth_lips = True
|
|
||||||
|
|
||||||
assert opt.pose != '', 'Must provide a pose source'
|
# if opt.O:
|
||||||
|
opt.fp16 = True
|
||||||
|
opt.cuda_ray = True
|
||||||
|
opt.exp_eye = True
|
||||||
|
opt.smooth_eye = True
|
||||||
|
|
||||||
# if opt.O:
|
if opt.torso_imgs=='': #no img,use model output
|
||||||
opt.fp16 = True
|
opt.torso = True
|
||||||
opt.cuda_ray = True
|
|
||||||
opt.exp_eye = True
|
|
||||||
opt.smooth_eye = True
|
|
||||||
|
|
||||||
if opt.torso_imgs=='': #no img,use model output
|
# assert opt.cuda_ray, "Only support CUDA ray mode."
|
||||||
opt.torso = True
|
opt.asr = True
|
||||||
|
|
||||||
# assert opt.cuda_ray, "Only support CUDA ray mode."
|
if opt.patch_size > 1:
|
||||||
opt.asr = True
|
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
|
||||||
|
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
|
||||||
|
seed_everything(opt.seed)
|
||||||
|
print(opt)
|
||||||
|
|
||||||
if opt.patch_size > 1:
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
|
model = NeRFNetwork(opt)
|
||||||
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
|
|
||||||
seed_everything(opt.seed)
|
|
||||||
print(opt)
|
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
criterion = torch.nn.MSELoss(reduction='none')
|
||||||
model = NeRFNetwork(opt)
|
metrics = [] # use no metric in GUI for faster initialization...
|
||||||
|
print(model)
|
||||||
|
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
|
||||||
|
|
||||||
criterion = torch.nn.MSELoss(reduction='none')
|
test_loader = NeRFDataset_Test(opt, device=device).dataloader()
|
||||||
metrics = [] # use no metric in GUI for faster initialization...
|
model.aud_features = test_loader._data.auds
|
||||||
print(model)
|
model.eye_areas = test_loader._data.eye_area
|
||||||
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
|
|
||||||
|
|
||||||
test_loader = NeRFDataset_Test(opt, device=device).dataloader()
|
# we still need test_loader to provide audio features for testing.
|
||||||
model.aud_features = test_loader._data.auds
|
for _ in range(opt.max_session):
|
||||||
model.eye_areas = test_loader._data.eye_area
|
nerfreal = NeRFReal(opt, trainer, test_loader)
|
||||||
|
nerfreals.append(nerfreal)
|
||||||
|
elif opt.model == 'musetalk':
|
||||||
|
from musereal import MuseReal
|
||||||
|
print(opt)
|
||||||
|
for _ in range(opt.max_session):
|
||||||
|
nerfreal = MuseReal(opt)
|
||||||
|
nerfreals.append(nerfreal)
|
||||||
|
elif opt.model == 'wav2lip':
|
||||||
|
from lipreal import LipReal
|
||||||
|
print(opt)
|
||||||
|
for _ in range(opt.max_session):
|
||||||
|
nerfreal = LipReal(opt)
|
||||||
|
nerfreals.append(nerfreal)
|
||||||
|
|
||||||
|
for _ in range(opt.max_session):
|
||||||
|
statreals.append(0)
|
||||||
|
|
||||||
# we still need test_loader to provide audio features for testing.
|
|
||||||
nerfreal = NeRFReal(opt, trainer, test_loader)
|
|
||||||
#txt_to_audio('我是中国人,我来自北京')
|
|
||||||
if opt.transport=='rtmp':
|
if opt.transport=='rtmp':
|
||||||
thread_quit = Event()
|
thread_quit = Event()
|
||||||
rendthrd = Thread(target=nerfreal.render,args=(thread_quit,))
|
rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
|
||||||
rendthrd.start()
|
rendthrd.start()
|
||||||
|
|
||||||
#############################################################################
|
#############################################################################
|
||||||
|
@ -511,6 +523,10 @@ if __name__ == '__main__':
|
||||||
appasync.on_shutdown.append(on_shutdown)
|
appasync.on_shutdown.append(on_shutdown)
|
||||||
appasync.router.add_post("/offer", offer)
|
appasync.router.add_post("/offer", offer)
|
||||||
appasync.router.add_post("/human", human)
|
appasync.router.add_post("/human", human)
|
||||||
|
appasync.router.add_post("/humanaudio", humanaudio)
|
||||||
|
appasync.router.add_post("/set_audiotype", set_audiotype)
|
||||||
|
appasync.router.add_post("/record", record)
|
||||||
|
appasync.router.add_post("/is_speaking", is_speaking)
|
||||||
appasync.router.add_static('/',path='web')
|
appasync.router.add_static('/',path='web')
|
||||||
|
|
||||||
# Configure default CORS settings.
|
# Configure default CORS settings.
|
||||||
|
@ -525,6 +541,12 @@ if __name__ == '__main__':
|
||||||
for route in list(appasync.router.routes()):
|
for route in list(appasync.router.routes()):
|
||||||
cors.add(route)
|
cors.add(route)
|
||||||
|
|
||||||
|
pagename='webrtcapi.html'
|
||||||
|
if opt.transport=='rtmp':
|
||||||
|
pagename='echoapi.html'
|
||||||
|
elif opt.transport=='rtcpush':
|
||||||
|
pagename='rtcpushapi.html'
|
||||||
|
print('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
|
||||||
def run_server(runner):
|
def run_server(runner):
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
@ -534,12 +556,14 @@ if __name__ == '__main__':
|
||||||
if opt.transport=='rtcpush':
|
if opt.transport=='rtcpush':
|
||||||
loop.run_until_complete(run(opt.push_url))
|
loop.run_until_complete(run(opt.push_url))
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
|
#Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
|
||||||
|
run_server(web.AppRunner(appasync))
|
||||||
|
|
||||||
print('start websocket server')
|
|
||||||
#app.on_shutdown.append(on_shutdown)
|
#app.on_shutdown.append(on_shutdown)
|
||||||
#app.router.add_post("/offer", offer)
|
#app.router.add_post("/offer", offer)
|
||||||
server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
|
|
||||||
server.serve_forever()
|
# print('start websocket server')
|
||||||
|
# server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
|
||||||
|
# server.serve_forever()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from queue import Queue
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
|
||||||
|
class BaseASR:
|
||||||
|
def __init__(self, opt, parent=None):
|
||||||
|
self.opt = opt
|
||||||
|
self.parent = parent
|
||||||
|
|
||||||
|
self.fps = opt.fps # 20 ms per frame
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
|
||||||
|
self.queue = Queue()
|
||||||
|
self.output_queue = mp.Queue()
|
||||||
|
|
||||||
|
self.batch_size = opt.batch_size
|
||||||
|
|
||||||
|
self.frames = []
|
||||||
|
self.stride_left_size = opt.l
|
||||||
|
self.stride_right_size = opt.r
|
||||||
|
#self.context_size = 10
|
||||||
|
self.feat_queue = mp.Queue(2)
|
||||||
|
|
||||||
|
#self.warm_up()
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
self.queue.queue.clear()
|
||||||
|
|
||||||
|
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
|
||||||
|
self.queue.put(audio_chunk)
|
||||||
|
|
||||||
|
def get_audio_frame(self):
|
||||||
|
try:
|
||||||
|
frame = self.queue.get(block=True,timeout=0.01)
|
||||||
|
type = 0
|
||||||
|
#print(f'[INFO] get frame {frame.shape}')
|
||||||
|
except queue.Empty:
|
||||||
|
if self.parent and self.parent.curr_state>1: #播放自定义音频
|
||||||
|
frame = self.parent.get_audio_stream(self.parent.curr_state)
|
||||||
|
type = self.parent.curr_state
|
||||||
|
else:
|
||||||
|
frame = np.zeros(self.chunk, dtype=np.float32)
|
||||||
|
type = 1
|
||||||
|
|
||||||
|
return frame,type
|
||||||
|
|
||||||
|
def is_audio_frame_empty(self)->bool:
|
||||||
|
return self.queue.empty()
|
||||||
|
|
||||||
|
def get_audio_out(self): #get origin audio pcm to nerf
|
||||||
|
return self.output_queue.get()
|
||||||
|
|
||||||
|
def warm_up(self):
|
||||||
|
for _ in range(self.stride_left_size + self.stride_right_size):
|
||||||
|
audio_frame,type=self.get_audio_frame()
|
||||||
|
self.frames.append(audio_frame)
|
||||||
|
self.output_queue.put((audio_frame,type))
|
||||||
|
for _ in range(self.stride_left_size):
|
||||||
|
self.output_queue.get()
|
||||||
|
|
||||||
|
def run_step(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_next_feat(self,block,timeout):
|
||||||
|
return self.feat_queue.get(block,timeout)
|
|
@ -0,0 +1,207 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import cv2
|
||||||
|
import glob
|
||||||
|
import pickle
|
||||||
|
import copy
|
||||||
|
import resampy
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread, Event
|
||||||
|
from io import BytesIO
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
import av
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
|
from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
def read_imgs(img_list):
|
||||||
|
frames = []
|
||||||
|
print('reading images...')
|
||||||
|
for img_path in tqdm(img_list):
|
||||||
|
frame = cv2.imread(img_path)
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
class BaseReal:
|
||||||
|
def __init__(self, opt):
|
||||||
|
self.opt = opt
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000)
|
||||||
|
|
||||||
|
if opt.tts == "edgetts":
|
||||||
|
self.tts = EdgeTTS(opt,self)
|
||||||
|
elif opt.tts == "gpt-sovits":
|
||||||
|
self.tts = VoitsTTS(opt,self)
|
||||||
|
elif opt.tts == "xtts":
|
||||||
|
self.tts = XTTS(opt,self)
|
||||||
|
elif opt.tts == "cosyvoice":
|
||||||
|
self.tts = CosyVoiceTTS(opt,self)
|
||||||
|
|
||||||
|
self.speaking = False
|
||||||
|
|
||||||
|
self.recording = False
|
||||||
|
self.recordq_video = Queue()
|
||||||
|
self.recordq_audio = Queue()
|
||||||
|
|
||||||
|
self.curr_state=0
|
||||||
|
self.custom_img_cycle = {}
|
||||||
|
self.custom_audio_cycle = {}
|
||||||
|
self.custom_audio_index = {}
|
||||||
|
self.custom_index = {}
|
||||||
|
self.custom_opt = {}
|
||||||
|
self.__loadcustom()
|
||||||
|
|
||||||
|
def put_msg_txt(self,msg):
|
||||||
|
self.tts.put_msg_txt(msg)
|
||||||
|
|
||||||
|
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
|
||||||
|
self.asr.put_audio_frame(audio_chunk)
|
||||||
|
|
||||||
|
def put_audio_file(self,filebyte):
|
||||||
|
input_stream = BytesIO(filebyte)
|
||||||
|
stream = self.__create_bytes_stream(input_stream)
|
||||||
|
streamlen = stream.shape[0]
|
||||||
|
idx=0
|
||||||
|
while streamlen >= self.chunk: #and self.state==State.RUNNING
|
||||||
|
self.put_audio_frame(stream[idx:idx+self.chunk])
|
||||||
|
streamlen -= self.chunk
|
||||||
|
idx += self.chunk
|
||||||
|
|
||||||
|
def __create_bytes_stream(self,byte_stream):
|
||||||
|
#byte_stream=BytesIO(buffer)
|
||||||
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
|
print(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
|
||||||
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
|
if stream.ndim > 1:
|
||||||
|
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||||
|
stream = stream[:, 0]
|
||||||
|
|
||||||
|
if sample_rate != self.sample_rate and stream.shape[0]>0:
|
||||||
|
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
self.tts.pause_talk()
|
||||||
|
self.asr.pause_talk()
|
||||||
|
|
||||||
|
def is_speaking(self)->bool:
|
||||||
|
return self.speaking
|
||||||
|
|
||||||
|
def __loadcustom(self):
|
||||||
|
for item in self.opt.customopt:
|
||||||
|
print(item)
|
||||||
|
input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
|
||||||
|
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||||
|
self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
|
||||||
|
self.custom_audio_cycle[item['audiotype']], sample_rate = sf.read(item['audiopath'], dtype='float32')
|
||||||
|
self.custom_audio_index[item['audiotype']] = 0
|
||||||
|
self.custom_index[item['audiotype']] = 0
|
||||||
|
self.custom_opt[item['audiotype']] = item
|
||||||
|
|
||||||
|
def init_customindex(self):
|
||||||
|
self.curr_state=0
|
||||||
|
for key in self.custom_audio_index:
|
||||||
|
self.custom_audio_index[key]=0
|
||||||
|
for key in self.custom_index:
|
||||||
|
self.custom_index[key]=0
|
||||||
|
|
||||||
|
def start_recording(self,path):
|
||||||
|
"""开始录制视频"""
|
||||||
|
if self.recording:
|
||||||
|
return
|
||||||
|
self.recording = True
|
||||||
|
self.recordq_video.queue.clear()
|
||||||
|
self.recordq_audio.queue.clear()
|
||||||
|
self.container = av.open(path, mode="w")
|
||||||
|
|
||||||
|
process_thread = Thread(target=self.record_frame, args=())
|
||||||
|
process_thread.start()
|
||||||
|
|
||||||
|
def record_frame(self):
|
||||||
|
videostream = self.container.add_stream("libx264", rate=25)
|
||||||
|
videostream.codec_context.time_base = Fraction(1, 25)
|
||||||
|
audiostream = self.container.add_stream("aac")
|
||||||
|
audiostream.codec_context.time_base = Fraction(1, 16000)
|
||||||
|
init = True
|
||||||
|
framenum = 0
|
||||||
|
while self.recording:
|
||||||
|
try:
|
||||||
|
videoframe = self.recordq_video.get(block=True, timeout=1)
|
||||||
|
videoframe.pts = framenum #int(round(framenum*0.04 / videostream.codec_context.time_base))
|
||||||
|
videoframe.dts = videoframe.pts
|
||||||
|
if init:
|
||||||
|
videostream.width = videoframe.width
|
||||||
|
videostream.height = videoframe.height
|
||||||
|
init = False
|
||||||
|
for packet in videostream.encode(videoframe):
|
||||||
|
self.container.mux(packet)
|
||||||
|
for k in range(2):
|
||||||
|
audioframe = self.recordq_audio.get(block=True, timeout=1)
|
||||||
|
audioframe.pts = int(round((framenum*2+k)*0.02 / audiostream.codec_context.time_base))
|
||||||
|
audioframe.dts = audioframe.pts
|
||||||
|
for packet in audiostream.encode(audioframe):
|
||||||
|
self.container.mux(packet)
|
||||||
|
framenum += 1
|
||||||
|
except queue.Empty:
|
||||||
|
print('record queue empty,')
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
#break
|
||||||
|
for packet in videostream.encode(None):
|
||||||
|
self.container.mux(packet)
|
||||||
|
for packet in audiostream.encode(None):
|
||||||
|
self.container.mux(packet)
|
||||||
|
self.container.close()
|
||||||
|
self.recordq_video.queue.clear()
|
||||||
|
self.recordq_audio.queue.clear()
|
||||||
|
print('record thread stop')
|
||||||
|
|
||||||
|
def stop_recording(self):
|
||||||
|
"""停止录制视频"""
|
||||||
|
if not self.recording:
|
||||||
|
return
|
||||||
|
self.recording = False
|
||||||
|
|
||||||
|
def mirror_index(self,size, index):
|
||||||
|
#size = len(self.coord_list_cycle)
|
||||||
|
turn = index // size
|
||||||
|
res = index % size
|
||||||
|
if turn % 2 == 0:
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
return size - res - 1
|
||||||
|
|
||||||
|
def get_audio_stream(self,audiotype):
|
||||||
|
idx = self.custom_audio_index[audiotype]
|
||||||
|
stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk]
|
||||||
|
self.custom_audio_index[audiotype] += self.chunk
|
||||||
|
if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]:
|
||||||
|
self.curr_state = 1 #当前视频不循环播放,切换到静音状态
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def set_curr_state(self,audiotype, reinit):
|
||||||
|
print('set_curr_state:',audiotype)
|
||||||
|
self.curr_state = audiotype
|
||||||
|
if reinit:
|
||||||
|
self.custom_audio_index[audiotype] = 0
|
||||||
|
self.custom_index[audiotype] = 0
|
||||||
|
|
||||||
|
# def process_custom(self,audiotype:int,idx:int):
|
||||||
|
# if self.curr_state!=audiotype: #从推理切到口播
|
||||||
|
# if idx in self.switch_pos: #在卡点位置可以切换
|
||||||
|
# self.curr_state=audiotype
|
||||||
|
# self.custom_index=0
|
||||||
|
# else:
|
||||||
|
# self.custom_index+=1
|
|
@ -0,0 +1,7 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"audiotype":2,
|
||||||
|
"imgpath":"data/customvideo/image",
|
||||||
|
"audiopath":"data/customvideo/audio.wav"
|
||||||
|
}
|
||||||
|
]
|
|
@ -4,13 +4,13 @@ from torch.utils.cpp_extension import load
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
|
||||||
'-use_fast_math'
|
'-use_fast_math'
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14']
|
c_flags = ['-O3', '-std=c++17']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17']
|
c_flags = ['/O2', '/std:c++17']
|
||||||
|
|
||||||
|
|
|
@ -5,13 +5,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
|
||||||
'-use_fast_math'
|
'-use_fast_math'
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14']
|
c_flags = ['-O3', '-std=c++17']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17']
|
c_flags = ['/O2', '/std:c++17']
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,12 @@ from torch.utils.cpp_extension import load
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14', '-finput-charset=UTF-8']
|
c_flags = ['-O3', '-std=c++17', '-finput-charset=UTF-8']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8']
|
c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8']
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,12 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14']
|
c_flags = ['-O3', '-std=c++17']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17']
|
c_flags = ['/O2', '/std:c++17']
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,12 @@ from torch.utils.cpp_extension import load
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14']
|
c_flags = ['-O3', '-std=c++17']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17']
|
c_flags = ['/O2', '/std:c++17']
|
||||||
|
|
||||||
|
|
|
@ -5,13 +5,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
# '-lineinfo', # to debug illegal memory access
|
# '-lineinfo', # to debug illegal memory access
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14']
|
c_flags = ['-O3', '-std=c++17']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17']
|
c_flags = ['/O2', '/std:c++17']
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,12 @@ from torch.utils.cpp_extension import load
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14', '-finput-charset=utf-8']
|
c_flags = ['-O3', '-std=c++17', '-finput-charset=utf-8']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17', '/source-charset:utf-8']
|
c_flags = ['/O2', '/std:c++17', '/source-charset:utf-8']
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,12 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
'-O3', '-std=c++14',
|
'-O3', '-std=c++17',
|
||||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.name == "posix":
|
if os.name == "posix":
|
||||||
c_flags = ['-O3', '-std=c++14']
|
c_flags = ['-O3', '-std=c++17']
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
c_flags = ['/O2', '/std:c++17']
|
c_flags = ['/O2', '/std:c++17']
|
||||||
|
|
||||||
|
|
|
@ -1,315 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import websockets
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
import tracemalloc
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
import ssl
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--host",
|
|
||||||
type=str,
|
|
||||||
default="0.0.0.0",
|
|
||||||
required=False,
|
|
||||||
help="host ip, localhost, 0.0.0.0")
|
|
||||||
parser.add_argument("--port",
|
|
||||||
type=int,
|
|
||||||
default=10095,
|
|
||||||
required=False,
|
|
||||||
help="grpc server port")
|
|
||||||
parser.add_argument("--asr_model",
|
|
||||||
type=str,
|
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
|
||||||
help="model from modelscope")
|
|
||||||
parser.add_argument("--asr_model_revision",
|
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="")
|
|
||||||
parser.add_argument("--asr_model_online",
|
|
||||||
type=str,
|
|
||||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
|
||||||
help="model from modelscope")
|
|
||||||
parser.add_argument("--asr_model_online_revision",
|
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="")
|
|
||||||
parser.add_argument("--vad_model",
|
|
||||||
type=str,
|
|
||||||
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
|
||||||
help="model from modelscope")
|
|
||||||
parser.add_argument("--vad_model_revision",
|
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="")
|
|
||||||
parser.add_argument("--punc_model",
|
|
||||||
type=str,
|
|
||||||
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
|
|
||||||
help="model from modelscope")
|
|
||||||
parser.add_argument("--punc_model_revision",
|
|
||||||
type=str,
|
|
||||||
default="v2.0.4",
|
|
||||||
help="")
|
|
||||||
parser.add_argument("--ngpu",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="0 for cpu, 1 for gpu")
|
|
||||||
parser.add_argument("--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda",
|
|
||||||
help="cuda, cpu")
|
|
||||||
parser.add_argument("--ncpu",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="cpu cores")
|
|
||||||
parser.add_argument("--certfile",
|
|
||||||
type=str,
|
|
||||||
default="ssl_key/server.crt",
|
|
||||||
required=False,
|
|
||||||
help="certfile for ssl")
|
|
||||||
|
|
||||||
parser.add_argument("--keyfile",
|
|
||||||
type=str,
|
|
||||||
default="ssl_key/server.key",
|
|
||||||
required=False,
|
|
||||||
help="keyfile for ssl")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
websocket_users = set()
|
|
||||||
|
|
||||||
print("model loading")
|
|
||||||
from funasr import AutoModel
|
|
||||||
|
|
||||||
# asr
|
|
||||||
model_asr = AutoModel(model=args.asr_model,
|
|
||||||
model_revision=args.asr_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
# asr
|
|
||||||
model_asr_streaming = AutoModel(model=args.asr_model_online,
|
|
||||||
model_revision=args.asr_model_online_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
# vad
|
|
||||||
model_vad = AutoModel(model=args.vad_model,
|
|
||||||
model_revision=args.vad_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
# chunk_size=60,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.punc_model != "":
|
|
||||||
model_punc = AutoModel(model=args.punc_model,
|
|
||||||
model_revision=args.punc_model_revision,
|
|
||||||
ngpu=args.ngpu,
|
|
||||||
ncpu=args.ncpu,
|
|
||||||
device=args.device,
|
|
||||||
disable_pbar=True,
|
|
||||||
disable_log=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model_punc = None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("model loaded! only support one client at the same time now!!!!")
|
|
||||||
|
|
||||||
async def ws_reset(websocket):
|
|
||||||
print("ws reset now, total num is ",len(websocket_users))
|
|
||||||
|
|
||||||
websocket.status_dict_asr_online["cache"] = {}
|
|
||||||
websocket.status_dict_asr_online["is_final"] = True
|
|
||||||
websocket.status_dict_vad["cache"] = {}
|
|
||||||
websocket.status_dict_vad["is_final"] = True
|
|
||||||
websocket.status_dict_punc["cache"] = {}
|
|
||||||
|
|
||||||
await websocket.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def clear_websocket():
|
|
||||||
for websocket in websocket_users:
|
|
||||||
await ws_reset(websocket)
|
|
||||||
websocket_users.clear()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def ws_serve(websocket, path):
|
|
||||||
frames = []
|
|
||||||
frames_asr = []
|
|
||||||
frames_asr_online = []
|
|
||||||
global websocket_users
|
|
||||||
# await clear_websocket()
|
|
||||||
websocket_users.add(websocket)
|
|
||||||
websocket.status_dict_asr = {}
|
|
||||||
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
|
|
||||||
websocket.status_dict_vad = {'cache': {}, "is_final": False}
|
|
||||||
websocket.status_dict_punc = {'cache': {}}
|
|
||||||
websocket.chunk_interval = 10
|
|
||||||
websocket.vad_pre_idx = 0
|
|
||||||
speech_start = False
|
|
||||||
speech_end_i = -1
|
|
||||||
websocket.wav_name = "microphone"
|
|
||||||
websocket.mode = "2pass"
|
|
||||||
print("new user connected", flush=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for message in websocket:
|
|
||||||
if isinstance(message, str):
|
|
||||||
messagejson = json.loads(message)
|
|
||||||
|
|
||||||
if "is_speaking" in messagejson:
|
|
||||||
websocket.is_speaking = messagejson["is_speaking"]
|
|
||||||
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
|
|
||||||
if "chunk_interval" in messagejson:
|
|
||||||
websocket.chunk_interval = messagejson["chunk_interval"]
|
|
||||||
if "wav_name" in messagejson:
|
|
||||||
websocket.wav_name = messagejson.get("wav_name")
|
|
||||||
if "chunk_size" in messagejson:
|
|
||||||
chunk_size = messagejson["chunk_size"]
|
|
||||||
if isinstance(chunk_size, str):
|
|
||||||
chunk_size = chunk_size.split(',')
|
|
||||||
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
|
|
||||||
if "encoder_chunk_look_back" in messagejson:
|
|
||||||
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
|
|
||||||
if "decoder_chunk_look_back" in messagejson:
|
|
||||||
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
|
|
||||||
if "hotword" in messagejson:
|
|
||||||
websocket.status_dict_asr["hotword"] = messagejson["hotword"]
|
|
||||||
if "mode" in messagejson:
|
|
||||||
websocket.mode = messagejson["mode"]
|
|
||||||
|
|
||||||
websocket.status_dict_vad["chunk_size"] = int(websocket.status_dict_asr_online["chunk_size"][1]*60/websocket.chunk_interval)
|
|
||||||
if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
|
|
||||||
if not isinstance(message, str):
|
|
||||||
frames.append(message)
|
|
||||||
duration_ms = len(message)//32
|
|
||||||
websocket.vad_pre_idx += duration_ms
|
|
||||||
|
|
||||||
# asr online
|
|
||||||
frames_asr_online.append(message)
|
|
||||||
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
|
|
||||||
if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.status_dict_asr_online["is_final"]:
|
|
||||||
if websocket.mode == "2pass" or websocket.mode == "online":
|
|
||||||
audio_in = b"".join(frames_asr_online)
|
|
||||||
try:
|
|
||||||
await async_asr_online(websocket, audio_in)
|
|
||||||
except:
|
|
||||||
print(f"error in asr streaming, {websocket.status_dict_asr_online}")
|
|
||||||
frames_asr_online = []
|
|
||||||
if speech_start:
|
|
||||||
frames_asr.append(message)
|
|
||||||
# vad online
|
|
||||||
try:
|
|
||||||
speech_start_i, speech_end_i = await async_vad(websocket, message)
|
|
||||||
except:
|
|
||||||
print("error in vad")
|
|
||||||
if speech_start_i != -1:
|
|
||||||
speech_start = True
|
|
||||||
beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
|
|
||||||
frames_pre = frames[-beg_bias:]
|
|
||||||
frames_asr = []
|
|
||||||
frames_asr.extend(frames_pre)
|
|
||||||
# asr punc offline
|
|
||||||
if speech_end_i != -1 or not websocket.is_speaking:
|
|
||||||
# print("vad end point")
|
|
||||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
|
||||||
audio_in = b"".join(frames_asr)
|
|
||||||
try:
|
|
||||||
await async_asr(websocket, audio_in)
|
|
||||||
except:
|
|
||||||
print("error in asr offline")
|
|
||||||
frames_asr = []
|
|
||||||
speech_start = False
|
|
||||||
frames_asr_online = []
|
|
||||||
websocket.status_dict_asr_online["cache"] = {}
|
|
||||||
if not websocket.is_speaking:
|
|
||||||
websocket.vad_pre_idx = 0
|
|
||||||
frames = []
|
|
||||||
websocket.status_dict_vad["cache"] = {}
|
|
||||||
else:
|
|
||||||
frames = frames[-20:]
|
|
||||||
|
|
||||||
|
|
||||||
except websockets.ConnectionClosed:
|
|
||||||
print("ConnectionClosed...", websocket_users,flush=True)
|
|
||||||
await ws_reset(websocket)
|
|
||||||
websocket_users.remove(websocket)
|
|
||||||
except websockets.InvalidState:
|
|
||||||
print("InvalidState...")
|
|
||||||
except Exception as e:
|
|
||||||
print("Exception:", e)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_vad(websocket, audio_in):
|
|
||||||
|
|
||||||
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
|
|
||||||
# print(segments_result)
|
|
||||||
|
|
||||||
speech_start = -1
|
|
||||||
speech_end = -1
|
|
||||||
|
|
||||||
if len(segments_result) == 0 or len(segments_result) > 1:
|
|
||||||
return speech_start, speech_end
|
|
||||||
if segments_result[0][0] != -1:
|
|
||||||
speech_start = segments_result[0][0]
|
|
||||||
if segments_result[0][1] != -1:
|
|
||||||
speech_end = segments_result[0][1]
|
|
||||||
return speech_start, speech_end
|
|
||||||
|
|
||||||
|
|
||||||
async def async_asr(websocket, audio_in):
|
|
||||||
if len(audio_in) > 0:
|
|
||||||
# print(len(audio_in))
|
|
||||||
rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
|
|
||||||
# print("offline_asr, ", rec_result)
|
|
||||||
if model_punc is not None and len(rec_result["text"])>0:
|
|
||||||
# print("offline, before punc", rec_result, "cache", websocket.status_dict_punc)
|
|
||||||
rec_result = model_punc.generate(input=rec_result['text'], **websocket.status_dict_punc)[0]
|
|
||||||
# print("offline, after punc", rec_result)
|
|
||||||
if len(rec_result["text"])>0:
|
|
||||||
# print("offline", rec_result)
|
|
||||||
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
|
||||||
message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name,"is_final":websocket.is_speaking})
|
|
||||||
await websocket.send(message)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_asr_online(websocket, audio_in):
|
|
||||||
if len(audio_in) > 0:
|
|
||||||
# print(websocket.status_dict_asr_online.get("is_final", False))
|
|
||||||
rec_result = model_asr_streaming.generate(input=audio_in, **websocket.status_dict_asr_online)[0]
|
|
||||||
# print("online, ", rec_result)
|
|
||||||
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
|
|
||||||
return
|
|
||||||
# websocket.status_dict_asr_online["cache"] = dict()
|
|
||||||
if len(rec_result["text"]):
|
|
||||||
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
|
|
||||||
message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name,"is_final":websocket.is_speaking})
|
|
||||||
await websocket.send(message)
|
|
||||||
|
|
||||||
if len(args.certfile)>0:
|
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
|
||||||
|
|
||||||
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
|
|
||||||
ssl_cert = args.certfile
|
|
||||||
ssl_key = args.keyfile
|
|
||||||
|
|
||||||
ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
|
|
||||||
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
|
|
||||||
else:
|
|
||||||
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
|
|
||||||
asyncio.get_event_loop().run_until_complete(start_server)
|
|
||||||
asyncio.get_event_loop().run_forever()
|
|
|
@ -1 +0,0 @@
|
||||||
在浏览器中打开 samples/html/static/index.html,输入ASR服务器地址,支持麦克风输入,也支持文件输入
|
|
|
@ -1,132 +0,0 @@
|
||||||
<!-- index.html -->
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<script type="text/javascript" src="mpegts-1.7.3.min.js"></script>
|
|
||||||
<script type="text/javascript" src="http://cdn.sockjs.org/sockjs-0.3.4.js"></script>
|
|
||||||
<script src="http://code.jquery.com/jquery-2.1.1.min.js"></script>
|
|
||||||
<script src="recorder-core.js" charset="UTF-8"></script>
|
|
||||||
<script src="wav.js" charset="UTF-8"></script>
|
|
||||||
<script src="pcm.js" charset="UTF-8"></script>
|
|
||||||
|
|
||||||
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="container">
|
|
||||||
<h1>metahuman voice test</h1>
|
|
||||||
<form class="form-inline" id="echo-form" name="ssbtn">
|
|
||||||
<div class="form-group">
|
|
||||||
<p>input text</p>
|
|
||||||
|
|
||||||
<textarea cols="2" rows="3" style="width:600px;height:50px;" class="form-control" id="message"></textarea>
|
|
||||||
</div>
|
|
||||||
<button type="submit" class="btn btn-default">Send</button>
|
|
||||||
</form>
|
|
||||||
<div id="log">
|
|
||||||
|
|
||||||
</div>
|
|
||||||
<video id="video_player" width="40%" controls autoplay muted></video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div>------------------------------------------------------------------------------------------------------------------------------</div>
|
|
||||||
<div class="div_class_topArea">
|
|
||||||
|
|
||||||
<div class="div_class_recordControl">
|
|
||||||
asr服务器地址(必填):
|
|
||||||
<br>
|
|
||||||
<input id="wssip" type="text" onchange="addresschange()" style="width: 500px;" value="wss://127.0.0.1:10095/"/>
|
|
||||||
<br>
|
|
||||||
<a id="wsslink" style="display: none;" href="#" onclick="window.open('https://127.0.0.1:10095/', '_blank')"><div id="info_wslink">点此处手工授权wss://127.0.0.1:10095/</div></a>
|
|
||||||
<br>
|
|
||||||
<br>
|
|
||||||
<div style="border:2px solid #ccc;display: none;">
|
|
||||||
选择录音模式:<br/>
|
|
||||||
|
|
||||||
<label ><input name="recoder_mode" onclick="on_recoder_mode_change()" type="radio" value="mic" checked="true"/>麦克风 </label>
|
|
||||||
<label><input name="recoder_mode" onclick="on_recoder_mode_change()" type="radio" value="file" />文件 </label>
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
|
||||||
<div id="mic_mode_div" style="border:2px solid #ccc;display:none;">
|
|
||||||
选择asr模型模式:<br/>
|
|
||||||
|
|
||||||
<label><input name="asr_mode" type="radio" value="2pass" />2pass </label>
|
|
||||||
<label><input name="asr_mode" type="radio" value="online" checked="true"/>online </label>
|
|
||||||
<label><input name="asr_mode" type="radio" value="2pass-offline" />2pass-offline </label>
|
|
||||||
<label><input name="asr_mode" type="radio" value="offline" />offline </label>
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div id="rec_mode_div" style="border:2px solid #ccc;display:none;">
|
|
||||||
|
|
||||||
|
|
||||||
<input type="file" id="upfile">
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div style="border:2px solid #ccc;display: none;">
|
|
||||||
热词设置(一行一个关键字,空格隔开权重,如"阿里巴巴 20"):
|
|
||||||
|
|
||||||
<textarea rows="1" id="varHot" style=" width: 100%;height:auto" >阿里巴巴 20 hello world 40</textarea>
|
|
||||||
|
|
||||||
</div>
|
|
||||||
<div style="display: none;">语音识别结果显示:</div>
|
|
||||||
<br>
|
|
||||||
|
|
||||||
<textarea rows="10" id="varArea" readonly="true" style=" width: 100%;height:auto;display: none;" ></textarea>
|
|
||||||
<br>
|
|
||||||
<div id="info_div">请点击开始</div>
|
|
||||||
<div class="div_class_buttons">
|
|
||||||
<button id="btnConnect">连接</button>
|
|
||||||
<button id="btnStart">开始</button>
|
|
||||||
<button id="btnStop">停止</button>
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<audio id="audio_record" type="audio/wav" controls style="margin-top: 2px; width: 100%;display: none;"></audio>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script src="wsconnecter.js" charset="utf-8"></script>
|
|
||||||
<script src="main.js" charset="utf-8"></script>
|
|
||||||
|
|
||||||
</body>
|
|
||||||
<script type="text/javascript" charset="utf-8">
|
|
||||||
|
|
||||||
// $(document).ready(function() {
|
|
||||||
// var host = window.location.hostname
|
|
||||||
// var ws = new WebSocket("ws://"+host+":8000/humanecho");
|
|
||||||
// //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
|
|
||||||
// ws.onopen = function() {
|
|
||||||
// console.log('Connected');
|
|
||||||
// };
|
|
||||||
// ws.onmessage = function(e) {
|
|
||||||
// console.log('Received: ' + e.data);
|
|
||||||
// data = e
|
|
||||||
// var vid = JSON.parse(data.data);
|
|
||||||
// console.log(typeof(vid),vid)
|
|
||||||
// //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
|
|
||||||
|
|
||||||
// };
|
|
||||||
// ws.onclose = function(e) {
|
|
||||||
// console.log('Closed');
|
|
||||||
// };
|
|
||||||
|
|
||||||
// flvPlayer = mpegts.createPlayer({type: 'flv', url: "http://"+host+":8080/live/livestream.flv", isLive: true, enableStashBuffer: false});
|
|
||||||
// flvPlayer.attachMediaElement(document.getElementById('video_player'));
|
|
||||||
// flvPlayer.load();
|
|
||||||
// flvPlayer.play();
|
|
||||||
|
|
||||||
// $('#echo-form').on('submit', function(e) {
|
|
||||||
// e.preventDefault();
|
|
||||||
// var message = $('#message').val();
|
|
||||||
// console.log('Sending: ' + message);
|
|
||||||
// ws.send(message);
|
|
||||||
// $('#message').val('');
|
|
||||||
// });
|
|
||||||
// });
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
</script>
|
|
||||||
</html>
|
|
File diff suppressed because one or more lines are too long
|
@ -1,24 +0,0 @@
|
||||||
1、启动语言识别服务端
|
|
||||||
创建虚拟环境
|
|
||||||
conda create -n funasr
|
|
||||||
conda activate funasr
|
|
||||||
安装依赖库
|
|
||||||
pip install torch
|
|
||||||
pip install modelscope
|
|
||||||
pip install testresources
|
|
||||||
pip install websockets
|
|
||||||
pip install torchaudio
|
|
||||||
pip install FunASR
|
|
||||||
pip install pyaudio
|
|
||||||
|
|
||||||
|
|
||||||
python funasr_wss_server.py --port 10095
|
|
||||||
或者
|
|
||||||
python funasr_wss_server.py --host "0.0.0.0" --port 10197 --ngpu 0
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
https://github.com/alibaba-damo-academy/FunASR
|
|
||||||
https://zhuanlan.zhihu.com/p/649935170
|
|
|
@ -1,17 +0,0 @@
|
||||||
## certificate generation by yourself
|
|
||||||
generated certificate may not suitable for all browsers due to security concerns. you'd better buy or download an authenticated ssl certificate from authorized agency.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
### 1) Generate a private key
|
|
||||||
openssl genrsa -des3 -out server.key 2048
|
|
||||||
|
|
||||||
### 2) Generate a csr file
|
|
||||||
openssl req -new -key server.key -out server.csr
|
|
||||||
|
|
||||||
### 3) Remove pass
|
|
||||||
cp server.key server.key.org
|
|
||||||
openssl rsa -in server.key.org -out server.key
|
|
||||||
|
|
||||||
### 4) Generated a crt file, valid for 1 year
|
|
||||||
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
|
|
||||||
```
|
|
|
@ -1,17 +0,0 @@
|
||||||
## 自行生成证书
|
|
||||||
生成证书(注意这种证书并不能被所有浏览器认可,部分手动授权可以访问,最好使用其他认证的官方ssl证书)
|
|
||||||
|
|
||||||
```shell
|
|
||||||
### 1)生成私钥,按照提示填写内容
|
|
||||||
openssl genrsa -des3 -out server.key 1024
|
|
||||||
|
|
||||||
### 2)生成csr文件 ,按照提示填写内容
|
|
||||||
openssl req -new -key server.key -out server.csr
|
|
||||||
|
|
||||||
### 去掉pass
|
|
||||||
cp server.key server.key.org
|
|
||||||
openssl rsa -in server.key.org -out server.key
|
|
||||||
|
|
||||||
### 生成crt文件,有效期1年(365天)
|
|
||||||
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
|
|
||||||
```
|
|
|
@ -1,21 +0,0 @@
|
||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
|
|
||||||
CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
|
|
||||||
MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
|
|
||||||
bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
|
|
||||||
DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
|
|
||||||
EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
|
|
||||||
aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
|
|
||||||
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
|
|
||||||
qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
|
|
||||||
f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
|
|
||||||
EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
|
|
||||||
Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
|
|
||||||
LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
|
|
||||||
AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
|
|
||||||
D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
|
|
||||||
dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
|
|
||||||
Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
|
|
||||||
7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
|
|
||||||
Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
|
|
||||||
-----END CERTIFICATE-----
|
|
|
@ -1,27 +0,0 @@
|
||||||
-----BEGIN RSA PRIVATE KEY-----
|
|
||||||
MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k
|
|
||||||
vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4
|
|
||||||
5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD
|
|
||||||
jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT
|
|
||||||
ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G
|
|
||||||
NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs
|
|
||||||
MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+
|
|
||||||
vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw
|
|
||||||
jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2
|
|
||||||
eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i
|
|
||||||
C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g
|
|
||||||
6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua
|
|
||||||
jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU
|
|
||||||
qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb
|
|
||||||
aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY
|
|
||||||
zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9
|
|
||||||
xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg
|
|
||||||
9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp
|
|
||||||
3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG
|
|
||||||
P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1
|
|
||||||
7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE
|
|
||||||
cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR
|
|
||||||
dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh
|
|
||||||
cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ
|
|
||||||
W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa
|
|
||||||
-----END RSA PRIVATE KEY-----
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from queue import Queue
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
from baseasr import BaseASR
|
||||||
|
from wav2lip import audio
|
||||||
|
|
||||||
|
class LipASR(BaseASR):
|
||||||
|
|
||||||
|
def run_step(self):
|
||||||
|
############################################## extract audio feature ##############################################
|
||||||
|
# get a frame of audio
|
||||||
|
for _ in range(self.batch_size*2):
|
||||||
|
frame,type = self.get_audio_frame()
|
||||||
|
self.frames.append(frame)
|
||||||
|
# put to output
|
||||||
|
self.output_queue.put((frame,type))
|
||||||
|
# context not enough, do not run network.
|
||||||
|
if len(self.frames) <= self.stride_left_size + self.stride_right_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
inputs = np.concatenate(self.frames) # [N * chunk]
|
||||||
|
mel = audio.melspectrogram(inputs)
|
||||||
|
#print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
|
||||||
|
# cut off stride
|
||||||
|
left = max(0, self.stride_left_size*80/50)
|
||||||
|
right = min(len(mel[0]), len(mel[0]) - self.stride_right_size*80/50)
|
||||||
|
mel_idx_multiplier = 80.*2/self.fps
|
||||||
|
mel_step_size = 16
|
||||||
|
i = 0
|
||||||
|
mel_chunks = []
|
||||||
|
while i < (len(self.frames)-self.stride_left_size-self.stride_right_size)/2:
|
||||||
|
start_idx = int(left + i * mel_idx_multiplier)
|
||||||
|
#print(start_idx)
|
||||||
|
if start_idx + mel_step_size > len(mel[0]):
|
||||||
|
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
|
||||||
|
else:
|
||||||
|
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
|
||||||
|
i += 1
|
||||||
|
self.feat_queue.put(mel_chunks)
|
||||||
|
|
||||||
|
# discard the old part to save memory
|
||||||
|
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
|
@ -0,0 +1,281 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
#from .utils import *
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import cv2
|
||||||
|
import glob
|
||||||
|
import pickle
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread, Event
|
||||||
|
from io import BytesIO
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
|
||||||
|
from ttsreal import EdgeTTS,VoitsTTS,XTTS
|
||||||
|
|
||||||
|
from lipasr import LipASR
|
||||||
|
import asyncio
|
||||||
|
from av import AudioFrame, VideoFrame
|
||||||
|
from wav2lip.models import Wav2Lip
|
||||||
|
from basereal import BaseReal
|
||||||
|
|
||||||
|
#from imgcache import ImgCache
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print('Using {} for inference.'.format(device))
|
||||||
|
|
||||||
|
def _load(checkpoint_path):
|
||||||
|
if device == 'cuda':
|
||||||
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(checkpoint_path,
|
||||||
|
map_location=lambda storage, loc: storage)
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
def load_model(path):
|
||||||
|
model = Wav2Lip()
|
||||||
|
print("Load checkpoint from: {}".format(path))
|
||||||
|
checkpoint = _load(path)
|
||||||
|
s = checkpoint["state_dict"]
|
||||||
|
new_s = {}
|
||||||
|
for k, v in s.items():
|
||||||
|
new_s[k.replace('module.', '')] = v
|
||||||
|
model.load_state_dict(new_s)
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
def read_imgs(img_list):
|
||||||
|
frames = []
|
||||||
|
print('reading images...')
|
||||||
|
for img_path in tqdm(img_list):
|
||||||
|
frame = cv2.imread(img_path)
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def __mirror_index(size, index):
|
||||||
|
#size = len(self.coord_list_cycle)
|
||||||
|
turn = index // size
|
||||||
|
res = index % size
|
||||||
|
if turn % 2 == 0:
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
return size - res - 1
|
||||||
|
|
||||||
|
def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_queue,res_frame_queue):
|
||||||
|
|
||||||
|
model = load_model("./models/wav2lip.pth")
|
||||||
|
input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
|
||||||
|
input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||||
|
face_list_cycle = read_imgs(input_face_list)
|
||||||
|
|
||||||
|
#input_latent_list_cycle = torch.load(latents_out_path)
|
||||||
|
length = len(face_list_cycle)
|
||||||
|
index = 0
|
||||||
|
count=0
|
||||||
|
counttime=0
|
||||||
|
print('start inference')
|
||||||
|
while True:
|
||||||
|
if render_event.is_set():
|
||||||
|
starttime=time.perf_counter()
|
||||||
|
mel_batch = []
|
||||||
|
try:
|
||||||
|
mel_batch = audio_feat_queue.get(block=True, timeout=1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_all_silence=True
|
||||||
|
audio_frames = []
|
||||||
|
for _ in range(batch_size*2):
|
||||||
|
frame,type = audio_out_queue.get()
|
||||||
|
audio_frames.append((frame,type))
|
||||||
|
if type==0:
|
||||||
|
is_all_silence=False
|
||||||
|
|
||||||
|
if is_all_silence:
|
||||||
|
for i in range(batch_size):
|
||||||
|
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
|
||||||
|
index = index + 1
|
||||||
|
else:
|
||||||
|
# print('infer=======')
|
||||||
|
t=time.perf_counter()
|
||||||
|
img_batch = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
idx = __mirror_index(length,index+i)
|
||||||
|
face = face_list_cycle[idx]
|
||||||
|
img_batch.append(face)
|
||||||
|
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
||||||
|
|
||||||
|
img_masked = img_batch.copy()
|
||||||
|
img_masked[:, face.shape[0]//2:] = 0
|
||||||
|
|
||||||
|
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
||||||
|
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
||||||
|
|
||||||
|
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
||||||
|
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pred = model(mel_batch, img_batch)
|
||||||
|
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
|
||||||
|
|
||||||
|
counttime += (time.perf_counter() - t)
|
||||||
|
count += batch_size
|
||||||
|
#_totalframe += 1
|
||||||
|
if count>=100:
|
||||||
|
print(f"------actual avg infer fps:{count/counttime:.4f}")
|
||||||
|
count=0
|
||||||
|
counttime=0
|
||||||
|
for i,res_frame in enumerate(pred):
|
||||||
|
#self.__pushmedia(res_frame,loop,audio_track,video_track)
|
||||||
|
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
|
||||||
|
index = index + 1
|
||||||
|
#print('total batch time:',time.perf_counter()-starttime)
|
||||||
|
else:
|
||||||
|
time.sleep(1)
|
||||||
|
print('musereal inference processor stop')
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
class LipReal(BaseReal):
|
||||||
|
def __init__(self, opt):
|
||||||
|
super().__init__(opt)
|
||||||
|
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
||||||
|
self.W = opt.W
|
||||||
|
self.H = opt.H
|
||||||
|
|
||||||
|
self.fps = opt.fps # 20 ms per frame
|
||||||
|
|
||||||
|
#### musetalk
|
||||||
|
self.avatar_id = opt.avatar_id
|
||||||
|
self.avatar_path = f"./data/avatars/{self.avatar_id}"
|
||||||
|
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
|
||||||
|
self.face_imgs_path = f"{self.avatar_path}/face_imgs"
|
||||||
|
self.coords_path = f"{self.avatar_path}/coords.pkl"
|
||||||
|
self.batch_size = opt.batch_size
|
||||||
|
self.idx = 0
|
||||||
|
self.res_frame_queue = mp.Queue(self.batch_size*2)
|
||||||
|
#self.__loadmodels()
|
||||||
|
self.__loadavatar()
|
||||||
|
|
||||||
|
self.asr = LipASR(opt,self)
|
||||||
|
self.asr.warm_up()
|
||||||
|
#self.__warm_up()
|
||||||
|
|
||||||
|
self.render_event = mp.Event()
|
||||||
|
mp.Process(target=inference, args=(self.render_event,self.batch_size,self.face_imgs_path,
|
||||||
|
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
|
||||||
|
)).start()
|
||||||
|
|
||||||
|
# def __loadmodels(self):
|
||||||
|
# # load model weights
|
||||||
|
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
|
||||||
|
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
# self.timesteps = torch.tensor([0], device=device)
|
||||||
|
# self.pe = self.pe.half()
|
||||||
|
# self.vae.vae = self.vae.vae.half()
|
||||||
|
# self.unet.model = self.unet.model.half()
|
||||||
|
|
||||||
|
def __loadavatar(self):
|
||||||
|
with open(self.coords_path, 'rb') as f:
|
||||||
|
self.coord_list_cycle = pickle.load(f)
|
||||||
|
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
|
||||||
|
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||||
|
self.frame_list_cycle = read_imgs(input_img_list)
|
||||||
|
#self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
|
||||||
|
|
||||||
|
|
||||||
|
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
|
||||||
|
|
||||||
|
while not quit_event.is_set():
|
||||||
|
try:
|
||||||
|
res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
if audio_frames[0][1]!=0 and audio_frames[1][1]!=0: #全为静音数据,只需要取fullimg
|
||||||
|
self.speaking = False
|
||||||
|
audiotype = audio_frames[0][1]
|
||||||
|
if self.custom_index.get(audiotype) is not None: #有自定义视频
|
||||||
|
mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
|
||||||
|
combine_frame = self.custom_img_cycle[audiotype][mirindex]
|
||||||
|
self.custom_index[audiotype] += 1
|
||||||
|
# if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
|
||||||
|
# self.curr_state = 1 #当前视频不循环播放,切换到静音状态
|
||||||
|
else:
|
||||||
|
combine_frame = self.frame_list_cycle[idx]
|
||||||
|
#combine_frame = self.imagecache.get_img(idx)
|
||||||
|
else:
|
||||||
|
self.speaking = True
|
||||||
|
bbox = self.coord_list_cycle[idx]
|
||||||
|
combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
|
||||||
|
#combine_frame = copy.deepcopy(self.imagecache.get_img(idx))
|
||||||
|
y1, y2, x1, x2 = bbox
|
||||||
|
try:
|
||||||
|
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
#combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||||
|
#t=time.perf_counter()
|
||||||
|
combine_frame[y1:y2, x1:x2] = res_frame
|
||||||
|
#print('blending time:',time.perf_counter()-t)
|
||||||
|
|
||||||
|
image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
|
||||||
|
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
|
||||||
|
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
||||||
|
if self.recording:
|
||||||
|
self.recordq_video.put(new_frame)
|
||||||
|
|
||||||
|
for audio_frame in audio_frames:
|
||||||
|
frame,type = audio_frame
|
||||||
|
frame = (frame * 32767).astype(np.int16)
|
||||||
|
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
|
||||||
|
new_frame.planes[0].update(frame.tobytes())
|
||||||
|
new_frame.sample_rate=16000
|
||||||
|
# if audio_track._queue.qsize()>10:
|
||||||
|
# time.sleep(0.1)
|
||||||
|
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
|
||||||
|
if self.recording:
|
||||||
|
self.recordq_audio.put(new_frame)
|
||||||
|
print('musereal process_frames thread stop')
|
||||||
|
|
||||||
|
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
|
||||||
|
#if self.opt.asr:
|
||||||
|
# self.asr.warm_up()
|
||||||
|
|
||||||
|
self.tts.render(quit_event)
|
||||||
|
self.init_customindex()
|
||||||
|
process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
|
||||||
|
process_thread.start()
|
||||||
|
|
||||||
|
self.render_event.set() #start infer process render
|
||||||
|
count=0
|
||||||
|
totaltime=0
|
||||||
|
_starttime=time.perf_counter()
|
||||||
|
#_totalframe=0
|
||||||
|
while not quit_event.is_set():
|
||||||
|
# update texture every frame
|
||||||
|
# audio stream thread...
|
||||||
|
t = time.perf_counter()
|
||||||
|
self.asr.run_step()
|
||||||
|
|
||||||
|
# if video_track._queue.qsize()>=2*self.opt.batch_size:
|
||||||
|
# print('sleep qsize=',video_track._queue.qsize())
|
||||||
|
# time.sleep(0.04*video_track._queue.qsize()*0.8)
|
||||||
|
if video_track._queue.qsize()>=5:
|
||||||
|
print('sleep qsize=',video_track._queue.qsize())
|
||||||
|
time.sleep(0.04*video_track._queue.qsize()*0.8)
|
||||||
|
|
||||||
|
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
|
||||||
|
# if delay > 0:
|
||||||
|
# time.sleep(delay)
|
||||||
|
self.render_event.clear() #end infer process render
|
||||||
|
print('musereal thread stop')
|
||||||
|
|
|
@ -15,7 +15,7 @@ class VllmGPT:
|
||||||
self.__URL = "http://{}:{}/v1/completions".format(self.host, self.port)
|
self.__URL = "http://{}:{}/v1/completions".format(self.host, self.port)
|
||||||
self.__URL2 = "http://{}:{}/v1/chat/completions".format(self.host, self.port)
|
self.__URL2 = "http://{}:{}/v1/chat/completions".format(self.host, self.port)
|
||||||
|
|
||||||
def question(self,cont):
|
def chat(self,cont):
|
||||||
chat_list = []
|
chat_list = []
|
||||||
# contentdb = content_db.new_instance()
|
# contentdb = content_db.new_instance()
|
||||||
# list = contentdb.get_list('all','desc',11)
|
# list = contentdb.get_list('all','desc',11)
|
||||||
|
@ -77,5 +77,5 @@ class VllmGPT:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
vllm = VllmGPT('192.168.1.3','8101')
|
vllm = VllmGPT('192.168.1.3','8101')
|
||||||
req = vllm.question("你叫什么名字啊今年多大了")
|
req = vllm.chat("你叫什么名字啊今年多大了")
|
||||||
print(req)
|
print(req)
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from queue import Queue
|
||||||
|
import multiprocessing as mp
|
||||||
|
from baseasr import BaseASR
|
||||||
|
from musetalk.whisper.audio2feature import Audio2Feature
|
||||||
|
|
||||||
|
class MuseASR(BaseASR):
|
||||||
|
def __init__(self, opt, parent,audio_processor:Audio2Feature):
|
||||||
|
super().__init__(opt,parent)
|
||||||
|
self.audio_processor = audio_processor
|
||||||
|
|
||||||
|
def run_step(self):
|
||||||
|
############################################## extract audio feature ##############################################
|
||||||
|
start_time = time.time()
|
||||||
|
for _ in range(self.batch_size*2):
|
||||||
|
audio_frame,type=self.get_audio_frame()
|
||||||
|
self.frames.append(audio_frame)
|
||||||
|
self.output_queue.put((audio_frame,type))
|
||||||
|
|
||||||
|
if len(self.frames) <= self.stride_left_size + self.stride_right_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
inputs = np.concatenate(self.frames) # [N * chunk]
|
||||||
|
whisper_feature = self.audio_processor.audio2feat(inputs)
|
||||||
|
# for feature in whisper_feature:
|
||||||
|
# self.audio_feats.append(feature)
|
||||||
|
#print(f"processing audio costs {(time.time() - start_time) * 1000}ms, inputs shape:{inputs.shape} whisper_feature len:{len(whisper_feature)}")
|
||||||
|
whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
|
||||||
|
#print(f"whisper_chunks len:{len(whisper_chunks)},self.audio_feats len:{len(self.audio_feats)},self.output_queue len:{self.output_queue.qsize()}")
|
||||||
|
#self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
|
||||||
|
self.feat_queue.put(whisper_chunks)
|
||||||
|
# discard the old part to save memory
|
||||||
|
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
|
@ -0,0 +1,318 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
#from .utils import *
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import cv2
|
||||||
|
import glob
|
||||||
|
import pickle
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread, Event
|
||||||
|
from io import BytesIO
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||||
|
#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||||
|
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
|
||||||
|
from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
|
||||||
|
from ttsreal import EdgeTTS,VoitsTTS,XTTS
|
||||||
|
|
||||||
|
from museasr import MuseASR
|
||||||
|
import asyncio
|
||||||
|
from av import AudioFrame, VideoFrame
|
||||||
|
from basereal import BaseReal
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
def read_imgs(img_list):
|
||||||
|
frames = []
|
||||||
|
print('reading images...')
|
||||||
|
for img_path in tqdm(img_list):
|
||||||
|
frame = cv2.imread(img_path)
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def __mirror_index(size, index):
|
||||||
|
#size = len(self.coord_list_cycle)
|
||||||
|
turn = index // size
|
||||||
|
res = index % size
|
||||||
|
if turn % 2 == 0:
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
return size - res - 1
|
||||||
|
@torch.no_grad()
|
||||||
|
def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
|
||||||
|
): #vae, unet, pe,timesteps
|
||||||
|
|
||||||
|
vae, unet, pe = load_diffusion_model()
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
timesteps = torch.tensor([0], device=device)
|
||||||
|
pe = pe.half()
|
||||||
|
vae.vae = vae.vae.half()
|
||||||
|
unet.model = unet.model.half()
|
||||||
|
|
||||||
|
input_latent_list_cycle = torch.load(latents_out_path)
|
||||||
|
length = len(input_latent_list_cycle)
|
||||||
|
index = 0
|
||||||
|
count=0
|
||||||
|
counttime=0
|
||||||
|
print('start inference')
|
||||||
|
while True:
|
||||||
|
if render_event.is_set():
|
||||||
|
starttime=time.perf_counter()
|
||||||
|
try:
|
||||||
|
whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
is_all_silence=True
|
||||||
|
audio_frames = []
|
||||||
|
for _ in range(batch_size*2):
|
||||||
|
frame,type = audio_out_queue.get()
|
||||||
|
audio_frames.append((frame,type))
|
||||||
|
if type==0:
|
||||||
|
is_all_silence=False
|
||||||
|
if is_all_silence:
|
||||||
|
for i in range(batch_size):
|
||||||
|
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
|
||||||
|
index = index + 1
|
||||||
|
else:
|
||||||
|
# print('infer=======')
|
||||||
|
t=time.perf_counter()
|
||||||
|
whisper_batch = np.stack(whisper_chunks)
|
||||||
|
latent_batch = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
idx = __mirror_index(length,index+i)
|
||||||
|
latent = input_latent_list_cycle[idx]
|
||||||
|
latent_batch.append(latent)
|
||||||
|
latent_batch = torch.cat(latent_batch, dim=0)
|
||||||
|
|
||||||
|
# for i, (whisper_batch,latent_batch) in enumerate(gen):
|
||||||
|
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||||
|
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||||
|
dtype=unet.model.dtype)
|
||||||
|
audio_feature_batch = pe(audio_feature_batch)
|
||||||
|
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||||
|
# print('prepare time:',time.perf_counter()-t)
|
||||||
|
# t=time.perf_counter()
|
||||||
|
|
||||||
|
pred_latents = unet.model(latent_batch,
|
||||||
|
timesteps,
|
||||||
|
encoder_hidden_states=audio_feature_batch).sample
|
||||||
|
# print('unet time:',time.perf_counter()-t)
|
||||||
|
# t=time.perf_counter()
|
||||||
|
recon = vae.decode_latents(pred_latents)
|
||||||
|
# print('vae time:',time.perf_counter()-t)
|
||||||
|
#print('diffusion len=',len(recon))
|
||||||
|
counttime += (time.perf_counter() - t)
|
||||||
|
count += batch_size
|
||||||
|
#_totalframe += 1
|
||||||
|
if count>=100:
|
||||||
|
print(f"------actual avg infer fps:{count/counttime:.4f}")
|
||||||
|
count=0
|
||||||
|
counttime=0
|
||||||
|
for i,res_frame in enumerate(recon):
|
||||||
|
#self.__pushmedia(res_frame,loop,audio_track,video_track)
|
||||||
|
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
|
||||||
|
index = index + 1
|
||||||
|
#print('total batch time:',time.perf_counter()-starttime)
|
||||||
|
else:
|
||||||
|
time.sleep(1)
|
||||||
|
print('musereal inference processor stop')
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
class MuseReal(BaseReal):
|
||||||
|
def __init__(self, opt):
|
||||||
|
super().__init__(opt)
|
||||||
|
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
||||||
|
self.W = opt.W
|
||||||
|
self.H = opt.H
|
||||||
|
|
||||||
|
self.fps = opt.fps # 20 ms per frame
|
||||||
|
|
||||||
|
#### musetalk
|
||||||
|
self.avatar_id = opt.avatar_id
|
||||||
|
self.video_path = '' #video_path
|
||||||
|
self.bbox_shift = opt.bbox_shift
|
||||||
|
self.avatar_path = f"./data/avatars/{self.avatar_id}"
|
||||||
|
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
|
||||||
|
self.coords_path = f"{self.avatar_path}/coords.pkl"
|
||||||
|
self.latents_out_path= f"{self.avatar_path}/latents.pt"
|
||||||
|
self.video_out_path = f"{self.avatar_path}/vid_output/"
|
||||||
|
self.mask_out_path =f"{self.avatar_path}/mask"
|
||||||
|
self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
|
||||||
|
self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
|
||||||
|
self.avatar_info = {
|
||||||
|
"avatar_id":self.avatar_id,
|
||||||
|
"video_path":self.video_path,
|
||||||
|
"bbox_shift":self.bbox_shift
|
||||||
|
}
|
||||||
|
self.batch_size = opt.batch_size
|
||||||
|
self.idx = 0
|
||||||
|
self.res_frame_queue = mp.Queue(self.batch_size*2)
|
||||||
|
self.__loadmodels()
|
||||||
|
self.__loadavatar()
|
||||||
|
|
||||||
|
self.asr = MuseASR(opt,self,self.audio_processor)
|
||||||
|
self.asr.warm_up()
|
||||||
|
#self.__warm_up()
|
||||||
|
|
||||||
|
self.render_event = mp.Event()
|
||||||
|
mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
|
||||||
|
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
|
||||||
|
)).start() #self.vae, self.unet, self.pe,self.timesteps
|
||||||
|
|
||||||
|
def __loadmodels(self):
|
||||||
|
# load model weights
|
||||||
|
self.audio_processor= load_audio_model()
|
||||||
|
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
|
||||||
|
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
# self.timesteps = torch.tensor([0], device=device)
|
||||||
|
# self.pe = self.pe.half()
|
||||||
|
# self.vae.vae = self.vae.vae.half()
|
||||||
|
# self.unet.model = self.unet.model.half()
|
||||||
|
|
||||||
|
def __loadavatar(self):
|
||||||
|
#self.input_latent_list_cycle = torch.load(self.latents_out_path)
|
||||||
|
with open(self.coords_path, 'rb') as f:
|
||||||
|
self.coord_list_cycle = pickle.load(f)
|
||||||
|
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
|
||||||
|
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||||
|
self.frame_list_cycle = read_imgs(input_img_list)
|
||||||
|
with open(self.mask_coords_path, 'rb') as f:
|
||||||
|
self.mask_coords_list_cycle = pickle.load(f)
|
||||||
|
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
|
||||||
|
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||||
|
self.mask_list_cycle = read_imgs(input_mask_list)
|
||||||
|
|
||||||
|
|
||||||
|
def __mirror_index(self, index):
|
||||||
|
size = len(self.coord_list_cycle)
|
||||||
|
turn = index // size
|
||||||
|
res = index % size
|
||||||
|
if turn % 2 == 0:
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
return size - res - 1
|
||||||
|
|
||||||
|
def __warm_up(self):
|
||||||
|
self.asr.run_step()
|
||||||
|
whisper_chunks = self.asr.get_next_feat()
|
||||||
|
whisper_batch = np.stack(whisper_chunks)
|
||||||
|
latent_batch = []
|
||||||
|
for i in range(self.batch_size):
|
||||||
|
idx = self.__mirror_index(self.idx+i)
|
||||||
|
latent = self.input_latent_list_cycle[idx]
|
||||||
|
latent_batch.append(latent)
|
||||||
|
latent_batch = torch.cat(latent_batch, dim=0)
|
||||||
|
print('infer=======')
|
||||||
|
# for i, (whisper_batch,latent_batch) in enumerate(gen):
|
||||||
|
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||||
|
audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
|
||||||
|
dtype=self.unet.model.dtype)
|
||||||
|
audio_feature_batch = self.pe(audio_feature_batch)
|
||||||
|
latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
|
||||||
|
|
||||||
|
pred_latents = self.unet.model(latent_batch,
|
||||||
|
self.timesteps,
|
||||||
|
encoder_hidden_states=audio_feature_batch).sample
|
||||||
|
recon = self.vae.decode_latents(pred_latents)
|
||||||
|
|
||||||
|
|
||||||
|
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
|
||||||
|
|
||||||
|
while not quit_event.is_set():
|
||||||
|
try:
|
||||||
|
res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
if audio_frames[0][1]!=0 and audio_frames[1][1]!=0: #全为静音数据,只需要取fullimg
|
||||||
|
self.speaking = False
|
||||||
|
audiotype = audio_frames[0][1]
|
||||||
|
if self.custom_index.get(audiotype) is not None: #有自定义视频
|
||||||
|
mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
|
||||||
|
combine_frame = self.custom_img_cycle[audiotype][mirindex]
|
||||||
|
self.custom_index[audiotype] += 1
|
||||||
|
# if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
|
||||||
|
# self.curr_state = 1 #当前视频不循环播放,切换到静音状态
|
||||||
|
else:
|
||||||
|
combine_frame = self.frame_list_cycle[idx]
|
||||||
|
else:
|
||||||
|
self.speaking = True
|
||||||
|
bbox = self.coord_list_cycle[idx]
|
||||||
|
ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
try:
|
||||||
|
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
mask = self.mask_list_cycle[idx]
|
||||||
|
mask_crop_box = self.mask_coords_list_cycle[idx]
|
||||||
|
#combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||||
|
#t=time.perf_counter()
|
||||||
|
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
|
||||||
|
#print('blending time:',time.perf_counter()-t)
|
||||||
|
|
||||||
|
image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
|
||||||
|
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
|
||||||
|
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
||||||
|
if self.recording:
|
||||||
|
self.recordq_video.put(new_frame)
|
||||||
|
|
||||||
|
for audio_frame in audio_frames:
|
||||||
|
frame,type = audio_frame
|
||||||
|
frame = (frame * 32767).astype(np.int16)
|
||||||
|
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
|
||||||
|
new_frame.planes[0].update(frame.tobytes())
|
||||||
|
new_frame.sample_rate=16000
|
||||||
|
# if audio_track._queue.qsize()>10:
|
||||||
|
# time.sleep(0.1)
|
||||||
|
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
|
||||||
|
if self.recording:
|
||||||
|
self.recordq_audio.put(new_frame)
|
||||||
|
print('musereal process_frames thread stop')
|
||||||
|
|
||||||
|
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
|
||||||
|
#if self.opt.asr:
|
||||||
|
# self.asr.warm_up()
|
||||||
|
|
||||||
|
self.tts.render(quit_event)
|
||||||
|
self.init_customindex()
|
||||||
|
process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
|
||||||
|
process_thread.start()
|
||||||
|
|
||||||
|
self.render_event.set() #start infer process render
|
||||||
|
count=0
|
||||||
|
totaltime=0
|
||||||
|
_starttime=time.perf_counter()
|
||||||
|
#_totalframe=0
|
||||||
|
while not quit_event.is_set(): #todo
|
||||||
|
# update texture every frame
|
||||||
|
# audio stream thread...
|
||||||
|
t = time.perf_counter()
|
||||||
|
self.asr.run_step()
|
||||||
|
#self.test_step(loop,audio_track,video_track)
|
||||||
|
# totaltime += (time.perf_counter() - t)
|
||||||
|
# count += self.opt.batch_size
|
||||||
|
# if count>=100:
|
||||||
|
# print(f"------actual avg infer fps:{count/totaltime:.4f}")
|
||||||
|
# count=0
|
||||||
|
# totaltime=0
|
||||||
|
if video_track._queue.qsize()>=1.5*self.opt.batch_size:
|
||||||
|
print('sleep qsize=',video_track._queue.qsize())
|
||||||
|
time.sleep(0.04*video_track._queue.qsize()*0.8)
|
||||||
|
# if video_track._queue.qsize()>=5:
|
||||||
|
# print('sleep qsize=',video_track._queue.qsize())
|
||||||
|
# time.sleep(0.04*video_track._queue.qsize()*0.8)
|
||||||
|
|
||||||
|
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
|
||||||
|
# if delay > 0:
|
||||||
|
# time.sleep(delay)
|
||||||
|
self.render_event.clear() #end infer process render
|
||||||
|
print('musereal thread stop')
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
import json
|
||||||
|
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
class PositionalEncoding(nn.Module):
|
||||||
|
def __init__(self, d_model=384, max_len=5000):
|
||||||
|
super(PositionalEncoding, self).__init__()
|
||||||
|
pe = torch.zeros(max_len, d_model)
|
||||||
|
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||||
|
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||||
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
pe = pe.unsqueeze(0)
|
||||||
|
self.register_buffer('pe', pe)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, seq_len, d_model = x.size()
|
||||||
|
pe = self.pe[:, :seq_len, :]
|
||||||
|
x = x + pe.to(x.device)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class UNet():
|
||||||
|
def __init__(self,
|
||||||
|
unet_config,
|
||||||
|
model_path,
|
||||||
|
use_float16=False,
|
||||||
|
):
|
||||||
|
with open(unet_config, 'r') as f:
|
||||||
|
unet_config = json.load(f)
|
||||||
|
self.model = UNet2DConditionModel(**unet_config)
|
||||||
|
self.pe = PositionalEncoding(d_model=384)
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
||||||
|
self.model.load_state_dict(weights)
|
||||||
|
if use_float16:
|
||||||
|
self.model = self.model.half()
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unet = UNet()
|
|
@ -0,0 +1,148 @@
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
|
||||||
|
class VAE():
|
||||||
|
"""
|
||||||
|
VAE (Variational Autoencoder) class for image processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
|
||||||
|
"""
|
||||||
|
Initialize the VAE instance.
|
||||||
|
|
||||||
|
:param model_path: Path to the trained model.
|
||||||
|
:param resized_img: The size to which images are resized.
|
||||||
|
:param use_float16: Whether to use float16 precision.
|
||||||
|
"""
|
||||||
|
self.model_path = model_path
|
||||||
|
self.vae = AutoencoderKL.from_pretrained(self.model_path)
|
||||||
|
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.vae.to(self.device)
|
||||||
|
|
||||||
|
if use_float16:
|
||||||
|
self.vae = self.vae.half()
|
||||||
|
self._use_float16 = True
|
||||||
|
else:
|
||||||
|
self._use_float16 = False
|
||||||
|
|
||||||
|
self.scaling_factor = self.vae.config.scaling_factor
|
||||||
|
self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
self._resized_img = resized_img
|
||||||
|
self._mask_tensor = self.get_mask_tensor()
|
||||||
|
|
||||||
|
def get_mask_tensor(self):
|
||||||
|
"""
|
||||||
|
Creates a mask tensor for image processing.
|
||||||
|
:return: A mask tensor.
|
||||||
|
"""
|
||||||
|
mask_tensor = torch.zeros((self._resized_img,self._resized_img))
|
||||||
|
mask_tensor[:self._resized_img//2,:] = 1
|
||||||
|
mask_tensor[mask_tensor< 0.5] = 0
|
||||||
|
mask_tensor[mask_tensor>= 0.5] = 1
|
||||||
|
return mask_tensor
|
||||||
|
|
||||||
|
def preprocess_img(self,img_name,half_mask=False):
|
||||||
|
"""
|
||||||
|
Preprocess an image for the VAE.
|
||||||
|
|
||||||
|
:param img_name: The image file path or a list of image file paths.
|
||||||
|
:param half_mask: Whether to apply a half mask to the image.
|
||||||
|
:return: A preprocessed image tensor.
|
||||||
|
"""
|
||||||
|
window = []
|
||||||
|
if isinstance(img_name, str):
|
||||||
|
window_fnames = [img_name]
|
||||||
|
for fname in window_fnames:
|
||||||
|
img = cv2.imread(fname)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img = cv2.resize(img, (self._resized_img, self._resized_img),
|
||||||
|
interpolation=cv2.INTER_LANCZOS4)
|
||||||
|
window.append(img)
|
||||||
|
else:
|
||||||
|
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
|
||||||
|
window.append(img)
|
||||||
|
|
||||||
|
x = np.asarray(window) / 255.
|
||||||
|
x = np.transpose(x, (3, 0, 1, 2))
|
||||||
|
x = torch.squeeze(torch.FloatTensor(x))
|
||||||
|
if half_mask:
|
||||||
|
x = x * (self._mask_tensor>0.5)
|
||||||
|
x = self.transform(x)
|
||||||
|
|
||||||
|
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
|
||||||
|
x = x.to(self.vae.device)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def encode_latents(self,image):
|
||||||
|
"""
|
||||||
|
Encode an image into latent variables.
|
||||||
|
|
||||||
|
:param image: The image tensor to encode.
|
||||||
|
:return: The encoded latent variables.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
|
||||||
|
init_latents = self.scaling_factor * init_latent_dist.sample()
|
||||||
|
return init_latents
|
||||||
|
|
||||||
|
def decode_latents(self, latents):
|
||||||
|
"""
|
||||||
|
Decode latent variables back into an image.
|
||||||
|
:param latents: The latent variables to decode.
|
||||||
|
:return: A NumPy array representing the decoded image.
|
||||||
|
"""
|
||||||
|
latents = (1/ self.scaling_factor) * latents
|
||||||
|
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
||||||
|
image = (image / 2 + 0.5).clamp(0, 1)
|
||||||
|
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||||
|
image = (image * 255).round().astype("uint8")
|
||||||
|
image = image[...,::-1] # RGB to BGR
|
||||||
|
return image
|
||||||
|
|
||||||
|
def get_latents_for_unet(self,img):
|
||||||
|
"""
|
||||||
|
Prepare latent variables for a U-Net model.
|
||||||
|
:param img: The image to process.
|
||||||
|
:return: A concatenated tensor of latents for U-Net input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
|
||||||
|
masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||||
|
ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
|
||||||
|
ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||||
|
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||||
|
return latent_model_input
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
vae_mode_path = "./models/sd-vae-ft-mse/"
|
||||||
|
vae = VAE(model_path = vae_mode_path,use_float16=False)
|
||||||
|
img_path = "./results/sun001_crop/00000.png"
|
||||||
|
|
||||||
|
crop_imgs_path = "./results/sun001_crop/"
|
||||||
|
latents_out_path = "./results/latents/"
|
||||||
|
if not os.path.exists(latents_out_path):
|
||||||
|
os.mkdir(latents_out_path)
|
||||||
|
|
||||||
|
files = os.listdir(crop_imgs_path)
|
||||||
|
files.sort()
|
||||||
|
files = [file for file in files if file.split(".")[-1] == "png"]
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
index = file.split(".")[0]
|
||||||
|
img_path = crop_imgs_path + file
|
||||||
|
latents = vae.get_latents_for_unet(img_path)
|
||||||
|
print(img_path,"latents",latents.size())
|
||||||
|
#torch.save(latents,os.path.join(latents_out_path,index+".pt"))
|
||||||
|
#reload_tensor = torch.load('tensor.pt')
|
||||||
|
#print(reload_tensor.size())
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,348 @@
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
from face_alignment import NetworkSize
|
||||||
|
from mmpose.apis import inference_topdown, init_model
|
||||||
|
from mmpose.structures import merge_data_samples
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
try:
|
||||||
|
from utils.face_parsing import FaceParsing
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from musetalk.utils.face_parsing import FaceParsing
|
||||||
|
|
||||||
|
|
||||||
|
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
|
||||||
|
cap = cv2.VideoCapture(vid_path)
|
||||||
|
count = 0
|
||||||
|
while True:
|
||||||
|
if count > cut_frame:
|
||||||
|
break
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if ret:
|
||||||
|
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
|
||||||
|
count += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def read_imgs(img_list):
|
||||||
|
frames = []
|
||||||
|
print('reading images...')
|
||||||
|
for img_path in tqdm(img_list):
|
||||||
|
frame = cv2.imread(img_path)
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def get_landmark_and_bbox(img_list, upperbondrange=0):
|
||||||
|
frames = read_imgs(img_list)
|
||||||
|
batch_size_fa = 1
|
||||||
|
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
|
||||||
|
coords_list = []
|
||||||
|
landmarks = []
|
||||||
|
if upperbondrange != 0:
|
||||||
|
print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
|
||||||
|
else:
|
||||||
|
print('get key_landmark and face bounding boxes with the default value')
|
||||||
|
average_range_minus = []
|
||||||
|
average_range_plus = []
|
||||||
|
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
|
||||||
|
for fb in tqdm(batches):
|
||||||
|
results = inference_topdown(model, np.asarray(fb)[0])
|
||||||
|
results = merge_data_samples(results)
|
||||||
|
keypoints = results.pred_instances.keypoints
|
||||||
|
face_land_mark = keypoints[0][23:91]
|
||||||
|
face_land_mark = face_land_mark.astype(np.int32)
|
||||||
|
|
||||||
|
# get bounding boxes by face detetion
|
||||||
|
bbox = fa.get_detections_for_batch(np.asarray(fb))
|
||||||
|
|
||||||
|
# adjust the bounding box refer to landmark
|
||||||
|
# Add the bounding box to a tuple and append it to the coordinates list
|
||||||
|
for j, f in enumerate(bbox):
|
||||||
|
if f is None: # no face in the image
|
||||||
|
coords_list += [coord_placeholder]
|
||||||
|
continue
|
||||||
|
|
||||||
|
half_face_coord = face_land_mark[29] # np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
|
||||||
|
range_minus = (face_land_mark[30] - face_land_mark[29])[1]
|
||||||
|
range_plus = (face_land_mark[29] - face_land_mark[28])[1]
|
||||||
|
average_range_minus.append(range_minus)
|
||||||
|
average_range_plus.append(range_plus)
|
||||||
|
if upperbondrange != 0:
|
||||||
|
half_face_coord[1] = upperbondrange + half_face_coord[1] # 手动调整 + 向下(偏29) - 向上(偏28)
|
||||||
|
half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
|
||||||
|
upper_bond = half_face_coord[1] - half_face_dist
|
||||||
|
|
||||||
|
f_landmark = (
|
||||||
|
np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
|
||||||
|
np.max(face_land_mark[:, 1]))
|
||||||
|
x1, y1, x2, y2 = f_landmark
|
||||||
|
|
||||||
|
if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0: # if the landmark bbox is not suitable, reuse the bbox
|
||||||
|
coords_list += [f]
|
||||||
|
w, h = f[2] - f[0], f[3] - f[1]
|
||||||
|
print("error bbox:", f)
|
||||||
|
else:
|
||||||
|
coords_list += [f_landmark]
|
||||||
|
return coords_list, frames
|
||||||
|
|
||||||
|
|
||||||
|
class FaceAlignment:
|
||||||
|
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
||||||
|
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
||||||
|
self.device = device
|
||||||
|
self.flip_input = flip_input
|
||||||
|
self.landmarks_type = landmarks_type
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
network_size = int(network_size)
|
||||||
|
if 'cuda' in device:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
# torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
# torch.backends.cudnn.benchmark = True
|
||||||
|
# torch.backends.cudnn.deterministic = False
|
||||||
|
# torch.backends.cudnn.allow_tf32 = True
|
||||||
|
print('cuda start')
|
||||||
|
|
||||||
|
# Get the face detector
|
||||||
|
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
||||||
|
globals(), locals(), [face_detector], 0)
|
||||||
|
|
||||||
|
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
||||||
|
|
||||||
|
def get_detections_for_batch(self, images):
|
||||||
|
images = images[..., ::-1]
|
||||||
|
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, d in enumerate(detected_faces):
|
||||||
|
if len(d) == 0:
|
||||||
|
results.append(None)
|
||||||
|
continue
|
||||||
|
d = d[0]
|
||||||
|
d = np.clip(d, 0, None)
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = map(int, d[:-1])
|
||||||
|
results.append((x1, y1, x2, y2))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def get_mask_tensor():
|
||||||
|
"""
|
||||||
|
Creates a mask tensor for image processing.
|
||||||
|
:return: A mask tensor.
|
||||||
|
"""
|
||||||
|
mask_tensor = torch.zeros((256, 256))
|
||||||
|
mask_tensor[:256 // 2, :] = 1
|
||||||
|
mask_tensor[mask_tensor < 0.5] = 0
|
||||||
|
mask_tensor[mask_tensor >= 0.5] = 1
|
||||||
|
return mask_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_img(img_name, half_mask=False):
|
||||||
|
window = []
|
||||||
|
if isinstance(img_name, str):
|
||||||
|
window_fnames = [img_name]
|
||||||
|
for fname in window_fnames:
|
||||||
|
img = cv2.imread(fname)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img = cv2.resize(img, (256, 256),
|
||||||
|
interpolation=cv2.INTER_LANCZOS4)
|
||||||
|
window.append(img)
|
||||||
|
else:
|
||||||
|
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
|
||||||
|
window.append(img)
|
||||||
|
x = np.asarray(window) / 255.
|
||||||
|
x = np.transpose(x, (3, 0, 1, 2))
|
||||||
|
x = torch.squeeze(torch.FloatTensor(x))
|
||||||
|
if half_mask:
|
||||||
|
x = x * (get_mask_tensor() > 0.5)
|
||||||
|
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
x = normalize(x)
|
||||||
|
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
|
||||||
|
x = x.to(device)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def encode_latents(image):
|
||||||
|
with torch.no_grad():
|
||||||
|
init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
|
||||||
|
init_latents = vae.config.scaling_factor * init_latent_dist.sample()
|
||||||
|
return init_latents
|
||||||
|
|
||||||
|
|
||||||
|
def get_latents_for_unet(img):
|
||||||
|
ref_image = preprocess_img(img, half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
|
||||||
|
masked_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||||
|
ref_image = preprocess_img(img, half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
|
||||||
|
ref_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||||
|
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||||
|
return latent_model_input
|
||||||
|
|
||||||
|
|
||||||
|
def get_crop_box(box, expand):
|
||||||
|
x, y, x1, y1 = box
|
||||||
|
x_c, y_c = (x + x1) // 2, (y + y1) // 2
|
||||||
|
w, h = x1 - x, y1 - y
|
||||||
|
s = int(max(w, h) // 2 * expand)
|
||||||
|
crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
|
||||||
|
return crop_box, s
|
||||||
|
|
||||||
|
|
||||||
|
def face_seg(image):
|
||||||
|
seg_image = fp(image)
|
||||||
|
if seg_image is None:
|
||||||
|
print("error, no person_segment")
|
||||||
|
return None
|
||||||
|
|
||||||
|
seg_image = seg_image.resize(image.size)
|
||||||
|
return seg_image
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
|
||||||
|
body = Image.fromarray(image[:, :, ::-1])
|
||||||
|
|
||||||
|
x, y, x1, y1 = face_box
|
||||||
|
# print(x1-x,y1-y)
|
||||||
|
crop_box, s = get_crop_box(face_box, expand)
|
||||||
|
x_s, y_s, x_e, y_e = crop_box
|
||||||
|
|
||||||
|
face_large = body.crop(crop_box)
|
||||||
|
ori_shape = face_large.size
|
||||||
|
|
||||||
|
mask_image = face_seg(face_large)
|
||||||
|
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
|
||||||
|
mask_image = Image.new('L', ori_shape, 0)
|
||||||
|
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
|
||||||
|
|
||||||
|
# keep upper_boundary_ratio of talking area
|
||||||
|
width, height = mask_image.size
|
||||||
|
top_boundary = int(height * upper_boundary_ratio)
|
||||||
|
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||||
|
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||||
|
|
||||||
|
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||||
|
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||||
|
return mask_array, crop_box
|
||||||
|
|
||||||
|
|
||||||
|
##todo 简单根据文件后缀判断 要更精确的可以自己修改 使用 magic
|
||||||
|
def is_video_file(file_path):
|
||||||
|
video_exts = ['.mp4', '.mkv', '.flv', '.avi', '.mov'] # 这里列出了一些常见的视频文件扩展名,可以根据需要添加更多
|
||||||
|
file_ext = os.path.splitext(file_path)[1].lower() # 获取文件扩展名并转换为小写
|
||||||
|
return file_ext in video_exts
|
||||||
|
|
||||||
|
|
||||||
|
def create_dir(dir_path):
|
||||||
|
if not os.path.exists(dir_path):
|
||||||
|
os.makedirs(dir_path)
|
||||||
|
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
def create_musetalk_human(file, avatar_id):
|
||||||
|
# 保存文件设置 可以不动
|
||||||
|
save_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}')
|
||||||
|
save_full_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/full_imgs')
|
||||||
|
create_dir(save_path)
|
||||||
|
create_dir(save_full_path)
|
||||||
|
mask_out_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/mask')
|
||||||
|
create_dir(mask_out_path)
|
||||||
|
|
||||||
|
# 模型
|
||||||
|
mask_coords_path = os.path.join(current_dir, f'{save_path}/mask_coords.pkl')
|
||||||
|
coords_path = os.path.join(current_dir, f'{save_path}/coords.pkl')
|
||||||
|
latents_out_path = os.path.join(current_dir, f'{save_path}/latents.pt')
|
||||||
|
|
||||||
|
with open(os.path.join(current_dir, f'{save_path}/avator_info.json'), "w") as f:
|
||||||
|
json.dump({
|
||||||
|
"avatar_id": avatar_id,
|
||||||
|
"video_path": file,
|
||||||
|
"bbox_shift": 5
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
if os.path.isfile(file):
|
||||||
|
if is_video_file(file):
|
||||||
|
video2imgs(file, save_full_path, ext='png')
|
||||||
|
else:
|
||||||
|
shutil.copyfile(file, f"{save_full_path}/{os.path.basename(file)}")
|
||||||
|
else:
|
||||||
|
files = os.listdir(file)
|
||||||
|
files.sort()
|
||||||
|
files = [file for file in files if file.split(".")[-1] == "png"]
|
||||||
|
for filename in files:
|
||||||
|
shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
|
||||||
|
input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))
|
||||||
|
print("extracting landmarks...")
|
||||||
|
coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5)
|
||||||
|
input_latent_list = []
|
||||||
|
idx = -1
|
||||||
|
# maker if the bbox is not sufficient
|
||||||
|
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
|
||||||
|
for bbox, frame in zip(coord_list, frame_list):
|
||||||
|
idx = idx + 1
|
||||||
|
if bbox == coord_placeholder:
|
||||||
|
continue
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
crop_frame = frame[y1:y2, x1:x2]
|
||||||
|
resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
|
||||||
|
latents = get_latents_for_unet(resized_crop_frame)
|
||||||
|
input_latent_list.append(latents)
|
||||||
|
|
||||||
|
frame_list_cycle = frame_list #+ frame_list[::-1]
|
||||||
|
coord_list_cycle = coord_list #+ coord_list[::-1]
|
||||||
|
input_latent_list_cycle = input_latent_list #+ input_latent_list[::-1]
|
||||||
|
mask_coords_list_cycle = []
|
||||||
|
mask_list_cycle = []
|
||||||
|
for i, frame in enumerate(tqdm(frame_list_cycle)):
|
||||||
|
cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)
|
||||||
|
face_box = coord_list_cycle[i]
|
||||||
|
mask, crop_box = get_image_prepare_material(frame, face_box)
|
||||||
|
cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
|
||||||
|
mask_coords_list_cycle += [crop_box]
|
||||||
|
mask_list_cycle.append(mask)
|
||||||
|
|
||||||
|
with open(mask_coords_path, 'wb') as f:
|
||||||
|
pickle.dump(mask_coords_list_cycle, f)
|
||||||
|
|
||||||
|
with open(coords_path, 'wb') as f:
|
||||||
|
pickle.dump(coord_list_cycle, f)
|
||||||
|
torch.save(input_latent_list_cycle, os.path.join(latents_out_path))
|
||||||
|
|
||||||
|
|
||||||
|
# initialize the mmpose model
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
fa = FaceAlignment(1, flip_input=False, device=device)
|
||||||
|
config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
|
||||||
|
checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))
|
||||||
|
model = init_model(config_file, checkpoint_file, device=device)
|
||||||
|
vae = AutoencoderKL.from_pretrained(os.path.abspath(os.path.join(current_dir, '../models/sd-vae-ft-mse')))
|
||||||
|
vae.to(device)
|
||||||
|
fp = FaceParsing(os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/resnet18-5c106cde.pth')),
|
||||||
|
os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/79999_iter.pth')))
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 视频文件地址
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--file",
|
||||||
|
type=str,
|
||||||
|
default=r'D:\ok\00000000.png',
|
||||||
|
)
|
||||||
|
parser.add_argument("--avatar_id",
|
||||||
|
type=str,
|
||||||
|
default='3',
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
create_musetalk_human(args.file, args.avatar_id)
|
|
@ -0,0 +1,5 @@
|
||||||
|
import sys
|
||||||
|
from os.path import abspath, dirname
|
||||||
|
current_dir = dirname(abspath(__file__))
|
||||||
|
parent_dir = dirname(current_dir)
|
||||||
|
sys.path.append(parent_dir+'/utils')
|
|
@ -0,0 +1,125 @@
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from face_parsing import FaceParsing
|
||||||
|
import copy
|
||||||
|
|
||||||
|
fp = FaceParsing()
|
||||||
|
|
||||||
|
def get_crop_box(box, expand):
|
||||||
|
x, y, x1, y1 = box
|
||||||
|
x_c, y_c = (x+x1)//2, (y+y1)//2
|
||||||
|
w, h = x1-x, y1-y
|
||||||
|
s = int(max(w, h)//2*expand)
|
||||||
|
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
|
||||||
|
return crop_box, s
|
||||||
|
|
||||||
|
def face_seg(image):
|
||||||
|
seg_image = fp(image)
|
||||||
|
if seg_image is None:
|
||||||
|
print("error, no person_segment")
|
||||||
|
return None
|
||||||
|
|
||||||
|
seg_image = seg_image.resize(image.size)
|
||||||
|
return seg_image
|
||||||
|
|
||||||
|
def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
|
||||||
|
#print(image.shape)
|
||||||
|
#print(face.shape)
|
||||||
|
|
||||||
|
body = Image.fromarray(image[:,:,::-1])
|
||||||
|
face = Image.fromarray(face[:,:,::-1])
|
||||||
|
|
||||||
|
x, y, x1, y1 = face_box
|
||||||
|
#print(x1-x,y1-y)
|
||||||
|
crop_box, s = get_crop_box(face_box, expand)
|
||||||
|
x_s, y_s, x_e, y_e = crop_box
|
||||||
|
face_position = (x, y)
|
||||||
|
|
||||||
|
face_large = body.crop(crop_box)
|
||||||
|
ori_shape = face_large.size
|
||||||
|
|
||||||
|
mask_image = face_seg(face_large)
|
||||||
|
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||||
|
mask_image = Image.new('L', ori_shape, 0)
|
||||||
|
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||||
|
|
||||||
|
# keep upper_boundary_ratio of talking area
|
||||||
|
width, height = mask_image.size
|
||||||
|
top_boundary = int(height * upper_boundary_ratio)
|
||||||
|
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||||
|
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||||
|
|
||||||
|
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||||
|
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||||
|
mask_image = Image.fromarray(mask_array)
|
||||||
|
|
||||||
|
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||||
|
body.paste(face_large, crop_box[:2], mask_image)
|
||||||
|
body = np.array(body)
|
||||||
|
return body[:,:,::-1]
|
||||||
|
|
||||||
|
def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
|
||||||
|
body = Image.fromarray(image[:,:,::-1])
|
||||||
|
|
||||||
|
x, y, x1, y1 = face_box
|
||||||
|
#print(x1-x,y1-y)
|
||||||
|
crop_box, s = get_crop_box(face_box, expand)
|
||||||
|
x_s, y_s, x_e, y_e = crop_box
|
||||||
|
|
||||||
|
face_large = body.crop(crop_box)
|
||||||
|
ori_shape = face_large.size
|
||||||
|
|
||||||
|
mask_image = face_seg(face_large)
|
||||||
|
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||||
|
mask_image = Image.new('L', ori_shape, 0)
|
||||||
|
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||||
|
|
||||||
|
# keep upper_boundary_ratio of talking area
|
||||||
|
width, height = mask_image.size
|
||||||
|
top_boundary = int(height * upper_boundary_ratio)
|
||||||
|
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||||
|
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||||
|
|
||||||
|
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||||
|
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||||
|
return mask_array,crop_box
|
||||||
|
|
||||||
|
# def get_image_blending(image,face,face_box,mask_array,crop_box):
|
||||||
|
# body = Image.fromarray(image[:,:,::-1])
|
||||||
|
# face = Image.fromarray(face[:,:,::-1])
|
||||||
|
|
||||||
|
# x, y, x1, y1 = face_box
|
||||||
|
# x_s, y_s, x_e, y_e = crop_box
|
||||||
|
# face_large = body.crop(crop_box)
|
||||||
|
|
||||||
|
# mask_image = Image.fromarray(mask_array)
|
||||||
|
# mask_image = mask_image.convert("L")
|
||||||
|
# face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||||
|
# body.paste(face_large, crop_box[:2], mask_image)
|
||||||
|
# body = np.array(body)
|
||||||
|
# return body[:,:,::-1]
|
||||||
|
|
||||||
|
def get_image_blending(image,face,face_box,mask_array,crop_box):
|
||||||
|
body = image
|
||||||
|
x, y, x1, y1 = face_box
|
||||||
|
x_s, y_s, x_e, y_e = crop_box
|
||||||
|
face_large = copy.deepcopy(body[y_s:y_e, x_s:x_e])
|
||||||
|
face_large[y-y_s:y1-y_s, x-x_s:x1-x_s]=face
|
||||||
|
|
||||||
|
mask_image = cv2.cvtColor(mask_array,cv2.COLOR_BGR2GRAY)
|
||||||
|
mask_image = (mask_image/255).astype(np.float32)
|
||||||
|
|
||||||
|
# mask_not = cv2.bitwise_not(mask_array)
|
||||||
|
# prospect_tmp = cv2.bitwise_and(face_large, face_large, mask=mask_array)
|
||||||
|
# background_img = body[y_s:y_e, x_s:x_e]
|
||||||
|
# background_img = cv2.bitwise_and(background_img, background_img, mask=mask_not)
|
||||||
|
# body[y_s:y_e, x_s:x_e] = prospect_tmp + background_img
|
||||||
|
|
||||||
|
#print(mask_image.shape)
|
||||||
|
#print(cv2.minMaxLoc(mask_image))
|
||||||
|
|
||||||
|
body[y_s:y_e, x_s:x_e] = cv2.blendLinear(face_large,body[y_s:y_e, x_s:x_e],mask_image,1-mask_image)
|
||||||
|
|
||||||
|
#body.paste(face_large, crop_box[:2], mask_image)
|
||||||
|
return body
|
|
@ -0,0 +1,54 @@
|
||||||
|
default_scope = 'mmpose'
|
||||||
|
|
||||||
|
# hooks
|
||||||
|
default_hooks = dict(
|
||||||
|
timer=dict(type='IterTimerHook'),
|
||||||
|
logger=dict(type='LoggerHook', interval=50),
|
||||||
|
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||||
|
checkpoint=dict(type='CheckpointHook', interval=10),
|
||||||
|
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||||
|
visualization=dict(type='PoseVisualizationHook', enable=False),
|
||||||
|
badcase=dict(
|
||||||
|
type='BadCaseAnalysisHook',
|
||||||
|
enable=False,
|
||||||
|
out_dir='badcase',
|
||||||
|
metric_type='loss',
|
||||||
|
badcase_thr=5))
|
||||||
|
|
||||||
|
# custom hooks
|
||||||
|
custom_hooks = [
|
||||||
|
# Synchronize model buffers such as running_mean and running_var in BN
|
||||||
|
# at the end of each epoch
|
||||||
|
dict(type='SyncBuffersHook')
|
||||||
|
]
|
||||||
|
|
||||||
|
# multi-processing backend
|
||||||
|
env_cfg = dict(
|
||||||
|
cudnn_benchmark=False,
|
||||||
|
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||||
|
dist_cfg=dict(backend='nccl'),
|
||||||
|
)
|
||||||
|
|
||||||
|
# visualizer
|
||||||
|
vis_backends = [
|
||||||
|
dict(type='LocalVisBackend'),
|
||||||
|
# dict(type='TensorboardVisBackend'),
|
||||||
|
# dict(type='WandbVisBackend'),
|
||||||
|
]
|
||||||
|
visualizer = dict(
|
||||||
|
type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||||
|
|
||||||
|
# logger
|
||||||
|
log_processor = dict(
|
||||||
|
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
|
||||||
|
log_level = 'INFO'
|
||||||
|
load_from = None
|
||||||
|
resume = False
|
||||||
|
|
||||||
|
# file I/O backend
|
||||||
|
backend_args = dict(backend='local')
|
||||||
|
|
||||||
|
# training/validation/testing progress
|
||||||
|
train_cfg = dict(by_epoch=True)
|
||||||
|
val_cfg = dict()
|
||||||
|
test_cfg = dict()
|
|
@ -0,0 +1,257 @@
|
||||||
|
#_base_ = ['../../../_base_/default_runtime.py']
|
||||||
|
_base_ = ['default_runtime.py']
|
||||||
|
|
||||||
|
# runtime
|
||||||
|
max_epochs = 270
|
||||||
|
stage2_num_epochs = 30
|
||||||
|
base_lr = 4e-3
|
||||||
|
train_batch_size = 32
|
||||||
|
val_batch_size = 32
|
||||||
|
|
||||||
|
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||||
|
randomness = dict(seed=21)
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optim_wrapper = dict(
|
||||||
|
type='OptimWrapper',
|
||||||
|
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
|
||||||
|
paramwise_cfg=dict(
|
||||||
|
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
||||||
|
|
||||||
|
# learning rate
|
||||||
|
param_scheduler = [
|
||||||
|
dict(
|
||||||
|
type='LinearLR',
|
||||||
|
start_factor=1.0e-5,
|
||||||
|
by_epoch=False,
|
||||||
|
begin=0,
|
||||||
|
end=1000),
|
||||||
|
dict(
|
||||||
|
# use cosine lr from 150 to 300 epoch
|
||||||
|
type='CosineAnnealingLR',
|
||||||
|
eta_min=base_lr * 0.05,
|
||||||
|
begin=max_epochs // 2,
|
||||||
|
end=max_epochs,
|
||||||
|
T_max=max_epochs // 2,
|
||||||
|
by_epoch=True,
|
||||||
|
convert_to_iter_based=True),
|
||||||
|
]
|
||||||
|
|
||||||
|
# automatically scaling LR based on the actual training batch size
|
||||||
|
auto_scale_lr = dict(base_batch_size=512)
|
||||||
|
|
||||||
|
# codec settings
|
||||||
|
codec = dict(
|
||||||
|
type='SimCCLabel',
|
||||||
|
input_size=(288, 384),
|
||||||
|
sigma=(6., 6.93),
|
||||||
|
simcc_split_ratio=2.0,
|
||||||
|
normalize=False,
|
||||||
|
use_dark=False)
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='TopdownPoseEstimator',
|
||||||
|
data_preprocessor=dict(
|
||||||
|
type='PoseDataPreprocessor',
|
||||||
|
mean=[123.675, 116.28, 103.53],
|
||||||
|
std=[58.395, 57.12, 57.375],
|
||||||
|
bgr_to_rgb=True),
|
||||||
|
backbone=dict(
|
||||||
|
_scope_='mmdet',
|
||||||
|
type='CSPNeXt',
|
||||||
|
arch='P5',
|
||||||
|
expand_ratio=0.5,
|
||||||
|
deepen_factor=1.,
|
||||||
|
widen_factor=1.,
|
||||||
|
out_indices=(4, ),
|
||||||
|
channel_attention=True,
|
||||||
|
norm_cfg=dict(type='SyncBN'),
|
||||||
|
act_cfg=dict(type='SiLU'),
|
||||||
|
init_cfg=dict(
|
||||||
|
type='Pretrained',
|
||||||
|
prefix='backbone.',
|
||||||
|
checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
|
||||||
|
'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
|
||||||
|
)),
|
||||||
|
head=dict(
|
||||||
|
type='RTMCCHead',
|
||||||
|
in_channels=1024,
|
||||||
|
out_channels=133,
|
||||||
|
input_size=codec['input_size'],
|
||||||
|
in_featuremap_size=(9, 12),
|
||||||
|
simcc_split_ratio=codec['simcc_split_ratio'],
|
||||||
|
final_layer_kernel_size=7,
|
||||||
|
gau_cfg=dict(
|
||||||
|
hidden_dims=256,
|
||||||
|
s=128,
|
||||||
|
expansion_factor=2,
|
||||||
|
dropout_rate=0.,
|
||||||
|
drop_path=0.,
|
||||||
|
act_fn='SiLU',
|
||||||
|
use_rel_bias=False,
|
||||||
|
pos_enc=False),
|
||||||
|
loss=dict(
|
||||||
|
type='KLDiscretLoss',
|
||||||
|
use_target_weight=True,
|
||||||
|
beta=10.,
|
||||||
|
label_softmax=True),
|
||||||
|
decoder=codec),
|
||||||
|
test_cfg=dict(flip_test=True, ))
|
||||||
|
|
||||||
|
# base dataset settings
|
||||||
|
dataset_type = 'UBody2dDataset'
|
||||||
|
data_mode = 'topdown'
|
||||||
|
data_root = 'data/UBody/'
|
||||||
|
|
||||||
|
backend_args = dict(backend='local')
|
||||||
|
|
||||||
|
scenes = [
|
||||||
|
'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
|
||||||
|
'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
|
||||||
|
'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
|
||||||
|
]
|
||||||
|
|
||||||
|
train_datasets = [
|
||||||
|
dict(
|
||||||
|
type='CocoWholeBodyDataset',
|
||||||
|
data_root='data/coco/',
|
||||||
|
data_mode=data_mode,
|
||||||
|
ann_file='annotations/coco_wholebody_train_v1.0.json',
|
||||||
|
data_prefix=dict(img='train2017/'),
|
||||||
|
pipeline=[])
|
||||||
|
]
|
||||||
|
|
||||||
|
for scene in scenes:
|
||||||
|
train_dataset = dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root=data_root,
|
||||||
|
data_mode=data_mode,
|
||||||
|
ann_file=f'annotations/{scene}/train_annotations.json',
|
||||||
|
data_prefix=dict(img='images/'),
|
||||||
|
pipeline=[],
|
||||||
|
sample_interval=10)
|
||||||
|
train_datasets.append(train_dataset)
|
||||||
|
|
||||||
|
# pipelines
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImage', backend_args=backend_args),
|
||||||
|
dict(type='GetBBoxCenterScale'),
|
||||||
|
dict(type='RandomFlip', direction='horizontal'),
|
||||||
|
dict(type='RandomHalfBody'),
|
||||||
|
dict(
|
||||||
|
type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
|
||||||
|
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||||
|
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||||
|
dict(
|
||||||
|
type='Albumentation',
|
||||||
|
transforms=[
|
||||||
|
dict(type='Blur', p=0.1),
|
||||||
|
dict(type='MedianBlur', p=0.1),
|
||||||
|
dict(
|
||||||
|
type='CoarseDropout',
|
||||||
|
max_holes=1,
|
||||||
|
max_height=0.4,
|
||||||
|
max_width=0.4,
|
||||||
|
min_holes=1,
|
||||||
|
min_height=0.2,
|
||||||
|
min_width=0.2,
|
||||||
|
p=1.0),
|
||||||
|
]),
|
||||||
|
dict(type='GenerateTarget', encoder=codec),
|
||||||
|
dict(type='PackPoseInputs')
|
||||||
|
]
|
||||||
|
val_pipeline = [
|
||||||
|
dict(type='LoadImage', backend_args=backend_args),
|
||||||
|
dict(type='GetBBoxCenterScale'),
|
||||||
|
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||||
|
dict(type='PackPoseInputs')
|
||||||
|
]
|
||||||
|
|
||||||
|
train_pipeline_stage2 = [
|
||||||
|
dict(type='LoadImage', backend_args=backend_args),
|
||||||
|
dict(type='GetBBoxCenterScale'),
|
||||||
|
dict(type='RandomFlip', direction='horizontal'),
|
||||||
|
dict(type='RandomHalfBody'),
|
||||||
|
dict(
|
||||||
|
type='RandomBBoxTransform',
|
||||||
|
shift_factor=0.,
|
||||||
|
scale_factor=[0.5, 1.5],
|
||||||
|
rotate_factor=90),
|
||||||
|
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||||
|
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||||
|
dict(
|
||||||
|
type='Albumentation',
|
||||||
|
transforms=[
|
||||||
|
dict(type='Blur', p=0.1),
|
||||||
|
dict(type='MedianBlur', p=0.1),
|
||||||
|
dict(
|
||||||
|
type='CoarseDropout',
|
||||||
|
max_holes=1,
|
||||||
|
max_height=0.4,
|
||||||
|
max_width=0.4,
|
||||||
|
min_holes=1,
|
||||||
|
min_height=0.2,
|
||||||
|
min_width=0.2,
|
||||||
|
p=0.5),
|
||||||
|
]),
|
||||||
|
dict(type='GenerateTarget', encoder=codec),
|
||||||
|
dict(type='PackPoseInputs')
|
||||||
|
]
|
||||||
|
|
||||||
|
# data loaders
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=train_batch_size,
|
||||||
|
num_workers=10,
|
||||||
|
persistent_workers=True,
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
dataset=dict(
|
||||||
|
type='CombinedDataset',
|
||||||
|
metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
|
||||||
|
datasets=train_datasets,
|
||||||
|
pipeline=train_pipeline,
|
||||||
|
test_mode=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=val_batch_size,
|
||||||
|
num_workers=10,
|
||||||
|
persistent_workers=True,
|
||||||
|
drop_last=False,
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
||||||
|
dataset=dict(
|
||||||
|
type='CocoWholeBodyDataset',
|
||||||
|
data_root=data_root,
|
||||||
|
data_mode=data_mode,
|
||||||
|
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
|
||||||
|
bbox_file='data/coco/person_detection_results/'
|
||||||
|
'COCO_val2017_detections_AP_H_56_person.json',
|
||||||
|
data_prefix=dict(img='coco/val2017/'),
|
||||||
|
test_mode=True,
|
||||||
|
pipeline=val_pipeline,
|
||||||
|
))
|
||||||
|
test_dataloader = val_dataloader
|
||||||
|
|
||||||
|
# hooks
|
||||||
|
default_hooks = dict(
|
||||||
|
checkpoint=dict(
|
||||||
|
save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
|
||||||
|
|
||||||
|
custom_hooks = [
|
||||||
|
dict(
|
||||||
|
type='EMAHook',
|
||||||
|
ema_type='ExpMomentumEMA',
|
||||||
|
momentum=0.0002,
|
||||||
|
update_buffers=True,
|
||||||
|
priority=49),
|
||||||
|
dict(
|
||||||
|
type='mmdet.PipelineSwitchHook',
|
||||||
|
switch_epoch=max_epochs - stage2_num_epochs,
|
||||||
|
switch_pipeline=train_pipeline_stage2)
|
||||||
|
]
|
||||||
|
|
||||||
|
# evaluators
|
||||||
|
val_evaluator = dict(
|
||||||
|
type='CocoWholeBodyMetric',
|
||||||
|
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
|
||||||
|
test_evaluator = val_evaluator
|
|
@ -0,0 +1 @@
|
||||||
|
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
|
|
@ -0,0 +1,7 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
__author__ = """Adrian Bulat"""
|
||||||
|
__email__ = 'adrian.bulat@nottingham.ac.uk'
|
||||||
|
__version__ = '1.0.1'
|
||||||
|
|
||||||
|
from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face
|
|
@ -0,0 +1,240 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from torch.utils.model_zoo import load_url
|
||||||
|
from enum import Enum
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
try:
|
||||||
|
import urllib.request as request_file
|
||||||
|
except BaseException:
|
||||||
|
import urllib as request_file
|
||||||
|
|
||||||
|
from .models import FAN, ResNetDepth
|
||||||
|
from .utils import *
|
||||||
|
|
||||||
|
|
||||||
|
class LandmarksType(Enum):
|
||||||
|
"""Enum class defining the type of landmarks to detect.
|
||||||
|
|
||||||
|
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
||||||
|
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
||||||
|
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
||||||
|
|
||||||
|
"""
|
||||||
|
_2D = 1
|
||||||
|
_2halfD = 2
|
||||||
|
_3D = 3
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkSize(Enum):
|
||||||
|
# TINY = 1
|
||||||
|
# SMALL = 2
|
||||||
|
# MEDIUM = 3
|
||||||
|
LARGE = 4
|
||||||
|
|
||||||
|
def __new__(cls, value):
|
||||||
|
member = object.__new__(cls)
|
||||||
|
member._value_ = value
|
||||||
|
return member
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FaceAlignment:
|
||||||
|
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
||||||
|
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
||||||
|
self.device = device
|
||||||
|
self.flip_input = flip_input
|
||||||
|
self.landmarks_type = landmarks_type
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
network_size = int(network_size)
|
||||||
|
|
||||||
|
if 'cuda' in device:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
# torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
# torch.backends.cudnn.benchmark = True
|
||||||
|
# torch.backends.cudnn.deterministic = False
|
||||||
|
# torch.backends.cudnn.allow_tf32 = True
|
||||||
|
print('cuda start')
|
||||||
|
|
||||||
|
|
||||||
|
# Get the face detector
|
||||||
|
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
||||||
|
globals(), locals(), [face_detector], 0)
|
||||||
|
|
||||||
|
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
||||||
|
|
||||||
|
def get_detections_for_batch(self, images):
|
||||||
|
images = images[..., ::-1]
|
||||||
|
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, d in enumerate(detected_faces):
|
||||||
|
if len(d) == 0:
|
||||||
|
results.append(None)
|
||||||
|
continue
|
||||||
|
d = d[0]
|
||||||
|
d = np.clip(d, 0, None)
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = map(int, d[:-1])
|
||||||
|
results.append((x1, y1, x2, y2))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class YOLOv8_face:
|
||||||
|
def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
|
||||||
|
self.conf_threshold = conf_thres
|
||||||
|
self.iou_threshold = iou_thres
|
||||||
|
self.class_names = ['face']
|
||||||
|
self.num_classes = len(self.class_names)
|
||||||
|
# Initialize model
|
||||||
|
self.net = cv2.dnn.readNet(path)
|
||||||
|
self.input_height = 640
|
||||||
|
self.input_width = 640
|
||||||
|
self.reg_max = 16
|
||||||
|
|
||||||
|
self.project = np.arange(self.reg_max)
|
||||||
|
self.strides = (8, 16, 32)
|
||||||
|
self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
|
||||||
|
self.anchors = self.make_anchors(self.feats_hw)
|
||||||
|
|
||||||
|
def make_anchors(self, feats_hw, grid_cell_offset=0.5):
|
||||||
|
"""Generate anchors from features."""
|
||||||
|
anchor_points = {}
|
||||||
|
for i, stride in enumerate(self.strides):
|
||||||
|
h,w = feats_hw[i]
|
||||||
|
x = np.arange(0, w) + grid_cell_offset # shift x
|
||||||
|
y = np.arange(0, h) + grid_cell_offset # shift y
|
||||||
|
sx, sy = np.meshgrid(x, y)
|
||||||
|
# sy, sx = np.meshgrid(y, x)
|
||||||
|
anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
|
||||||
|
return anchor_points
|
||||||
|
|
||||||
|
def softmax(self, x, axis=1):
|
||||||
|
x_exp = np.exp(x)
|
||||||
|
# 如果是列向量,则axis=0
|
||||||
|
x_sum = np.sum(x_exp, axis=axis, keepdims=True)
|
||||||
|
s = x_exp / x_sum
|
||||||
|
return s
|
||||||
|
|
||||||
|
def resize_image(self, srcimg, keep_ratio=True):
|
||||||
|
top, left, newh, neww = 0, 0, self.input_width, self.input_height
|
||||||
|
if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
|
||||||
|
hw_scale = srcimg.shape[0] / srcimg.shape[1]
|
||||||
|
if hw_scale > 1:
|
||||||
|
newh, neww = self.input_height, int(self.input_width / hw_scale)
|
||||||
|
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
||||||
|
left = int((self.input_width - neww) * 0.5)
|
||||||
|
img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
|
||||||
|
value=(0, 0, 0)) # add border
|
||||||
|
else:
|
||||||
|
newh, neww = int(self.input_height * hw_scale), self.input_width
|
||||||
|
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
||||||
|
top = int((self.input_height - newh) * 0.5)
|
||||||
|
img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
|
||||||
|
value=(0, 0, 0))
|
||||||
|
else:
|
||||||
|
img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
|
||||||
|
return img, newh, neww, top, left
|
||||||
|
|
||||||
|
def detect(self, srcimg):
|
||||||
|
input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
|
||||||
|
scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
|
||||||
|
input_img = input_img.astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
blob = cv2.dnn.blobFromImage(input_img)
|
||||||
|
self.net.setInput(blob)
|
||||||
|
outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
|
||||||
|
# if isinstance(outputs, tuple):
|
||||||
|
# outputs = list(outputs)
|
||||||
|
# if float(cv2.__version__[:3])>=4.7:
|
||||||
|
# outputs = [outputs[2], outputs[0], outputs[1]] ###opencv4.7需要这一步,opencv4.5不需要
|
||||||
|
# Perform inference on the image
|
||||||
|
det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
|
||||||
|
return det_bboxes, det_conf, det_classid, landmarks
|
||||||
|
|
||||||
|
def post_process(self, preds, scale_h, scale_w, padh, padw):
|
||||||
|
bboxes, scores, landmarks = [], [], []
|
||||||
|
for i, pred in enumerate(preds):
|
||||||
|
stride = int(self.input_height/pred.shape[2])
|
||||||
|
pred = pred.transpose((0, 2, 3, 1))
|
||||||
|
|
||||||
|
box = pred[..., :self.reg_max * 4]
|
||||||
|
cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
|
||||||
|
kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
|
||||||
|
|
||||||
|
# tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
|
||||||
|
tmp = box.reshape(-1, 4, self.reg_max)
|
||||||
|
bbox_pred = self.softmax(tmp, axis=-1)
|
||||||
|
bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
|
||||||
|
|
||||||
|
bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
|
||||||
|
kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
|
||||||
|
kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
|
||||||
|
kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
|
||||||
|
|
||||||
|
bbox -= np.array([[padw, padh, padw, padh]]) ###合理使用广播法则
|
||||||
|
bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
|
||||||
|
kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
|
||||||
|
kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
|
||||||
|
|
||||||
|
bboxes.append(bbox)
|
||||||
|
scores.append(cls)
|
||||||
|
landmarks.append(kpts)
|
||||||
|
|
||||||
|
bboxes = np.concatenate(bboxes, axis=0)
|
||||||
|
scores = np.concatenate(scores, axis=0)
|
||||||
|
landmarks = np.concatenate(landmarks, axis=0)
|
||||||
|
|
||||||
|
bboxes_wh = bboxes.copy()
|
||||||
|
bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
|
||||||
|
classIds = np.argmax(scores, axis=1)
|
||||||
|
confidences = np.max(scores, axis=1) ####max_class_confidence
|
||||||
|
|
||||||
|
mask = confidences>self.conf_threshold
|
||||||
|
bboxes_wh = bboxes_wh[mask] ###合理使用广播法则
|
||||||
|
confidences = confidences[mask]
|
||||||
|
classIds = classIds[mask]
|
||||||
|
landmarks = landmarks[mask]
|
||||||
|
|
||||||
|
indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
|
||||||
|
self.iou_threshold).flatten()
|
||||||
|
if len(indices) > 0:
|
||||||
|
mlvl_bboxes = bboxes_wh[indices]
|
||||||
|
confidences = confidences[indices]
|
||||||
|
classIds = classIds[indices]
|
||||||
|
landmarks = landmarks[indices]
|
||||||
|
return mlvl_bboxes, confidences, classIds, landmarks
|
||||||
|
else:
|
||||||
|
print('nothing detect')
|
||||||
|
return np.array([]), np.array([]), np.array([]), np.array([])
|
||||||
|
|
||||||
|
def distance2bbox(self, points, distance, max_shape=None):
|
||||||
|
x1 = points[:, 0] - distance[:, 0]
|
||||||
|
y1 = points[:, 1] - distance[:, 1]
|
||||||
|
x2 = points[:, 0] + distance[:, 2]
|
||||||
|
y2 = points[:, 1] + distance[:, 3]
|
||||||
|
if max_shape is not None:
|
||||||
|
x1 = np.clip(x1, 0, max_shape[1])
|
||||||
|
y1 = np.clip(y1, 0, max_shape[0])
|
||||||
|
x2 = np.clip(x2, 0, max_shape[1])
|
||||||
|
y2 = np.clip(y2, 0, max_shape[0])
|
||||||
|
return np.stack([x1, y1, x2, y2], axis=-1)
|
||||||
|
|
||||||
|
def draw_detections(self, image, boxes, scores, kpts):
|
||||||
|
for box, score, kp in zip(boxes, scores, kpts):
|
||||||
|
x, y, w, h = box.astype(int)
|
||||||
|
# Draw rectangle
|
||||||
|
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
|
||||||
|
cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
|
||||||
|
for i in range(5):
|
||||||
|
cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
|
||||||
|
# cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
|
||||||
|
return image
|
||||||
|
|
||||||
|
ROOT = os.path.dirname(os.path.abspath(__file__))
|
|
@ -0,0 +1 @@
|
||||||
|
from .core import FaceDetector
|
|
@ -0,0 +1,130 @@
|
||||||
|
import logging
|
||||||
|
import glob
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class FaceDetector(object):
|
||||||
|
"""An abstract class representing a face detector.
|
||||||
|
|
||||||
|
Any other face detection implementation must subclass it. All subclasses
|
||||||
|
must implement ``detect_from_image``, that return a list of detected
|
||||||
|
bounding boxes. Optionally, for speed considerations detect from path is
|
||||||
|
recommended.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device, verbose):
|
||||||
|
self.device = device
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
if 'cpu' in device:
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning("Detection running on CPU, this may be potentially slow.")
|
||||||
|
|
||||||
|
if 'cpu' not in device and 'cuda' not in device:
|
||||||
|
if verbose:
|
||||||
|
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def detect_from_image(self, tensor_or_path):
|
||||||
|
"""Detects faces in a given image.
|
||||||
|
|
||||||
|
This function detects the faces present in a provided BGR(usually)
|
||||||
|
image. The input can be either the image itself or the path to it.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
||||||
|
to an image or the image itself.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> path_to_image = 'data/image_01.jpg'
|
||||||
|
... detected_faces = detect_from_image(path_to_image)
|
||||||
|
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||||
|
>>> image = cv2.imread(path_to_image)
|
||||||
|
... detected_faces = detect_from_image(image)
|
||||||
|
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
||||||
|
"""Detects faces from all the images present in a given directory.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
path {string} -- a string containing a path that points to the folder containing the images
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
extensions {list} -- list of string containing the extensions to be
|
||||||
|
consider in the following format: ``.extension_name`` (default:
|
||||||
|
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
||||||
|
folder recursively (default: {False}) show_progress_bar {bool} --
|
||||||
|
display a progressbar (default: {True})
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> directory = 'data'
|
||||||
|
... detected_faces = detect_from_directory(directory)
|
||||||
|
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.verbose:
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if len(extensions) == 0:
|
||||||
|
if self.verbose:
|
||||||
|
logger.error("Expected at list one extension, but none was received.")
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("Constructing the list of images.")
|
||||||
|
additional_pattern = '/**/*' if recursive else '/*'
|
||||||
|
files = []
|
||||||
|
for extension in extensions:
|
||||||
|
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("Finished searching for images. %s images found", len(files))
|
||||||
|
logger.info("Preparing to run the detection.")
|
||||||
|
|
||||||
|
predictions = {}
|
||||||
|
for image_path in tqdm(files, disable=not show_progress_bar):
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("Running the face detector on image: %s", image_path)
|
||||||
|
predictions[image_path] = self.detect_from_image(image_path)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("The detector was successfully run on all %s images", len(files))
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_scale(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_x_shift(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_y_shift(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
||||||
|
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
||||||
|
"""
|
||||||
|
if isinstance(tensor_or_path, str):
|
||||||
|
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
||||||
|
elif torch.is_tensor(tensor_or_path):
|
||||||
|
# Call cpu in case its coming from cuda
|
||||||
|
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
||||||
|
elif isinstance(tensor_or_path, np.ndarray):
|
||||||
|
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
||||||
|
else:
|
||||||
|
raise TypeError
|
|
@ -0,0 +1 @@
|
||||||
|
from .sfd_detector import SFDDetector as FaceDetector
|
|
@ -0,0 +1,129 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import random
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from iou import IOU
|
||||||
|
except BaseException:
|
||||||
|
# IOU cython speedup 10x
|
||||||
|
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
||||||
|
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
||||||
|
sb = abs((bx2 - bx1) * (by2 - by1))
|
||||||
|
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
||||||
|
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
||||||
|
w = x2 - x1
|
||||||
|
h = y2 - y1
|
||||||
|
if w < 0 or h < 0:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
return 1.0 * w * h / (sa + sb - w * h)
|
||||||
|
|
||||||
|
|
||||||
|
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
||||||
|
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
||||||
|
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
||||||
|
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
||||||
|
return dx, dy, dw, dh
|
||||||
|
|
||||||
|
|
||||||
|
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
||||||
|
xc, yc = dx * aww + axc, dy * ahh + ayc
|
||||||
|
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
||||||
|
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
|
||||||
|
|
||||||
|
def nms(dets, thresh):
|
||||||
|
if 0 == len(dets):
|
||||||
|
return []
|
||||||
|
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
||||||
|
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||||
|
order = scores.argsort()[::-1]
|
||||||
|
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
||||||
|
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
||||||
|
|
||||||
|
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
||||||
|
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
||||||
|
|
||||||
|
inds = np.where(ovr <= thresh)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|
||||||
|
def encode(matched, priors, variances):
|
||||||
|
"""Encode the variances from the priorbox layers into the ground truth boxes
|
||||||
|
we have matched (based on jaccard overlap) with the prior boxes.
|
||||||
|
Args:
|
||||||
|
matched: (tensor) Coords of ground truth for each prior in point-form
|
||||||
|
Shape: [num_priors, 4].
|
||||||
|
priors: (tensor) Prior boxes in center-offset form
|
||||||
|
Shape: [num_priors,4].
|
||||||
|
variances: (list[float]) Variances of priorboxes
|
||||||
|
Return:
|
||||||
|
encoded boxes (tensor), Shape: [num_priors, 4]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# dist b/t match center and prior's center
|
||||||
|
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
||||||
|
# encode variance
|
||||||
|
g_cxcy /= (variances[0] * priors[:, 2:])
|
||||||
|
# match wh / prior wh
|
||||||
|
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
||||||
|
g_wh = torch.log(g_wh) / variances[1]
|
||||||
|
# return target for smooth_l1_loss
|
||||||
|
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
||||||
|
|
||||||
|
|
||||||
|
def decode(loc, priors, variances):
|
||||||
|
"""Decode locations from predictions using priors to undo
|
||||||
|
the encoding we did for offset regression at train time.
|
||||||
|
Args:
|
||||||
|
loc (tensor): location predictions for loc layers,
|
||||||
|
Shape: [num_priors,4]
|
||||||
|
priors (tensor): Prior boxes in center-offset form.
|
||||||
|
Shape: [num_priors,4].
|
||||||
|
variances: (list[float]) Variances of priorboxes
|
||||||
|
Return:
|
||||||
|
decoded bounding box predictions
|
||||||
|
"""
|
||||||
|
|
||||||
|
boxes = torch.cat((
|
||||||
|
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
||||||
|
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
||||||
|
boxes[:, :2] -= boxes[:, 2:] / 2
|
||||||
|
boxes[:, 2:] += boxes[:, :2]
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
def batch_decode(loc, priors, variances):
|
||||||
|
"""Decode locations from predictions using priors to undo
|
||||||
|
the encoding we did for offset regression at train time.
|
||||||
|
Args:
|
||||||
|
loc (tensor): location predictions for loc layers,
|
||||||
|
Shape: [num_priors,4]
|
||||||
|
priors (tensor): Prior boxes in center-offset form.
|
||||||
|
Shape: [num_priors,4].
|
||||||
|
variances: (list[float]) Variances of priorboxes
|
||||||
|
Return:
|
||||||
|
decoded bounding box predictions
|
||||||
|
"""
|
||||||
|
|
||||||
|
boxes = torch.cat((
|
||||||
|
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
||||||
|
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
||||||
|
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
||||||
|
boxes[:, :, 2:] += boxes[:, :, :2]
|
||||||
|
return boxes
|
|
@ -0,0 +1,114 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import random
|
||||||
|
import datetime
|
||||||
|
import math
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import scipy.io as sio
|
||||||
|
import zipfile
|
||||||
|
from .net_s3fd import s3fd
|
||||||
|
from .bbox import *
|
||||||
|
|
||||||
|
|
||||||
|
def detect(net, img, device):
|
||||||
|
img = img - np.array([104, 117, 123])
|
||||||
|
img = img.transpose(2, 0, 1)
|
||||||
|
img = img.reshape((1,) + img.shape)
|
||||||
|
|
||||||
|
if 'cuda' in device:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
img = torch.from_numpy(img).float().to(device)
|
||||||
|
BB, CC, HH, WW = img.size()
|
||||||
|
with torch.no_grad():
|
||||||
|
olist = net(img)
|
||||||
|
|
||||||
|
bboxlist = []
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||||
|
olist = [oelem.data.cpu() for oelem in olist]
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||||
|
FB, FC, FH, FW = ocls.size() # feature map size
|
||||||
|
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||||
|
anchor = stride * 4
|
||||||
|
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||||
|
for Iindex, hindex, windex in poss:
|
||||||
|
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||||
|
score = ocls[0, 1, hindex, windex]
|
||||||
|
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
||||||
|
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
||||||
|
variances = [0.1, 0.2]
|
||||||
|
box = decode(loc, priors, variances)
|
||||||
|
x1, y1, x2, y2 = box[0] * 1.0
|
||||||
|
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||||
|
bboxlist.append([x1, y1, x2, y2, score])
|
||||||
|
bboxlist = np.array(bboxlist)
|
||||||
|
if 0 == len(bboxlist):
|
||||||
|
bboxlist = np.zeros((1, 5))
|
||||||
|
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
def batch_detect(net, imgs, device):
|
||||||
|
imgs = imgs - np.array([104, 117, 123])
|
||||||
|
imgs = imgs.transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
if 'cuda' in device:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
imgs = torch.from_numpy(imgs).float().to(device)
|
||||||
|
BB, CC, HH, WW = imgs.size()
|
||||||
|
with torch.no_grad():
|
||||||
|
olist = net(imgs)
|
||||||
|
# print(olist)
|
||||||
|
|
||||||
|
bboxlist = []
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||||
|
|
||||||
|
olist = [oelem.cpu() for oelem in olist]
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||||
|
FB, FC, FH, FW = ocls.size() # feature map size
|
||||||
|
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||||
|
anchor = stride * 4
|
||||||
|
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||||
|
for Iindex, hindex, windex in poss:
|
||||||
|
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||||
|
score = ocls[:, 1, hindex, windex]
|
||||||
|
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
||||||
|
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
||||||
|
variances = [0.1, 0.2]
|
||||||
|
box = batch_decode(loc, priors, variances)
|
||||||
|
box = box[:, 0] * 1.0
|
||||||
|
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||||
|
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
||||||
|
bboxlist = np.array(bboxlist)
|
||||||
|
if 0 == len(bboxlist):
|
||||||
|
bboxlist = np.zeros((1, BB, 5))
|
||||||
|
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
def flip_detect(net, img, device):
|
||||||
|
img = cv2.flip(img, 1)
|
||||||
|
b = detect(net, img, device)
|
||||||
|
|
||||||
|
bboxlist = np.zeros(b.shape)
|
||||||
|
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
||||||
|
bboxlist[:, 1] = b[:, 1]
|
||||||
|
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
||||||
|
bboxlist[:, 3] = b[:, 3]
|
||||||
|
bboxlist[:, 4] = b[:, 4]
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
|
||||||
|
def pts_to_bb(pts):
|
||||||
|
min_x, min_y = np.min(pts, axis=0)
|
||||||
|
max_x, max_y = np.max(pts, axis=0)
|
||||||
|
return np.array([min_x, min_y, max_x, max_y])
|
|
@ -0,0 +1,129 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class L2Norm(nn.Module):
|
||||||
|
def __init__(self, n_channels, scale=1.0):
|
||||||
|
super(L2Norm, self).__init__()
|
||||||
|
self.n_channels = n_channels
|
||||||
|
self.scale = scale
|
||||||
|
self.eps = 1e-10
|
||||||
|
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
||||||
|
self.weight.data *= 0.0
|
||||||
|
self.weight.data += self.scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
||||||
|
x = x / norm * self.weight.view(1, -1, 1, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class s3fd(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(s3fd, self).__init__()
|
||||||
|
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
||||||
|
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
self.conv3_3_norm = L2Norm(256, scale=10)
|
||||||
|
self.conv4_3_norm = L2Norm(512, scale=8)
|
||||||
|
self.conv5_3_norm = L2Norm(512, scale=5)
|
||||||
|
|
||||||
|
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = F.relu(self.conv1_1(x))
|
||||||
|
h = F.relu(self.conv1_2(h))
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv2_1(h))
|
||||||
|
h = F.relu(self.conv2_2(h))
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv3_1(h))
|
||||||
|
h = F.relu(self.conv3_2(h))
|
||||||
|
h = F.relu(self.conv3_3(h))
|
||||||
|
f3_3 = h
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv4_1(h))
|
||||||
|
h = F.relu(self.conv4_2(h))
|
||||||
|
h = F.relu(self.conv4_3(h))
|
||||||
|
f4_3 = h
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv5_1(h))
|
||||||
|
h = F.relu(self.conv5_2(h))
|
||||||
|
h = F.relu(self.conv5_3(h))
|
||||||
|
f5_3 = h
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.fc6(h))
|
||||||
|
h = F.relu(self.fc7(h))
|
||||||
|
ffc7 = h
|
||||||
|
h = F.relu(self.conv6_1(h))
|
||||||
|
h = F.relu(self.conv6_2(h))
|
||||||
|
f6_2 = h
|
||||||
|
h = F.relu(self.conv7_1(h))
|
||||||
|
h = F.relu(self.conv7_2(h))
|
||||||
|
f7_2 = h
|
||||||
|
|
||||||
|
f3_3 = self.conv3_3_norm(f3_3)
|
||||||
|
f4_3 = self.conv4_3_norm(f4_3)
|
||||||
|
f5_3 = self.conv5_3_norm(f5_3)
|
||||||
|
|
||||||
|
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
||||||
|
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
||||||
|
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
||||||
|
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
||||||
|
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
||||||
|
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
||||||
|
cls4 = self.fc7_mbox_conf(ffc7)
|
||||||
|
reg4 = self.fc7_mbox_loc(ffc7)
|
||||||
|
cls5 = self.conv6_2_mbox_conf(f6_2)
|
||||||
|
reg5 = self.conv6_2_mbox_loc(f6_2)
|
||||||
|
cls6 = self.conv7_2_mbox_conf(f7_2)
|
||||||
|
reg6 = self.conv7_2_mbox_loc(f7_2)
|
||||||
|
|
||||||
|
# max-out background label
|
||||||
|
chunk = torch.chunk(cls1, 4, 1)
|
||||||
|
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
||||||
|
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
||||||
|
|
||||||
|
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
|
@ -0,0 +1,59 @@
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
from torch.utils.model_zoo import load_url
|
||||||
|
|
||||||
|
from ..core import FaceDetector
|
||||||
|
|
||||||
|
from .net_s3fd import s3fd
|
||||||
|
from .bbox import *
|
||||||
|
from .detect import *
|
||||||
|
|
||||||
|
models_urls = {
|
||||||
|
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SFDDetector(FaceDetector):
|
||||||
|
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
||||||
|
super(SFDDetector, self).__init__(device, verbose)
|
||||||
|
|
||||||
|
# Initialise the face detector
|
||||||
|
if not os.path.isfile(path_to_detector):
|
||||||
|
model_weights = load_url(models_urls['s3fd'])
|
||||||
|
else:
|
||||||
|
model_weights = torch.load(path_to_detector)
|
||||||
|
|
||||||
|
self.face_detector = s3fd()
|
||||||
|
self.face_detector.load_state_dict(model_weights)
|
||||||
|
self.face_detector.to(device)
|
||||||
|
self.face_detector.eval()
|
||||||
|
|
||||||
|
def detect_from_image(self, tensor_or_path):
|
||||||
|
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
||||||
|
|
||||||
|
bboxlist = detect(self.face_detector, image, device=self.device)
|
||||||
|
keep = nms(bboxlist, 0.3)
|
||||||
|
bboxlist = bboxlist[keep, :]
|
||||||
|
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
||||||
|
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
def detect_from_batch(self, images):
|
||||||
|
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
||||||
|
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
||||||
|
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
||||||
|
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
||||||
|
|
||||||
|
return bboxlists
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_scale(self):
|
||||||
|
return 195
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_x_shift(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_y_shift(self):
|
||||||
|
return 0
|
|
@ -0,0 +1,261 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
||||||
|
"3x3 convolution with padding"
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
||||||
|
stride=strd, padding=padding, bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, out_planes):
|
||||||
|
super(ConvBlock, self).__init__()
|
||||||
|
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||||
|
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
||||||
|
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
||||||
|
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
||||||
|
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
||||||
|
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
||||||
|
|
||||||
|
if in_planes != out_planes:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.BatchNorm2d(in_planes),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Conv2d(in_planes, out_planes,
|
||||||
|
kernel_size=1, stride=1, bias=False),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out1 = self.bn1(x)
|
||||||
|
out1 = F.relu(out1, True)
|
||||||
|
out1 = self.conv1(out1)
|
||||||
|
|
||||||
|
out2 = self.bn2(out1)
|
||||||
|
out2 = F.relu(out2, True)
|
||||||
|
out2 = self.conv2(out2)
|
||||||
|
|
||||||
|
out3 = self.bn3(out2)
|
||||||
|
out3 = F.relu(out3, True)
|
||||||
|
out3 = self.conv3(out3)
|
||||||
|
|
||||||
|
out3 = torch.cat((out1, out2, out3), 1)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(residual)
|
||||||
|
|
||||||
|
out3 += residual
|
||||||
|
|
||||||
|
return out3
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||||
|
padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
|
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(x)
|
||||||
|
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class HourGlass(nn.Module):
|
||||||
|
def __init__(self, num_modules, depth, num_features):
|
||||||
|
super(HourGlass, self).__init__()
|
||||||
|
self.num_modules = num_modules
|
||||||
|
self.depth = depth
|
||||||
|
self.features = num_features
|
||||||
|
|
||||||
|
self._generate_network(self.depth)
|
||||||
|
|
||||||
|
def _generate_network(self, level):
|
||||||
|
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
if level > 1:
|
||||||
|
self._generate_network(level - 1)
|
||||||
|
else:
|
||||||
|
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
def _forward(self, level, inp):
|
||||||
|
# Upper branch
|
||||||
|
up1 = inp
|
||||||
|
up1 = self._modules['b1_' + str(level)](up1)
|
||||||
|
|
||||||
|
# Lower branch
|
||||||
|
low1 = F.avg_pool2d(inp, 2, stride=2)
|
||||||
|
low1 = self._modules['b2_' + str(level)](low1)
|
||||||
|
|
||||||
|
if level > 1:
|
||||||
|
low2 = self._forward(level - 1, low1)
|
||||||
|
else:
|
||||||
|
low2 = low1
|
||||||
|
low2 = self._modules['b2_plus_' + str(level)](low2)
|
||||||
|
|
||||||
|
low3 = low2
|
||||||
|
low3 = self._modules['b3_' + str(level)](low3)
|
||||||
|
|
||||||
|
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
||||||
|
|
||||||
|
return up1 + up2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward(self.depth, x)
|
||||||
|
|
||||||
|
|
||||||
|
class FAN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, num_modules=1):
|
||||||
|
super(FAN, self).__init__()
|
||||||
|
self.num_modules = num_modules
|
||||||
|
|
||||||
|
# Base part
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
||||||
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
|
self.conv2 = ConvBlock(64, 128)
|
||||||
|
self.conv3 = ConvBlock(128, 128)
|
||||||
|
self.conv4 = ConvBlock(128, 256)
|
||||||
|
|
||||||
|
# Stacking part
|
||||||
|
for hg_module in range(self.num_modules):
|
||||||
|
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
||||||
|
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
||||||
|
self.add_module('conv_last' + str(hg_module),
|
||||||
|
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||||
|
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
||||||
|
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
||||||
|
68, kernel_size=1, stride=1, padding=0))
|
||||||
|
|
||||||
|
if hg_module < self.num_modules - 1:
|
||||||
|
self.add_module(
|
||||||
|
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||||
|
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
||||||
|
256, kernel_size=1, stride=1, padding=0))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.relu(self.bn1(self.conv1(x)), True)
|
||||||
|
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = self.conv4(x)
|
||||||
|
|
||||||
|
previous = x
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for i in range(self.num_modules):
|
||||||
|
hg = self._modules['m' + str(i)](previous)
|
||||||
|
|
||||||
|
ll = hg
|
||||||
|
ll = self._modules['top_m_' + str(i)](ll)
|
||||||
|
|
||||||
|
ll = F.relu(self._modules['bn_end' + str(i)]
|
||||||
|
(self._modules['conv_last' + str(i)](ll)), True)
|
||||||
|
|
||||||
|
# Predict heatmaps
|
||||||
|
tmp_out = self._modules['l' + str(i)](ll)
|
||||||
|
outputs.append(tmp_out)
|
||||||
|
|
||||||
|
if i < self.num_modules - 1:
|
||||||
|
ll = self._modules['bl' + str(i)](ll)
|
||||||
|
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
||||||
|
previous = previous + ll + tmp_out_
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetDepth(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
||||||
|
self.inplanes = 64
|
||||||
|
super(ResNetDepth, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||||
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||||
|
self.avgpool = nn.AvgPool2d(7)
|
||||||
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
|
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1):
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||||
|
kernel_size=1, stride=stride, bias=False),
|
||||||
|
nn.BatchNorm2d(planes * block.expansion),
|
||||||
|
)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for i in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
x = self.fc(x)
|
||||||
|
|
||||||
|
return x
|
|
@ -0,0 +1,313 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def _gaussian(
|
||||||
|
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
||||||
|
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
||||||
|
mean_vert=0.5):
|
||||||
|
# handle some defaults
|
||||||
|
if width is None:
|
||||||
|
width = size
|
||||||
|
if height is None:
|
||||||
|
height = size
|
||||||
|
if sigma_horz is None:
|
||||||
|
sigma_horz = sigma
|
||||||
|
if sigma_vert is None:
|
||||||
|
sigma_vert = sigma
|
||||||
|
center_x = mean_horz * width + 0.5
|
||||||
|
center_y = mean_vert * height + 0.5
|
||||||
|
gauss = np.empty((height, width), dtype=np.float32)
|
||||||
|
# generate kernel
|
||||||
|
for i in range(height):
|
||||||
|
for j in range(width):
|
||||||
|
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
||||||
|
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
||||||
|
if normalize:
|
||||||
|
gauss = gauss / np.sum(gauss)
|
||||||
|
return gauss
|
||||||
|
|
||||||
|
|
||||||
|
def draw_gaussian(image, point, sigma):
|
||||||
|
# Check if the gaussian is inside
|
||||||
|
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
||||||
|
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
||||||
|
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
||||||
|
return image
|
||||||
|
size = 6 * sigma + 1
|
||||||
|
g = _gaussian(size)
|
||||||
|
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
||||||
|
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
||||||
|
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
||||||
|
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
||||||
|
assert (g_x[0] > 0 and g_y[1] > 0)
|
||||||
|
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
||||||
|
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
||||||
|
image[image > 1] = 1
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def transform(point, center, scale, resolution, invert=False):
|
||||||
|
"""Generate and affine transformation matrix.
|
||||||
|
|
||||||
|
Given a set of points, a center, a scale and a targer resolution, the
|
||||||
|
function generates and affine transformation matrix. If invert is ``True``
|
||||||
|
it will produce the inverse transformation.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
point {torch.tensor} -- the input 2D point
|
||||||
|
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
||||||
|
scale {float} -- the scale of the face/object
|
||||||
|
resolution {float} -- the output resolution
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
invert {bool} -- define wherever the function should produce the direct or the
|
||||||
|
inverse transformation matrix (default: {False})
|
||||||
|
"""
|
||||||
|
_pt = torch.ones(3)
|
||||||
|
_pt[0] = point[0]
|
||||||
|
_pt[1] = point[1]
|
||||||
|
|
||||||
|
h = 200.0 * scale
|
||||||
|
t = torch.eye(3)
|
||||||
|
t[0, 0] = resolution / h
|
||||||
|
t[1, 1] = resolution / h
|
||||||
|
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
||||||
|
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
||||||
|
|
||||||
|
if invert:
|
||||||
|
t = torch.inverse(t)
|
||||||
|
|
||||||
|
new_point = (torch.matmul(t, _pt))[0:2]
|
||||||
|
|
||||||
|
return new_point.int()
|
||||||
|
|
||||||
|
|
||||||
|
def crop(image, center, scale, resolution=256.0):
|
||||||
|
"""Center crops an image or set of heatmaps
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
image {numpy.array} -- an rgb image
|
||||||
|
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
||||||
|
scale {float} -- scale of the face
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
resolution {float} -- the size of the output cropped image (default: {256.0})
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[type] -- [description]
|
||||||
|
""" # Crop around the center point
|
||||||
|
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
||||||
|
ul = transform([1, 1], center, scale, resolution, True)
|
||||||
|
br = transform([resolution, resolution], center, scale, resolution, True)
|
||||||
|
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
||||||
|
if image.ndim > 2:
|
||||||
|
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
||||||
|
image.shape[2]], dtype=np.int32)
|
||||||
|
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||||
|
else:
|
||||||
|
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
||||||
|
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||||
|
ht = image.shape[0]
|
||||||
|
wd = image.shape[1]
|
||||||
|
newX = np.array(
|
||||||
|
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
||||||
|
newY = np.array(
|
||||||
|
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
||||||
|
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
||||||
|
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
||||||
|
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
||||||
|
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
||||||
|
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
return newImg
|
||||||
|
|
||||||
|
|
||||||
|
def get_preds_fromhm(hm, center=None, scale=None):
|
||||||
|
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
||||||
|
and the scale is provided the function will return the points also in
|
||||||
|
the original coordinate frame.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
center {torch.tensor} -- the center of the bounding box (default: {None})
|
||||||
|
scale {float} -- face scale (default: {None})
|
||||||
|
"""
|
||||||
|
max, idx = torch.max(
|
||||||
|
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||||
|
idx += 1
|
||||||
|
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||||
|
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||||
|
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||||
|
|
||||||
|
for i in range(preds.size(0)):
|
||||||
|
for j in range(preds.size(1)):
|
||||||
|
hm_ = hm[i, j, :]
|
||||||
|
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||||
|
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||||
|
diff = torch.FloatTensor(
|
||||||
|
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||||
|
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||||
|
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||||
|
|
||||||
|
preds.add_(-.5)
|
||||||
|
|
||||||
|
preds_orig = torch.zeros(preds.size())
|
||||||
|
if center is not None and scale is not None:
|
||||||
|
for i in range(hm.size(0)):
|
||||||
|
for j in range(hm.size(1)):
|
||||||
|
preds_orig[i, j] = transform(
|
||||||
|
preds[i, j], center, scale, hm.size(2), True)
|
||||||
|
|
||||||
|
return preds, preds_orig
|
||||||
|
|
||||||
|
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
||||||
|
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
||||||
|
and the scales is provided the function will return the points also in
|
||||||
|
the original coordinate frame.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
||||||
|
scales {float} -- face scales (default: {None})
|
||||||
|
"""
|
||||||
|
max, idx = torch.max(
|
||||||
|
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||||
|
idx += 1
|
||||||
|
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||||
|
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||||
|
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||||
|
|
||||||
|
for i in range(preds.size(0)):
|
||||||
|
for j in range(preds.size(1)):
|
||||||
|
hm_ = hm[i, j, :]
|
||||||
|
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||||
|
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||||
|
diff = torch.FloatTensor(
|
||||||
|
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||||
|
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||||
|
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||||
|
|
||||||
|
preds.add_(-.5)
|
||||||
|
|
||||||
|
preds_orig = torch.zeros(preds.size())
|
||||||
|
if centers is not None and scales is not None:
|
||||||
|
for i in range(hm.size(0)):
|
||||||
|
for j in range(hm.size(1)):
|
||||||
|
preds_orig[i, j] = transform(
|
||||||
|
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
||||||
|
|
||||||
|
return preds, preds_orig
|
||||||
|
|
||||||
|
def shuffle_lr(parts, pairs=None):
|
||||||
|
"""Shuffle the points left-right according to the axis of symmetry
|
||||||
|
of the object.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
parts {torch.tensor} -- a 3D or 4D object containing the
|
||||||
|
heatmaps.
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
||||||
|
"""
|
||||||
|
if pairs is None:
|
||||||
|
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
||||||
|
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
||||||
|
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
||||||
|
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
||||||
|
62, 61, 60, 67, 66, 65]
|
||||||
|
if parts.ndimension() == 3:
|
||||||
|
parts = parts[pairs, ...]
|
||||||
|
else:
|
||||||
|
parts = parts[:, pairs, ...]
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def flip(tensor, is_label=False):
|
||||||
|
"""Flip an image or a set of heatmaps left-right
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
||||||
|
"""
|
||||||
|
if not torch.is_tensor(tensor):
|
||||||
|
tensor = torch.from_numpy(tensor)
|
||||||
|
|
||||||
|
if is_label:
|
||||||
|
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
||||||
|
else:
|
||||||
|
tensor = tensor.flip(tensor.ndimension() - 1)
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
||||||
|
|
||||||
|
|
||||||
|
def appdata_dir(appname=None, roaming=False):
|
||||||
|
""" appdata_dir(appname=None, roaming=False)
|
||||||
|
|
||||||
|
Get the path to the application directory, where applications are allowed
|
||||||
|
to write user specific files (e.g. configurations). For non-user specific
|
||||||
|
data, consider using common_appdata_dir().
|
||||||
|
If appname is given, a subdir is appended (and created if necessary).
|
||||||
|
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Define default user directory
|
||||||
|
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
||||||
|
if userDir is None:
|
||||||
|
userDir = os.path.expanduser('~')
|
||||||
|
if not os.path.isdir(userDir): # pragma: no cover
|
||||||
|
userDir = '/var/tmp' # issue #54
|
||||||
|
|
||||||
|
# Get system app data dir
|
||||||
|
path = None
|
||||||
|
if sys.platform.startswith('win'):
|
||||||
|
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
||||||
|
path = (path2 or path1) if roaming else (path1 or path2)
|
||||||
|
elif sys.platform.startswith('darwin'):
|
||||||
|
path = os.path.join(userDir, 'Library', 'Application Support')
|
||||||
|
# On Linux and as fallback
|
||||||
|
if not (path and os.path.isdir(path)):
|
||||||
|
path = userDir
|
||||||
|
|
||||||
|
# Maybe we should store things local to the executable (in case of a
|
||||||
|
# portable distro or a frozen application that wants to be portable)
|
||||||
|
prefix = sys.prefix
|
||||||
|
if getattr(sys, 'frozen', None):
|
||||||
|
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
||||||
|
for reldir in ('settings', '../settings'):
|
||||||
|
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
||||||
|
if os.path.isdir(localpath): # pragma: no cover
|
||||||
|
try:
|
||||||
|
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
||||||
|
os.remove(os.path.join(localpath, 'test.write'))
|
||||||
|
except IOError:
|
||||||
|
pass # We cannot write in this directory
|
||||||
|
else:
|
||||||
|
path = localpath
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get path specific for this app
|
||||||
|
if appname:
|
||||||
|
if path == userDir:
|
||||||
|
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
||||||
|
path = os.path.join(path, appname)
|
||||||
|
if not os.path.isdir(path): # pragma: no cover
|
||||||
|
os.mkdir(path)
|
||||||
|
|
||||||
|
# Done
|
||||||
|
return path
|
|
@ -0,0 +1,57 @@
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from .model import BiSeNet
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
class FaceParsing():
|
||||||
|
def __init__(self,resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
||||||
|
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
||||||
|
self.net = self.model_init(resnet_path,model_pth)
|
||||||
|
self.preprocess = self.image_preprocess()
|
||||||
|
|
||||||
|
def model_init(self,
|
||||||
|
resnet_path,
|
||||||
|
model_pth):
|
||||||
|
net = BiSeNet(resnet_path)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.cuda()
|
||||||
|
net.load_state_dict(torch.load(model_pth))
|
||||||
|
else:
|
||||||
|
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
|
||||||
|
net.eval()
|
||||||
|
return net
|
||||||
|
|
||||||
|
def image_preprocess(self):
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def __call__(self, image, size=(512, 512)):
|
||||||
|
if isinstance(image, str):
|
||||||
|
image = Image.open(image)
|
||||||
|
|
||||||
|
width, height = image.size
|
||||||
|
with torch.no_grad():
|
||||||
|
image = image.resize(size, Image.BILINEAR)
|
||||||
|
img = self.preprocess(image)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
img = torch.unsqueeze(img, 0).cuda()
|
||||||
|
else:
|
||||||
|
img = torch.unsqueeze(img, 0)
|
||||||
|
out = self.net(img)[0]
|
||||||
|
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
||||||
|
parsing[np.where(parsing>13)] = 0
|
||||||
|
parsing[np.where(parsing>=1)] = 255
|
||||||
|
parsing = Image.fromarray(parsing.astype(np.uint8))
|
||||||
|
return parsing
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fp = FaceParsing()
|
||||||
|
segmap = fp('154_small.png')
|
||||||
|
segmap.save('res.png')
|
||||||
|
|
|
@ -0,0 +1,283 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
from .resnet import Resnet18
|
||||||
|
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNReLU(nn.Module):
|
||||||
|
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
||||||
|
super(ConvBNReLU, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_chan,
|
||||||
|
out_chan,
|
||||||
|
kernel_size = ks,
|
||||||
|
stride = stride,
|
||||||
|
padding = padding,
|
||||||
|
bias = False)
|
||||||
|
self.bn = nn.BatchNorm2d(out_chan)
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = F.relu(self.bn(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
for ly in self.children():
|
||||||
|
if isinstance(ly, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||||
|
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||||
|
|
||||||
|
class BiSeNetOutput(nn.Module):
|
||||||
|
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
||||||
|
super(BiSeNetOutput, self).__init__()
|
||||||
|
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
||||||
|
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
for ly in self.children():
|
||||||
|
if isinstance(ly, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||||
|
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
wd_params, nowd_params = [], []
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||||
|
wd_params.append(module.weight)
|
||||||
|
if not module.bias is None:
|
||||||
|
nowd_params.append(module.bias)
|
||||||
|
elif isinstance(module, nn.BatchNorm2d):
|
||||||
|
nowd_params += list(module.parameters())
|
||||||
|
return wd_params, nowd_params
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionRefinementModule(nn.Module):
|
||||||
|
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
||||||
|
super(AttentionRefinementModule, self).__init__()
|
||||||
|
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
||||||
|
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
||||||
|
self.bn_atten = nn.BatchNorm2d(out_chan)
|
||||||
|
self.sigmoid_atten = nn.Sigmoid()
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
feat = self.conv(x)
|
||||||
|
atten = F.avg_pool2d(feat, feat.size()[2:])
|
||||||
|
atten = self.conv_atten(atten)
|
||||||
|
atten = self.bn_atten(atten)
|
||||||
|
atten = self.sigmoid_atten(atten)
|
||||||
|
out = torch.mul(feat, atten)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
for ly in self.children():
|
||||||
|
if isinstance(ly, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||||
|
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPath(nn.Module):
|
||||||
|
def __init__(self, resnet_path, *args, **kwargs):
|
||||||
|
super(ContextPath, self).__init__()
|
||||||
|
self.resnet = Resnet18(resnet_path)
|
||||||
|
self.arm16 = AttentionRefinementModule(256, 128)
|
||||||
|
self.arm32 = AttentionRefinementModule(512, 128)
|
||||||
|
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
||||||
|
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
||||||
|
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
H0, W0 = x.size()[2:]
|
||||||
|
feat8, feat16, feat32 = self.resnet(x)
|
||||||
|
H8, W8 = feat8.size()[2:]
|
||||||
|
H16, W16 = feat16.size()[2:]
|
||||||
|
H32, W32 = feat32.size()[2:]
|
||||||
|
|
||||||
|
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
||||||
|
avg = self.conv_avg(avg)
|
||||||
|
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
||||||
|
|
||||||
|
feat32_arm = self.arm32(feat32)
|
||||||
|
feat32_sum = feat32_arm + avg_up
|
||||||
|
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
||||||
|
feat32_up = self.conv_head32(feat32_up)
|
||||||
|
|
||||||
|
feat16_arm = self.arm16(feat16)
|
||||||
|
feat16_sum = feat16_arm + feat32_up
|
||||||
|
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
||||||
|
feat16_up = self.conv_head16(feat16_up)
|
||||||
|
|
||||||
|
return feat8, feat16_up, feat32_up # x8, x8, x16
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
for ly in self.children():
|
||||||
|
if isinstance(ly, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||||
|
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
wd_params, nowd_params = [], []
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||||
|
wd_params.append(module.weight)
|
||||||
|
if not module.bias is None:
|
||||||
|
nowd_params.append(module.bias)
|
||||||
|
elif isinstance(module, nn.BatchNorm2d):
|
||||||
|
nowd_params += list(module.parameters())
|
||||||
|
return wd_params, nowd_params
|
||||||
|
|
||||||
|
|
||||||
|
### This is not used, since I replace this with the resnet feature with the same size
|
||||||
|
class SpatialPath(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(SpatialPath, self).__init__()
|
||||||
|
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
||||||
|
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
||||||
|
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
||||||
|
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
feat = self.conv1(x)
|
||||||
|
feat = self.conv2(feat)
|
||||||
|
feat = self.conv3(feat)
|
||||||
|
feat = self.conv_out(feat)
|
||||||
|
return feat
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
for ly in self.children():
|
||||||
|
if isinstance(ly, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||||
|
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
wd_params, nowd_params = [], []
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||||
|
wd_params.append(module.weight)
|
||||||
|
if not module.bias is None:
|
||||||
|
nowd_params.append(module.bias)
|
||||||
|
elif isinstance(module, nn.BatchNorm2d):
|
||||||
|
nowd_params += list(module.parameters())
|
||||||
|
return wd_params, nowd_params
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureFusionModule(nn.Module):
|
||||||
|
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
||||||
|
super(FeatureFusionModule, self).__init__()
|
||||||
|
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
||||||
|
self.conv1 = nn.Conv2d(out_chan,
|
||||||
|
out_chan//4,
|
||||||
|
kernel_size = 1,
|
||||||
|
stride = 1,
|
||||||
|
padding = 0,
|
||||||
|
bias = False)
|
||||||
|
self.conv2 = nn.Conv2d(out_chan//4,
|
||||||
|
out_chan,
|
||||||
|
kernel_size = 1,
|
||||||
|
stride = 1,
|
||||||
|
padding = 0,
|
||||||
|
bias = False)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
def forward(self, fsp, fcp):
|
||||||
|
fcat = torch.cat([fsp, fcp], dim=1)
|
||||||
|
feat = self.convblk(fcat)
|
||||||
|
atten = F.avg_pool2d(feat, feat.size()[2:])
|
||||||
|
atten = self.conv1(atten)
|
||||||
|
atten = self.relu(atten)
|
||||||
|
atten = self.conv2(atten)
|
||||||
|
atten = self.sigmoid(atten)
|
||||||
|
feat_atten = torch.mul(feat, atten)
|
||||||
|
feat_out = feat_atten + feat
|
||||||
|
return feat_out
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
for ly in self.children():
|
||||||
|
if isinstance(ly, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||||
|
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
wd_params, nowd_params = [], []
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||||
|
wd_params.append(module.weight)
|
||||||
|
if not module.bias is None:
|
||||||
|
nowd_params.append(module.bias)
|
||||||
|
elif isinstance(module, nn.BatchNorm2d):
|
||||||
|
nowd_params += list(module.parameters())
|
||||||
|
return wd_params, nowd_params
|
||||||
|
|
||||||
|
|
||||||
|
class BiSeNet(nn.Module):
|
||||||
|
def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
|
||||||
|
super(BiSeNet, self).__init__()
|
||||||
|
self.cp = ContextPath(resnet_path)
|
||||||
|
## here self.sp is deleted
|
||||||
|
self.ffm = FeatureFusionModule(256, 256)
|
||||||
|
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
||||||
|
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
||||||
|
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
H, W = x.size()[2:]
|
||||||
|
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
||||||
|
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
||||||
|
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
||||||
|
|
||||||
|
feat_out = self.conv_out(feat_fuse)
|
||||||
|
feat_out16 = self.conv_out16(feat_cp8)
|
||||||
|
feat_out32 = self.conv_out32(feat_cp16)
|
||||||
|
|
||||||
|
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
||||||
|
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
||||||
|
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
||||||
|
return feat_out, feat_out16, feat_out32
|
||||||
|
|
||||||
|
def init_weight(self):
|
||||||
|
for ly in self.children():
|
||||||
|
if isinstance(ly, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||||
|
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
||||||
|
for name, child in self.named_children():
|
||||||
|
child_wd_params, child_nowd_params = child.get_params()
|
||||||
|
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
||||||
|
lr_mul_wd_params += child_wd_params
|
||||||
|
lr_mul_nowd_params += child_nowd_params
|
||||||
|
else:
|
||||||
|
wd_params += child_wd_params
|
||||||
|
nowd_params += child_nowd_params
|
||||||
|
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
net = BiSeNet(19)
|
||||||
|
net.cuda()
|
||||||
|
net.eval()
|
||||||
|
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
||||||
|
out, out16, out32 = net(in_ten)
|
||||||
|
print(out.shape)
|
||||||
|
|
||||||
|
net.get_params()
|
|
@ -0,0 +1,109 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.model_zoo as modelzoo
|
||||||
|
|
||||||
|
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
||||||
|
|
||||||
|
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, stride=1):
|
||||||
|
"""3x3 convolution with padding"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
|
padding=1, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
def __init__(self, in_chan, out_chan, stride=1):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_chan)
|
||||||
|
self.conv2 = conv3x3(out_chan, out_chan)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_chan)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = None
|
||||||
|
if in_chan != out_chan or stride != 1:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(in_chan, out_chan,
|
||||||
|
kernel_size=1, stride=stride, bias=False),
|
||||||
|
nn.BatchNorm2d(out_chan),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = self.conv1(x)
|
||||||
|
residual = F.relu(self.bn1(residual))
|
||||||
|
residual = self.conv2(residual)
|
||||||
|
residual = self.bn2(residual)
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
if self.downsample is not None:
|
||||||
|
shortcut = self.downsample(x)
|
||||||
|
|
||||||
|
out = shortcut + residual
|
||||||
|
out = self.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
||||||
|
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
||||||
|
for i in range(bnum-1):
|
||||||
|
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
class Resnet18(nn.Module):
|
||||||
|
def __init__(self, model_path):
|
||||||
|
super(Resnet18, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
||||||
|
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
||||||
|
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
||||||
|
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
||||||
|
self.init_weight(model_path)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = F.relu(self.bn1(x))
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
feat8 = self.layer2(x) # 1/8
|
||||||
|
feat16 = self.layer3(feat8) # 1/16
|
||||||
|
feat32 = self.layer4(feat16) # 1/32
|
||||||
|
return feat8, feat16, feat32
|
||||||
|
|
||||||
|
def init_weight(self, model_path):
|
||||||
|
state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
|
||||||
|
self_state_dict = self.state_dict()
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if 'fc' in k: continue
|
||||||
|
self_state_dict.update({k: v})
|
||||||
|
self.load_state_dict(self_state_dict)
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
wd_params, nowd_params = [], []
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||||
|
wd_params.append(module.weight)
|
||||||
|
if not module.bias is None:
|
||||||
|
nowd_params.append(module.bias)
|
||||||
|
elif isinstance(module, nn.BatchNorm2d):
|
||||||
|
nowd_params += list(module.parameters())
|
||||||
|
return wd_params, nowd_params
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
net = Resnet18()
|
||||||
|
x = torch.randn(16, 3, 224, 224)
|
||||||
|
out = net(x)
|
||||||
|
print(out[0].size())
|
||||||
|
print(out[1].size())
|
||||||
|
print(out[2].size())
|
||||||
|
net.get_params()
|
|
@ -0,0 +1,154 @@
|
||||||
|
import sys
|
||||||
|
from face_detection import FaceAlignment,LandmarksType
|
||||||
|
from os import listdir, path
|
||||||
|
import subprocess
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from mmpose.apis import inference_topdown, init_model
|
||||||
|
from mmpose.structures import merge_data_samples
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# initialize the mmpose model
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
|
||||||
|
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
|
||||||
|
model = init_model(config_file, checkpoint_file, device=device)
|
||||||
|
|
||||||
|
# initialize the face detection model
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device)
|
||||||
|
|
||||||
|
# maker if the bbox is not sufficient
|
||||||
|
coord_placeholder = (0.0,0.0,0.0,0.0)
|
||||||
|
|
||||||
|
def resize_landmark(landmark, w, h, new_w, new_h):
|
||||||
|
w_ratio = new_w / w
|
||||||
|
h_ratio = new_h / h
|
||||||
|
landmark_norm = landmark / [w, h]
|
||||||
|
landmark_resized = landmark_norm * [new_w, new_h]
|
||||||
|
return landmark_resized
|
||||||
|
|
||||||
|
def read_imgs(img_list):
|
||||||
|
frames = []
|
||||||
|
print('reading images...')
|
||||||
|
for img_path in tqdm(img_list):
|
||||||
|
frame = cv2.imread(img_path)
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def get_bbox_range(img_list,upperbondrange =0):
|
||||||
|
frames = read_imgs(img_list)
|
||||||
|
batch_size_fa = 1
|
||||||
|
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
|
||||||
|
coords_list = []
|
||||||
|
landmarks = []
|
||||||
|
if upperbondrange != 0:
|
||||||
|
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
|
||||||
|
else:
|
||||||
|
print('get key_landmark and face bounding boxes with the default value')
|
||||||
|
average_range_minus = []
|
||||||
|
average_range_plus = []
|
||||||
|
for fb in tqdm(batches):
|
||||||
|
results = inference_topdown(model, np.asarray(fb)[0])
|
||||||
|
results = merge_data_samples(results)
|
||||||
|
keypoints = results.pred_instances.keypoints
|
||||||
|
face_land_mark= keypoints[0][23:91]
|
||||||
|
face_land_mark = face_land_mark.astype(np.int32)
|
||||||
|
|
||||||
|
# get bounding boxes by face detetion
|
||||||
|
bbox = fa.get_detections_for_batch(np.asarray(fb))
|
||||||
|
|
||||||
|
# adjust the bounding box refer to landmark
|
||||||
|
# Add the bounding box to a tuple and append it to the coordinates list
|
||||||
|
for j, f in enumerate(bbox):
|
||||||
|
if f is None: # no face in the image
|
||||||
|
coords_list += [coord_placeholder]
|
||||||
|
continue
|
||||||
|
|
||||||
|
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
|
||||||
|
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
|
||||||
|
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
|
||||||
|
average_range_minus.append(range_minus)
|
||||||
|
average_range_plus.append(range_plus)
|
||||||
|
if upperbondrange != 0:
|
||||||
|
half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28)
|
||||||
|
|
||||||
|
text_range=f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}"
|
||||||
|
return text_range
|
||||||
|
|
||||||
|
|
||||||
|
def get_landmark_and_bbox(img_list,upperbondrange =0):
|
||||||
|
frames = read_imgs(img_list)
|
||||||
|
batch_size_fa = 1
|
||||||
|
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
|
||||||
|
coords_list = []
|
||||||
|
landmarks = []
|
||||||
|
if upperbondrange != 0:
|
||||||
|
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
|
||||||
|
else:
|
||||||
|
print('get key_landmark and face bounding boxes with the default value')
|
||||||
|
average_range_minus = []
|
||||||
|
average_range_plus = []
|
||||||
|
for fb in tqdm(batches):
|
||||||
|
results = inference_topdown(model, np.asarray(fb)[0])
|
||||||
|
results = merge_data_samples(results)
|
||||||
|
keypoints = results.pred_instances.keypoints
|
||||||
|
face_land_mark= keypoints[0][23:91]
|
||||||
|
face_land_mark = face_land_mark.astype(np.int32)
|
||||||
|
|
||||||
|
# get bounding boxes by face detetion
|
||||||
|
bbox = fa.get_detections_for_batch(np.asarray(fb))
|
||||||
|
|
||||||
|
# adjust the bounding box refer to landmark
|
||||||
|
# Add the bounding box to a tuple and append it to the coordinates list
|
||||||
|
for j, f in enumerate(bbox):
|
||||||
|
if f is None: # no face in the image
|
||||||
|
coords_list += [coord_placeholder]
|
||||||
|
continue
|
||||||
|
|
||||||
|
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
|
||||||
|
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
|
||||||
|
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
|
||||||
|
average_range_minus.append(range_minus)
|
||||||
|
average_range_plus.append(range_plus)
|
||||||
|
if upperbondrange != 0:
|
||||||
|
half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28)
|
||||||
|
half_face_dist = np.max(face_land_mark[:,1]) - half_face_coord[1]
|
||||||
|
upper_bond = half_face_coord[1]-half_face_dist
|
||||||
|
|
||||||
|
f_landmark = (np.min(face_land_mark[:, 0]),int(upper_bond),np.max(face_land_mark[:, 0]),np.max(face_land_mark[:,1]))
|
||||||
|
x1, y1, x2, y2 = f_landmark
|
||||||
|
|
||||||
|
if y2-y1<=0 or x2-x1<=0 or x1<0: # if the landmark bbox is not suitable, reuse the bbox
|
||||||
|
coords_list += [f]
|
||||||
|
w,h = f[2]-f[0], f[3]-f[1]
|
||||||
|
print("error bbox:",f)
|
||||||
|
else:
|
||||||
|
coords_list += [f_landmark]
|
||||||
|
|
||||||
|
print("********************************************bbox_shift parameter adjustment**********************************************************")
|
||||||
|
print(f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}")
|
||||||
|
print("*************************************************************************************************************************************")
|
||||||
|
return coords_list,frames
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
img_list = ["./results/lyria/00000.png","./results/lyria/00001.png","./results/lyria/00002.png","./results/lyria/00003.png"]
|
||||||
|
crop_coord_path = "./coord_face.pkl"
|
||||||
|
coords_list,full_frames = get_landmark_and_bbox(img_list)
|
||||||
|
with open(crop_coord_path, 'wb') as f:
|
||||||
|
pickle.dump(coords_list, f)
|
||||||
|
|
||||||
|
for bbox, frame in zip(coords_list,full_frames):
|
||||||
|
if bbox == coord_placeholder:
|
||||||
|
continue
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
crop_frame = frame[y1:y2, x1:x2]
|
||||||
|
print('Cropped shape', crop_frame.shape)
|
||||||
|
|
||||||
|
#cv2.imwrite(path.join(save_dir, '{}.png'.format(i)),full_frames[i][0][y1:y2, x1:x2])
|
||||||
|
print(coords_list)
|
|
@ -0,0 +1,75 @@
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ffmpeg_path = os.getenv('FFMPEG_PATH')
|
||||||
|
if ffmpeg_path is None:
|
||||||
|
print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
|
||||||
|
elif ffmpeg_path not in os.getenv('PATH'):
|
||||||
|
print("add ffmpeg to path")
|
||||||
|
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
|
||||||
|
|
||||||
|
|
||||||
|
from musetalk.whisper.audio2feature import Audio2Feature
|
||||||
|
from musetalk.models.vae import VAE
|
||||||
|
from musetalk.models.unet import UNet,PositionalEncoding
|
||||||
|
|
||||||
|
def load_all_model():
|
||||||
|
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
|
||||||
|
vae = VAE(model_path = "./models/sd-vae-ft-mse/")
|
||||||
|
unet = UNet(unet_config="./models/musetalk/musetalk.json",
|
||||||
|
model_path ="./models/musetalk/pytorch_model.bin")
|
||||||
|
pe = PositionalEncoding(d_model=384)
|
||||||
|
return audio_processor,vae,unet,pe
|
||||||
|
|
||||||
|
def get_file_type(video_path):
|
||||||
|
_, ext = os.path.splitext(video_path)
|
||||||
|
|
||||||
|
if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
|
||||||
|
return 'image'
|
||||||
|
elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
|
||||||
|
return 'video'
|
||||||
|
else:
|
||||||
|
return 'unsupported'
|
||||||
|
|
||||||
|
def get_video_fps(video_path):
|
||||||
|
video = cv2.VideoCapture(video_path)
|
||||||
|
fps = video.get(cv2.CAP_PROP_FPS)
|
||||||
|
video.release()
|
||||||
|
return fps
|
||||||
|
|
||||||
|
def datagen(whisper_chunks,
|
||||||
|
vae_encode_latents,
|
||||||
|
batch_size=8,
|
||||||
|
delay_frame=0):
|
||||||
|
whisper_batch, latent_batch = [], []
|
||||||
|
for i, w in enumerate(whisper_chunks):
|
||||||
|
idx = (i+delay_frame)%len(vae_encode_latents)
|
||||||
|
latent = vae_encode_latents[idx]
|
||||||
|
whisper_batch.append(w)
|
||||||
|
latent_batch.append(latent)
|
||||||
|
|
||||||
|
if len(latent_batch) >= batch_size:
|
||||||
|
whisper_batch = np.stack(whisper_batch)
|
||||||
|
latent_batch = torch.cat(latent_batch, dim=0)
|
||||||
|
yield whisper_batch, latent_batch
|
||||||
|
whisper_batch, latent_batch = [], []
|
||||||
|
|
||||||
|
# the last batch may smaller than batch size
|
||||||
|
if len(latent_batch) > 0:
|
||||||
|
whisper_batch = np.stack(whisper_batch)
|
||||||
|
latent_batch = torch.cat(latent_batch, dim=0)
|
||||||
|
|
||||||
|
yield whisper_batch, latent_batch
|
||||||
|
|
||||||
|
def load_audio_model():
|
||||||
|
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
|
||||||
|
return audio_processor
|
||||||
|
|
||||||
|
def load_diffusion_model():
|
||||||
|
vae = VAE(model_path = "./models/sd-vae-ft-mse/")
|
||||||
|
unet = UNet(unet_config="./models/musetalk/musetalk.json",
|
||||||
|
model_path ="./models/musetalk/pytorch_model.bin")
|
||||||
|
pe = PositionalEncoding(d_model=384)
|
||||||
|
return vae,unet,pe
|
|
@ -0,0 +1,130 @@
|
||||||
|
import os
|
||||||
|
from .whisper import load_model
|
||||||
|
import soundfile as sf
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
sys.path.append("..")
|
||||||
|
|
||||||
|
class Audio2Feature():
|
||||||
|
def __init__(self,
|
||||||
|
whisper_model_type="tiny",
|
||||||
|
model_path="./models/whisper/tiny.pt"):
|
||||||
|
self.whisper_model_type = whisper_model_type
|
||||||
|
self.model = load_model(model_path) #
|
||||||
|
|
||||||
|
def get_sliced_feature(self,
|
||||||
|
feature_array,
|
||||||
|
vid_idx,
|
||||||
|
audio_feat_length=[2,2],
|
||||||
|
fps=25):
|
||||||
|
"""
|
||||||
|
Get sliced features based on a given index
|
||||||
|
:param feature_array:
|
||||||
|
:param start_idx: the start index of the feature
|
||||||
|
:param audio_feat_length:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
length = len(feature_array)
|
||||||
|
selected_feature = []
|
||||||
|
selected_idx = []
|
||||||
|
|
||||||
|
center_idx = int(vid_idx*50/fps)
|
||||||
|
left_idx = center_idx-audio_feat_length[0]*2
|
||||||
|
right_idx = center_idx + (audio_feat_length[1]+1)*2
|
||||||
|
|
||||||
|
for idx in range(left_idx,right_idx):
|
||||||
|
idx = max(0, idx)
|
||||||
|
idx = min(length-1, idx)
|
||||||
|
x = feature_array[idx]
|
||||||
|
selected_feature.append(x)
|
||||||
|
selected_idx.append(idx)
|
||||||
|
|
||||||
|
selected_feature = np.concatenate(selected_feature, axis=0)
|
||||||
|
selected_feature = selected_feature.reshape(-1, 384)# 50*384
|
||||||
|
return selected_feature,selected_idx
|
||||||
|
|
||||||
|
def get_sliced_feature_sparse(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
|
||||||
|
"""
|
||||||
|
Get sliced features based on a given index
|
||||||
|
:param feature_array:
|
||||||
|
:param start_idx: the start index of the feature
|
||||||
|
:param audio_feat_length:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
length = len(feature_array)
|
||||||
|
selected_feature = []
|
||||||
|
selected_idx = []
|
||||||
|
|
||||||
|
for dt in range(-audio_feat_length[0],audio_feat_length[1]+1):
|
||||||
|
left_idx = int((vid_idx+dt)*50/fps)
|
||||||
|
if left_idx<1 or left_idx>length-1:
|
||||||
|
print('test-----,left_idx=',left_idx)
|
||||||
|
left_idx = max(0, left_idx)
|
||||||
|
left_idx = min(length-1, left_idx)
|
||||||
|
|
||||||
|
x = feature_array[left_idx]
|
||||||
|
x = x[np.newaxis,:,:]
|
||||||
|
x = np.repeat(x, 2, axis=0)
|
||||||
|
selected_feature.append(x)
|
||||||
|
selected_idx.append(left_idx)
|
||||||
|
selected_idx.append(left_idx)
|
||||||
|
else:
|
||||||
|
x = feature_array[left_idx-1:left_idx+1]
|
||||||
|
selected_feature.append(x)
|
||||||
|
selected_idx.append(left_idx-1)
|
||||||
|
selected_idx.append(left_idx)
|
||||||
|
selected_feature = np.concatenate(selected_feature, axis=0)
|
||||||
|
selected_feature = selected_feature.reshape(-1, 384)# 50*384
|
||||||
|
return selected_feature,selected_idx
|
||||||
|
|
||||||
|
|
||||||
|
def feature2chunks(self,feature_array,fps,batch_size,audio_feat_length = [2,2],start=0):
|
||||||
|
whisper_chunks = []
|
||||||
|
whisper_idx_multiplier = 50./fps
|
||||||
|
i = 0
|
||||||
|
#print(f"video in {fps} FPS, audio idx in 50FPS")
|
||||||
|
for _ in range(batch_size):
|
||||||
|
# start_idx = int(i * whisper_idx_multiplier)
|
||||||
|
# if start_idx>=len(feature_array):
|
||||||
|
# break
|
||||||
|
selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i+start,audio_feat_length=audio_feat_length,fps=fps)
|
||||||
|
#print(f"i:{i},selected_idx {selected_idx}")
|
||||||
|
whisper_chunks.append(selected_feature)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
|
return whisper_chunks
|
||||||
|
|
||||||
|
def audio2feat(self,audio_path):
|
||||||
|
# get the sample rate of the audio
|
||||||
|
result = self.model.transcribe(audio_path)
|
||||||
|
embed_list = []
|
||||||
|
for emb in result['segments']:
|
||||||
|
encoder_embeddings = emb['encoder_embeddings']
|
||||||
|
encoder_embeddings = encoder_embeddings.transpose(0,2,1,3)
|
||||||
|
encoder_embeddings = encoder_embeddings.squeeze(0)
|
||||||
|
start_idx = int(emb['start'])
|
||||||
|
end_idx = int(emb['end'])
|
||||||
|
emb_end_idx = int((end_idx - start_idx)/2)
|
||||||
|
embed_list.append(encoder_embeddings[:emb_end_idx])
|
||||||
|
concatenated_array = np.concatenate(embed_list, axis=0)
|
||||||
|
return concatenated_array
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
audio_processor = Audio2Feature(model_path="../../models/whisper/whisper_tiny.pt")
|
||||||
|
audio_path = "./test.mp3"
|
||||||
|
array = audio_processor.audio2feat(audio_path)
|
||||||
|
print(array.shape)
|
||||||
|
fps = 25
|
||||||
|
whisper_idx_multiplier = 50./fps
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
print(f"video in {fps} FPS, audio idx in 50FPS")
|
||||||
|
while 1:
|
||||||
|
start_idx = int(i * whisper_idx_multiplier)
|
||||||
|
selected_feature,selected_idx = audio_processor.get_sliced_feature(feature_array= array,vid_idx = i,audio_feat_length=[2,2],fps=fps)
|
||||||
|
print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
|
||||||
|
i += 1
|
||||||
|
if start_idx>len(array):
|
||||||
|
break
|
|
@ -0,0 +1,116 @@
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import urllib
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
|
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||||
|
from .model import Whisper, ModelDimensions
|
||||||
|
from .transcribe import transcribe
|
||||||
|
|
||||||
|
|
||||||
|
_MODELS = {
|
||||||
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||||
|
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||||
|
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||||
|
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||||
|
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||||
|
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||||
|
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||||
|
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||||
|
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
|
||||||
|
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||||
|
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||||
|
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||||
|
os.makedirs(root, exist_ok=True)
|
||||||
|
|
||||||
|
expected_sha256 = url.split("/")[-2]
|
||||||
|
download_target = os.path.join(root, os.path.basename(url))
|
||||||
|
|
||||||
|
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||||
|
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||||
|
|
||||||
|
if os.path.isfile(download_target):
|
||||||
|
model_bytes = open(download_target, "rb").read()
|
||||||
|
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||||
|
return model_bytes if in_memory else download_target
|
||||||
|
else:
|
||||||
|
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||||
|
|
||||||
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||||
|
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||||
|
while True:
|
||||||
|
buffer = source.read(8192)
|
||||||
|
if not buffer:
|
||||||
|
break
|
||||||
|
|
||||||
|
output.write(buffer)
|
||||||
|
loop.update(len(buffer))
|
||||||
|
|
||||||
|
model_bytes = open(download_target, "rb").read()
|
||||||
|
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||||
|
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
||||||
|
|
||||||
|
return model_bytes if in_memory else download_target
|
||||||
|
|
||||||
|
|
||||||
|
def available_models() -> List[str]:
|
||||||
|
"""Returns the names of available models"""
|
||||||
|
return list(_MODELS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
|
||||||
|
"""
|
||||||
|
Load a Whisper ASR model
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
one of the official model names listed by `whisper.available_models()`, or
|
||||||
|
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||||
|
device : Union[str, torch.device]
|
||||||
|
the PyTorch device to put the model into
|
||||||
|
download_root: str
|
||||||
|
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||||
|
in_memory: bool
|
||||||
|
whether to preload the model weights into host memory
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model : Whisper
|
||||||
|
The Whisper ASR model instance
|
||||||
|
"""
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if download_root is None:
|
||||||
|
download_root = os.getenv(
|
||||||
|
"XDG_CACHE_HOME",
|
||||||
|
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||||
|
)
|
||||||
|
|
||||||
|
if name in _MODELS:
|
||||||
|
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||||
|
elif os.path.isfile(name):
|
||||||
|
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||||
|
|
||||||
|
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
|
||||||
|
checkpoint = torch.load(fp, map_location=device)
|
||||||
|
del checkpoint_file
|
||||||
|
|
||||||
|
dims = ModelDimensions(**checkpoint["dims"])
|
||||||
|
model = Whisper(dims)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
|
return model.to(device)
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .transcribe import cli
|
||||||
|
|
||||||
|
|
||||||
|
cli()
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1 @@
|
||||||
|
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
|
@ -0,0 +1 @@
|
||||||
|
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
|
File diff suppressed because one or more lines are too long
Binary file not shown.
|
@ -0,0 +1 @@
|
||||||
|
{"<|endoftext|>": 50257}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1 @@
|
||||||
|
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
|
@ -0,0 +1 @@
|
||||||
|
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,125 @@
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import ffmpeg
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .utils import exact_div
|
||||||
|
|
||||||
|
# hard-coded audio hyperparameters
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
N_FFT = 400
|
||||||
|
N_MELS = 80
|
||||||
|
HOP_LENGTH = 160
|
||||||
|
CHUNK_LENGTH = 30
|
||||||
|
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
||||||
|
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||||
|
"""
|
||||||
|
Open an audio file and read as mono waveform, resampling as necessary
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
file: str
|
||||||
|
The audio file to open
|
||||||
|
|
||||||
|
sr: int
|
||||||
|
The sample rate to resample the audio if necessary
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A NumPy array containing the audio waveform, in float32 dtype.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||||
|
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
||||||
|
out, _ = (
|
||||||
|
ffmpeg.input(file, threads=0)
|
||||||
|
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
|
||||||
|
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
||||||
|
)
|
||||||
|
except ffmpeg.Error as e:
|
||||||
|
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||||
|
|
||||||
|
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
|
||||||
|
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||||
|
"""
|
||||||
|
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||||
|
"""
|
||||||
|
if torch.is_tensor(array):
|
||||||
|
if array.shape[axis] > length:
|
||||||
|
array = array.index_select(dim=axis, index=torch.arange(length))
|
||||||
|
|
||||||
|
if array.shape[axis] < length:
|
||||||
|
pad_widths = [(0, 0)] * array.ndim
|
||||||
|
pad_widths[axis] = (0, length - array.shape[axis])
|
||||||
|
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
||||||
|
else:
|
||||||
|
if array.shape[axis] > length:
|
||||||
|
array = array.take(indices=range(length), axis=axis)
|
||||||
|
|
||||||
|
if array.shape[axis] < length:
|
||||||
|
pad_widths = [(0, 0)] * array.ndim
|
||||||
|
pad_widths[axis] = (0, length - array.shape[axis])
|
||||||
|
array = np.pad(array, pad_widths)
|
||||||
|
|
||||||
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||||
|
Allows decoupling librosa dependency; saved using:
|
||||||
|
|
||||||
|
np.savez_compressed(
|
||||||
|
"mel_filters.npz",
|
||||||
|
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||||
|
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
|
||||||
|
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
|
||||||
|
"""
|
||||||
|
Compute the log-Mel spectrogram of
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
||||||
|
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||||
|
|
||||||
|
n_mels: int
|
||||||
|
The number of Mel-frequency filters, only 80 is supported
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor, shape = (80, n_frames)
|
||||||
|
A Tensor that contains the Mel spectrogram
|
||||||
|
"""
|
||||||
|
if not torch.is_tensor(audio):
|
||||||
|
if isinstance(audio, str):
|
||||||
|
audio = load_audio(audio)
|
||||||
|
audio = torch.from_numpy(audio)
|
||||||
|
|
||||||
|
window = torch.hann_window(N_FFT).to(audio.device)
|
||||||
|
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||||
|
|
||||||
|
magnitudes = stft[:, :-1].abs() ** 2
|
||||||
|
|
||||||
|
filters = mel_filters(audio.device, n_mels)
|
||||||
|
mel_spec = filters @ magnitudes
|
||||||
|
|
||||||
|
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||||
|
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||||
|
log_spec = (log_spec + 4.0) / 4.0
|
||||||
|
return log_spec
|
|
@ -0,0 +1,729 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
|
||||||
|
from .audio import CHUNK_LENGTH
|
||||||
|
from .tokenizer import Tokenizer, get_tokenizer
|
||||||
|
from .utils import compression_ratio
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .model import Whisper
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
||||||
|
"""
|
||||||
|
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||||
|
of the most probable language tokens and the probability distribution over all language tokens.
|
||||||
|
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
language_tokens : Tensor, shape = (n_audio,)
|
||||||
|
ids of the most probable language tokens, which appears after the startoftranscript token.
|
||||||
|
language_probs : List[Dict[str, float]], length = n_audio
|
||||||
|
list of dictionaries containing the probability distribution over all languages.
|
||||||
|
"""
|
||||||
|
if tokenizer is None:
|
||||||
|
tokenizer = get_tokenizer(model.is_multilingual)
|
||||||
|
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
||||||
|
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
||||||
|
|
||||||
|
single = mel.ndim == 2
|
||||||
|
if single:
|
||||||
|
mel = mel.unsqueeze(0)
|
||||||
|
|
||||||
|
# skip encoder forward pass if already-encoded audio features were given
|
||||||
|
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||||
|
mel = model.encoder(mel)
|
||||||
|
|
||||||
|
# forward pass using a single token, startoftranscript
|
||||||
|
n_audio = mel.shape[0]
|
||||||
|
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||||
|
logits = model.logits(x, mel)[:, 0]
|
||||||
|
|
||||||
|
# collect detected languages; suppress all non-language tokens
|
||||||
|
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||||
|
mask[list(tokenizer.all_language_tokens)] = False
|
||||||
|
logits[:, mask] = -np.inf
|
||||||
|
language_tokens = logits.argmax(dim=-1)
|
||||||
|
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||||
|
language_probs = [
|
||||||
|
{
|
||||||
|
c: language_token_probs[i, j].item()
|
||||||
|
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
||||||
|
}
|
||||||
|
for i in range(n_audio)
|
||||||
|
]
|
||||||
|
|
||||||
|
if single:
|
||||||
|
language_tokens = language_tokens[0]
|
||||||
|
language_probs = language_probs[0]
|
||||||
|
|
||||||
|
return language_tokens, language_probs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DecodingOptions:
|
||||||
|
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
||||||
|
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
||||||
|
|
||||||
|
# sampling-related options
|
||||||
|
temperature: float = 0.0
|
||||||
|
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||||
|
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
||||||
|
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
||||||
|
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
||||||
|
|
||||||
|
# options for ranking generations (either beams or best-of-N samples)
|
||||||
|
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
||||||
|
|
||||||
|
# prompt, prefix, and token suppression
|
||||||
|
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
||||||
|
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
||||||
|
suppress_blank: bool = True # this will suppress blank outputs
|
||||||
|
|
||||||
|
# list of tokens ids (or comma-separated token ids) to suppress
|
||||||
|
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||||
|
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||||
|
|
||||||
|
# timestamp sampling options
|
||||||
|
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||||
|
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
||||||
|
|
||||||
|
# implementation details
|
||||||
|
fp16: bool = True # use fp16 for most of the calculation
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DecodingResult:
|
||||||
|
audio_features: Tensor
|
||||||
|
language: str
|
||||||
|
encoder_embeddings: np.ndarray
|
||||||
|
decoder_embeddings: np.ndarray
|
||||||
|
language_probs: Optional[Dict[str, float]] = None
|
||||||
|
tokens: List[int] = field(default_factory=list)
|
||||||
|
text: str = ""
|
||||||
|
avg_logprob: float = np.nan
|
||||||
|
no_speech_prob: float = np.nan
|
||||||
|
temperature: float = np.nan
|
||||||
|
compression_ratio: float = np.nan
|
||||||
|
|
||||||
|
|
||||||
|
class Inference:
|
||||||
|
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||||
|
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def rearrange_kv_cache(self, source_indices) -> None:
|
||||||
|
"""Update the key-value cache according to the updated beams"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def cleanup_caching(self) -> None:
|
||||||
|
"""Clean up any resources or hooks after decoding is finished"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PyTorchInference(Inference):
|
||||||
|
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||||
|
self.model: "Whisper" = model
|
||||||
|
self.initial_token_length = initial_token_length
|
||||||
|
self.kv_cache = {}
|
||||||
|
self.hooks = []
|
||||||
|
|
||||||
|
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
|
||||||
|
if not self.kv_cache:
|
||||||
|
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||||
|
|
||||||
|
if tokens.shape[-1] > self.initial_token_length:
|
||||||
|
# only need to use the last token except in the first forward pass
|
||||||
|
tokens = tokens[:, -1:]
|
||||||
|
|
||||||
|
return_val = self.model.decoder(tokens, audio_features,
|
||||||
|
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
|
||||||
|
return return_val
|
||||||
|
|
||||||
|
def cleanup_caching(self):
|
||||||
|
for hook in self.hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
self.kv_cache = {}
|
||||||
|
self.hooks = []
|
||||||
|
|
||||||
|
def rearrange_kv_cache(self, source_indices):
|
||||||
|
for module, tensor in self.kv_cache.items():
|
||||||
|
# update the key/value cache to contain the selected sequences
|
||||||
|
self.kv_cache[module] = tensor[source_indices].detach()
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceRanker:
|
||||||
|
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
||||||
|
"""
|
||||||
|
Given a list of groups of samples and their cumulative log probabilities,
|
||||||
|
return the indices of the samples in each group to select as the final result
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MaximumLikelihoodRanker(SequenceRanker):
|
||||||
|
"""
|
||||||
|
Select the sample with the highest log probabilities, penalized using either
|
||||||
|
a simple length normalization or Google NMT paper's length penalty
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, length_penalty: Optional[float]):
|
||||||
|
self.length_penalty = length_penalty
|
||||||
|
|
||||||
|
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
||||||
|
def scores(logprobs, lengths):
|
||||||
|
result = []
|
||||||
|
for logprob, length in zip(logprobs, lengths):
|
||||||
|
if self.length_penalty is None:
|
||||||
|
penalty = length
|
||||||
|
else:
|
||||||
|
# from the Google NMT paper
|
||||||
|
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||||
|
result.append(logprob / penalty)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# get the sequence with the highest score
|
||||||
|
lengths = [[len(t) for t in s] for s in tokens]
|
||||||
|
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenDecoder:
|
||||||
|
def reset(self):
|
||||||
|
"""Initialize any stateful variables for decoding a new sequence"""
|
||||||
|
|
||||||
|
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||||
|
"""Specify how to select the next token, based on the current trace and logits
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||||
|
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||||
|
|
||||||
|
logits : Tensor, shape = (n_batch, vocab_size)
|
||||||
|
per-token logits of the probability distribution at the current step
|
||||||
|
|
||||||
|
sum_logprobs : Tensor, shape = (n_batch)
|
||||||
|
cumulative log probabilities for each sequence
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
||||||
|
the tokens, appended with the selected next token
|
||||||
|
|
||||||
|
completed : bool
|
||||||
|
True if all sequences has reached the end of text
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def finalize(
|
||||||
|
self, tokens: Tensor, sum_logprobs: Tensor
|
||||||
|
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||||
|
"""Finalize search and return the final candidate sequences
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
||||||
|
all tokens in the context so far, including the prefix and sot_sequence
|
||||||
|
|
||||||
|
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
||||||
|
cumulative log probabilities for each sequence
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
||||||
|
sequence of Tensors containing candidate token sequences, for each audio input
|
||||||
|
|
||||||
|
sum_logprobs : List[List[float]], length = n_audio
|
||||||
|
sequence of cumulative log probabilities corresponding to the above
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class GreedyDecoder(TokenDecoder):
|
||||||
|
def __init__(self, temperature: float, eot: int):
|
||||||
|
self.temperature = temperature
|
||||||
|
self.eot = eot
|
||||||
|
|
||||||
|
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||||
|
temperature = self.temperature
|
||||||
|
if temperature == 0:
|
||||||
|
next_tokens = logits.argmax(dim=-1)
|
||||||
|
else:
|
||||||
|
next_tokens = Categorical(logits=logits / temperature).sample()
|
||||||
|
|
||||||
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||||
|
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||||
|
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||||
|
|
||||||
|
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
||||||
|
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
||||||
|
|
||||||
|
completed = (tokens[:, -1] == self.eot).all()
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
||||||
|
# make sure each sequence has at least one EOT token at the end
|
||||||
|
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||||
|
return tokens, sum_logprobs.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearchDecoder(TokenDecoder):
|
||||||
|
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
||||||
|
self.beam_size = beam_size
|
||||||
|
self.eot = eot
|
||||||
|
self.inference = inference
|
||||||
|
self.patience = patience or 1.0
|
||||||
|
self.max_candidates: int = round(beam_size * self.patience)
|
||||||
|
self.finished_sequences = None
|
||||||
|
|
||||||
|
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.finished_sequences = None
|
||||||
|
|
||||||
|
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||||
|
if tokens.shape[0] % self.beam_size != 0:
|
||||||
|
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||||
|
|
||||||
|
n_audio = tokens.shape[0] // self.beam_size
|
||||||
|
if self.finished_sequences is None: # for the first update
|
||||||
|
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||||
|
|
||||||
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||||
|
next_tokens, source_indices, finished_sequences = [], [], []
|
||||||
|
for i in range(n_audio):
|
||||||
|
scores, sources, finished = {}, {}, {}
|
||||||
|
|
||||||
|
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
||||||
|
for j in range(self.beam_size):
|
||||||
|
idx = i * self.beam_size + j
|
||||||
|
prefix = tokens[idx].tolist()
|
||||||
|
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
||||||
|
new_logprob = (sum_logprobs[idx] + logprob).item()
|
||||||
|
sequence = tuple(prefix + [token.item()])
|
||||||
|
scores[sequence] = new_logprob
|
||||||
|
sources[sequence] = idx
|
||||||
|
|
||||||
|
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
||||||
|
saved = 0
|
||||||
|
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||||
|
if sequence[-1] == self.eot:
|
||||||
|
finished[sequence] = scores[sequence]
|
||||||
|
else:
|
||||||
|
sum_logprobs[len(next_tokens)] = scores[sequence]
|
||||||
|
next_tokens.append(sequence)
|
||||||
|
source_indices.append(sources[sequence])
|
||||||
|
|
||||||
|
saved += 1
|
||||||
|
if saved == self.beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
finished_sequences.append(finished)
|
||||||
|
|
||||||
|
tokens = torch.tensor(next_tokens, device=tokens.device)
|
||||||
|
self.inference.rearrange_kv_cache(source_indices)
|
||||||
|
|
||||||
|
# add newly finished sequences to self.finished_sequences
|
||||||
|
assert len(self.finished_sequences) == len(finished_sequences)
|
||||||
|
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
||||||
|
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||||
|
if len(previously_finished) >= self.max_candidates:
|
||||||
|
break # the candidate list is full
|
||||||
|
previously_finished[seq] = newly_finished[seq]
|
||||||
|
|
||||||
|
# mark as completed if all audio has enough number of samples
|
||||||
|
completed = all(
|
||||||
|
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
||||||
|
)
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
||||||
|
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||||
|
sum_logprobs = sum_logprobs.cpu()
|
||||||
|
for i, sequences in enumerate(self.finished_sequences):
|
||||||
|
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
||||||
|
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||||
|
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||||
|
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||||
|
if len(sequences) >= self.beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
tokens: List[List[Tensor]] = [
|
||||||
|
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
||||||
|
]
|
||||||
|
sum_logprobs: List[List[float]] = [
|
||||||
|
list(sequences.values()) for sequences in self.finished_sequences
|
||||||
|
]
|
||||||
|
return tokens, sum_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
class LogitFilter:
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
||||||
|
"""Apply any filtering or masking to logits in-place
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
logits : Tensor, shape = (n_batch, vocab_size)
|
||||||
|
per-token logits of the probability distribution at the current step
|
||||||
|
|
||||||
|
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||||
|
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class SuppressBlank(LogitFilter):
|
||||||
|
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.sample_begin = sample_begin
|
||||||
|
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor):
|
||||||
|
if tokens.shape[1] == self.sample_begin:
|
||||||
|
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||||
|
|
||||||
|
|
||||||
|
class SuppressTokens(LogitFilter):
|
||||||
|
def __init__(self, suppress_tokens: Sequence[int]):
|
||||||
|
self.suppress_tokens = list(suppress_tokens)
|
||||||
|
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor):
|
||||||
|
logits[:, self.suppress_tokens] = -np.inf
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyTimestampRules(LogitFilter):
|
||||||
|
def __init__(
|
||||||
|
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.sample_begin = sample_begin
|
||||||
|
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||||
|
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor):
|
||||||
|
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||||
|
if self.tokenizer.no_timestamps is not None:
|
||||||
|
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
||||||
|
|
||||||
|
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||||
|
for k in range(tokens.shape[0]):
|
||||||
|
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
||||||
|
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||||
|
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||||
|
|
||||||
|
if last_was_timestamp:
|
||||||
|
if penultimate_was_timestamp: # has to be non-timestamp
|
||||||
|
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||||
|
else: # cannot be normal text tokens
|
||||||
|
logits[k, : self.tokenizer.eot] = -np.inf
|
||||||
|
|
||||||
|
# apply the `max_initial_timestamp` option
|
||||||
|
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
||||||
|
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||||
|
logits[:, last_allowed + 1 :] = -np.inf
|
||||||
|
|
||||||
|
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||||
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||||
|
for k in range(tokens.shape[0]):
|
||||||
|
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
||||||
|
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||||
|
if timestamp_logprob > max_text_token_logprob:
|
||||||
|
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||||
|
|
||||||
|
|
||||||
|
class DecodingTask:
|
||||||
|
inference: Inference
|
||||||
|
sequence_ranker: SequenceRanker
|
||||||
|
decoder: TokenDecoder
|
||||||
|
logit_filters: List[LogitFilter]
|
||||||
|
|
||||||
|
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
language = options.language or "en"
|
||||||
|
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
||||||
|
self.tokenizer: Tokenizer = tokenizer
|
||||||
|
self.options: DecodingOptions = self._verify_options(options)
|
||||||
|
|
||||||
|
self.n_group: int = options.beam_size or options.best_of or 1
|
||||||
|
self.n_ctx: int = model.dims.n_text_ctx
|
||||||
|
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||||
|
|
||||||
|
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||||
|
if self.options.without_timestamps:
|
||||||
|
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||||
|
|
||||||
|
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||||
|
self.sample_begin: int = len(self.initial_tokens)
|
||||||
|
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||||
|
|
||||||
|
# inference: implements the forward pass through the decoder, including kv caching
|
||||||
|
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
||||||
|
|
||||||
|
# sequence ranker: implements how to rank a group of sampled sequences
|
||||||
|
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||||
|
|
||||||
|
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||||
|
if options.beam_size is not None:
|
||||||
|
self.decoder = BeamSearchDecoder(
|
||||||
|
options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||||
|
|
||||||
|
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||||
|
self.logit_filters = []
|
||||||
|
if self.options.suppress_blank:
|
||||||
|
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
||||||
|
if self.options.suppress_tokens:
|
||||||
|
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
||||||
|
if not options.without_timestamps:
|
||||||
|
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||||
|
max_initial_timestamp_index = None
|
||||||
|
if options.max_initial_timestamp:
|
||||||
|
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
||||||
|
self.logit_filters.append(
|
||||||
|
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||||
|
if options.beam_size is not None and options.best_of is not None:
|
||||||
|
raise ValueError("beam_size and best_of can't be given together")
|
||||||
|
if options.temperature == 0:
|
||||||
|
if options.best_of is not None:
|
||||||
|
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||||
|
if options.patience is not None and options.beam_size is None:
|
||||||
|
raise ValueError("patience requires beam_size to be given")
|
||||||
|
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
||||||
|
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
def _get_initial_tokens(self) -> Tuple[int]:
|
||||||
|
tokens = list(self.sot_sequence)
|
||||||
|
prefix = self.options.prefix
|
||||||
|
prompt = self.options.prompt
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
prefix_tokens = (
|
||||||
|
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
||||||
|
)
|
||||||
|
if self.sample_len is not None:
|
||||||
|
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||||
|
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||||
|
tokens = tokens + prefix_tokens
|
||||||
|
|
||||||
|
if prompt:
|
||||||
|
prompt_tokens = (
|
||||||
|
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
||||||
|
)
|
||||||
|
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
||||||
|
|
||||||
|
return tuple(tokens)
|
||||||
|
|
||||||
|
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||||
|
suppress_tokens = self.options.suppress_tokens
|
||||||
|
|
||||||
|
if isinstance(suppress_tokens, str):
|
||||||
|
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||||
|
|
||||||
|
if -1 in suppress_tokens:
|
||||||
|
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||||
|
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||||
|
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||||
|
suppress_tokens = [] # interpret empty string as an empty list
|
||||||
|
else:
|
||||||
|
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||||
|
|
||||||
|
suppress_tokens.extend(
|
||||||
|
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
||||||
|
)
|
||||||
|
if self.tokenizer.no_speech is not None:
|
||||||
|
# no-speech probability is collected separately
|
||||||
|
suppress_tokens.append(self.tokenizer.no_speech)
|
||||||
|
|
||||||
|
return tuple(sorted(set(suppress_tokens)))
|
||||||
|
|
||||||
|
def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
|
||||||
|
if self.options.fp16:
|
||||||
|
mel = mel.half()
|
||||||
|
|
||||||
|
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||||
|
# encoded audio features are given; skip audio encoding
|
||||||
|
audio_features = mel
|
||||||
|
else:
|
||||||
|
result = self.model.encoder(mel, include_embeddings)
|
||||||
|
if include_embeddings:
|
||||||
|
audio_features, embeddings = result
|
||||||
|
else:
|
||||||
|
audio_features = result
|
||||||
|
|
||||||
|
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
||||||
|
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
||||||
|
|
||||||
|
if include_embeddings:
|
||||||
|
return audio_features, embeddings
|
||||||
|
else:
|
||||||
|
return audio_features
|
||||||
|
|
||||||
|
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
||||||
|
languages = [self.options.language] * audio_features.shape[0]
|
||||||
|
lang_probs = None
|
||||||
|
|
||||||
|
if self.options.language is None or self.options.task == "lang_id":
|
||||||
|
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
||||||
|
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||||
|
if self.options.language is None:
|
||||||
|
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||||
|
|
||||||
|
return languages, lang_probs
|
||||||
|
|
||||||
|
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
||||||
|
assert audio_features.shape[0] == tokens.shape[0]
|
||||||
|
n_batch = tokens.shape[0]
|
||||||
|
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
||||||
|
no_speech_probs = [np.nan] * n_batch
|
||||||
|
|
||||||
|
try:
|
||||||
|
embeddings = []
|
||||||
|
for i in range(self.sample_len):
|
||||||
|
logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
|
||||||
|
|
||||||
|
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
||||||
|
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||||
|
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||||
|
|
||||||
|
# now we need to consider the logits at the last token only
|
||||||
|
logits = logits[:, -1]
|
||||||
|
token_embeddings = token_embeddings[:, :, -1]
|
||||||
|
|
||||||
|
# Append embeddings together
|
||||||
|
embeddings.append(token_embeddings)
|
||||||
|
|
||||||
|
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||||
|
for logit_filter in self.logit_filters:
|
||||||
|
logit_filter.apply(logits, tokens)
|
||||||
|
|
||||||
|
# expand the tokens tensor with the selected next tokens
|
||||||
|
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
||||||
|
|
||||||
|
if completed or tokens.shape[-1] > self.n_ctx:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
if completed:
|
||||||
|
embeddings = embeddings[:-1]
|
||||||
|
embeddings = np.stack(embeddings, 2)
|
||||||
|
self.inference.cleanup_caching()
|
||||||
|
|
||||||
|
return tokens, sum_logprobs, no_speech_probs, embeddings
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run(self, mel: Tensor) -> List[DecodingResult]:
|
||||||
|
self.decoder.reset()
|
||||||
|
tokenizer: Tokenizer = self.tokenizer
|
||||||
|
n_audio: int = mel.shape[0]
|
||||||
|
|
||||||
|
# encoder forward pass
|
||||||
|
forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
|
||||||
|
audio_features, encoder_embeddings = forward_pass
|
||||||
|
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||||
|
|
||||||
|
# detect language if requested, overwriting the language token
|
||||||
|
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||||
|
if self.options.task == "lang_id":
|
||||||
|
return [
|
||||||
|
DecodingResult(audio_features=features, language=language, language_probs=probs)
|
||||||
|
for features, language, probs in zip(audio_features, languages, language_probs)
|
||||||
|
]
|
||||||
|
|
||||||
|
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
||||||
|
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
|
||||||
|
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
||||||
|
|
||||||
|
# call the main sampling loop
|
||||||
|
tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
|
||||||
|
|
||||||
|
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||||
|
audio_features = audio_features[:: self.n_group]
|
||||||
|
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||||
|
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||||
|
|
||||||
|
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||||
|
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||||
|
|
||||||
|
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||||
|
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||||
|
tokens: List[List[Tensor]] = [
|
||||||
|
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
# select the top-ranked sample in each group
|
||||||
|
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||||
|
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||||
|
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||||
|
|
||||||
|
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||||
|
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
||||||
|
|
||||||
|
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
|
||||||
|
if len(set(map(len, fields))) != 1:
|
||||||
|
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||||
|
|
||||||
|
return [
|
||||||
|
DecodingResult(
|
||||||
|
audio_features=features,
|
||||||
|
language=language,
|
||||||
|
tokens=tokens,
|
||||||
|
text=text,
|
||||||
|
avg_logprob=avg_logprob,
|
||||||
|
no_speech_prob=no_speech_prob,
|
||||||
|
temperature=self.options.temperature,
|
||||||
|
compression_ratio=compression_ratio(text),
|
||||||
|
encoder_embeddings=encoder_embeddings,
|
||||||
|
decoder_embeddings=decoder_embeddings
|
||||||
|
)
|
||||||
|
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
|
||||||
|
"""
|
||||||
|
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model: Whisper
|
||||||
|
the Whisper model instance
|
||||||
|
|
||||||
|
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
||||||
|
A tensor containing the Mel spectrogram(s)
|
||||||
|
|
||||||
|
options: DecodingOptions
|
||||||
|
A dataclass that contains all necessary options for decoding 30-second segments
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
result: Union[DecodingResult, List[DecodingResult]]
|
||||||
|
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||||
|
"""
|
||||||
|
single = mel.ndim == 2
|
||||||
|
if single:
|
||||||
|
mel = mel.unsqueeze(0)
|
||||||
|
|
||||||
|
result = DecodingTask(model, options).run(mel)
|
||||||
|
|
||||||
|
if single:
|
||||||
|
result = result[0]
|
||||||
|
|
||||||
|
return result
|
|
@ -0,0 +1,290 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .transcribe import transcribe as transcribe_function
|
||||||
|
from .decoding import detect_language as detect_language_function, decode as decode_function
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelDimensions:
|
||||||
|
n_mels: int
|
||||||
|
n_audio_ctx: int
|
||||||
|
n_audio_state: int
|
||||||
|
n_audio_head: int
|
||||||
|
n_audio_layer: int
|
||||||
|
n_vocab: int
|
||||||
|
n_text_ctx: int
|
||||||
|
n_text_state: int
|
||||||
|
n_text_head: int
|
||||||
|
n_text_layer: int
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Linear):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.linear(
|
||||||
|
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(nn.Conv1d):
|
||||||
|
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||||
|
return super()._conv_forward(
|
||||||
|
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoids(length, channels, max_timescale=10000):
|
||||||
|
"""Returns sinusoids for positional embedding"""
|
||||||
|
assert channels % 2 == 0
|
||||||
|
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||||
|
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||||
|
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||||
|
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, n_state: int, n_head: int):
|
||||||
|
super().__init__()
|
||||||
|
self.n_head = n_head
|
||||||
|
self.query = Linear(n_state, n_state)
|
||||||
|
self.key = Linear(n_state, n_state, bias=False)
|
||||||
|
self.value = Linear(n_state, n_state)
|
||||||
|
self.out = Linear(n_state, n_state)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
xa: Optional[Tensor] = None,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
kv_cache: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
q = self.query(x)
|
||||||
|
|
||||||
|
if kv_cache is None or xa is None:
|
||||||
|
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||||
|
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||||
|
k = self.key(x if xa is None else xa)
|
||||||
|
v = self.value(x if xa is None else xa)
|
||||||
|
else:
|
||||||
|
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||||
|
k = kv_cache.get(self.key, self.key(xa))
|
||||||
|
v = kv_cache.get(self.value, self.value(xa))
|
||||||
|
|
||||||
|
wv = self.qkv_attention(q, k, v, mask)
|
||||||
|
return self.out(wv)
|
||||||
|
|
||||||
|
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||||
|
n_batch, n_ctx, n_state = q.shape
|
||||||
|
scale = (n_state // self.n_head) ** -0.25
|
||||||
|
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||||
|
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||||
|
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
qk = q @ k
|
||||||
|
if mask is not None:
|
||||||
|
qk = qk + mask[:n_ctx, :n_ctx]
|
||||||
|
|
||||||
|
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
||||||
|
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn = MultiHeadAttention(n_state, n_head)
|
||||||
|
self.attn_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
|
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||||
|
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||||
|
|
||||||
|
n_mlp = n_state * 4
|
||||||
|
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
||||||
|
self.mlp_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
xa: Optional[Tensor] = None,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
kv_cache: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
|
||||||
|
if self.cross_attn:
|
||||||
|
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
|
||||||
|
x = x + self.mlp(self.mlp_ln(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AudioEncoder(nn.Module):
|
||||||
|
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||||
|
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||||
|
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||||
|
|
||||||
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||||
|
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
||||||
|
)
|
||||||
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, include_embeddings: bool = False):
|
||||||
|
"""
|
||||||
|
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||||
|
the mel spectrogram of the audio
|
||||||
|
include_embeddings: bool
|
||||||
|
whether to include intermediate steps in the output
|
||||||
|
"""
|
||||||
|
x = F.gelu(self.conv1(x))
|
||||||
|
x = F.gelu(self.conv2(x))
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||||
|
x = (x + self.positional_embedding).to(x.dtype)
|
||||||
|
|
||||||
|
if include_embeddings:
|
||||||
|
embeddings = [x.cpu().detach().numpy()]
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
if include_embeddings:
|
||||||
|
embeddings.append(x.cpu().detach().numpy())
|
||||||
|
|
||||||
|
x = self.ln_post(x)
|
||||||
|
|
||||||
|
if include_embeddings:
|
||||||
|
embeddings = np.stack(embeddings, axis=1)
|
||||||
|
return x, embeddings
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TextDecoder(nn.Module):
|
||||||
|
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||||
|
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||||
|
|
||||||
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||||
|
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||||
|
)
|
||||||
|
self.ln = LayerNorm(n_state)
|
||||||
|
|
||||||
|
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||||
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
|
||||||
|
"""
|
||||||
|
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||||
|
the text tokens
|
||||||
|
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
|
||||||
|
the encoded audio features to be attended on
|
||||||
|
include_embeddings : bool
|
||||||
|
Whether to include intermediate values in the output to this function
|
||||||
|
"""
|
||||||
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||||
|
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
||||||
|
x = x.to(xa.dtype)
|
||||||
|
|
||||||
|
if include_embeddings:
|
||||||
|
embeddings = [x.cpu().detach().numpy()]
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||||
|
if include_embeddings:
|
||||||
|
embeddings.append(x.cpu().detach().numpy())
|
||||||
|
|
||||||
|
x = self.ln(x)
|
||||||
|
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||||
|
|
||||||
|
if include_embeddings:
|
||||||
|
embeddings = np.stack(embeddings, axis=1)
|
||||||
|
return logits, embeddings
|
||||||
|
else:
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class Whisper(nn.Module):
|
||||||
|
def __init__(self, dims: ModelDimensions):
|
||||||
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
|
self.encoder = AudioEncoder(
|
||||||
|
self.dims.n_mels,
|
||||||
|
self.dims.n_audio_ctx,
|
||||||
|
self.dims.n_audio_state,
|
||||||
|
self.dims.n_audio_head,
|
||||||
|
self.dims.n_audio_layer,
|
||||||
|
)
|
||||||
|
self.decoder = TextDecoder(
|
||||||
|
self.dims.n_vocab,
|
||||||
|
self.dims.n_text_ctx,
|
||||||
|
self.dims.n_text_state,
|
||||||
|
self.dims.n_text_head,
|
||||||
|
self.dims.n_text_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed_audio(self, mel: torch.Tensor):
|
||||||
|
return self.encoder.forward(mel)
|
||||||
|
|
||||||
|
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||||
|
return self.decoder.forward(tokens, audio_features)
|
||||||
|
|
||||||
|
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||||
|
return self.decoder(tokens, self.encoder(mel))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_multilingual(self):
|
||||||
|
return self.dims.n_vocab == 51865
|
||||||
|
|
||||||
|
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||||
|
"""
|
||||||
|
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||||
|
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||||
|
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||||
|
intermediate tensors to be reused during later calculations.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
cache : Dict[nn.Module, torch.Tensor]
|
||||||
|
A dictionary object mapping the key/value projection modules to its cache
|
||||||
|
hooks : List[RemovableHandle]
|
||||||
|
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||||
|
"""
|
||||||
|
cache = {**cache} if cache is not None else {}
|
||||||
|
hooks = []
|
||||||
|
|
||||||
|
def save_to_cache(module, _, output):
|
||||||
|
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
|
||||||
|
cache[module] = output # save as-is, for the first token or cross attention
|
||||||
|
else:
|
||||||
|
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||||
|
return cache[module]
|
||||||
|
|
||||||
|
def install_hooks(layer: nn.Module):
|
||||||
|
if isinstance(layer, MultiHeadAttention):
|
||||||
|
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||||
|
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||||
|
|
||||||
|
self.decoder.apply(install_hooks)
|
||||||
|
return cache, hooks
|
||||||
|
|
||||||
|
detect_language = detect_language_function
|
||||||
|
transcribe = transcribe_function
|
||||||
|
decode = decode_function
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .basic import BasicTextNormalizer
|
||||||
|
from .english import EnglishTextNormalizer
|
|
@ -0,0 +1,71 @@
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
|
import regex
|
||||||
|
|
||||||
|
# non-ASCII letters that are not separated by "NFKD" normalization
|
||||||
|
ADDITIONAL_DIACRITICS = {
|
||||||
|
"œ": "oe",
|
||||||
|
"Œ": "OE",
|
||||||
|
"ø": "o",
|
||||||
|
"Ø": "O",
|
||||||
|
"æ": "ae",
|
||||||
|
"Æ": "AE",
|
||||||
|
"ß": "ss",
|
||||||
|
"ẞ": "SS",
|
||||||
|
"đ": "d",
|
||||||
|
"Đ": "D",
|
||||||
|
"ð": "d",
|
||||||
|
"Ð": "D",
|
||||||
|
"þ": "th",
|
||||||
|
"Þ": "th",
|
||||||
|
"ł": "l",
|
||||||
|
"Ł": "L",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_symbols_and_diacritics(s: str, keep=""):
|
||||||
|
"""
|
||||||
|
Replace any other markers, symbols, and punctuations with a space,
|
||||||
|
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||||
|
"""
|
||||||
|
return "".join(
|
||||||
|
c
|
||||||
|
if c in keep
|
||||||
|
else ADDITIONAL_DIACRITICS[c]
|
||||||
|
if c in ADDITIONAL_DIACRITICS
|
||||||
|
else ""
|
||||||
|
if unicodedata.category(c) == "Mn"
|
||||||
|
else " "
|
||||||
|
if unicodedata.category(c)[0] in "MSP"
|
||||||
|
else c
|
||||||
|
for c in unicodedata.normalize("NFKD", s)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_symbols(s: str):
|
||||||
|
"""
|
||||||
|
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||||
|
"""
|
||||||
|
return "".join(
|
||||||
|
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTextNormalizer:
|
||||||
|
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||||
|
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||||
|
self.split_letters = split_letters
|
||||||
|
|
||||||
|
def __call__(self, s: str):
|
||||||
|
s = s.lower()
|
||||||
|
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||||
|
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||||
|
s = self.clean(s).lower()
|
||||||
|
|
||||||
|
if self.split_letters:
|
||||||
|
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||||
|
|
||||||
|
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||||
|
|
||||||
|
return s
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,543 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Iterator, List, Match, Optional, Union
|
||||||
|
|
||||||
|
from more_itertools import windowed
|
||||||
|
|
||||||
|
from .basic import remove_symbols_and_diacritics
|
||||||
|
|
||||||
|
|
||||||
|
class EnglishNumberNormalizer:
|
||||||
|
"""
|
||||||
|
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||||
|
|
||||||
|
- remove any commas
|
||||||
|
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||||
|
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||||
|
- spell out `one` and `ones`
|
||||||
|
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.zeros = {"o", "oh", "zero"}
|
||||||
|
self.ones = {
|
||||||
|
name: i
|
||||||
|
for i, name in enumerate(
|
||||||
|
[
|
||||||
|
"one",
|
||||||
|
"two",
|
||||||
|
"three",
|
||||||
|
"four",
|
||||||
|
"five",
|
||||||
|
"six",
|
||||||
|
"seven",
|
||||||
|
"eight",
|
||||||
|
"nine",
|
||||||
|
"ten",
|
||||||
|
"eleven",
|
||||||
|
"twelve",
|
||||||
|
"thirteen",
|
||||||
|
"fourteen",
|
||||||
|
"fifteen",
|
||||||
|
"sixteen",
|
||||||
|
"seventeen",
|
||||||
|
"eighteen",
|
||||||
|
"nineteen",
|
||||||
|
],
|
||||||
|
start=1,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.ones_plural = {
|
||||||
|
"sixes" if name == "six" else name + "s": (value, "s")
|
||||||
|
for name, value in self.ones.items()
|
||||||
|
}
|
||||||
|
self.ones_ordinal = {
|
||||||
|
"zeroth": (0, "th"),
|
||||||
|
"first": (1, "st"),
|
||||||
|
"second": (2, "nd"),
|
||||||
|
"third": (3, "rd"),
|
||||||
|
"fifth": (5, "th"),
|
||||||
|
"twelfth": (12, "th"),
|
||||||
|
**{
|
||||||
|
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||||
|
for name, value in self.ones.items()
|
||||||
|
if value > 3 and value != 5 and value != 12
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||||
|
|
||||||
|
self.tens = {
|
||||||
|
"twenty": 20,
|
||||||
|
"thirty": 30,
|
||||||
|
"forty": 40,
|
||||||
|
"fifty": 50,
|
||||||
|
"sixty": 60,
|
||||||
|
"seventy": 70,
|
||||||
|
"eighty": 80,
|
||||||
|
"ninety": 90,
|
||||||
|
}
|
||||||
|
self.tens_plural = {
|
||||||
|
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||||
|
}
|
||||||
|
self.tens_ordinal = {
|
||||||
|
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
|
||||||
|
}
|
||||||
|
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||||
|
|
||||||
|
self.multipliers = {
|
||||||
|
"hundred": 100,
|
||||||
|
"thousand": 1_000,
|
||||||
|
"million": 1_000_000,
|
||||||
|
"billion": 1_000_000_000,
|
||||||
|
"trillion": 1_000_000_000_000,
|
||||||
|
"quadrillion": 1_000_000_000_000_000,
|
||||||
|
"quintillion": 1_000_000_000_000_000_000,
|
||||||
|
"sextillion": 1_000_000_000_000_000_000_000,
|
||||||
|
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||||
|
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||||
|
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||||
|
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||||
|
}
|
||||||
|
self.multipliers_plural = {
|
||||||
|
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||||
|
}
|
||||||
|
self.multipliers_ordinal = {
|
||||||
|
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||||
|
}
|
||||||
|
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
|
||||||
|
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||||
|
|
||||||
|
self.preceding_prefixers = {
|
||||||
|
"minus": "-",
|
||||||
|
"negative": "-",
|
||||||
|
"plus": "+",
|
||||||
|
"positive": "+",
|
||||||
|
}
|
||||||
|
self.following_prefixers = {
|
||||||
|
"pound": "£",
|
||||||
|
"pounds": "£",
|
||||||
|
"euro": "€",
|
||||||
|
"euros": "€",
|
||||||
|
"dollar": "$",
|
||||||
|
"dollars": "$",
|
||||||
|
"cent": "¢",
|
||||||
|
"cents": "¢",
|
||||||
|
}
|
||||||
|
self.prefixes = set(
|
||||||
|
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
|
||||||
|
)
|
||||||
|
self.suffixers = {
|
||||||
|
"per": {"cent": "%"},
|
||||||
|
"percent": "%",
|
||||||
|
}
|
||||||
|
self.specials = {"and", "double", "triple", "point"}
|
||||||
|
|
||||||
|
self.words = set(
|
||||||
|
[
|
||||||
|
key
|
||||||
|
for mapping in [
|
||||||
|
self.zeros,
|
||||||
|
self.ones,
|
||||||
|
self.ones_suffixed,
|
||||||
|
self.tens,
|
||||||
|
self.tens_suffixed,
|
||||||
|
self.multipliers,
|
||||||
|
self.multipliers_suffixed,
|
||||||
|
self.preceding_prefixers,
|
||||||
|
self.following_prefixers,
|
||||||
|
self.suffixers,
|
||||||
|
self.specials,
|
||||||
|
]
|
||||||
|
for key in mapping
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.literal_words = {"one", "ones"}
|
||||||
|
|
||||||
|
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||||
|
prefix: Optional[str] = None
|
||||||
|
value: Optional[Union[str, int]] = None
|
||||||
|
skip = False
|
||||||
|
|
||||||
|
def to_fraction(s: str):
|
||||||
|
try:
|
||||||
|
return Fraction(s)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def output(result: Union[str, int]):
|
||||||
|
nonlocal prefix, value
|
||||||
|
result = str(result)
|
||||||
|
if prefix is not None:
|
||||||
|
result = prefix + result
|
||||||
|
value = None
|
||||||
|
prefix = None
|
||||||
|
return result
|
||||||
|
|
||||||
|
if len(words) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for prev, current, next in windowed([None] + words + [None], 3):
|
||||||
|
if skip:
|
||||||
|
skip = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||||
|
has_prefix = current[0] in self.prefixes
|
||||||
|
current_without_prefix = current[1:] if has_prefix else current
|
||||||
|
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||||
|
# arabic numbers (potentially with signs and fractions)
|
||||||
|
f = to_fraction(current_without_prefix)
|
||||||
|
assert f is not None
|
||||||
|
if value is not None:
|
||||||
|
if isinstance(value, str) and value.endswith("."):
|
||||||
|
# concatenate decimals / ip address components
|
||||||
|
value = str(value) + str(current)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
|
||||||
|
prefix = current[0] if has_prefix else prefix
|
||||||
|
if f.denominator == 1:
|
||||||
|
value = f.numerator # store integers as int
|
||||||
|
else:
|
||||||
|
value = current_without_prefix
|
||||||
|
elif current not in self.words:
|
||||||
|
# non-numeric words
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.zeros:
|
||||||
|
value = str(value or "") + "0"
|
||||||
|
elif current in self.ones:
|
||||||
|
ones = self.ones[current]
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
value = ones
|
||||||
|
elif isinstance(value, str) or prev in self.ones:
|
||||||
|
if prev in self.tens and ones < 10: # replace the last zero with the digit
|
||||||
|
assert value[-1] == "0"
|
||||||
|
value = value[:-1] + str(ones)
|
||||||
|
else:
|
||||||
|
value = str(value) + str(ones)
|
||||||
|
elif ones < 10:
|
||||||
|
if value % 10 == 0:
|
||||||
|
value += ones
|
||||||
|
else:
|
||||||
|
value = str(value) + str(ones)
|
||||||
|
else: # eleven to nineteen
|
||||||
|
if value % 100 == 0:
|
||||||
|
value += ones
|
||||||
|
else:
|
||||||
|
value = str(value) + str(ones)
|
||||||
|
elif current in self.ones_suffixed:
|
||||||
|
# ordinal or cardinal; yield the number right away
|
||||||
|
ones, suffix = self.ones_suffixed[current]
|
||||||
|
if value is None:
|
||||||
|
yield output(str(ones) + suffix)
|
||||||
|
elif isinstance(value, str) or prev in self.ones:
|
||||||
|
if prev in self.tens and ones < 10:
|
||||||
|
assert value[-1] == "0"
|
||||||
|
yield output(value[:-1] + str(ones) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(ones) + suffix)
|
||||||
|
elif ones < 10:
|
||||||
|
if value % 10 == 0:
|
||||||
|
yield output(str(value + ones) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(ones) + suffix)
|
||||||
|
else: # eleven to nineteen
|
||||||
|
if value % 100 == 0:
|
||||||
|
yield output(str(value + ones) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(ones) + suffix)
|
||||||
|
value = None
|
||||||
|
elif current in self.tens:
|
||||||
|
tens = self.tens[current]
|
||||||
|
if value is None:
|
||||||
|
value = tens
|
||||||
|
elif isinstance(value, str):
|
||||||
|
value = str(value) + str(tens)
|
||||||
|
else:
|
||||||
|
if value % 100 == 0:
|
||||||
|
value += tens
|
||||||
|
else:
|
||||||
|
value = str(value) + str(tens)
|
||||||
|
elif current in self.tens_suffixed:
|
||||||
|
# ordinal or cardinal; yield the number right away
|
||||||
|
tens, suffix = self.tens_suffixed[current]
|
||||||
|
if value is None:
|
||||||
|
yield output(str(tens) + suffix)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
yield output(str(value) + str(tens) + suffix)
|
||||||
|
else:
|
||||||
|
if value % 100 == 0:
|
||||||
|
yield output(str(value + tens) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(tens) + suffix)
|
||||||
|
elif current in self.multipliers:
|
||||||
|
multiplier = self.multipliers[current]
|
||||||
|
if value is None:
|
||||||
|
value = multiplier
|
||||||
|
elif isinstance(value, str) or value == 0:
|
||||||
|
f = to_fraction(value)
|
||||||
|
p = f * multiplier if f is not None else None
|
||||||
|
if f is not None and p.denominator == 1:
|
||||||
|
value = p.numerator
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
value = multiplier
|
||||||
|
else:
|
||||||
|
before = value // 1000 * 1000
|
||||||
|
residual = value % 1000
|
||||||
|
value = before + residual * multiplier
|
||||||
|
elif current in self.multipliers_suffixed:
|
||||||
|
multiplier, suffix = self.multipliers_suffixed[current]
|
||||||
|
if value is None:
|
||||||
|
yield output(str(multiplier) + suffix)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
f = to_fraction(value)
|
||||||
|
p = f * multiplier if f is not None else None
|
||||||
|
if f is not None and p.denominator == 1:
|
||||||
|
yield output(str(p.numerator) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
yield output(str(multiplier) + suffix)
|
||||||
|
else: # int
|
||||||
|
before = value // 1000 * 1000
|
||||||
|
residual = value % 1000
|
||||||
|
value = before + residual * multiplier
|
||||||
|
yield output(str(value) + suffix)
|
||||||
|
value = None
|
||||||
|
elif current in self.preceding_prefixers:
|
||||||
|
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
|
||||||
|
if next in self.words or next_is_numeric:
|
||||||
|
prefix = self.preceding_prefixers[current]
|
||||||
|
else:
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.following_prefixers:
|
||||||
|
# apply prefix (dollars, cents, etc.) only after a number
|
||||||
|
if value is not None:
|
||||||
|
prefix = self.following_prefixers[current]
|
||||||
|
yield output(value)
|
||||||
|
else:
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.suffixers:
|
||||||
|
# apply suffix symbols (percent -> '%')
|
||||||
|
if value is not None:
|
||||||
|
suffix = self.suffixers[current]
|
||||||
|
if isinstance(suffix, dict):
|
||||||
|
if next in suffix:
|
||||||
|
yield output(str(value) + suffix[next])
|
||||||
|
skip = True
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.specials:
|
||||||
|
if next not in self.words and not next_is_numeric:
|
||||||
|
# apply special handling only if the next word can be numeric
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current == "and":
|
||||||
|
# ignore "and" after hundreds, thousands, etc.
|
||||||
|
if prev not in self.multipliers:
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current == "double" or current == "triple":
|
||||||
|
if next in self.ones or next in self.zeros:
|
||||||
|
repeats = 2 if current == "double" else 3
|
||||||
|
ones = self.ones.get(next, 0)
|
||||||
|
value = str(value or "") + str(ones) * repeats
|
||||||
|
skip = True
|
||||||
|
else:
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current == "point":
|
||||||
|
if next in self.decimals or next_is_numeric:
|
||||||
|
value = str(value or "") + "."
|
||||||
|
else:
|
||||||
|
# should all have been covered at this point
|
||||||
|
raise ValueError(f"Unexpected token: {current}")
|
||||||
|
else:
|
||||||
|
# all should have been covered at this point
|
||||||
|
raise ValueError(f"Unexpected token: {current}")
|
||||||
|
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
|
||||||
|
def preprocess(self, s: str):
|
||||||
|
# replace "<number> and a half" with "<number> point five"
|
||||||
|
results = []
|
||||||
|
|
||||||
|
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||||
|
for i, segment in enumerate(segments):
|
||||||
|
if len(segment.strip()) == 0:
|
||||||
|
continue
|
||||||
|
if i == len(segments) - 1:
|
||||||
|
results.append(segment)
|
||||||
|
else:
|
||||||
|
results.append(segment)
|
||||||
|
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||||
|
if last_word in self.decimals or last_word in self.multipliers:
|
||||||
|
results.append("point five")
|
||||||
|
else:
|
||||||
|
results.append("and a half")
|
||||||
|
|
||||||
|
s = " ".join(results)
|
||||||
|
|
||||||
|
# put a space at number/letter boundary
|
||||||
|
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||||
|
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||||
|
|
||||||
|
# but remove spaces which could be a suffix
|
||||||
|
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
def postprocess(self, s: str):
|
||||||
|
def combine_cents(m: Match):
|
||||||
|
try:
|
||||||
|
currency = m.group(1)
|
||||||
|
integer = m.group(2)
|
||||||
|
cents = int(m.group(3))
|
||||||
|
return f"{currency}{integer}.{cents:02d}"
|
||||||
|
except ValueError:
|
||||||
|
return m.string
|
||||||
|
|
||||||
|
def extract_cents(m: Match):
|
||||||
|
try:
|
||||||
|
return f"¢{int(m.group(1))}"
|
||||||
|
except ValueError:
|
||||||
|
return m.string
|
||||||
|
|
||||||
|
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
||||||
|
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
||||||
|
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
||||||
|
|
||||||
|
# write "one(s)" instead of "1(s)", just for the readability
|
||||||
|
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
def __call__(self, s: str):
|
||||||
|
s = self.preprocess(s)
|
||||||
|
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||||
|
s = self.postprocess(s)
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class EnglishSpellingNormalizer:
|
||||||
|
"""
|
||||||
|
Applies British-American spelling mappings as listed in [1].
|
||||||
|
|
||||||
|
[1] https://www.tysto.com/uk-us-spelling-list.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
||||||
|
self.mapping = json.load(open(mapping_path))
|
||||||
|
|
||||||
|
def __call__(self, s: str):
|
||||||
|
return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||||
|
|
||||||
|
|
||||||
|
class EnglishTextNormalizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||||
|
self.replacers = {
|
||||||
|
# common contractions
|
||||||
|
r"\bwon't\b": "will not",
|
||||||
|
r"\bcan't\b": "can not",
|
||||||
|
r"\blet's\b": "let us",
|
||||||
|
r"\bain't\b": "aint",
|
||||||
|
r"\by'all\b": "you all",
|
||||||
|
r"\bwanna\b": "want to",
|
||||||
|
r"\bgotta\b": "got to",
|
||||||
|
r"\bgonna\b": "going to",
|
||||||
|
r"\bi'ma\b": "i am going to",
|
||||||
|
r"\bimma\b": "i am going to",
|
||||||
|
r"\bwoulda\b": "would have",
|
||||||
|
r"\bcoulda\b": "could have",
|
||||||
|
r"\bshoulda\b": "should have",
|
||||||
|
r"\bma'am\b": "madam",
|
||||||
|
# contractions in titles/prefixes
|
||||||
|
r"\bmr\b": "mister ",
|
||||||
|
r"\bmrs\b": "missus ",
|
||||||
|
r"\bst\b": "saint ",
|
||||||
|
r"\bdr\b": "doctor ",
|
||||||
|
r"\bprof\b": "professor ",
|
||||||
|
r"\bcapt\b": "captain ",
|
||||||
|
r"\bgov\b": "governor ",
|
||||||
|
r"\bald\b": "alderman ",
|
||||||
|
r"\bgen\b": "general ",
|
||||||
|
r"\bsen\b": "senator ",
|
||||||
|
r"\brep\b": "representative ",
|
||||||
|
r"\bpres\b": "president ",
|
||||||
|
r"\brev\b": "reverend ",
|
||||||
|
r"\bhon\b": "honorable ",
|
||||||
|
r"\basst\b": "assistant ",
|
||||||
|
r"\bassoc\b": "associate ",
|
||||||
|
r"\blt\b": "lieutenant ",
|
||||||
|
r"\bcol\b": "colonel ",
|
||||||
|
r"\bjr\b": "junior ",
|
||||||
|
r"\bsr\b": "senior ",
|
||||||
|
r"\besq\b": "esquire ",
|
||||||
|
# prefect tenses, ideally it should be any past participles, but it's harder..
|
||||||
|
r"'d been\b": " had been",
|
||||||
|
r"'s been\b": " has been",
|
||||||
|
r"'d gone\b": " had gone",
|
||||||
|
r"'s gone\b": " has gone",
|
||||||
|
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||||
|
r"'s got\b": " has got",
|
||||||
|
# general contractions
|
||||||
|
r"n't\b": " not",
|
||||||
|
r"'re\b": " are",
|
||||||
|
r"'s\b": " is",
|
||||||
|
r"'d\b": " would",
|
||||||
|
r"'ll\b": " will",
|
||||||
|
r"'t\b": " not",
|
||||||
|
r"'ve\b": " have",
|
||||||
|
r"'m\b": " am",
|
||||||
|
}
|
||||||
|
self.standardize_numbers = EnglishNumberNormalizer()
|
||||||
|
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||||
|
|
||||||
|
def __call__(self, s: str):
|
||||||
|
s = s.lower()
|
||||||
|
|
||||||
|
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||||
|
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||||
|
s = re.sub(self.ignore_patterns, "", s)
|
||||||
|
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
|
||||||
|
|
||||||
|
for pattern, replacement in self.replacers.items():
|
||||||
|
s = re.sub(pattern, replacement, s)
|
||||||
|
|
||||||
|
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||||
|
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||||
|
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
|
||||||
|
|
||||||
|
s = self.standardize_numbers(s)
|
||||||
|
s = self.standardize_spellings(s)
|
||||||
|
|
||||||
|
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||||
|
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||||
|
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||||
|
|
||||||
|
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||||
|
|
||||||
|
return s
|
|
@ -0,0 +1,331 @@
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import GPT2TokenizerFast
|
||||||
|
|
||||||
|
LANGUAGES = {
|
||||||
|
"en": "english",
|
||||||
|
"zh": "chinese",
|
||||||
|
"de": "german",
|
||||||
|
"es": "spanish",
|
||||||
|
"ru": "russian",
|
||||||
|
"ko": "korean",
|
||||||
|
"fr": "french",
|
||||||
|
"ja": "japanese",
|
||||||
|
"pt": "portuguese",
|
||||||
|
"tr": "turkish",
|
||||||
|
"pl": "polish",
|
||||||
|
"ca": "catalan",
|
||||||
|
"nl": "dutch",
|
||||||
|
"ar": "arabic",
|
||||||
|
"sv": "swedish",
|
||||||
|
"it": "italian",
|
||||||
|
"id": "indonesian",
|
||||||
|
"hi": "hindi",
|
||||||
|
"fi": "finnish",
|
||||||
|
"vi": "vietnamese",
|
||||||
|
"iw": "hebrew",
|
||||||
|
"uk": "ukrainian",
|
||||||
|
"el": "greek",
|
||||||
|
"ms": "malay",
|
||||||
|
"cs": "czech",
|
||||||
|
"ro": "romanian",
|
||||||
|
"da": "danish",
|
||||||
|
"hu": "hungarian",
|
||||||
|
"ta": "tamil",
|
||||||
|
"no": "norwegian",
|
||||||
|
"th": "thai",
|
||||||
|
"ur": "urdu",
|
||||||
|
"hr": "croatian",
|
||||||
|
"bg": "bulgarian",
|
||||||
|
"lt": "lithuanian",
|
||||||
|
"la": "latin",
|
||||||
|
"mi": "maori",
|
||||||
|
"ml": "malayalam",
|
||||||
|
"cy": "welsh",
|
||||||
|
"sk": "slovak",
|
||||||
|
"te": "telugu",
|
||||||
|
"fa": "persian",
|
||||||
|
"lv": "latvian",
|
||||||
|
"bn": "bengali",
|
||||||
|
"sr": "serbian",
|
||||||
|
"az": "azerbaijani",
|
||||||
|
"sl": "slovenian",
|
||||||
|
"kn": "kannada",
|
||||||
|
"et": "estonian",
|
||||||
|
"mk": "macedonian",
|
||||||
|
"br": "breton",
|
||||||
|
"eu": "basque",
|
||||||
|
"is": "icelandic",
|
||||||
|
"hy": "armenian",
|
||||||
|
"ne": "nepali",
|
||||||
|
"mn": "mongolian",
|
||||||
|
"bs": "bosnian",
|
||||||
|
"kk": "kazakh",
|
||||||
|
"sq": "albanian",
|
||||||
|
"sw": "swahili",
|
||||||
|
"gl": "galician",
|
||||||
|
"mr": "marathi",
|
||||||
|
"pa": "punjabi",
|
||||||
|
"si": "sinhala",
|
||||||
|
"km": "khmer",
|
||||||
|
"sn": "shona",
|
||||||
|
"yo": "yoruba",
|
||||||
|
"so": "somali",
|
||||||
|
"af": "afrikaans",
|
||||||
|
"oc": "occitan",
|
||||||
|
"ka": "georgian",
|
||||||
|
"be": "belarusian",
|
||||||
|
"tg": "tajik",
|
||||||
|
"sd": "sindhi",
|
||||||
|
"gu": "gujarati",
|
||||||
|
"am": "amharic",
|
||||||
|
"yi": "yiddish",
|
||||||
|
"lo": "lao",
|
||||||
|
"uz": "uzbek",
|
||||||
|
"fo": "faroese",
|
||||||
|
"ht": "haitian creole",
|
||||||
|
"ps": "pashto",
|
||||||
|
"tk": "turkmen",
|
||||||
|
"nn": "nynorsk",
|
||||||
|
"mt": "maltese",
|
||||||
|
"sa": "sanskrit",
|
||||||
|
"lb": "luxembourgish",
|
||||||
|
"my": "myanmar",
|
||||||
|
"bo": "tibetan",
|
||||||
|
"tl": "tagalog",
|
||||||
|
"mg": "malagasy",
|
||||||
|
"as": "assamese",
|
||||||
|
"tt": "tatar",
|
||||||
|
"haw": "hawaiian",
|
||||||
|
"ln": "lingala",
|
||||||
|
"ha": "hausa",
|
||||||
|
"ba": "bashkir",
|
||||||
|
"jw": "javanese",
|
||||||
|
"su": "sundanese",
|
||||||
|
}
|
||||||
|
|
||||||
|
# language code lookup by name, with a few language aliases
|
||||||
|
TO_LANGUAGE_CODE = {
|
||||||
|
**{language: code for code, language in LANGUAGES.items()},
|
||||||
|
"burmese": "my",
|
||||||
|
"valencian": "ca",
|
||||||
|
"flemish": "nl",
|
||||||
|
"haitian": "ht",
|
||||||
|
"letzeburgesch": "lb",
|
||||||
|
"pushto": "ps",
|
||||||
|
"panjabi": "pa",
|
||||||
|
"moldavian": "ro",
|
||||||
|
"moldovan": "ro",
|
||||||
|
"sinhalese": "si",
|
||||||
|
"castilian": "es",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Tokenizer:
|
||||||
|
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
||||||
|
|
||||||
|
tokenizer: "GPT2TokenizerFast"
|
||||||
|
language: Optional[str]
|
||||||
|
sot_sequence: Tuple[int]
|
||||||
|
|
||||||
|
def encode(self, text, **kwargs):
|
||||||
|
return self.tokenizer.encode(text, **kwargs)
|
||||||
|
|
||||||
|
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
||||||
|
return self.tokenizer.decode(token_ids, **kwargs)
|
||||||
|
|
||||||
|
def decode_with_timestamps(self, tokens) -> str:
|
||||||
|
"""
|
||||||
|
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
||||||
|
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||||
|
"""
|
||||||
|
outputs = [[]]
|
||||||
|
for token in tokens:
|
||||||
|
if token >= self.timestamp_begin:
|
||||||
|
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
||||||
|
outputs.append(timestamp)
|
||||||
|
outputs.append([])
|
||||||
|
else:
|
||||||
|
outputs[-1].append(token)
|
||||||
|
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||||
|
return "".join(outputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def eot(self) -> int:
|
||||||
|
return self.tokenizer.eos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def sot(self) -> int:
|
||||||
|
return self._get_single_token_id("<|startoftranscript|>")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def sot_lm(self) -> int:
|
||||||
|
return self._get_single_token_id("<|startoflm|>")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def sot_prev(self) -> int:
|
||||||
|
return self._get_single_token_id("<|startofprev|>")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def no_speech(self) -> int:
|
||||||
|
return self._get_single_token_id("<|nospeech|>")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def no_timestamps(self) -> int:
|
||||||
|
return self._get_single_token_id("<|notimestamps|>")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def timestamp_begin(self) -> int:
|
||||||
|
return self.tokenizer.all_special_ids[-1] + 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def language_token(self) -> int:
|
||||||
|
"""Returns the token id corresponding to the value of the `language` field"""
|
||||||
|
if self.language is None:
|
||||||
|
raise ValueError(f"This tokenizer does not have language token configured")
|
||||||
|
|
||||||
|
additional_tokens = dict(
|
||||||
|
zip(
|
||||||
|
self.tokenizer.additional_special_tokens,
|
||||||
|
self.tokenizer.additional_special_tokens_ids,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
candidate = f"<|{self.language}|>"
|
||||||
|
if candidate in additional_tokens:
|
||||||
|
return additional_tokens[candidate]
|
||||||
|
|
||||||
|
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def all_language_tokens(self) -> Tuple[int]:
|
||||||
|
result = []
|
||||||
|
for token, token_id in zip(
|
||||||
|
self.tokenizer.additional_special_tokens,
|
||||||
|
self.tokenizer.additional_special_tokens_ids,
|
||||||
|
):
|
||||||
|
if token.strip("<|>") in LANGUAGES:
|
||||||
|
result.append(token_id)
|
||||||
|
return tuple(result)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def all_language_codes(self) -> Tuple[str]:
|
||||||
|
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||||
|
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache()
|
||||||
|
def non_speech_tokens(self) -> Tuple[int]:
|
||||||
|
"""
|
||||||
|
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||||
|
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||||
|
|
||||||
|
- ♪♪♪
|
||||||
|
- ( SPEAKING FOREIGN LANGUAGE )
|
||||||
|
- [DAVID] Hey there,
|
||||||
|
|
||||||
|
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||||
|
"""
|
||||||
|
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
||||||
|
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||||
|
|
||||||
|
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||||
|
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||||
|
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||||
|
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||||
|
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||||
|
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||||
|
|
||||||
|
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||||
|
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||||
|
for symbol in symbols + list(miscellaneous):
|
||||||
|
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
||||||
|
if len(tokens) == 1 or symbol in miscellaneous:
|
||||||
|
result.add(tokens[0])
|
||||||
|
|
||||||
|
return tuple(sorted(result))
|
||||||
|
|
||||||
|
def _get_single_token_id(self, text) -> int:
|
||||||
|
tokens = self.tokenizer.encode(text)
|
||||||
|
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||||
|
return tokens[0]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def build_tokenizer(name: str = "gpt2"):
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
||||||
|
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
||||||
|
|
||||||
|
specials = [
|
||||||
|
"<|startoftranscript|>",
|
||||||
|
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||||
|
"<|translate|>",
|
||||||
|
"<|transcribe|>",
|
||||||
|
"<|startoflm|>",
|
||||||
|
"<|startofprev|>",
|
||||||
|
"<|nospeech|>",
|
||||||
|
"<|notimestamps|>",
|
||||||
|
]
|
||||||
|
|
||||||
|
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_tokenizer(
|
||||||
|
multilingual: bool,
|
||||||
|
*,
|
||||||
|
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||||
|
language: Optional[str] = None,
|
||||||
|
) -> Tokenizer:
|
||||||
|
if language is not None:
|
||||||
|
language = language.lower()
|
||||||
|
if language not in LANGUAGES:
|
||||||
|
if language in TO_LANGUAGE_CODE:
|
||||||
|
language = TO_LANGUAGE_CODE[language]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported language: {language}")
|
||||||
|
|
||||||
|
if multilingual:
|
||||||
|
tokenizer_name = "multilingual"
|
||||||
|
task = task or "transcribe"
|
||||||
|
language = language or "en"
|
||||||
|
else:
|
||||||
|
tokenizer_name = "gpt2"
|
||||||
|
task = None
|
||||||
|
language = None
|
||||||
|
|
||||||
|
tokenizer = build_tokenizer(name=tokenizer_name)
|
||||||
|
all_special_ids: List[int] = tokenizer.all_special_ids
|
||||||
|
sot: int = all_special_ids[1]
|
||||||
|
translate: int = all_special_ids[-6]
|
||||||
|
transcribe: int = all_special_ids[-5]
|
||||||
|
|
||||||
|
langs = tuple(LANGUAGES.keys())
|
||||||
|
sot_sequence = [sot]
|
||||||
|
if language is not None:
|
||||||
|
sot_sequence.append(sot + 1 + langs.index(language))
|
||||||
|
if task is not None:
|
||||||
|
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||||
|
|
||||||
|
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
|
@ -0,0 +1,207 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||||
|
from .decoding import DecodingOptions, DecodingResult
|
||||||
|
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||||
|
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .model import Whisper
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe(
|
||||||
|
model: "Whisper",
|
||||||
|
audio: Union[str, np.ndarray, torch.Tensor],
|
||||||
|
*,
|
||||||
|
verbose: Optional[bool] = None,
|
||||||
|
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||||
|
compression_ratio_threshold: Optional[float] = 2.4,
|
||||||
|
logprob_threshold: Optional[float] = -1.0,
|
||||||
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
|
condition_on_previous_text: bool = True,
|
||||||
|
force_extraction: bool = False,
|
||||||
|
**decode_options,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Transcribe an audio file using Whisper
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model: Whisper
|
||||||
|
The Whisper model instance
|
||||||
|
|
||||||
|
audio: Union[str, np.ndarray, torch.Tensor]
|
||||||
|
The path to the audio file to open, or the audio waveform
|
||||||
|
|
||||||
|
verbose: bool
|
||||||
|
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||||
|
If False, displays minimal details. If None, does not display anything
|
||||||
|
|
||||||
|
temperature: Union[float, Tuple[float, ...]]
|
||||||
|
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
||||||
|
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||||
|
|
||||||
|
compression_ratio_threshold: float
|
||||||
|
If the gzip compression ratio is above this value, treat as failed
|
||||||
|
|
||||||
|
logprob_threshold: float
|
||||||
|
If the average log probability over sampled tokens is below this value, treat as failed
|
||||||
|
|
||||||
|
no_speech_threshold: float
|
||||||
|
If the no_speech probability is higher than this value AND the average log probability
|
||||||
|
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||||
|
|
||||||
|
condition_on_previous_text: bool
|
||||||
|
if True, the previous output of the model is provided as a prompt for the next window;
|
||||||
|
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||||
|
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||||
|
|
||||||
|
decode_options: dict
|
||||||
|
Keyword arguments to construct `DecodingOptions` instances
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||||
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||||
|
"""
|
||||||
|
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||||
|
if model.device == torch.device("cpu"):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||||
|
if dtype == torch.float16:
|
||||||
|
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
decode_options["fp16"] = False
|
||||||
|
|
||||||
|
mel = log_mel_spectrogram(audio)
|
||||||
|
|
||||||
|
all_segments = []
|
||||||
|
def add_segment(
|
||||||
|
*, start: float, end: float, encoder_embeddings
|
||||||
|
):
|
||||||
|
|
||||||
|
all_segments.append(
|
||||||
|
{
|
||||||
|
"start": start,
|
||||||
|
"end": end,
|
||||||
|
"encoder_embeddings":encoder_embeddings,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||||
|
num_frames = mel.shape[-1]
|
||||||
|
seek = 0
|
||||||
|
previous_seek_value = seek
|
||||||
|
sample_skip = 3000 #
|
||||||
|
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||||
|
while seek < num_frames:
|
||||||
|
# seek是开始的帧数
|
||||||
|
end_seek = min(seek + sample_skip, num_frames)
|
||||||
|
segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
|
single = segment.ndim == 2
|
||||||
|
if single:
|
||||||
|
segment = segment.unsqueeze(0)
|
||||||
|
if dtype == torch.float16:
|
||||||
|
segment = segment.half()
|
||||||
|
audio_features, embeddings = model.encoder(segment, include_embeddings = True)
|
||||||
|
|
||||||
|
encoder_embeddings = embeddings
|
||||||
|
#print(f"encoder_embeddings shape {encoder_embeddings.shape}")
|
||||||
|
add_segment(
|
||||||
|
start=seek,
|
||||||
|
end=end_seek,
|
||||||
|
#text_tokens=tokens,
|
||||||
|
#result=result,
|
||||||
|
encoder_embeddings=encoder_embeddings,
|
||||||
|
)
|
||||||
|
seek+=sample_skip
|
||||||
|
|
||||||
|
return dict(segments=all_segments)
|
||||||
|
|
||||||
|
|
||||||
|
def cli():
|
||||||
|
from . import available_models
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||||
|
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||||
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
|
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||||
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
|
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||||
|
|
||||||
|
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||||
|
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||||
|
|
||||||
|
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||||
|
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||||
|
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||||
|
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||||
|
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||||
|
|
||||||
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
|
|
||||||
|
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||||
|
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||||
|
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||||
|
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||||
|
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||||
|
|
||||||
|
args = parser.parse_args().__dict__
|
||||||
|
model_name: str = args.pop("model")
|
||||||
|
model_dir: str = args.pop("model_dir")
|
||||||
|
output_dir: str = args.pop("output_dir")
|
||||||
|
device: str = args.pop("device")
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||||
|
if args["language"] is not None:
|
||||||
|
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||||
|
args["language"] = "en"
|
||||||
|
|
||||||
|
temperature = args.pop("temperature")
|
||||||
|
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
||||||
|
if temperature_increment_on_fallback is not None:
|
||||||
|
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
||||||
|
else:
|
||||||
|
temperature = [temperature]
|
||||||
|
|
||||||
|
threads = args.pop("threads")
|
||||||
|
if threads > 0:
|
||||||
|
torch.set_num_threads(threads)
|
||||||
|
|
||||||
|
from . import load_model
|
||||||
|
model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
|
|
||||||
|
for audio_path in args.pop("audio"):
|
||||||
|
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||||
|
|
||||||
|
audio_basename = os.path.basename(audio_path)
|
||||||
|
|
||||||
|
# save TXT
|
||||||
|
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
||||||
|
write_txt(result["segments"], file=txt)
|
||||||
|
|
||||||
|
# save VTT
|
||||||
|
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
||||||
|
write_vtt(result["segments"], file=vtt)
|
||||||
|
|
||||||
|
# save SRT
|
||||||
|
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||||
|
write_srt(result["segments"], file=srt)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
cli()
|
|
@ -0,0 +1,87 @@
|
||||||
|
import zlib
|
||||||
|
from typing import Iterator, TextIO
|
||||||
|
|
||||||
|
|
||||||
|
def exact_div(x, y):
|
||||||
|
assert x % y == 0
|
||||||
|
return x // y
|
||||||
|
|
||||||
|
|
||||||
|
def str2bool(string):
|
||||||
|
str2val = {"True": True, "False": False}
|
||||||
|
if string in str2val:
|
||||||
|
return str2val[string]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||||
|
|
||||||
|
|
||||||
|
def optional_int(string):
|
||||||
|
return None if string == "None" else int(string)
|
||||||
|
|
||||||
|
|
||||||
|
def optional_float(string):
|
||||||
|
return None if string == "None" else float(string)
|
||||||
|
|
||||||
|
|
||||||
|
def compression_ratio(text) -> float:
|
||||||
|
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
||||||
|
|
||||||
|
|
||||||
|
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
||||||
|
assert seconds >= 0, "non-negative timestamp expected"
|
||||||
|
milliseconds = round(seconds * 1000.0)
|
||||||
|
|
||||||
|
hours = milliseconds // 3_600_000
|
||||||
|
milliseconds -= hours * 3_600_000
|
||||||
|
|
||||||
|
minutes = milliseconds // 60_000
|
||||||
|
milliseconds -= minutes * 60_000
|
||||||
|
|
||||||
|
seconds = milliseconds // 1_000
|
||||||
|
milliseconds -= seconds * 1_000
|
||||||
|
|
||||||
|
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||||
|
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||||
|
|
||||||
|
|
||||||
|
def write_txt(transcript: Iterator[dict], file: TextIO):
|
||||||
|
for segment in transcript:
|
||||||
|
print(segment['text'].strip(), file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
||||||
|
print("WEBVTT\n", file=file)
|
||||||
|
for segment in transcript:
|
||||||
|
print(
|
||||||
|
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||||
|
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||||
|
file=file,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||||
|
"""
|
||||||
|
Write a transcript to a file in SRT format.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
from pathlib import Path
|
||||||
|
from whisper.utils import write_srt
|
||||||
|
|
||||||
|
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||||
|
|
||||||
|
# save SRT
|
||||||
|
audio_basename = Path(audio_path).stem
|
||||||
|
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||||
|
write_srt(result["segments"], file=srt)
|
||||||
|
"""
|
||||||
|
for i, segment in enumerate(transcript, start=1):
|
||||||
|
# write srt lines
|
||||||
|
print(
|
||||||
|
f"{i}\n"
|
||||||
|
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||||
|
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||||
|
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||||
|
file=file,
|
||||||
|
flush=True,
|
||||||
|
)
|
|
@ -4,50 +4,19 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
|
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
|
||||||
|
|
||||||
#import pyaudio
|
|
||||||
import soundfile as sf
|
|
||||||
import resampy
|
|
||||||
|
|
||||||
import queue
|
import queue
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
#from collections import deque
|
#from collections import deque
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
|
from baseasr import BaseASR
|
||||||
|
|
||||||
def _read_frame(stream, exit_event, queue, chunk):
|
class NerfASR(BaseASR):
|
||||||
|
def __init__(self, opt, parent):
|
||||||
while True:
|
super().__init__(opt,parent)
|
||||||
if exit_event.is_set():
|
|
||||||
print(f'[INFO] read frame thread ends')
|
|
||||||
break
|
|
||||||
frame = stream.read(chunk, exception_on_overflow=False)
|
|
||||||
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
|
|
||||||
queue.put(frame)
|
|
||||||
|
|
||||||
def _play_frame(stream, exit_event, queue, chunk):
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if exit_event.is_set():
|
|
||||||
print(f'[INFO] play frame thread ends')
|
|
||||||
break
|
|
||||||
frame = queue.get()
|
|
||||||
frame = (frame * 32767).astype(np.int16).tobytes()
|
|
||||||
stream.write(frame, chunk)
|
|
||||||
|
|
||||||
class ASR:
|
|
||||||
def __init__(self, opt):
|
|
||||||
|
|
||||||
self.opt = opt
|
|
||||||
|
|
||||||
self.play = opt.asr_play #false
|
|
||||||
|
|
||||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
self.fps = opt.fps # 20 ms per frame
|
|
||||||
self.sample_rate = 16000
|
|
||||||
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
|
|
||||||
self.mode = 'live' if opt.asr_wav == '' else 'file'
|
|
||||||
|
|
||||||
if 'esperanto' in self.opt.asr_model:
|
if 'esperanto' in self.opt.asr_model:
|
||||||
self.audio_dim = 44
|
self.audio_dim = 44
|
||||||
elif 'deepspeech' in self.opt.asr_model:
|
elif 'deepspeech' in self.opt.asr_model:
|
||||||
|
@ -62,40 +31,11 @@ class ASR:
|
||||||
self.context_size = opt.m
|
self.context_size = opt.m
|
||||||
self.stride_left_size = opt.l
|
self.stride_left_size = opt.l
|
||||||
self.stride_right_size = opt.r
|
self.stride_right_size = opt.r
|
||||||
self.text = '[START]\n'
|
|
||||||
self.terminated = False
|
|
||||||
self.frames = []
|
|
||||||
self.inwarm = False
|
|
||||||
|
|
||||||
# pad left frames
|
# pad left frames
|
||||||
if self.stride_left_size > 0:
|
if self.stride_left_size > 0:
|
||||||
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
|
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
|
||||||
|
|
||||||
|
|
||||||
self.exit_event = Event()
|
|
||||||
#self.audio_instance = pyaudio.PyAudio() #not need
|
|
||||||
|
|
||||||
# create input stream
|
|
||||||
if self.mode == 'file': #live mode
|
|
||||||
self.file_stream = self.create_file_stream()
|
|
||||||
else:
|
|
||||||
self.queue = Queue()
|
|
||||||
self.input_stream = BytesIO()
|
|
||||||
self.output_queue = Queue()
|
|
||||||
# start a background process to read frames
|
|
||||||
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
|
|
||||||
#self.queue = Queue()
|
|
||||||
#self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
|
|
||||||
|
|
||||||
# play out the audio too...?
|
|
||||||
if self.play:
|
|
||||||
self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
|
|
||||||
self.output_queue = Queue()
|
|
||||||
self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
|
|
||||||
|
|
||||||
# current location of audio
|
|
||||||
self.idx = 0
|
|
||||||
|
|
||||||
# create wav2vec model
|
# create wav2vec model
|
||||||
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
|
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
|
||||||
if 'hubert' in self.opt.asr_model:
|
if 'hubert' in self.opt.asr_model:
|
||||||
|
@ -105,10 +45,6 @@ class ASR:
|
||||||
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
|
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
|
||||||
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
|
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
|
||||||
|
|
||||||
# prepare to save logits
|
|
||||||
if self.opt.asr_save_feats:
|
|
||||||
self.all_feats = []
|
|
||||||
|
|
||||||
# the extracted features
|
# the extracted features
|
||||||
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
|
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
|
||||||
self.feat_buffer_size = 4
|
self.feat_buffer_size = 4
|
||||||
|
@ -124,8 +60,20 @@ class ASR:
|
||||||
# warm up steps needed: mid + right + window_size + attention_size
|
# warm up steps needed: mid + right + window_size + attention_size
|
||||||
self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
|
self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
|
||||||
|
|
||||||
self.listening = False
|
def get_audio_frame(self):
|
||||||
self.playing = False
|
try:
|
||||||
|
frame = self.queue.get(block=False)
|
||||||
|
type = 0
|
||||||
|
#print(f'[INFO] get frame {frame.shape}')
|
||||||
|
except queue.Empty:
|
||||||
|
if self.parent and self.parent.curr_state>1: #播放自定义音频
|
||||||
|
frame = self.parent.get_audio_stream(self.parent.curr_state)
|
||||||
|
type = self.parent.curr_state
|
||||||
|
else:
|
||||||
|
frame = np.zeros(self.chunk, dtype=np.float32)
|
||||||
|
type = 1
|
||||||
|
|
||||||
|
return frame,type
|
||||||
|
|
||||||
def get_next_feat(self): #get audio embedding to nerf
|
def get_next_feat(self): #get audio embedding to nerf
|
||||||
# return a [1/8, 16] window, for the next input to nerf side.
|
# return a [1/8, 16] window, for the next input to nerf side.
|
||||||
|
@ -167,29 +115,19 @@ class ASR:
|
||||||
|
|
||||||
def run_step(self):
|
def run_step(self):
|
||||||
|
|
||||||
if self.terminated:
|
|
||||||
return
|
|
||||||
|
|
||||||
# get a frame of audio
|
# get a frame of audio
|
||||||
frame,type = self.__get_audio_frame()
|
frame,type = self.get_audio_frame()
|
||||||
|
self.frames.append(frame)
|
||||||
# the last frame
|
# put to output
|
||||||
if frame is None:
|
self.output_queue.put((frame,type))
|
||||||
# terminate, but always run the network for the left frames
|
# context not enough, do not run network.
|
||||||
self.terminated = True
|
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
|
||||||
else:
|
return
|
||||||
self.frames.append(frame)
|
|
||||||
# put to output
|
|
||||||
self.output_queue.put((frame,type))
|
|
||||||
# context not enough, do not run network.
|
|
||||||
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
|
|
||||||
return
|
|
||||||
|
|
||||||
inputs = np.concatenate(self.frames) # [N * chunk]
|
inputs = np.concatenate(self.frames) # [N * chunk]
|
||||||
|
|
||||||
# discard the old part to save memory
|
# discard the old part to save memory
|
||||||
if not self.terminated:
|
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
||||||
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
|
||||||
|
|
||||||
#print(f'[INFO] frame_to_text... ')
|
#print(f'[INFO] frame_to_text... ')
|
||||||
#t = time.time()
|
#t = time.time()
|
||||||
|
@ -197,10 +135,6 @@ class ASR:
|
||||||
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
|
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
|
||||||
feats = logits # better lips-sync than labels
|
feats = logits # better lips-sync than labels
|
||||||
|
|
||||||
# save feats
|
|
||||||
if self.opt.asr_save_feats:
|
|
||||||
self.all_feats.append(feats)
|
|
||||||
|
|
||||||
# record the feats efficiently.. (no concat, constant memory)
|
# record the feats efficiently.. (no concat, constant memory)
|
||||||
start = self.feat_buffer_idx * self.context_size
|
start = self.feat_buffer_idx * self.context_size
|
||||||
end = start + feats.shape[0]
|
end = start + feats.shape[0]
|
||||||
|
@ -212,51 +146,28 @@ class ASR:
|
||||||
# self.text = self.text + ' ' + text
|
# self.text = self.text + ' ' + text
|
||||||
|
|
||||||
# will only run once at ternimation
|
# will only run once at ternimation
|
||||||
if self.terminated:
|
# if self.terminated:
|
||||||
self.text += '\n[END]'
|
# self.text += '\n[END]'
|
||||||
print(self.text)
|
# print(self.text)
|
||||||
if self.opt.asr_save_feats:
|
# if self.opt.asr_save_feats:
|
||||||
print(f'[INFO] save all feats for training purpose... ')
|
# print(f'[INFO] save all feats for training purpose... ')
|
||||||
feats = torch.cat(self.all_feats, dim=0) # [N, C]
|
# feats = torch.cat(self.all_feats, dim=0) # [N, C]
|
||||||
# print('[INFO] before unfold', feats.shape)
|
# # print('[INFO] before unfold', feats.shape)
|
||||||
window_size = 16
|
# window_size = 16
|
||||||
padding = window_size // 2
|
# padding = window_size // 2
|
||||||
feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
|
# feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
|
||||||
feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
|
# feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
|
||||||
unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
|
# unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
|
||||||
unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
|
# unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
|
||||||
# print('[INFO] after unfold', unfold_feats.shape)
|
# # print('[INFO] after unfold', unfold_feats.shape)
|
||||||
# save to a npy file
|
# # save to a npy file
|
||||||
if 'esperanto' in self.opt.asr_model:
|
# if 'esperanto' in self.opt.asr_model:
|
||||||
output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
|
# output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
|
||||||
else:
|
# else:
|
||||||
output_path = self.opt.asr_wav.replace('.wav', '.npy')
|
# output_path = self.opt.asr_wav.replace('.wav', '.npy')
|
||||||
np.save(output_path, unfold_feats.cpu().numpy())
|
# np.save(output_path, unfold_feats.cpu().numpy())
|
||||||
print(f"[INFO] saved logits to {output_path}")
|
# print(f"[INFO] saved logits to {output_path}")
|
||||||
|
|
||||||
def __get_audio_frame(self):
|
|
||||||
if self.inwarm: # warm up
|
|
||||||
return np.zeros(self.chunk, dtype=np.float32),1
|
|
||||||
|
|
||||||
if self.mode == 'file':
|
|
||||||
if self.idx < self.file_stream.shape[0]:
|
|
||||||
frame = self.file_stream[self.idx: self.idx + self.chunk]
|
|
||||||
self.idx = self.idx + self.chunk
|
|
||||||
return frame,0
|
|
||||||
else:
|
|
||||||
return None,0
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
frame = self.queue.get(block=False)
|
|
||||||
type = 0
|
|
||||||
print(f'[INFO] get frame {frame.shape}')
|
|
||||||
except queue.Empty:
|
|
||||||
frame = np.zeros(self.chunk, dtype=np.float32)
|
|
||||||
type = 1
|
|
||||||
|
|
||||||
self.idx = self.idx + self.chunk
|
|
||||||
|
|
||||||
return frame,type
|
|
||||||
|
|
||||||
|
|
||||||
def __frame_to_text(self, frame):
|
def __frame_to_text(self, frame):
|
||||||
|
@ -277,8 +188,8 @@ class ASR:
|
||||||
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
|
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
|
||||||
|
|
||||||
# do not cut right if terminated.
|
# do not cut right if terminated.
|
||||||
if self.terminated:
|
# if self.terminated:
|
||||||
right = logits.shape[1]
|
# right = logits.shape[1]
|
||||||
|
|
||||||
logits = logits[:, left:right]
|
logits = logits[:, left:right]
|
||||||
|
|
||||||
|
@ -298,60 +209,23 @@ class ASR:
|
||||||
|
|
||||||
return logits[0], None,None #predicted_ids[0], transcription # [N,]
|
return logits[0], None,None #predicted_ids[0], transcription # [N,]
|
||||||
|
|
||||||
def __create_bytes_stream(self,byte_stream):
|
|
||||||
#byte_stream=BytesIO(buffer)
|
|
||||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
|
||||||
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
|
||||||
stream = stream.astype(np.float32)
|
|
||||||
|
|
||||||
if stream.ndim > 1:
|
def warm_up(self):
|
||||||
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
|
||||||
stream = stream[:, 0]
|
t = time.time()
|
||||||
|
#for _ in range(self.stride_left_size):
|
||||||
if sample_rate != self.sample_rate and stream.shape[0]>0:
|
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
|
||||||
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
for _ in range(self.warm_up_steps):
|
||||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
self.run_step()
|
||||||
|
#if torch.cuda.is_available():
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
t = time.time() - t
|
||||||
|
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
|
||||||
|
|
||||||
return stream
|
#self.clear_queue()
|
||||||
|
|
||||||
def push_audio(self,buffer): #push audio pcm from tts
|
#####not used function#####################################
|
||||||
print(f'[INFO] push_audio {len(buffer)}')
|
'''
|
||||||
if self.opt.tts == "xtts" or self.opt.tts == "gpt-sovits":
|
|
||||||
if len(buffer)>0:
|
|
||||||
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
|
|
||||||
if self.opt.tts == "xtts":
|
|
||||||
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
|
|
||||||
else:
|
|
||||||
stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
|
|
||||||
#byte_stream=BytesIO(buffer)
|
|
||||||
#stream = self.__create_bytes_stream(byte_stream)
|
|
||||||
streamlen = stream.shape[0]
|
|
||||||
idx=0
|
|
||||||
while streamlen >= self.chunk:
|
|
||||||
self.queue.put(stream[idx:idx+self.chunk])
|
|
||||||
streamlen -= self.chunk
|
|
||||||
idx += self.chunk
|
|
||||||
# if streamlen>0: #skip last frame(not 20ms)
|
|
||||||
# self.queue.put(stream[idx:])
|
|
||||||
else: #edge tts
|
|
||||||
self.input_stream.write(buffer)
|
|
||||||
if len(buffer)<=0:
|
|
||||||
self.input_stream.seek(0)
|
|
||||||
stream = self.__create_bytes_stream(self.input_stream)
|
|
||||||
streamlen = stream.shape[0]
|
|
||||||
idx=0
|
|
||||||
while streamlen >= self.chunk:
|
|
||||||
self.queue.put(stream[idx:idx+self.chunk])
|
|
||||||
streamlen -= self.chunk
|
|
||||||
idx += self.chunk
|
|
||||||
#if streamlen>0: #skip last frame(not 20ms)
|
|
||||||
# self.queue.put(stream[idx:])
|
|
||||||
self.input_stream.seek(0)
|
|
||||||
self.input_stream.truncate()
|
|
||||||
|
|
||||||
def get_audio_out(self): #get origin audio pcm to nerf
|
|
||||||
return self.output_queue.get()
|
|
||||||
|
|
||||||
def __init_queue(self):
|
def __init_queue(self):
|
||||||
self.frames = []
|
self.frames = []
|
||||||
self.queue.queue.clear()
|
self.queue.queue.clear()
|
||||||
|
@ -360,10 +234,6 @@ class ASR:
|
||||||
self.tail = 8
|
self.tail = 8
|
||||||
# attention window...
|
# attention window...
|
||||||
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
|
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
|
||||||
|
|
||||||
def before_push_audio(self):
|
|
||||||
self.__init_queue()
|
|
||||||
self.warm_up()
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
|
@ -380,73 +250,6 @@ class ASR:
|
||||||
if self.play:
|
if self.play:
|
||||||
self.output_queue.queue.clear()
|
self.output_queue.queue.clear()
|
||||||
|
|
||||||
def warm_up(self):
|
|
||||||
|
|
||||||
#self.listen()
|
|
||||||
|
|
||||||
self.inwarm = True
|
|
||||||
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
|
|
||||||
t = time.time()
|
|
||||||
#for _ in range(self.stride_left_size):
|
|
||||||
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
|
|
||||||
for _ in range(self.warm_up_steps):
|
|
||||||
self.run_step()
|
|
||||||
#if torch.cuda.is_available():
|
|
||||||
# torch.cuda.synchronize()
|
|
||||||
t = time.time() - t
|
|
||||||
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
|
|
||||||
self.inwarm = False
|
|
||||||
|
|
||||||
#self.clear_queue()
|
|
||||||
|
|
||||||
'''
|
|
||||||
def create_file_stream(self):
|
|
||||||
|
|
||||||
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
|
|
||||||
stream = stream.astype(np.float32)
|
|
||||||
|
|
||||||
if stream.ndim > 1:
|
|
||||||
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
|
||||||
stream = stream[:, 0]
|
|
||||||
|
|
||||||
if sample_rate != self.sample_rate:
|
|
||||||
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
|
||||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
|
||||||
|
|
||||||
print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
|
|
||||||
|
|
||||||
return stream
|
|
||||||
|
|
||||||
|
|
||||||
def create_pyaudio_stream(self):
|
|
||||||
|
|
||||||
import pyaudio
|
|
||||||
|
|
||||||
print(f'[INFO] creating live audio stream ...')
|
|
||||||
|
|
||||||
audio = pyaudio.PyAudio()
|
|
||||||
|
|
||||||
# get devices
|
|
||||||
info = audio.get_host_api_info_by_index(0)
|
|
||||||
n_devices = info.get('deviceCount')
|
|
||||||
|
|
||||||
for i in range(0, n_devices):
|
|
||||||
if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
|
|
||||||
name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
|
|
||||||
print(f'[INFO] choose audio device {name}, id {i}')
|
|
||||||
break
|
|
||||||
|
|
||||||
# get stream
|
|
||||||
stream = audio.open(input_device_index=i,
|
|
||||||
format=pyaudio.paInt16,
|
|
||||||
channels=1,
|
|
||||||
rate=self.sample_rate,
|
|
||||||
input=True,
|
|
||||||
frames_per_buffer=self.chunk)
|
|
||||||
|
|
||||||
return audio, stream
|
|
||||||
'''
|
|
||||||
#####not used function#####################################
|
|
||||||
def listen(self):
|
def listen(self):
|
||||||
# start
|
# start
|
||||||
if self.mode == 'live' and not self.listening:
|
if self.mode == 'live' and not self.listening:
|
||||||
|
@ -489,6 +292,25 @@ class ASR:
|
||||||
# live mode: also print the result text.
|
# live mode: also print the result text.
|
||||||
self.text += '\n[END]'
|
self.text += '\n[END]'
|
||||||
print(self.text)
|
print(self.text)
|
||||||
|
|
||||||
|
def _read_frame(stream, exit_event, queue, chunk):
|
||||||
|
while True:
|
||||||
|
if exit_event.is_set():
|
||||||
|
print(f'[INFO] read frame thread ends')
|
||||||
|
break
|
||||||
|
frame = stream.read(chunk, exception_on_overflow=False)
|
||||||
|
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
|
||||||
|
queue.put(frame)
|
||||||
|
|
||||||
|
def _play_frame(stream, exit_event, queue, chunk):
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if exit_event.is_set():
|
||||||
|
print(f'[INFO] play frame thread ends')
|
||||||
|
break
|
||||||
|
frame = queue.get()
|
||||||
|
frame = (frame * 32767).astype(np.int16).tobytes()
|
||||||
|
stream.write(frame, chunk)
|
||||||
#########################################################
|
#########################################################
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -522,4 +344,5 @@ if __name__ == '__main__':
|
||||||
raise ValueError("DeepSpeech features should not use this code to extract...")
|
raise ValueError("DeepSpeech features should not use this code to extract...")
|
||||||
|
|
||||||
with ASR(opt) as asr:
|
with ASR(opt) as asr:
|
||||||
asr.run()
|
asr.run()
|
||||||
|
'''
|
275
nerfreal.py
275
nerfreal.py
|
@ -8,61 +8,79 @@ import os
|
||||||
import time
|
import time
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import cv2
|
import cv2
|
||||||
|
import glob
|
||||||
|
|
||||||
|
from nerfasr import NerfASR
|
||||||
|
from ttsreal import EdgeTTS,VoitsTTS,XTTS
|
||||||
|
|
||||||
from asrreal import ASR
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from av import AudioFrame, VideoFrame
|
from av import AudioFrame, VideoFrame
|
||||||
|
from basereal import BaseReal
|
||||||
|
|
||||||
class NeRFReal:
|
#from imgcache import ImgCache
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
def read_imgs(img_list):
|
||||||
|
frames = []
|
||||||
|
print('reading images...')
|
||||||
|
for img_path in tqdm(img_list):
|
||||||
|
frame = cv2.imread(img_path)
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
class NeRFReal(BaseReal):
|
||||||
def __init__(self, opt, trainer, data_loader, debug=True):
|
def __init__(self, opt, trainer, data_loader, debug=True):
|
||||||
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
super().__init__(opt)
|
||||||
|
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
||||||
self.W = opt.W
|
self.W = opt.W
|
||||||
self.H = opt.H
|
self.H = opt.H
|
||||||
self.debug = debug
|
|
||||||
self.training = False
|
|
||||||
self.step = 0 # training step
|
|
||||||
|
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
self.data_loader = data_loader
|
self.data_loader = data_loader
|
||||||
|
|
||||||
# use dataloader's bg
|
# use dataloader's bg
|
||||||
bg_img = data_loader._data.bg_img #.view(1, -1, 3)
|
#bg_img = data_loader._data.bg_img #.view(1, -1, 3)
|
||||||
if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
|
#if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
|
||||||
bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
|
# bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
|
||||||
self.bg_color = bg_img.view(1, -1, 3)
|
#self.bg_color = bg_img.view(1, -1, 3)
|
||||||
|
|
||||||
# audio features (from dataloader, only used in non-playing mode)
|
# audio features (from dataloader, only used in non-playing mode)
|
||||||
self.audio_features = data_loader._data.auds # [N, 29, 16]
|
#self.audio_features = data_loader._data.auds # [N, 29, 16]
|
||||||
self.audio_idx = 0
|
#self.audio_idx = 0
|
||||||
|
|
||||||
#self.frame_total_num = data_loader._data.end_index
|
#self.frame_total_num = data_loader._data.end_index
|
||||||
#print("frame_total_num:",self.frame_total_num)
|
#print("frame_total_num:",self.frame_total_num)
|
||||||
|
|
||||||
# control eye
|
# control eye
|
||||||
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
|
#self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
|
||||||
|
|
||||||
# playing seq from dataloader, or pause.
|
# playing seq from dataloader, or pause.
|
||||||
self.playing = True #False todo
|
|
||||||
self.loader = iter(data_loader)
|
self.loader = iter(data_loader)
|
||||||
|
frame_total_num = data_loader._data.end_index
|
||||||
|
if opt.fullbody:
|
||||||
|
input_img_list = glob.glob(os.path.join(self.opt.fullbody_img, '*.[jpJP][pnPN]*[gG]'))
|
||||||
|
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||||
|
#print('input_img_list:',input_img_list)
|
||||||
|
self.fullbody_list_cycle = read_imgs(input_img_list[:frame_total_num])
|
||||||
|
#self.imagecache = ImgCache(frame_total_num,self.opt.fullbody_img,1000)
|
||||||
|
|
||||||
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
|
#self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
|
||||||
self.need_update = True # camera moved, should reset accumulation
|
#self.need_update = True # camera moved, should reset accumulation
|
||||||
self.spp = 1 # sample per pixel
|
#self.spp = 1 # sample per pixel
|
||||||
self.mode = 'image' # choose from ['image', 'depth']
|
#self.mode = 'image' # choose from ['image', 'depth']
|
||||||
|
|
||||||
self.dynamic_resolution = False # assert False!
|
#self.dynamic_resolution = False # assert False!
|
||||||
self.downscale = 1
|
#self.downscale = 1
|
||||||
self.train_steps = 16
|
#self.train_steps = 16
|
||||||
|
|
||||||
self.ind_index = 0
|
#self.ind_index = 0
|
||||||
self.ind_num = trainer.model.individual_codes.shape[0]
|
#self.ind_num = trainer.model.individual_codes.shape[0]
|
||||||
|
|
||||||
self.customimg_index = 0
|
#self.customimg_index = 0
|
||||||
|
|
||||||
# build asr
|
# build asr
|
||||||
if self.opt.asr:
|
self.asr = NerfASR(opt,self)
|
||||||
self.asr = ASR(opt)
|
self.asr.warm_up()
|
||||||
self.asr.warm_up()
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
video_path = 'video_stream'
|
video_path = 'video_stream'
|
||||||
|
@ -108,108 +126,115 @@ class NeRFReal:
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
if self.opt.asr:
|
if self.opt.asr:
|
||||||
self.asr.stop()
|
self.asr.stop()
|
||||||
|
|
||||||
def push_audio(self,chunk):
|
|
||||||
self.asr.push_audio(chunk)
|
|
||||||
|
|
||||||
def before_push_audio(self):
|
|
||||||
self.asr.before_push_audio()
|
|
||||||
|
|
||||||
def mirror_index(self, index):
|
# def mirror_index(self, index):
|
||||||
size = self.opt.customvideo_imgnum
|
# size = self.opt.customvideo_imgnum
|
||||||
turn = index // size
|
# turn = index // size
|
||||||
res = index % size
|
# res = index % size
|
||||||
if turn % 2 == 0:
|
# if turn % 2 == 0:
|
||||||
return res
|
# return res
|
||||||
else:
|
# else:
|
||||||
return size - res - 1
|
# return size - res - 1
|
||||||
|
|
||||||
def prepare_buffer(self, outputs):
|
|
||||||
if self.mode == 'image':
|
|
||||||
return outputs['image']
|
|
||||||
else:
|
|
||||||
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
|
|
||||||
|
|
||||||
def test_step(self,loop=None,audio_track=None,video_track=None):
|
def test_step(self,loop=None,audio_track=None,video_track=None):
|
||||||
|
|
||||||
#starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
#starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||||
#starter.record()
|
#starter.record()
|
||||||
|
|
||||||
if self.playing:
|
try:
|
||||||
try:
|
data = next(self.loader)
|
||||||
data = next(self.loader)
|
except StopIteration:
|
||||||
except StopIteration:
|
self.loader = iter(self.data_loader)
|
||||||
self.loader = iter(self.data_loader)
|
data = next(self.loader)
|
||||||
data = next(self.loader)
|
|
||||||
|
if self.opt.asr:
|
||||||
if self.opt.asr:
|
# use the live audio stream
|
||||||
# use the live audio stream
|
data['auds'] = self.asr.get_next_feat()
|
||||||
data['auds'] = self.asr.get_next_feat()
|
|
||||||
|
|
||||||
audiotype = 0
|
audiotype1 = 0
|
||||||
if self.opt.transport=='rtmp':
|
audiotype2 = 0
|
||||||
for _ in range(2):
|
#send audio
|
||||||
frame,type = self.asr.get_audio_out()
|
for i in range(2):
|
||||||
audiotype += type
|
frame,type = self.asr.get_audio_out()
|
||||||
#print(f'[INFO] get_audio_out shape ',frame.shape)
|
if i==0:
|
||||||
self.streamer.stream_frame_audio(frame)
|
audiotype1 = type
|
||||||
else:
|
else:
|
||||||
for _ in range(2):
|
audiotype2 = type
|
||||||
frame,type = self.asr.get_audio_out()
|
#print(f'[INFO] get_audio_out shape ',frame.shape)
|
||||||
audiotype += type
|
if self.opt.transport=='rtmp':
|
||||||
frame = (frame * 32767).astype(np.int16)
|
self.streamer.stream_frame_audio(frame)
|
||||||
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
|
else: #webrtc
|
||||||
new_frame.planes[0].update(frame.tobytes())
|
frame = (frame * 32767).astype(np.int16)
|
||||||
new_frame.sample_rate=16000
|
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
|
||||||
# if audio_track._queue.qsize()>10:
|
new_frame.planes[0].update(frame.tobytes())
|
||||||
# time.sleep(0.1)
|
new_frame.sample_rate=16000
|
||||||
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
|
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
|
||||||
#t = time.time()
|
|
||||||
if self.opt.customvideo and audiotype!=0:
|
# if self.opt.transport=='rtmp':
|
||||||
self.loader = iter(self.data_loader) #init
|
# for _ in range(2):
|
||||||
imgindex = self.mirror_index(self.customimg_index)
|
# frame,type = self.asr.get_audio_out()
|
||||||
#print('custom img index:',imgindex)
|
# audiotype += type
|
||||||
image = cv2.imread(os.path.join(self.opt.customvideo_img, str(int(imgindex))+'.png'))
|
# #print(f'[INFO] get_audio_out shape ',frame.shape)
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
# self.streamer.stream_frame_audio(frame)
|
||||||
|
# else: #webrtc
|
||||||
|
# for _ in range(2):
|
||||||
|
# frame,type = self.asr.get_audio_out()
|
||||||
|
# audiotype += type
|
||||||
|
# frame = (frame * 32767).astype(np.int16)
|
||||||
|
# new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
|
||||||
|
# new_frame.planes[0].update(frame.tobytes())
|
||||||
|
# new_frame.sample_rate=16000
|
||||||
|
# # if audio_track._queue.qsize()>10:
|
||||||
|
# # time.sleep(0.1)
|
||||||
|
# asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
|
||||||
|
#t = time.time()
|
||||||
|
if audiotype1!=0 and audiotype2!=0: #全为静音数据
|
||||||
|
self.speaking = False
|
||||||
|
else:
|
||||||
|
self.speaking = True
|
||||||
|
|
||||||
|
if audiotype1!=0 and audiotype2!=0 and self.custom_index.get(audiotype1) is not None: #不为推理视频并且有自定义视频
|
||||||
|
mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype1]),self.custom_index[audiotype1])
|
||||||
|
#imgindex = self.mirror_index(self.customimg_index)
|
||||||
|
#print('custom img index:',imgindex)
|
||||||
|
#image = cv2.imread(os.path.join(self.opt.customvideo_img, str(int(imgindex))+'.png'))
|
||||||
|
image = self.custom_img_cycle[audiotype1][mirindex]
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
self.custom_index[audiotype1] += 1
|
||||||
|
if self.opt.transport=='rtmp':
|
||||||
|
self.streamer.stream_frame(image)
|
||||||
|
else:
|
||||||
|
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
|
||||||
|
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
||||||
|
else: #推理视频+贴回
|
||||||
|
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
|
||||||
|
#print('-------ernerf time: ',time.time()-t)
|
||||||
|
#print(f'[INFO] outputs shape ',outputs['image'].shape)
|
||||||
|
image = (outputs['image'] * 255).astype(np.uint8)
|
||||||
|
if not self.opt.fullbody:
|
||||||
if self.opt.transport=='rtmp':
|
if self.opt.transport=='rtmp':
|
||||||
self.streamer.stream_frame(image)
|
self.streamer.stream_frame(image)
|
||||||
else:
|
else:
|
||||||
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
|
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
|
||||||
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
||||||
self.customimg_index += 1
|
else: #fullbody human
|
||||||
else:
|
#print("frame index:",data['index'])
|
||||||
self.customimg_index = 0
|
#image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
|
||||||
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
|
image_fullbody = self.fullbody_list_cycle[data['index'][0]]
|
||||||
#print('-------ernerf time: ',time.time()-t)
|
#image_fullbody = self.imagecache.get_img(data['index'][0])
|
||||||
#print(f'[INFO] outputs shape ',outputs['image'].shape)
|
image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB)
|
||||||
image = (outputs['image'] * 255).astype(np.uint8)
|
start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标
|
||||||
if not self.opt.fullbody:
|
start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标
|
||||||
if self.opt.transport=='rtmp':
|
image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image
|
||||||
self.streamer.stream_frame(image)
|
if self.opt.transport=='rtmp':
|
||||||
else:
|
self.streamer.stream_frame(image_fullbody)
|
||||||
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
|
else:
|
||||||
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
|
||||||
else: #fullbody human
|
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
||||||
#print("frame index:",data['index'])
|
|
||||||
image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
|
|
||||||
image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB)
|
|
||||||
start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标
|
|
||||||
start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标
|
|
||||||
image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image
|
|
||||||
if self.opt.transport=='rtmp':
|
|
||||||
self.streamer.stream_frame(image_fullbody)
|
|
||||||
else:
|
|
||||||
new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
|
|
||||||
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
|
|
||||||
#self.pipe.stdin.write(image.tostring())
|
#self.pipe.stdin.write(image.tostring())
|
||||||
else:
|
|
||||||
if self.audio_features is not None:
|
|
||||||
auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
|
|
||||||
else:
|
|
||||||
auds = None
|
|
||||||
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
|
|
||||||
|
|
||||||
#ender.record()
|
#ender.record()
|
||||||
#torch.cuda.synchronize()
|
#torch.cuda.synchronize()
|
||||||
#t = starter.elapsed_time(ender)
|
#t = starter.elapsed_time(ender)
|
||||||
|
@ -218,6 +243,8 @@ class NeRFReal:
|
||||||
#if self.opt.asr:
|
#if self.opt.asr:
|
||||||
# self.asr.warm_up()
|
# self.asr.warm_up()
|
||||||
|
|
||||||
|
self.init_customindex()
|
||||||
|
|
||||||
if self.opt.transport=='rtmp':
|
if self.opt.transport=='rtmp':
|
||||||
from rtmp_streaming import StreamerConfig, Streamer
|
from rtmp_streaming import StreamerConfig, Streamer
|
||||||
fps=25
|
fps=25
|
||||||
|
@ -246,14 +273,15 @@ class NeRFReal:
|
||||||
totaltime=0
|
totaltime=0
|
||||||
_starttime=time.perf_counter()
|
_starttime=time.perf_counter()
|
||||||
_totalframe=0
|
_totalframe=0
|
||||||
|
|
||||||
|
self.tts.render(quit_event)
|
||||||
while not quit_event.is_set(): #todo
|
while not quit_event.is_set(): #todo
|
||||||
# update texture every frame
|
# update texture every frame
|
||||||
# audio stream thread...
|
# audio stream thread...
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
if self.opt.asr and self.playing:
|
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
|
||||||
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
|
for _ in range(2):
|
||||||
for _ in range(2):
|
self.asr.run_step()
|
||||||
self.asr.run_step()
|
|
||||||
self.test_step(loop,audio_track,video_track)
|
self.test_step(loop,audio_track,video_track)
|
||||||
totaltime += (time.perf_counter() - t)
|
totaltime += (time.perf_counter() - t)
|
||||||
count += 1
|
count += 1
|
||||||
|
@ -262,7 +290,14 @@ class NeRFReal:
|
||||||
print(f"------actual avg infer fps:{count/totaltime:.4f}")
|
print(f"------actual avg infer fps:{count/totaltime:.4f}")
|
||||||
count=0
|
count=0
|
||||||
totaltime=0
|
totaltime=0
|
||||||
delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
|
if self.opt.transport=='rtmp':
|
||||||
if delay > 0:
|
delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
|
||||||
time.sleep(delay)
|
if delay > 0:
|
||||||
|
time.sleep(delay)
|
||||||
|
else:
|
||||||
|
if video_track._queue.qsize()>=5:
|
||||||
|
#print('sleep qsize=',video_track._queue.qsize())
|
||||||
|
time.sleep(0.04*video_track._queue.qsize()*0.8)
|
||||||
|
print('nerfreal thread stop')
|
||||||
|
|
||||||
|
|
|
@ -23,13 +23,21 @@ soundfile
|
||||||
einops
|
einops
|
||||||
configargparse
|
configargparse
|
||||||
|
|
||||||
lpips
|
lpips==0.1.3
|
||||||
imageio-ffmpeg
|
imageio-ffmpeg
|
||||||
|
|
||||||
transformers
|
transformers
|
||||||
edge_tts
|
edge_tts==6.1.11
|
||||||
flask
|
flask
|
||||||
flask_sockets
|
flask_sockets
|
||||||
opencv-python-headless
|
opencv-python-headless
|
||||||
aiortc
|
aiortc
|
||||||
aiohttp_cors
|
aiohttp_cors
|
||||||
|
|
||||||
|
ffmpeg-python
|
||||||
|
omegaconf
|
||||||
|
diffusers
|
||||||
|
accelerate
|
||||||
|
|
||||||
|
librosa
|
||||||
|
openai
|
||||||
|
|
|
@ -1,99 +0,0 @@
|
||||||
# 采用gpt-sovits方案,bert-sovits适合长音频训练,gpt-sovits运行短音频快速推理
|
|
||||||
## 部署tts推理
|
|
||||||
git clone https://github.com/X-T-E-R/GPT-SoVITS-Inference.git
|
|
||||||
|
|
||||||
## 1. 安装依赖库
|
|
||||||
```
|
|
||||||
conda create -n GPTSoVits python=3.9
|
|
||||||
conda activate GPTSoVits
|
|
||||||
bash install.sh
|
|
||||||
```
|
|
||||||
从 [GPT-SoVITS Models](https://huggingface.co/lj1995/GPT-SoVITS) 下载预训练模型,并将它们放置在 `GPT_SoVITS\pretrained_models` 中
|
|
||||||
|
|
||||||
## 2. Model Folder Format
|
|
||||||
模型文件下载地址 https://www.yuque.com/xter/zibxlp/gsximn7ditzgispg
|
|
||||||
下载的模型文件放到trained目录下, 如 `trained/Character1/`
|
|
||||||
Put the pth / ckpt / wav files in it, the wav should be named as the prompt text
|
|
||||||
Like :
|
|
||||||
|
|
||||||
```
|
|
||||||
trained
|
|
||||||
--hutao
|
|
||||||
----hutao-e75.ckpt
|
|
||||||
----hutao_e60_s3360.pth
|
|
||||||
----hutao said something.wav
|
|
||||||
```
|
|
||||||
|
|
||||||
## 3. 启动
|
|
||||||
### 3.1 后端服务:
|
|
||||||
python Inference/src/tts_backend.py
|
|
||||||
如果有错误提示找不到cmudict,从这下载https://github.com/nltk/nltk_data,将packages改名为nltk_data放到home目录下
|
|
||||||
### 3.2 管理character:
|
|
||||||
python Inference/src/Character_Manager.py
|
|
||||||
浏览器打开可以管理character和emotion
|
|
||||||
### 3.3 测试tts功能:
|
|
||||||
python Inference/src/TTS_Webui.py
|
|
||||||
|
|
||||||
|
|
||||||
## 4. 接口说明
|
|
||||||
### 4.1 Character and Emotion List
|
|
||||||
To obtain the supported characters and their corresponding emotions, please visit the following URL:
|
|
||||||
- URL: `http://127.0.0.1:5000/character_list`
|
|
||||||
- Returns: A JSON format list of characters and corresponding emotions
|
|
||||||
- Method: `GET`
|
|
||||||
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"Hanabi": [
|
|
||||||
"default",
|
|
||||||
"Normal",
|
|
||||||
"Yandere",
|
|
||||||
],
|
|
||||||
"Hutao": [
|
|
||||||
"default"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4.2 Text-to-Speech
|
|
||||||
|
|
||||||
- URL: `http://127.0.0.1:5000/tts`
|
|
||||||
- Returns: Audio on success. Error message on failure.
|
|
||||||
- Method: `GET`/`POST`
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"method": "POST",
|
|
||||||
"body": {
|
|
||||||
"character": "${chaName}",
|
|
||||||
"emotion": "${Emotion}",
|
|
||||||
"text": "${speakText}",
|
|
||||||
"text_language": "${textLanguage}",
|
|
||||||
"batch_size": ${batch_size},
|
|
||||||
"speed": ${speed},
|
|
||||||
"top_k": ${topK},
|
|
||||||
"top_p": ${topP},
|
|
||||||
"temperature": ${temperature},
|
|
||||||
"stream": "${stream}",
|
|
||||||
"format": "${Format}",
|
|
||||||
"save_temp": "${saveTemp}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
##### Parameter Explanation
|
|
||||||
|
|
||||||
- **text**: The text to be converted, URL encoding is recommended.
|
|
||||||
- **character**: Character folder name, pay attention to case sensitivity, full/half width, and language.
|
|
||||||
- **emotion**: Character emotion, must be an actually supported emotion of the character, otherwise, the default emotion will be used.
|
|
||||||
- **text_language**: Text language (auto / zh / en / ja), default is multilingual mixed.
|
|
||||||
- **top_k**, **top_p**, **temperature**: GPT model parameters, no need to modify if unfamiliar.
|
|
||||||
|
|
||||||
- **batch_size**: How many batches at a time, can be increased for faster processing if you have a powerful computer, integer, default is 1.
|
|
||||||
- **speed**: Speech speed, default is 1.0.
|
|
||||||
- **save_temp**: Whether to save temporary files, when true, the backend will save the generated audio, and subsequent identical requests will directly return that data, default is false.
|
|
||||||
- **stream**: Whether to stream, when true, audio will be returned sentence by sentence, default is false.
|
|
||||||
- **format**: Format, default is WAV, allows MP3/ WAV/ OGG.
|
|
||||||
|
|
||||||
## 部署tts训练
|
|
||||||
https://github.com/RVC-Boss/GPT-SoVITS
|
|
||||||
根据文档说明部署,将训练后的模型拷到推理服务的trained目录下
|
|
|
@ -1,28 +0,0 @@
|
||||||
import requests
|
|
||||||
import pyaudio
|
|
||||||
|
|
||||||
# 流式传输音频的URL,你可以自由改成Post
|
|
||||||
stream_url = 'http://127.0.0.1:5000/tts?text=这是一段测试文本,旨在通过多种语言风格和复杂性的内容来全面检验文本到语音系统的性能。接下来,我们会探索各种主题和语言结构,包括文学引用、技术性描述、日常会话以及诗歌等。首先,让我们从一段简单的描述性文本开始:“在一个阳光明媚的下午,一位年轻的旅者站在山顶上,眺望着下方那宽广而繁忙的城市。他的心中充满了对未来的憧憬和对旅途的期待。”这段文本测试了系统对自然景观描写的处理能力和情感表达的细腻程度。&stream=true'
|
|
||||||
|
|
||||||
# 初始化pyaudio
|
|
||||||
p = pyaudio.PyAudio()
|
|
||||||
|
|
||||||
# 打开音频流
|
|
||||||
stream = p.open(format=p.get_format_from_width(2),
|
|
||||||
channels=1,
|
|
||||||
rate=32000,
|
|
||||||
output=True)
|
|
||||||
|
|
||||||
# 使用requests获取音频流,你可以自由改成Post
|
|
||||||
response = requests.get(stream_url, stream=True)
|
|
||||||
|
|
||||||
# 读取数据块并播放
|
|
||||||
for data in response.iter_content(chunk_size=1024):
|
|
||||||
stream.write(data)
|
|
||||||
|
|
||||||
# 停止和关闭流
|
|
||||||
stream.stop_stream()
|
|
||||||
stream.close()
|
|
||||||
|
|
||||||
# 终止pyaudio
|
|
||||||
p.terminate()
|
|
|
@ -0,0 +1,310 @@
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import resampy
|
||||||
|
import asyncio
|
||||||
|
import edge_tts
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from queue import Queue
|
||||||
|
from io import BytesIO
|
||||||
|
from threading import Thread, Event
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class State(Enum):
|
||||||
|
RUNNING=0
|
||||||
|
PAUSE=1
|
||||||
|
|
||||||
|
class BaseTTS:
|
||||||
|
def __init__(self, opt, parent):
|
||||||
|
self.opt=opt
|
||||||
|
self.parent = parent
|
||||||
|
|
||||||
|
self.fps = opt.fps # 20 ms per frame
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
|
||||||
|
self.input_stream = BytesIO()
|
||||||
|
|
||||||
|
self.msgqueue = Queue()
|
||||||
|
self.state = State.RUNNING
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
self.msgqueue.queue.clear()
|
||||||
|
self.state = State.PAUSE
|
||||||
|
|
||||||
|
def put_msg_txt(self,msg):
|
||||||
|
if len(msg)>0:
|
||||||
|
self.msgqueue.put(msg)
|
||||||
|
|
||||||
|
def render(self,quit_event):
|
||||||
|
process_thread = Thread(target=self.process_tts, args=(quit_event,))
|
||||||
|
process_thread.start()
|
||||||
|
|
||||||
|
def process_tts(self,quit_event):
|
||||||
|
while not quit_event.is_set():
|
||||||
|
try:
|
||||||
|
msg = self.msgqueue.get(block=True, timeout=1)
|
||||||
|
self.state=State.RUNNING
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
self.txt_to_audio(msg)
|
||||||
|
print('ttsreal thread stop')
|
||||||
|
|
||||||
|
def txt_to_audio(self,msg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################################################
|
||||||
|
class EdgeTTS(BaseTTS):
|
||||||
|
def txt_to_audio(self,msg):
|
||||||
|
voicename = "zh-CN-YunxiaNeural"
|
||||||
|
text = msg
|
||||||
|
t = time.time()
|
||||||
|
asyncio.new_event_loop().run_until_complete(self.__main(voicename,text))
|
||||||
|
print(f'-------edge tts time:{time.time()-t:.4f}s')
|
||||||
|
if self.input_stream.getbuffer().nbytes<=0: #edgetts err
|
||||||
|
print('edgetts err!!!!!')
|
||||||
|
return
|
||||||
|
|
||||||
|
self.input_stream.seek(0)
|
||||||
|
stream = self.__create_bytes_stream(self.input_stream)
|
||||||
|
streamlen = stream.shape[0]
|
||||||
|
idx=0
|
||||||
|
while streamlen >= self.chunk and self.state==State.RUNNING:
|
||||||
|
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
|
||||||
|
streamlen -= self.chunk
|
||||||
|
idx += self.chunk
|
||||||
|
#if streamlen>0: #skip last frame(not 20ms)
|
||||||
|
# self.queue.put(stream[idx:])
|
||||||
|
self.input_stream.seek(0)
|
||||||
|
self.input_stream.truncate()
|
||||||
|
|
||||||
|
def __create_bytes_stream(self,byte_stream):
|
||||||
|
#byte_stream=BytesIO(buffer)
|
||||||
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
|
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||||
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
|
if stream.ndim > 1:
|
||||||
|
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||||
|
stream = stream[:, 0]
|
||||||
|
|
||||||
|
if sample_rate != self.sample_rate and stream.shape[0]>0:
|
||||||
|
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
async def __main(self,voicename: str, text: str):
|
||||||
|
try:
|
||||||
|
communicate = edge_tts.Communicate(text, voicename)
|
||||||
|
|
||||||
|
#with open(OUTPUT_FILE, "wb") as file:
|
||||||
|
first = True
|
||||||
|
async for chunk in communicate.stream():
|
||||||
|
if first:
|
||||||
|
first = False
|
||||||
|
if chunk["type"] == "audio" and self.state==State.RUNNING:
|
||||||
|
#self.push_audio(chunk["data"])
|
||||||
|
self.input_stream.write(chunk["data"])
|
||||||
|
#file.write(chunk["data"])
|
||||||
|
elif chunk["type"] == "WordBoundary":
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
###########################################################################################
|
||||||
|
class VoitsTTS(BaseTTS):
|
||||||
|
def txt_to_audio(self,msg):
|
||||||
|
self.stream_tts(
|
||||||
|
self.gpt_sovits(
|
||||||
|
msg,
|
||||||
|
self.opt.REF_FILE,
|
||||||
|
self.opt.REF_TEXT,
|
||||||
|
"zh", #en args.language,
|
||||||
|
self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def gpt_sovits(self, text, reffile, reftext,language, server_url) -> Iterator[bytes]:
|
||||||
|
start = time.perf_counter()
|
||||||
|
req={
|
||||||
|
'text':text,
|
||||||
|
'text_lang':language,
|
||||||
|
'ref_audio_path':reffile,
|
||||||
|
'prompt_text':reftext,
|
||||||
|
'prompt_lang':language,
|
||||||
|
'media_type':'raw',
|
||||||
|
'streaming_mode':True
|
||||||
|
}
|
||||||
|
# req["text"] = text
|
||||||
|
# req["text_language"] = language
|
||||||
|
# req["character"] = character
|
||||||
|
# req["emotion"] = emotion
|
||||||
|
# #req["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
|
||||||
|
# req["streaming_mode"] = True
|
||||||
|
try:
|
||||||
|
res = requests.post(
|
||||||
|
f"{server_url}/tts",
|
||||||
|
json=req,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"gpt_sovits Time to make POST: {end-start}s")
|
||||||
|
|
||||||
|
if res.status_code != 200:
|
||||||
|
print("Error:", res.text)
|
||||||
|
return
|
||||||
|
|
||||||
|
first = True
|
||||||
|
|
||||||
|
for chunk in res.iter_content(chunk_size=12800): # 1280 32K*20ms*2
|
||||||
|
if first:
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"gpt_sovits Time to first chunk: {end-start}s")
|
||||||
|
first = False
|
||||||
|
if chunk and self.state==State.RUNNING:
|
||||||
|
yield chunk
|
||||||
|
#print("gpt_sovits response.elapsed:", res.elapsed)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
def stream_tts(self,audio_stream):
|
||||||
|
for chunk in audio_stream:
|
||||||
|
if chunk is not None and len(chunk)>0:
|
||||||
|
stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
|
||||||
|
#byte_stream=BytesIO(buffer)
|
||||||
|
#stream = self.__create_bytes_stream(byte_stream)
|
||||||
|
streamlen = stream.shape[0]
|
||||||
|
idx=0
|
||||||
|
while streamlen >= self.chunk:
|
||||||
|
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
|
||||||
|
streamlen -= self.chunk
|
||||||
|
idx += self.chunk
|
||||||
|
|
||||||
|
###########################################################################################
|
||||||
|
class CosyVoiceTTS(BaseTTS):
|
||||||
|
def txt_to_audio(self,msg):
|
||||||
|
self.stream_tts(
|
||||||
|
self.cosy_voice(
|
||||||
|
msg,
|
||||||
|
self.opt.REF_FILE,
|
||||||
|
self.opt.REF_TEXT,
|
||||||
|
"zh", #en args.language,
|
||||||
|
self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def cosy_voice(self, text, reffile, reftext,language, server_url) -> Iterator[bytes]:
|
||||||
|
start = time.perf_counter()
|
||||||
|
payload = {
|
||||||
|
'tts_text': text,
|
||||||
|
'prompt_text': reftext
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
files = [('prompt_wav', ('prompt_wav', open(reffile, 'rb'), 'application/octet-stream'))]
|
||||||
|
res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True)
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"cosy_voice Time to make POST: {end-start}s")
|
||||||
|
|
||||||
|
if res.status_code != 200:
|
||||||
|
print("Error:", res.text)
|
||||||
|
return
|
||||||
|
|
||||||
|
first = True
|
||||||
|
|
||||||
|
for chunk in res.iter_content(chunk_size=8820): # 882 22.05K*20ms*2
|
||||||
|
if first:
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"cosy_voice Time to first chunk: {end-start}s")
|
||||||
|
first = False
|
||||||
|
if chunk and self.state==State.RUNNING:
|
||||||
|
yield chunk
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
def stream_tts(self,audio_stream):
|
||||||
|
for chunk in audio_stream:
|
||||||
|
if chunk is not None and len(chunk)>0:
|
||||||
|
stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=22050, sr_new=self.sample_rate)
|
||||||
|
#byte_stream=BytesIO(buffer)
|
||||||
|
#stream = self.__create_bytes_stream(byte_stream)
|
||||||
|
streamlen = stream.shape[0]
|
||||||
|
idx=0
|
||||||
|
while streamlen >= self.chunk:
|
||||||
|
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
|
||||||
|
streamlen -= self.chunk
|
||||||
|
idx += self.chunk
|
||||||
|
|
||||||
|
###########################################################################################
|
||||||
|
class XTTS(BaseTTS):
|
||||||
|
def __init__(self, opt, parent):
|
||||||
|
super().__init__(opt,parent)
|
||||||
|
self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER)
|
||||||
|
|
||||||
|
def txt_to_audio(self,msg):
|
||||||
|
self.stream_tts(
|
||||||
|
self.xtts(
|
||||||
|
msg,
|
||||||
|
self.speaker,
|
||||||
|
"zh-cn", #en args.language,
|
||||||
|
self.opt.TTS_SERVER, #"http://localhost:9000", #args.server_url,
|
||||||
|
"20" #args.stream_chunk_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_speaker(self,ref_audio,server_url):
|
||||||
|
files = {"wav_file": ("reference.wav", open(ref_audio, "rb"))}
|
||||||
|
response = requests.post(f"{server_url}/clone_speaker", files=files)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def xtts(self,text, speaker, language, server_url, stream_chunk_size) -> Iterator[bytes]:
|
||||||
|
start = time.perf_counter()
|
||||||
|
speaker["text"] = text
|
||||||
|
speaker["language"] = language
|
||||||
|
speaker["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
|
||||||
|
try:
|
||||||
|
res = requests.post(
|
||||||
|
f"{server_url}/tts_stream",
|
||||||
|
json=speaker,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"xtts Time to make POST: {end-start}s")
|
||||||
|
|
||||||
|
if res.status_code != 200:
|
||||||
|
print("Error:", res.text)
|
||||||
|
return
|
||||||
|
|
||||||
|
first = True
|
||||||
|
|
||||||
|
for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2
|
||||||
|
if first:
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"xtts Time to first chunk: {end-start}s")
|
||||||
|
first = False
|
||||||
|
if chunk:
|
||||||
|
yield chunk
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
def stream_tts(self,audio_stream):
|
||||||
|
for chunk in audio_stream:
|
||||||
|
if chunk is not None and len(chunk)>0:
|
||||||
|
stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
|
||||||
|
#byte_stream=BytesIO(buffer)
|
||||||
|
#stream = self.__create_bytes_stream(byte_stream)
|
||||||
|
streamlen = stream.shape[0]
|
||||||
|
idx=0
|
||||||
|
while streamlen >= self.chunk:
|
||||||
|
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
|
||||||
|
streamlen -= self.chunk
|
||||||
|
idx += self.chunk
|
|
@ -0,0 +1,136 @@
|
||||||
|
import librosa
|
||||||
|
import librosa.filters
|
||||||
|
import numpy as np
|
||||||
|
# import tensorflow as tf
|
||||||
|
from scipy import signal
|
||||||
|
from scipy.io import wavfile
|
||||||
|
from .hparams import hparams as hp
|
||||||
|
|
||||||
|
def load_wav(path, sr):
|
||||||
|
return librosa.core.load(path, sr=sr)[0]
|
||||||
|
|
||||||
|
def save_wav(wav, path, sr):
|
||||||
|
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||||
|
#proposed by @dsmiller
|
||||||
|
wavfile.write(path, sr, wav.astype(np.int16))
|
||||||
|
|
||||||
|
def save_wavenet_wav(wav, path, sr):
|
||||||
|
librosa.output.write_wav(path, wav, sr=sr)
|
||||||
|
|
||||||
|
def preemphasis(wav, k, preemphasize=True):
|
||||||
|
if preemphasize:
|
||||||
|
return signal.lfilter([1, -k], [1], wav)
|
||||||
|
return wav
|
||||||
|
|
||||||
|
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
||||||
|
if inv_preemphasize:
|
||||||
|
return signal.lfilter([1], [1, -k], wav)
|
||||||
|
return wav
|
||||||
|
|
||||||
|
def get_hop_size():
|
||||||
|
hop_size = hp.hop_size
|
||||||
|
if hop_size is None:
|
||||||
|
assert hp.frame_shift_ms is not None
|
||||||
|
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
||||||
|
return hop_size
|
||||||
|
|
||||||
|
def linearspectrogram(wav):
|
||||||
|
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||||
|
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
||||||
|
|
||||||
|
if hp.signal_normalization:
|
||||||
|
return _normalize(S)
|
||||||
|
return S
|
||||||
|
|
||||||
|
def melspectrogram(wav):
|
||||||
|
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||||
|
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
||||||
|
|
||||||
|
if hp.signal_normalization:
|
||||||
|
return _normalize(S)
|
||||||
|
return S
|
||||||
|
|
||||||
|
def _lws_processor():
|
||||||
|
import lws
|
||||||
|
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
|
||||||
|
|
||||||
|
def _stft(y):
|
||||||
|
if hp.use_lws:
|
||||||
|
return _lws_processor(hp).stft(y).T
|
||||||
|
else:
|
||||||
|
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
||||||
|
def num_frames(length, fsize, fshift):
|
||||||
|
"""Compute number of time frames of spectrogram
|
||||||
|
"""
|
||||||
|
pad = (fsize - fshift)
|
||||||
|
if length % fshift == 0:
|
||||||
|
M = (length + pad * 2 - fsize) // fshift + 1
|
||||||
|
else:
|
||||||
|
M = (length + pad * 2 - fsize) // fshift + 2
|
||||||
|
return M
|
||||||
|
|
||||||
|
|
||||||
|
def pad_lr(x, fsize, fshift):
|
||||||
|
"""Compute left and right padding
|
||||||
|
"""
|
||||||
|
M = num_frames(len(x), fsize, fshift)
|
||||||
|
pad = (fsize - fshift)
|
||||||
|
T = len(x) + 2 * pad
|
||||||
|
r = (M - 1) * fshift + fsize - T
|
||||||
|
return pad, pad + r
|
||||||
|
##########################################################
|
||||||
|
#Librosa correct padding
|
||||||
|
def librosa_pad_lr(x, fsize, fshift):
|
||||||
|
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
||||||
|
|
||||||
|
# Conversions
|
||||||
|
_mel_basis = None
|
||||||
|
|
||||||
|
def _linear_to_mel(spectogram):
|
||||||
|
global _mel_basis
|
||||||
|
if _mel_basis is None:
|
||||||
|
_mel_basis = _build_mel_basis()
|
||||||
|
return np.dot(_mel_basis, spectogram)
|
||||||
|
|
||||||
|
def _build_mel_basis():
|
||||||
|
assert hp.fmax <= hp.sample_rate // 2
|
||||||
|
return librosa.filters.mel(sr=float(hp.sample_rate), n_fft=hp.n_fft, n_mels=hp.num_mels,
|
||||||
|
fmin=hp.fmin, fmax=hp.fmax)
|
||||||
|
|
||||||
|
def _amp_to_db(x):
|
||||||
|
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
||||||
|
return 20 * np.log10(np.maximum(min_level, x))
|
||||||
|
|
||||||
|
def _db_to_amp(x):
|
||||||
|
return np.power(10.0, (x) * 0.05)
|
||||||
|
|
||||||
|
def _normalize(S):
|
||||||
|
if hp.allow_clipping_in_normalization:
|
||||||
|
if hp.symmetric_mels:
|
||||||
|
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
||||||
|
-hp.max_abs_value, hp.max_abs_value)
|
||||||
|
else:
|
||||||
|
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
||||||
|
|
||||||
|
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
||||||
|
if hp.symmetric_mels:
|
||||||
|
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
||||||
|
else:
|
||||||
|
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
||||||
|
|
||||||
|
def _denormalize(D):
|
||||||
|
if hp.allow_clipping_in_normalization:
|
||||||
|
if hp.symmetric_mels:
|
||||||
|
return (((np.clip(D, -hp.max_abs_value,
|
||||||
|
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
||||||
|
+ hp.min_level_db)
|
||||||
|
else:
|
||||||
|
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
||||||
|
|
||||||
|
if hp.symmetric_mels:
|
||||||
|
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
||||||
|
else:
|
||||||
|
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
|
@ -0,0 +1 @@
|
||||||
|
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
|
|
@ -0,0 +1,7 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
__author__ = """Adrian Bulat"""
|
||||||
|
__email__ = 'adrian.bulat@nottingham.ac.uk'
|
||||||
|
__version__ = '1.0.1'
|
||||||
|
|
||||||
|
from .api import FaceAlignment, LandmarksType, NetworkSize
|
|
@ -0,0 +1,79 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from torch.utils.model_zoo import load_url
|
||||||
|
from enum import Enum
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
try:
|
||||||
|
import urllib.request as request_file
|
||||||
|
except BaseException:
|
||||||
|
import urllib as request_file
|
||||||
|
|
||||||
|
from .models import FAN, ResNetDepth
|
||||||
|
from .utils import *
|
||||||
|
|
||||||
|
|
||||||
|
class LandmarksType(Enum):
|
||||||
|
"""Enum class defining the type of landmarks to detect.
|
||||||
|
|
||||||
|
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
||||||
|
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
||||||
|
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
||||||
|
|
||||||
|
"""
|
||||||
|
_2D = 1
|
||||||
|
_2halfD = 2
|
||||||
|
_3D = 3
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkSize(Enum):
|
||||||
|
# TINY = 1
|
||||||
|
# SMALL = 2
|
||||||
|
# MEDIUM = 3
|
||||||
|
LARGE = 4
|
||||||
|
|
||||||
|
def __new__(cls, value):
|
||||||
|
member = object.__new__(cls)
|
||||||
|
member._value_ = value
|
||||||
|
return member
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
class FaceAlignment:
|
||||||
|
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
||||||
|
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
||||||
|
self.device = device
|
||||||
|
self.flip_input = flip_input
|
||||||
|
self.landmarks_type = landmarks_type
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
network_size = int(network_size)
|
||||||
|
|
||||||
|
if 'cuda' in device:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
# Get the face detector
|
||||||
|
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
||||||
|
globals(), locals(), [face_detector], 0)
|
||||||
|
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
||||||
|
|
||||||
|
def get_detections_for_batch(self, images):
|
||||||
|
images = images[..., ::-1]
|
||||||
|
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, d in enumerate(detected_faces):
|
||||||
|
if len(d) == 0:
|
||||||
|
results.append(None)
|
||||||
|
continue
|
||||||
|
d = d[0]
|
||||||
|
d = np.clip(d, 0, None)
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = map(int, d[:-1])
|
||||||
|
results.append((x1, y1, x2, y2))
|
||||||
|
|
||||||
|
return results
|
|
@ -0,0 +1 @@
|
||||||
|
from .core import FaceDetector
|
|
@ -0,0 +1,130 @@
|
||||||
|
import logging
|
||||||
|
import glob
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class FaceDetector(object):
|
||||||
|
"""An abstract class representing a face detector.
|
||||||
|
|
||||||
|
Any other face detection implementation must subclass it. All subclasses
|
||||||
|
must implement ``detect_from_image``, that return a list of detected
|
||||||
|
bounding boxes. Optionally, for speed considerations detect from path is
|
||||||
|
recommended.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device, verbose):
|
||||||
|
self.device = device
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
if 'cpu' in device:
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning("Detection running on CPU, this may be potentially slow.")
|
||||||
|
|
||||||
|
if 'cpu' not in device and 'cuda' not in device:
|
||||||
|
if verbose:
|
||||||
|
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def detect_from_image(self, tensor_or_path):
|
||||||
|
"""Detects faces in a given image.
|
||||||
|
|
||||||
|
This function detects the faces present in a provided BGR(usually)
|
||||||
|
image. The input can be either the image itself or the path to it.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
||||||
|
to an image or the image itself.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> path_to_image = 'data/image_01.jpg'
|
||||||
|
... detected_faces = detect_from_image(path_to_image)
|
||||||
|
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||||
|
>>> image = cv2.imread(path_to_image)
|
||||||
|
... detected_faces = detect_from_image(image)
|
||||||
|
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
||||||
|
"""Detects faces from all the images present in a given directory.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
path {string} -- a string containing a path that points to the folder containing the images
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
extensions {list} -- list of string containing the extensions to be
|
||||||
|
consider in the following format: ``.extension_name`` (default:
|
||||||
|
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
||||||
|
folder recursively (default: {False}) show_progress_bar {bool} --
|
||||||
|
display a progressbar (default: {True})
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> directory = 'data'
|
||||||
|
... detected_faces = detect_from_directory(directory)
|
||||||
|
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.verbose:
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if len(extensions) == 0:
|
||||||
|
if self.verbose:
|
||||||
|
logger.error("Expected at list one extension, but none was received.")
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("Constructing the list of images.")
|
||||||
|
additional_pattern = '/**/*' if recursive else '/*'
|
||||||
|
files = []
|
||||||
|
for extension in extensions:
|
||||||
|
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("Finished searching for images. %s images found", len(files))
|
||||||
|
logger.info("Preparing to run the detection.")
|
||||||
|
|
||||||
|
predictions = {}
|
||||||
|
for image_path in tqdm(files, disable=not show_progress_bar):
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("Running the face detector on image: %s", image_path)
|
||||||
|
predictions[image_path] = self.detect_from_image(image_path)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
logger.info("The detector was successfully run on all %s images", len(files))
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_scale(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_x_shift(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_y_shift(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
||||||
|
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
||||||
|
"""
|
||||||
|
if isinstance(tensor_or_path, str):
|
||||||
|
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
||||||
|
elif torch.is_tensor(tensor_or_path):
|
||||||
|
# Call cpu in case its coming from cuda
|
||||||
|
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
||||||
|
elif isinstance(tensor_or_path, np.ndarray):
|
||||||
|
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
||||||
|
else:
|
||||||
|
raise TypeError
|
|
@ -0,0 +1 @@
|
||||||
|
from .sfd_detector import SFDDetector as FaceDetector
|
|
@ -0,0 +1,129 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import random
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from iou import IOU
|
||||||
|
except BaseException:
|
||||||
|
# IOU cython speedup 10x
|
||||||
|
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
||||||
|
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
||||||
|
sb = abs((bx2 - bx1) * (by2 - by1))
|
||||||
|
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
||||||
|
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
||||||
|
w = x2 - x1
|
||||||
|
h = y2 - y1
|
||||||
|
if w < 0 or h < 0:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
return 1.0 * w * h / (sa + sb - w * h)
|
||||||
|
|
||||||
|
|
||||||
|
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
||||||
|
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
||||||
|
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
||||||
|
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
||||||
|
return dx, dy, dw, dh
|
||||||
|
|
||||||
|
|
||||||
|
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
||||||
|
xc, yc = dx * aww + axc, dy * ahh + ayc
|
||||||
|
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
||||||
|
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
|
||||||
|
|
||||||
|
def nms(dets, thresh):
|
||||||
|
if 0 == len(dets):
|
||||||
|
return []
|
||||||
|
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
||||||
|
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||||
|
order = scores.argsort()[::-1]
|
||||||
|
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
||||||
|
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
||||||
|
|
||||||
|
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
||||||
|
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
||||||
|
|
||||||
|
inds = np.where(ovr <= thresh)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|
||||||
|
def encode(matched, priors, variances):
|
||||||
|
"""Encode the variances from the priorbox layers into the ground truth boxes
|
||||||
|
we have matched (based on jaccard overlap) with the prior boxes.
|
||||||
|
Args:
|
||||||
|
matched: (tensor) Coords of ground truth for each prior in point-form
|
||||||
|
Shape: [num_priors, 4].
|
||||||
|
priors: (tensor) Prior boxes in center-offset form
|
||||||
|
Shape: [num_priors,4].
|
||||||
|
variances: (list[float]) Variances of priorboxes
|
||||||
|
Return:
|
||||||
|
encoded boxes (tensor), Shape: [num_priors, 4]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# dist b/t match center and prior's center
|
||||||
|
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
||||||
|
# encode variance
|
||||||
|
g_cxcy /= (variances[0] * priors[:, 2:])
|
||||||
|
# match wh / prior wh
|
||||||
|
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
||||||
|
g_wh = torch.log(g_wh) / variances[1]
|
||||||
|
# return target for smooth_l1_loss
|
||||||
|
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
||||||
|
|
||||||
|
|
||||||
|
def decode(loc, priors, variances):
|
||||||
|
"""Decode locations from predictions using priors to undo
|
||||||
|
the encoding we did for offset regression at train time.
|
||||||
|
Args:
|
||||||
|
loc (tensor): location predictions for loc layers,
|
||||||
|
Shape: [num_priors,4]
|
||||||
|
priors (tensor): Prior boxes in center-offset form.
|
||||||
|
Shape: [num_priors,4].
|
||||||
|
variances: (list[float]) Variances of priorboxes
|
||||||
|
Return:
|
||||||
|
decoded bounding box predictions
|
||||||
|
"""
|
||||||
|
|
||||||
|
boxes = torch.cat((
|
||||||
|
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
||||||
|
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
||||||
|
boxes[:, :2] -= boxes[:, 2:] / 2
|
||||||
|
boxes[:, 2:] += boxes[:, :2]
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
def batch_decode(loc, priors, variances):
|
||||||
|
"""Decode locations from predictions using priors to undo
|
||||||
|
the encoding we did for offset regression at train time.
|
||||||
|
Args:
|
||||||
|
loc (tensor): location predictions for loc layers,
|
||||||
|
Shape: [num_priors,4]
|
||||||
|
priors (tensor): Prior boxes in center-offset form.
|
||||||
|
Shape: [num_priors,4].
|
||||||
|
variances: (list[float]) Variances of priorboxes
|
||||||
|
Return:
|
||||||
|
decoded bounding box predictions
|
||||||
|
"""
|
||||||
|
|
||||||
|
boxes = torch.cat((
|
||||||
|
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
||||||
|
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
||||||
|
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
||||||
|
boxes[:, :, 2:] += boxes[:, :, :2]
|
||||||
|
return boxes
|
|
@ -0,0 +1,112 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import random
|
||||||
|
import datetime
|
||||||
|
import math
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import scipy.io as sio
|
||||||
|
import zipfile
|
||||||
|
from .net_s3fd import s3fd
|
||||||
|
from .bbox import *
|
||||||
|
|
||||||
|
|
||||||
|
def detect(net, img, device):
|
||||||
|
img = img - np.array([104, 117, 123])
|
||||||
|
img = img.transpose(2, 0, 1)
|
||||||
|
img = img.reshape((1,) + img.shape)
|
||||||
|
|
||||||
|
if 'cuda' in device:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
img = torch.from_numpy(img).float().to(device)
|
||||||
|
BB, CC, HH, WW = img.size()
|
||||||
|
with torch.no_grad():
|
||||||
|
olist = net(img)
|
||||||
|
|
||||||
|
bboxlist = []
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||||
|
olist = [oelem.data.cpu() for oelem in olist]
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||||
|
FB, FC, FH, FW = ocls.size() # feature map size
|
||||||
|
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||||
|
anchor = stride * 4
|
||||||
|
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||||
|
for Iindex, hindex, windex in poss:
|
||||||
|
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||||
|
score = ocls[0, 1, hindex, windex]
|
||||||
|
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
||||||
|
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
||||||
|
variances = [0.1, 0.2]
|
||||||
|
box = decode(loc, priors, variances)
|
||||||
|
x1, y1, x2, y2 = box[0] * 1.0
|
||||||
|
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||||
|
bboxlist.append([x1, y1, x2, y2, score])
|
||||||
|
bboxlist = np.array(bboxlist)
|
||||||
|
if 0 == len(bboxlist):
|
||||||
|
bboxlist = np.zeros((1, 5))
|
||||||
|
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
def batch_detect(net, imgs, device):
|
||||||
|
imgs = imgs - np.array([104, 117, 123])
|
||||||
|
imgs = imgs.transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
if 'cuda' in device:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
imgs = torch.from_numpy(imgs).float().to(device)
|
||||||
|
BB, CC, HH, WW = imgs.size()
|
||||||
|
with torch.no_grad():
|
||||||
|
olist = net(imgs)
|
||||||
|
|
||||||
|
bboxlist = []
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||||
|
olist = [oelem.data.cpu() for oelem in olist]
|
||||||
|
for i in range(len(olist) // 2):
|
||||||
|
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||||
|
FB, FC, FH, FW = ocls.size() # feature map size
|
||||||
|
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||||
|
anchor = stride * 4
|
||||||
|
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||||
|
for Iindex, hindex, windex in poss:
|
||||||
|
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||||
|
score = ocls[:, 1, hindex, windex]
|
||||||
|
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
||||||
|
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
||||||
|
variances = [0.1, 0.2]
|
||||||
|
box = batch_decode(loc, priors, variances)
|
||||||
|
box = box[:, 0] * 1.0
|
||||||
|
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||||
|
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
||||||
|
bboxlist = np.array(bboxlist)
|
||||||
|
if 0 == len(bboxlist):
|
||||||
|
bboxlist = np.zeros((1, BB, 5))
|
||||||
|
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
def flip_detect(net, img, device):
|
||||||
|
img = cv2.flip(img, 1)
|
||||||
|
b = detect(net, img, device)
|
||||||
|
|
||||||
|
bboxlist = np.zeros(b.shape)
|
||||||
|
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
||||||
|
bboxlist[:, 1] = b[:, 1]
|
||||||
|
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
||||||
|
bboxlist[:, 3] = b[:, 3]
|
||||||
|
bboxlist[:, 4] = b[:, 4]
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
|
||||||
|
def pts_to_bb(pts):
|
||||||
|
min_x, min_y = np.min(pts, axis=0)
|
||||||
|
max_x, max_y = np.max(pts, axis=0)
|
||||||
|
return np.array([min_x, min_y, max_x, max_y])
|
|
@ -0,0 +1,129 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class L2Norm(nn.Module):
|
||||||
|
def __init__(self, n_channels, scale=1.0):
|
||||||
|
super(L2Norm, self).__init__()
|
||||||
|
self.n_channels = n_channels
|
||||||
|
self.scale = scale
|
||||||
|
self.eps = 1e-10
|
||||||
|
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
||||||
|
self.weight.data *= 0.0
|
||||||
|
self.weight.data += self.scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
||||||
|
x = x / norm * self.weight.view(1, -1, 1, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class s3fd(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(s3fd, self).__init__()
|
||||||
|
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
||||||
|
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
self.conv3_3_norm = L2Norm(256, scale=10)
|
||||||
|
self.conv4_3_norm = L2Norm(512, scale=8)
|
||||||
|
self.conv5_3_norm = L2Norm(512, scale=5)
|
||||||
|
|
||||||
|
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = F.relu(self.conv1_1(x))
|
||||||
|
h = F.relu(self.conv1_2(h))
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv2_1(h))
|
||||||
|
h = F.relu(self.conv2_2(h))
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv3_1(h))
|
||||||
|
h = F.relu(self.conv3_2(h))
|
||||||
|
h = F.relu(self.conv3_3(h))
|
||||||
|
f3_3 = h
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv4_1(h))
|
||||||
|
h = F.relu(self.conv4_2(h))
|
||||||
|
h = F.relu(self.conv4_3(h))
|
||||||
|
f4_3 = h
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.conv5_1(h))
|
||||||
|
h = F.relu(self.conv5_2(h))
|
||||||
|
h = F.relu(self.conv5_3(h))
|
||||||
|
f5_3 = h
|
||||||
|
h = F.max_pool2d(h, 2, 2)
|
||||||
|
|
||||||
|
h = F.relu(self.fc6(h))
|
||||||
|
h = F.relu(self.fc7(h))
|
||||||
|
ffc7 = h
|
||||||
|
h = F.relu(self.conv6_1(h))
|
||||||
|
h = F.relu(self.conv6_2(h))
|
||||||
|
f6_2 = h
|
||||||
|
h = F.relu(self.conv7_1(h))
|
||||||
|
h = F.relu(self.conv7_2(h))
|
||||||
|
f7_2 = h
|
||||||
|
|
||||||
|
f3_3 = self.conv3_3_norm(f3_3)
|
||||||
|
f4_3 = self.conv4_3_norm(f4_3)
|
||||||
|
f5_3 = self.conv5_3_norm(f5_3)
|
||||||
|
|
||||||
|
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
||||||
|
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
||||||
|
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
||||||
|
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
||||||
|
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
||||||
|
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
||||||
|
cls4 = self.fc7_mbox_conf(ffc7)
|
||||||
|
reg4 = self.fc7_mbox_loc(ffc7)
|
||||||
|
cls5 = self.conv6_2_mbox_conf(f6_2)
|
||||||
|
reg5 = self.conv6_2_mbox_loc(f6_2)
|
||||||
|
cls6 = self.conv7_2_mbox_conf(f7_2)
|
||||||
|
reg6 = self.conv7_2_mbox_loc(f7_2)
|
||||||
|
|
||||||
|
# max-out background label
|
||||||
|
chunk = torch.chunk(cls1, 4, 1)
|
||||||
|
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
||||||
|
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
||||||
|
|
||||||
|
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
|
@ -0,0 +1,59 @@
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
from torch.utils.model_zoo import load_url
|
||||||
|
|
||||||
|
from ..core import FaceDetector
|
||||||
|
|
||||||
|
from .net_s3fd import s3fd
|
||||||
|
from .bbox import *
|
||||||
|
from .detect import *
|
||||||
|
|
||||||
|
models_urls = {
|
||||||
|
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SFDDetector(FaceDetector):
|
||||||
|
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
||||||
|
super(SFDDetector, self).__init__(device, verbose)
|
||||||
|
|
||||||
|
# Initialise the face detector
|
||||||
|
if not os.path.isfile(path_to_detector):
|
||||||
|
model_weights = load_url(models_urls['s3fd'])
|
||||||
|
else:
|
||||||
|
model_weights = torch.load(path_to_detector)
|
||||||
|
|
||||||
|
self.face_detector = s3fd()
|
||||||
|
self.face_detector.load_state_dict(model_weights)
|
||||||
|
self.face_detector.to(device)
|
||||||
|
self.face_detector.eval()
|
||||||
|
|
||||||
|
def detect_from_image(self, tensor_or_path):
|
||||||
|
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
||||||
|
|
||||||
|
bboxlist = detect(self.face_detector, image, device=self.device)
|
||||||
|
keep = nms(bboxlist, 0.3)
|
||||||
|
bboxlist = bboxlist[keep, :]
|
||||||
|
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
||||||
|
|
||||||
|
return bboxlist
|
||||||
|
|
||||||
|
def detect_from_batch(self, images):
|
||||||
|
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
||||||
|
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
||||||
|
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
||||||
|
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
||||||
|
|
||||||
|
return bboxlists
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_scale(self):
|
||||||
|
return 195
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_x_shift(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_y_shift(self):
|
||||||
|
return 0
|
|
@ -0,0 +1,261 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
||||||
|
"3x3 convolution with padding"
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
||||||
|
stride=strd, padding=padding, bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, out_planes):
|
||||||
|
super(ConvBlock, self).__init__()
|
||||||
|
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||||
|
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
||||||
|
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
||||||
|
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
||||||
|
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
||||||
|
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
||||||
|
|
||||||
|
if in_planes != out_planes:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.BatchNorm2d(in_planes),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Conv2d(in_planes, out_planes,
|
||||||
|
kernel_size=1, stride=1, bias=False),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out1 = self.bn1(x)
|
||||||
|
out1 = F.relu(out1, True)
|
||||||
|
out1 = self.conv1(out1)
|
||||||
|
|
||||||
|
out2 = self.bn2(out1)
|
||||||
|
out2 = F.relu(out2, True)
|
||||||
|
out2 = self.conv2(out2)
|
||||||
|
|
||||||
|
out3 = self.bn3(out2)
|
||||||
|
out3 = F.relu(out3, True)
|
||||||
|
out3 = self.conv3(out3)
|
||||||
|
|
||||||
|
out3 = torch.cat((out1, out2, out3), 1)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(residual)
|
||||||
|
|
||||||
|
out3 += residual
|
||||||
|
|
||||||
|
return out3
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||||
|
padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
|
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(x)
|
||||||
|
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class HourGlass(nn.Module):
|
||||||
|
def __init__(self, num_modules, depth, num_features):
|
||||||
|
super(HourGlass, self).__init__()
|
||||||
|
self.num_modules = num_modules
|
||||||
|
self.depth = depth
|
||||||
|
self.features = num_features
|
||||||
|
|
||||||
|
self._generate_network(self.depth)
|
||||||
|
|
||||||
|
def _generate_network(self, level):
|
||||||
|
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
if level > 1:
|
||||||
|
self._generate_network(level - 1)
|
||||||
|
else:
|
||||||
|
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
||||||
|
|
||||||
|
def _forward(self, level, inp):
|
||||||
|
# Upper branch
|
||||||
|
up1 = inp
|
||||||
|
up1 = self._modules['b1_' + str(level)](up1)
|
||||||
|
|
||||||
|
# Lower branch
|
||||||
|
low1 = F.avg_pool2d(inp, 2, stride=2)
|
||||||
|
low1 = self._modules['b2_' + str(level)](low1)
|
||||||
|
|
||||||
|
if level > 1:
|
||||||
|
low2 = self._forward(level - 1, low1)
|
||||||
|
else:
|
||||||
|
low2 = low1
|
||||||
|
low2 = self._modules['b2_plus_' + str(level)](low2)
|
||||||
|
|
||||||
|
low3 = low2
|
||||||
|
low3 = self._modules['b3_' + str(level)](low3)
|
||||||
|
|
||||||
|
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
||||||
|
|
||||||
|
return up1 + up2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward(self.depth, x)
|
||||||
|
|
||||||
|
|
||||||
|
class FAN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, num_modules=1):
|
||||||
|
super(FAN, self).__init__()
|
||||||
|
self.num_modules = num_modules
|
||||||
|
|
||||||
|
# Base part
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
||||||
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
|
self.conv2 = ConvBlock(64, 128)
|
||||||
|
self.conv3 = ConvBlock(128, 128)
|
||||||
|
self.conv4 = ConvBlock(128, 256)
|
||||||
|
|
||||||
|
# Stacking part
|
||||||
|
for hg_module in range(self.num_modules):
|
||||||
|
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
||||||
|
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
||||||
|
self.add_module('conv_last' + str(hg_module),
|
||||||
|
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||||
|
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
||||||
|
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
||||||
|
68, kernel_size=1, stride=1, padding=0))
|
||||||
|
|
||||||
|
if hg_module < self.num_modules - 1:
|
||||||
|
self.add_module(
|
||||||
|
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||||
|
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
||||||
|
256, kernel_size=1, stride=1, padding=0))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.relu(self.bn1(self.conv1(x)), True)
|
||||||
|
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = self.conv4(x)
|
||||||
|
|
||||||
|
previous = x
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for i in range(self.num_modules):
|
||||||
|
hg = self._modules['m' + str(i)](previous)
|
||||||
|
|
||||||
|
ll = hg
|
||||||
|
ll = self._modules['top_m_' + str(i)](ll)
|
||||||
|
|
||||||
|
ll = F.relu(self._modules['bn_end' + str(i)]
|
||||||
|
(self._modules['conv_last' + str(i)](ll)), True)
|
||||||
|
|
||||||
|
# Predict heatmaps
|
||||||
|
tmp_out = self._modules['l' + str(i)](ll)
|
||||||
|
outputs.append(tmp_out)
|
||||||
|
|
||||||
|
if i < self.num_modules - 1:
|
||||||
|
ll = self._modules['bl' + str(i)](ll)
|
||||||
|
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
||||||
|
previous = previous + ll + tmp_out_
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetDepth(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
||||||
|
self.inplanes = 64
|
||||||
|
super(ResNetDepth, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||||
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||||
|
self.avgpool = nn.AvgPool2d(7)
|
||||||
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
|
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1):
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||||
|
kernel_size=1, stride=stride, bias=False),
|
||||||
|
nn.BatchNorm2d(planes * block.expansion),
|
||||||
|
)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for i in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
x = self.fc(x)
|
||||||
|
|
||||||
|
return x
|
|
@ -0,0 +1,313 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def _gaussian(
|
||||||
|
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
||||||
|
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
||||||
|
mean_vert=0.5):
|
||||||
|
# handle some defaults
|
||||||
|
if width is None:
|
||||||
|
width = size
|
||||||
|
if height is None:
|
||||||
|
height = size
|
||||||
|
if sigma_horz is None:
|
||||||
|
sigma_horz = sigma
|
||||||
|
if sigma_vert is None:
|
||||||
|
sigma_vert = sigma
|
||||||
|
center_x = mean_horz * width + 0.5
|
||||||
|
center_y = mean_vert * height + 0.5
|
||||||
|
gauss = np.empty((height, width), dtype=np.float32)
|
||||||
|
# generate kernel
|
||||||
|
for i in range(height):
|
||||||
|
for j in range(width):
|
||||||
|
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
||||||
|
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
||||||
|
if normalize:
|
||||||
|
gauss = gauss / np.sum(gauss)
|
||||||
|
return gauss
|
||||||
|
|
||||||
|
|
||||||
|
def draw_gaussian(image, point, sigma):
|
||||||
|
# Check if the gaussian is inside
|
||||||
|
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
||||||
|
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
||||||
|
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
||||||
|
return image
|
||||||
|
size = 6 * sigma + 1
|
||||||
|
g = _gaussian(size)
|
||||||
|
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
||||||
|
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
||||||
|
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
||||||
|
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
||||||
|
assert (g_x[0] > 0 and g_y[1] > 0)
|
||||||
|
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
||||||
|
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
||||||
|
image[image > 1] = 1
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def transform(point, center, scale, resolution, invert=False):
|
||||||
|
"""Generate and affine transformation matrix.
|
||||||
|
|
||||||
|
Given a set of points, a center, a scale and a targer resolution, the
|
||||||
|
function generates and affine transformation matrix. If invert is ``True``
|
||||||
|
it will produce the inverse transformation.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
point {torch.tensor} -- the input 2D point
|
||||||
|
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
||||||
|
scale {float} -- the scale of the face/object
|
||||||
|
resolution {float} -- the output resolution
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
invert {bool} -- define wherever the function should produce the direct or the
|
||||||
|
inverse transformation matrix (default: {False})
|
||||||
|
"""
|
||||||
|
_pt = torch.ones(3)
|
||||||
|
_pt[0] = point[0]
|
||||||
|
_pt[1] = point[1]
|
||||||
|
|
||||||
|
h = 200.0 * scale
|
||||||
|
t = torch.eye(3)
|
||||||
|
t[0, 0] = resolution / h
|
||||||
|
t[1, 1] = resolution / h
|
||||||
|
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
||||||
|
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
||||||
|
|
||||||
|
if invert:
|
||||||
|
t = torch.inverse(t)
|
||||||
|
|
||||||
|
new_point = (torch.matmul(t, _pt))[0:2]
|
||||||
|
|
||||||
|
return new_point.int()
|
||||||
|
|
||||||
|
|
||||||
|
def crop(image, center, scale, resolution=256.0):
|
||||||
|
"""Center crops an image or set of heatmaps
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
image {numpy.array} -- an rgb image
|
||||||
|
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
||||||
|
scale {float} -- scale of the face
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
resolution {float} -- the size of the output cropped image (default: {256.0})
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[type] -- [description]
|
||||||
|
""" # Crop around the center point
|
||||||
|
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
||||||
|
ul = transform([1, 1], center, scale, resolution, True)
|
||||||
|
br = transform([resolution, resolution], center, scale, resolution, True)
|
||||||
|
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
||||||
|
if image.ndim > 2:
|
||||||
|
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
||||||
|
image.shape[2]], dtype=np.int32)
|
||||||
|
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||||
|
else:
|
||||||
|
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
||||||
|
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||||
|
ht = image.shape[0]
|
||||||
|
wd = image.shape[1]
|
||||||
|
newX = np.array(
|
||||||
|
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
||||||
|
newY = np.array(
|
||||||
|
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
||||||
|
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
||||||
|
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
||||||
|
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
||||||
|
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
||||||
|
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
return newImg
|
||||||
|
|
||||||
|
|
||||||
|
def get_preds_fromhm(hm, center=None, scale=None):
|
||||||
|
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
||||||
|
and the scale is provided the function will return the points also in
|
||||||
|
the original coordinate frame.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
center {torch.tensor} -- the center of the bounding box (default: {None})
|
||||||
|
scale {float} -- face scale (default: {None})
|
||||||
|
"""
|
||||||
|
max, idx = torch.max(
|
||||||
|
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||||
|
idx += 1
|
||||||
|
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||||
|
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||||
|
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||||
|
|
||||||
|
for i in range(preds.size(0)):
|
||||||
|
for j in range(preds.size(1)):
|
||||||
|
hm_ = hm[i, j, :]
|
||||||
|
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||||
|
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||||
|
diff = torch.FloatTensor(
|
||||||
|
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||||
|
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||||
|
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||||
|
|
||||||
|
preds.add_(-.5)
|
||||||
|
|
||||||
|
preds_orig = torch.zeros(preds.size())
|
||||||
|
if center is not None and scale is not None:
|
||||||
|
for i in range(hm.size(0)):
|
||||||
|
for j in range(hm.size(1)):
|
||||||
|
preds_orig[i, j] = transform(
|
||||||
|
preds[i, j], center, scale, hm.size(2), True)
|
||||||
|
|
||||||
|
return preds, preds_orig
|
||||||
|
|
||||||
|
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
||||||
|
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
||||||
|
and the scales is provided the function will return the points also in
|
||||||
|
the original coordinate frame.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
||||||
|
scales {float} -- face scales (default: {None})
|
||||||
|
"""
|
||||||
|
max, idx = torch.max(
|
||||||
|
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||||
|
idx += 1
|
||||||
|
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||||
|
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||||
|
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||||
|
|
||||||
|
for i in range(preds.size(0)):
|
||||||
|
for j in range(preds.size(1)):
|
||||||
|
hm_ = hm[i, j, :]
|
||||||
|
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||||
|
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||||
|
diff = torch.FloatTensor(
|
||||||
|
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||||
|
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||||
|
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||||
|
|
||||||
|
preds.add_(-.5)
|
||||||
|
|
||||||
|
preds_orig = torch.zeros(preds.size())
|
||||||
|
if centers is not None and scales is not None:
|
||||||
|
for i in range(hm.size(0)):
|
||||||
|
for j in range(hm.size(1)):
|
||||||
|
preds_orig[i, j] = transform(
|
||||||
|
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
||||||
|
|
||||||
|
return preds, preds_orig
|
||||||
|
|
||||||
|
def shuffle_lr(parts, pairs=None):
|
||||||
|
"""Shuffle the points left-right according to the axis of symmetry
|
||||||
|
of the object.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
parts {torch.tensor} -- a 3D or 4D object containing the
|
||||||
|
heatmaps.
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
||||||
|
"""
|
||||||
|
if pairs is None:
|
||||||
|
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
||||||
|
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
||||||
|
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
||||||
|
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
||||||
|
62, 61, 60, 67, 66, 65]
|
||||||
|
if parts.ndimension() == 3:
|
||||||
|
parts = parts[pairs, ...]
|
||||||
|
else:
|
||||||
|
parts = parts[:, pairs, ...]
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def flip(tensor, is_label=False):
|
||||||
|
"""Flip an image or a set of heatmaps left-right
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
||||||
|
"""
|
||||||
|
if not torch.is_tensor(tensor):
|
||||||
|
tensor = torch.from_numpy(tensor)
|
||||||
|
|
||||||
|
if is_label:
|
||||||
|
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
||||||
|
else:
|
||||||
|
tensor = tensor.flip(tensor.ndimension() - 1)
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
||||||
|
|
||||||
|
|
||||||
|
def appdata_dir(appname=None, roaming=False):
|
||||||
|
""" appdata_dir(appname=None, roaming=False)
|
||||||
|
|
||||||
|
Get the path to the application directory, where applications are allowed
|
||||||
|
to write user specific files (e.g. configurations). For non-user specific
|
||||||
|
data, consider using common_appdata_dir().
|
||||||
|
If appname is given, a subdir is appended (and created if necessary).
|
||||||
|
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Define default user directory
|
||||||
|
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
||||||
|
if userDir is None:
|
||||||
|
userDir = os.path.expanduser('~')
|
||||||
|
if not os.path.isdir(userDir): # pragma: no cover
|
||||||
|
userDir = '/var/tmp' # issue #54
|
||||||
|
|
||||||
|
# Get system app data dir
|
||||||
|
path = None
|
||||||
|
if sys.platform.startswith('win'):
|
||||||
|
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
||||||
|
path = (path2 or path1) if roaming else (path1 or path2)
|
||||||
|
elif sys.platform.startswith('darwin'):
|
||||||
|
path = os.path.join(userDir, 'Library', 'Application Support')
|
||||||
|
# On Linux and as fallback
|
||||||
|
if not (path and os.path.isdir(path)):
|
||||||
|
path = userDir
|
||||||
|
|
||||||
|
# Maybe we should store things local to the executable (in case of a
|
||||||
|
# portable distro or a frozen application that wants to be portable)
|
||||||
|
prefix = sys.prefix
|
||||||
|
if getattr(sys, 'frozen', None):
|
||||||
|
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
||||||
|
for reldir in ('settings', '../settings'):
|
||||||
|
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
||||||
|
if os.path.isdir(localpath): # pragma: no cover
|
||||||
|
try:
|
||||||
|
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
||||||
|
os.remove(os.path.join(localpath, 'test.write'))
|
||||||
|
except IOError:
|
||||||
|
pass # We cannot write in this directory
|
||||||
|
else:
|
||||||
|
path = localpath
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get path specific for this app
|
||||||
|
if appname:
|
||||||
|
if path == userDir:
|
||||||
|
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
||||||
|
path = os.path.join(path, appname)
|
||||||
|
if not os.path.isdir(path): # pragma: no cover
|
||||||
|
os.mkdir(path)
|
||||||
|
|
||||||
|
# Done
|
||||||
|
return path
|
|
@ -0,0 +1,125 @@
|
||||||
|
from os import listdir, path
|
||||||
|
import numpy as np
|
||||||
|
import scipy, cv2, os, sys, argparse
|
||||||
|
import json, subprocess, random, string
|
||||||
|
from tqdm import tqdm
|
||||||
|
from glob import glob
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import face_detection
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
|
||||||
|
parser.add_argument('--img_size', default=96, type=int)
|
||||||
|
parser.add_argument('--avatar_id', default='wav2lip_avatar1', type=str)
|
||||||
|
parser.add_argument('--video_path', default='', type=str)
|
||||||
|
parser.add_argument('--nosmooth', default=False, action='store_true',
|
||||||
|
help='Prevent smoothing face detections over a short temporal window')
|
||||||
|
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
|
||||||
|
help='Padding (top, bottom, left, right). Please adjust to include chin at least')
|
||||||
|
parser.add_argument('--face_det_batch_size', type=int,
|
||||||
|
help='Batch size for face detection', default=16)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print('Using {} for inference.'.format(device))
|
||||||
|
|
||||||
|
def osmakedirs(path_list):
|
||||||
|
for path in path_list:
|
||||||
|
os.makedirs(path) if not os.path.exists(path) else None
|
||||||
|
|
||||||
|
def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
|
||||||
|
cap = cv2.VideoCapture(vid_path)
|
||||||
|
count = 0
|
||||||
|
while True:
|
||||||
|
if count > cut_frame:
|
||||||
|
break
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if ret:
|
||||||
|
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
|
||||||
|
count += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def read_imgs(img_list):
|
||||||
|
frames = []
|
||||||
|
print('reading images...')
|
||||||
|
for img_path in tqdm(img_list):
|
||||||
|
frame = cv2.imread(img_path)
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def get_smoothened_boxes(boxes, T):
|
||||||
|
for i in range(len(boxes)):
|
||||||
|
if i + T > len(boxes):
|
||||||
|
window = boxes[len(boxes) - T:]
|
||||||
|
else:
|
||||||
|
window = boxes[i : i + T]
|
||||||
|
boxes[i] = np.mean(window, axis=0)
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
def face_detect(images):
|
||||||
|
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
|
||||||
|
flip_input=False, device=device)
|
||||||
|
|
||||||
|
batch_size = args.face_det_batch_size
|
||||||
|
|
||||||
|
while 1:
|
||||||
|
predictions = []
|
||||||
|
try:
|
||||||
|
for i in tqdm(range(0, len(images), batch_size)):
|
||||||
|
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
|
||||||
|
except RuntimeError:
|
||||||
|
if batch_size == 1:
|
||||||
|
raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
|
||||||
|
batch_size //= 2
|
||||||
|
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
results = []
|
||||||
|
pady1, pady2, padx1, padx2 = args.pads
|
||||||
|
for rect, image in zip(predictions, images):
|
||||||
|
if rect is None:
|
||||||
|
cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
|
||||||
|
raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
|
||||||
|
|
||||||
|
y1 = max(0, rect[1] - pady1)
|
||||||
|
y2 = min(image.shape[0], rect[3] + pady2)
|
||||||
|
x1 = max(0, rect[0] - padx1)
|
||||||
|
x2 = min(image.shape[1], rect[2] + padx2)
|
||||||
|
|
||||||
|
results.append([x1, y1, x2, y2])
|
||||||
|
|
||||||
|
boxes = np.array(results)
|
||||||
|
if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
|
||||||
|
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
|
||||||
|
|
||||||
|
del detector
|
||||||
|
return results
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
avatar_path = f"./results/avatars/{args.avatar_id}"
|
||||||
|
full_imgs_path = f"{avatar_path}/full_imgs"
|
||||||
|
face_imgs_path = f"{avatar_path}/face_imgs"
|
||||||
|
coords_path = f"{avatar_path}/coords.pkl"
|
||||||
|
osmakedirs([avatar_path,full_imgs_path,face_imgs_path])
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
#if os.path.isfile(args.video_path):
|
||||||
|
video2imgs(args.video_path, full_imgs_path, ext = 'png')
|
||||||
|
input_img_list = sorted(glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
|
||||||
|
|
||||||
|
frames = read_imgs(input_img_list)
|
||||||
|
face_det_results = face_detect(frames)
|
||||||
|
coord_list = []
|
||||||
|
idx = 0
|
||||||
|
for frame,coords in face_det_results:
|
||||||
|
#x1, y1, x2, y2 = bbox
|
||||||
|
resized_crop_frame = cv2.resize(frame,(args.img_size, args.img_size)) #,interpolation = cv2.INTER_LANCZOS4)
|
||||||
|
cv2.imwrite(f"{face_imgs_path}/{idx:08d}.png", resized_crop_frame)
|
||||||
|
coord_list.append(coords)
|
||||||
|
idx = idx + 1
|
||||||
|
|
||||||
|
with open(coords_path, 'wb') as f:
|
||||||
|
pickle.dump(coord_list, f)
|
|
@ -0,0 +1,101 @@
|
||||||
|
from glob import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_image_list(data_root, split):
|
||||||
|
filelist = []
|
||||||
|
|
||||||
|
with open('filelists/{}.txt'.format(split)) as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if ' ' in line: line = line.split()[0]
|
||||||
|
filelist.append(os.path.join(data_root, line))
|
||||||
|
|
||||||
|
return filelist
|
||||||
|
|
||||||
|
class HParams:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.data = {}
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
self.data[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
if key not in self.data:
|
||||||
|
raise AttributeError("'HParams' object has no attribute %s" % key)
|
||||||
|
return self.data[key]
|
||||||
|
|
||||||
|
def set_hparam(self, key, value):
|
||||||
|
self.data[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
# Default hyperparameters
|
||||||
|
hparams = HParams(
|
||||||
|
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
|
||||||
|
# network
|
||||||
|
rescale=True, # Whether to rescale audio prior to preprocessing
|
||||||
|
rescaling_max=0.9, # Rescaling value
|
||||||
|
|
||||||
|
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
|
||||||
|
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
|
||||||
|
# Does not work if n_ffit is not multiple of hop_size!!
|
||||||
|
use_lws=False,
|
||||||
|
|
||||||
|
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
|
||||||
|
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
|
||||||
|
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
|
||||||
|
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
|
||||||
|
|
||||||
|
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
|
||||||
|
|
||||||
|
# Mel and Linear spectrograms normalization/scaling and clipping
|
||||||
|
signal_normalization=True,
|
||||||
|
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
|
||||||
|
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
|
||||||
|
symmetric_mels=True,
|
||||||
|
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
|
||||||
|
# faster and cleaner convergence)
|
||||||
|
max_abs_value=4.,
|
||||||
|
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
|
||||||
|
# be too big to avoid gradient explosion,
|
||||||
|
# not too small for fast convergence)
|
||||||
|
# Contribution by @begeekmyfriend
|
||||||
|
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
|
||||||
|
# levels. Also allows for better G&L phase reconstruction)
|
||||||
|
preemphasize=True, # whether to apply filter
|
||||||
|
preemphasis=0.97, # filter coefficient.
|
||||||
|
|
||||||
|
# Limits
|
||||||
|
min_level_db=-100,
|
||||||
|
ref_level_db=20,
|
||||||
|
fmin=55,
|
||||||
|
# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
|
||||||
|
# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
||||||
|
fmax=7600, # To be increased/reduced depending on data.
|
||||||
|
|
||||||
|
###################### Our training parameters #################################
|
||||||
|
img_size=96,
|
||||||
|
fps=25,
|
||||||
|
|
||||||
|
batch_size=16,
|
||||||
|
initial_learning_rate=1e-4,
|
||||||
|
nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
|
||||||
|
num_workers=16,
|
||||||
|
checkpoint_interval=3000,
|
||||||
|
eval_interval=3000,
|
||||||
|
save_optimizer_state=True,
|
||||||
|
|
||||||
|
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
|
||||||
|
syncnet_batch_size=64,
|
||||||
|
syncnet_lr=1e-4,
|
||||||
|
syncnet_eval_interval=10000,
|
||||||
|
syncnet_checkpoint_interval=10000,
|
||||||
|
|
||||||
|
disc_wt=0.07,
|
||||||
|
disc_initial_learning_rate=1e-4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hparams_debug_string():
|
||||||
|
values = hparams.values()
|
||||||
|
hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
|
||||||
|
return "Hyperparameters:\n" + "\n".join(hp)
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
|
||||||
|
from .syncnet import SyncNet_color
|
|
@ -0,0 +1,44 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
class Conv2d(nn.Module):
|
||||||
|
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.conv_block = nn.Sequential(
|
||||||
|
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||||
|
nn.BatchNorm2d(cout)
|
||||||
|
)
|
||||||
|
self.act = nn.ReLU()
|
||||||
|
self.residual = residual
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.conv_block(x)
|
||||||
|
if self.residual:
|
||||||
|
out += x
|
||||||
|
return self.act(out)
|
||||||
|
|
||||||
|
class nonorm_Conv2d(nn.Module):
|
||||||
|
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.conv_block = nn.Sequential(
|
||||||
|
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||||
|
)
|
||||||
|
self.act = nn.LeakyReLU(0.01, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.conv_block(x)
|
||||||
|
return self.act(out)
|
||||||
|
|
||||||
|
class Conv2dTranspose(nn.Module):
|
||||||
|
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.conv_block = nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
|
||||||
|
nn.BatchNorm2d(cout)
|
||||||
|
)
|
||||||
|
self.act = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.conv_block(x)
|
||||||
|
return self.act(out)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue