Compare commits


60 Commits
v1.0 ... main

Author SHA1 Message Date
lipku 14208c6d60 audio chat 2024-10-05 17:54:38 +08:00
lipku 959ecf9be8 add llm stream func 2024-10-05 17:25:01 +08:00
lipku 5e8884fcf3 add audio echo 2024-09-21 10:55:30 +08:00
lipku 00dbc71db9 remove unuse code 2024-09-20 21:25:07 +08:00
lipku a8b40fa813 add audio asr input 2024-09-17 22:11:46 +08:00
lipku 8d5a38222b init funasr 2024-09-15 16:36:04 +08:00
lipku 5340e77e76 webrtc prefer h264 codec 2024-09-08 22:53:37 +08:00
lipku f584cb25d1 add tts cosyvoice 2024-09-08 12:13:33 +08:00
lipku 275af1ed9e fix edgetts exception 2024-09-07 13:44:59 +08:00
lipku 995428b426 update readme 2024-09-03 20:43:30 +08:00
lipku baf8270fc5 add video record 2024-09-01 18:37:43 +08:00
lipku e9faa50b9e load fullbody image to memory 2024-08-24 17:55:03 +08:00
Bruce.Lu 93a6513504 resolve building errors 2024-08-23 22:04:25 +08:00
anxu 93f3ed9895 no gradient computation needed for inference 2024-08-15 14:22:51 +08:00
lipku 9c8f020b3f update readme 2024-08-03 17:26:35 +08:00
yuheng 3e60fd7738 Update LICENSE 2024-08-03 16:45:53 +08:00
lipku a9e9cfb220 fix customvideo 2024-08-03 12:58:49 +08:00
lipku 391512f68c add wav2lip customvideo 2024-08-03 08:26:17 +08:00
unknown 0c63e9a11b support multi session 2024-07-17 08:21:31 +08:00
unknown 2883b2243e remove websocket 2024-07-09 08:20:44 +08:00
yuheng 0465437abc Merge pull request #139 from ShelikeSnow/main (migrate musetalk digital human generation to support images and video) 2024-07-08 20:06:56 +08:00
ShelikeSnow 7917c5f7cc Merge branch 'lipku:main' into main 2024-07-07 13:54:43 +08:00
lipku 4f14468e19 fix pan address 2024-07-04 20:11:39 +08:00
Yun 1d5c7e1542 Merge remote-tracking branch 'origin/main' 2024-07-04 09:49:14 +08:00
Yun 79df82ebea feat: improve automatic absolute path handling, add API generation 2024-07-04 09:46:42 +08:00
Yun cd7d5f31b5 feat: improve automatic absolute path handling, add API generation 2024-07-04 09:43:56 +08:00
lipku c812e45f35 update readme 2024-07-01 07:38:26 +08:00
lipku 9fe4c7fccf wrapper class baseasr; add talk interrupt 2024-06-30 09:41:31 +08:00
Yun 18d7db35a7 feat: improve automatic absolute path handling, add API generation 2024-06-23 14:51:58 +08:00
ShelikeSnow 6eb03ecbff Merge branch 'lipku:main' into main 2024-06-23 11:46:33 +08:00
lipku 98eeeb17af update readme 2024-06-22 16:11:44 +08:00
ShelikeSnow 994535fe3e Merge branch 'lipku:main' into main 2024-06-22 12:49:03 +08:00
lipku da9ffa9521 improve musetalk lipsync and speed 2024-06-22 09:02:01 +08:00
Yun c0682408c5 feat: add simple automatic musetalk digital human generation 2024-06-20 20:21:37 +08:00
Yun 5da818b9d9 feat: add musereal static img 2024-06-19 14:47:57 +08:00
lipku 592312ab8c add wav2lip stream 2024-06-17 08:21:03 +08:00
lipku 39d7aff90a add init wav2lip 2024-06-16 11:09:07 +08:00
lipku 6fb8a19fd5 fix musetalk for windows 2024-06-10 13:10:21 +08:00
lipku 58e763fdb6 fix gpt-sovits 2024-06-09 09:43:12 +08:00
lipku d01860176e improve musetalk infer speed 2024-06-09 09:04:04 +08:00
yuheng 016442272e Merge pull request #105 from yni9ht/fix-gpt-sovits (fix: tts gpt sovits function) 2024-06-04 19:30:59 +08:00
yni9ht ff0e11866d fix: tts gpt sovits function 2024-06-04 16:06:21 +08:00
lipku 632409da1e Refactoring tts code 2024-06-02 22:25:19 +08:00
lipku 4e355e9ab9 del nouse code 2024-06-01 06:58:02 +08:00
lipku af1ad0aed8 fix rtmp send sleep time 2024-05-31 23:12:48 +08:00
lipku 677227145e improve nerf audio video sync 2024-05-31 22:39:03 +08:00
yuheng dc94e87620 Merge pull request #92 from Degree-21/feat-add-auto (Feat: add autodl usage tutorial) 2024-05-31 14:34:11 +08:00
21 dce3085231 feat: add musetalk 2024-05-30 12:22:20 +08:00
21 c2ce2e25a4 fix: update prompt text 2024-05-30 11:45:56 +08:00
21 78324506fb fix: revert js 2024-05-30 11:44:27 +08:00
21 d384aaaa1c fix: revert js 2024-05-30 11:42:16 +08:00
21 f1d6821d62 Merge branch 'main' into feat-add-auto 2024-05-30 11:40:47 +08:00
21 8dd3441fcd add 2024-05-30 11:39:43 +08:00
yuheng b902d3244c Merge pull request #91 from Degree-21/fix-doc-tts (Fix GPT-SoVITS tts doc) 2024-05-29 19:32:32 +08:00
21 1fa1620c5e fix: update tts doc 2024-05-29 12:04:31 +08:00
21 a1ae58ffa7 add 2024-05-29 12:04:15 +08:00
unknown bf4e4b0251 fix edgetts temp 2024-05-29 08:46:15 +08:00
lipku 6508a9160c improve musetalk quality 2024-05-26 18:07:22 +08:00
lipku 5a4a459ad5 add musetalk 2024-05-26 11:10:03 +08:00
lipku 6294f64795 add musetalk init 2024-05-25 06:33:59 +08:00
118 changed files with 113073 additions and 2558 deletions

.gitignore (vendored, 1 changed line)

@@ -15,3 +15,4 @@ pretrained
*.mp4
.DS_Store
workspace/log_ngp.txt
+.idea

LICENSE (214 changed lines)

@@ -1,21 +1,201 @@
-MIT License
-Copyright (c) 2023 LiHengzhong
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md (152 changed lines)

@@ -1,12 +1,14 @@
-A streaming digital human based on the Ernerf model realize audio video synchronous dialogue. It can basically achieve commercial effects.
-A streaming digital human based on the ernerf model, enabling synchronized audio-video dialogue. It can basically achieve commercial-grade results.
-[Demo](https://www.bilibili.com/video/BV1PM4m1y7Q2/)
+Real time interactive streaming digital human realize audio video synchronous dialogue. It can basically achieve commercial effects.
+A real-time interactive streaming digital human, enabling synchronized audio-video dialogue. It can basically achieve commercial-grade results.
+[ernerf demo](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk demo](https://www.bilibili.com/video/BV1gm421N7vQ/) [wav2lip demo](https://www.bilibili.com/video/BV1Bw4m1e74P/)
+## To avoid confusion with 3D digital humans, the original project metahuman-stream has been renamed to livetalking; the original links remain valid
## Features
-1. Voice cloning support
-2. LLM dialogue support
-3. Multiple audio feature drivers: wav2vec, hubert
+1. Multiple digital human models: ernerf, musetalk, wav2lip
+2. Voice cloning
+3. The digital human can be interrupted while speaking
4. Full-body video stitching
5. rtmp and webrtc support
6. Video orchestration: play a custom video when not speaking
@@ -22,17 +24,19 @@ conda create -n nerfstream python=3.10
conda activate nerfstream
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
+# If you are not training the ernerf model, the libraries below are not needed
pip install "git+https://github.com/facebookresearch/pytorch3d.git"
pip install tensorflow-gpu==2.8.0
pip install --upgrade "protobuf<=3.20.1"
```
+If you use pytorch 2.1, use torchvision 0.16 (check the torchvision site for the build matching your pytorch version); cudatoolkit does not need to be installed.
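A minimal sketch of that alternative setup; the cu118 wheel index below is an assumption, pick the build that matches your CUDA version:
```
# pytorch 2.1 with the matching torchvision 0.16; no conda cudatoolkit needed
pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
```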
Common installation issues: [FAQ](/assets/faq.md)
For setting up a linux cuda environment see https://zhuanlan.zhihu.com/p/674972886
## 2. Quick Start
-By default, webrtc streaming to srs is used.
+By default the ernerf model is used, with webrtc streaming to srs.
-### 2.1 Run the rtmpserver (srs)
+### 2.1 Run srs
```
export CANDIDATE='<server public IP>'
docker run --rm --env CANDIDATE=$CANDIDATE \
@@ -52,135 +56,45 @@ python app.py
export HF_ENDPOINT=https://hf-mirror.com
```
-Open http://serverip:8010/rtcpush.html in a browser, enter any text in the text box and submit; the digital human reads the text aloud.
+Open http://serverip:8010/rtcpushapi.html in a browser, enter any text in the text box and submit; the digital human reads the text aloud.
Note: the server needs to open ports tcp:8000,8010,1985; udp:8000
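For reference, the /human echo endpoint added in app.py later in this change accepts the same kind of text request; a hedged command-line example, assuming the default listen port 8010 and session id 0:
```
curl -X POST http://serverip:8010/human \
  -H 'Content-Type: application/json' \
  -d '{"text": "hello", "type": "echo", "sessionid": 0}'
```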
## 3. More Usage
+Documentation: <https://livetalking-doc.readthedocs.io/>
-### 3.1 Digital human dialogue with an LLM
-Currently follows the approach of the dialogue system [LinlyTalker](https://github.com/Kedreamix/Linly-Talker); supported LLMs are Chatgpt, Qwen and GeminiPro. Fill in your own api_key in app.py.
-Open http://serverip:8010/rtcpushchat.html in a browser.
-### 3.2 Voice cloning
-Either of the two services below can be used; gpt-sovits is recommended.
-#### 3.2.1 gpt-sovits
-For service deployment see [gpt-sovits](/tts/README.md)
-Run
-```
-python app.py --tts gpt-sovits --TTS_SERVER http://127.0.0.1:5000 --CHARACTER test --EMOTION default
-```
-#### 3.2.2 xtts
-To run the xtts service, see https://github.com/coqui-ai/xtts-streaming-server
-```
-docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 9000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest
-```
-Then run the following, where ref.wav is the voice file to clone
-```
-python app.py --tts xtts --REF_FILE data/ref.wav --TTS_SERVER http://localhost:9000
-```
-### 3.3 Using hubert audio features
-If hubert was used to extract audio features when training the model, start the digital human with
-```
-python app.py --asr_model facebook/hubert-large-ls960-ft
-```
-### 3.4 Set a background image
-```
-python app.py --bg_img bg.jpg
-```
-### 3.5 Full-body video stitching
-#### 3.5.1 Crop the video used for training
-```
-ffmpeg -i fullbody.mp4 -vf crop="400:400:100:5" train.mp4
-```
-Train the model with train.mp4
-#### 3.5.2 Extract full-body images
-```
-ffmpeg -i fullbody.mp4 -vf fps=25 -qmin 1 -q:v 1 -start_number 0 data/fullbody/img/%d.jpg
-```
-#### 3.5.3 Start the digital human
-```
-python app.py --fullbody --fullbody_img data/fullbody/img --fullbody_offset_x 100 --fullbody_offset_y 5 --fullbody_width 580 --fullbody_height 1080 --W 400 --H 400
-```
-- --fullbody_width, --fullbody_height: width and height of the full-body video
-- --W, --H: width and height of the training video
-- If the torso from the third ernerf training step is not trained well, there will be a seam at the joint. You can add --torso_imgs data/xxx/torso_imgs to the command above so the torso is taken directly from the training-set torso images instead of model inference; this may leave some artifacts around the head and neck.
-### 3.6 Replace silence with a custom video
-- Extract images from the custom video
-```
-ffmpeg -i silence.mp4 -vf fps=25 -qmin 1 -q:v 1 -start_number 0 data/customvideo/img/%d.png
-```
-- Run the digital human
-```
-python app.py --customvideo --customvideo_img data/customvideo/img --customvideo_imgnum 100
-```
-### 3.7 webrtc p2p
-This mode does not need srs
-```
-python app.py --transport webrtc
-```
-Open http://serverip:8010/webrtc.html in a browser
-### 3.8 rtmp push to srs
-- Install the rtmpstream library
-See https://github.com/lipku/python_rtmpstream
-- Start srs
-```
-docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5
-```
-- Run the digital human
-```
-python app.py --transport rtmp --push_url 'rtmp://localhost/live/livestream'
-```
-Open http://serverip:8010/echo.html in a browser
## 4. Docker Run
-No need for the installation in step 1; run directly.
+No need for the installation above; run directly.
```
-docker run --gpus all -it --network=host --rm registry.cn-hangzhou.aliyuncs.com/lipku/nerfstream:v1.3
+docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:vjo1Y6NJ3N
```
-The docker image no longer contains the latest code; it can be used as a clean environment into which you copy the latest code.
-An autodl image is also provided: https://www.codewithgpu.com/i/lipku/metahuman-stream/base
-## 5. Data flow
-![](/assets/dataflow.png)
-## 6. Digital human model files
-You can replace them with your own trained model (https://github.com/Fictionarry/ER-NeRF)
-```
-.
-├── data
-│   ├── data_kf.json
-│   ├── au.csv
-│   ├── pretrained
-│   └── └── ngp_kf.pth
-```
-## 7. Performance analysis
+The code is in /root/metahuman-stream; first run git pull to get the latest code, then execute the same commands as in steps 2 and 3.
+The following images are provided:
+- autodl image: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
+[autodl tutorial](autodl/README.md)
+## 5. Performance analysis
1. Frame rate
Tested on a Tesla T4, overall fps is about 18; about 20 if audio/video encoding and streaming are excluded. A 4090 can reach 40+ fps.
-Optimization: run audio/video encoding and streaming in a separate thread.
2. Latency
Overall latency is about 3 s
(1) tts latency about 1.7 s; the current edgetts needs each full sentence converted before input; could be optimized to streaming input
(2) wav2vec latency about 0.4 s; 18 audio frames need to be buffered for the computation
(3) srs forwarding latency: configure the srs server to reduce buffering, see https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency
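For (2), with the 20 ms audio frames used elsewhere in this change (fps defaults to 50 in app.py), buffering 18 frames works out to 18 × 20 ms = 360 ms, which matches the quoted ~0.4 s.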
-## 8. TODO
+## 6. TODO
- [x] Add chatgpt for digital human dialogue
- [x] Voice cloning
- [x] Play a video clip while the digital human is silent
-- [ ] MuseTalk
+- [x] MuseTalk
+- [x] Wav2Lip
+- [ ] TalkingGaussian
----
-If this project helps you, please give it a star. Friends who are interested are also welcome to help improve it.
-* 知识星球: https://t.zsxq.com/7NMyO (collects high-quality FAQs, best practices and answers)
-* WeChat official account: 数字人技术
-![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&amp;from=appmsg)
+If this project helps you, please give it a star. Friends who are interested are also welcome to help improve it.
+Email: lipku@foxmail.com
+知识星球: https://t.zsxq.com/7NMyO
+WeChat official account: 数字人技术
+![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&amp;from=appmsg)

app.py (486 changed lines)

@ -17,142 +17,20 @@ from aiohttp import web
import aiohttp import aiohttp
import aiohttp_cors import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer from webrtc import HumanPlayer
import argparse import argparse
from ernerf.nerf_triplane.provider import NeRFDataset_Test
from ernerf.nerf_triplane.utils import *
from ernerf.nerf_triplane.network import NeRFNetwork
from nerfreal import NeRFReal
import shutil import shutil
import asyncio import asyncio
import edge_tts import string
from typing import Iterator
import requests
app = Flask(__name__) app = Flask(__name__)
sockets = Sockets(app) sockets = Sockets(app)
global nerfreal nerfreals = []
global tts_type statreals = []
global gspeaker
async def main(voicename: str, text: str, render):
communicate = edge_tts.Communicate(text, voicename)
#with open(OUTPUT_FILE, "wb") as file:
first = True
async for chunk in communicate.stream():
if first:
#render.before_push_audio()
first = False
if chunk["type"] == "audio":
render.push_audio(chunk["data"])
#file.write(chunk["data"])
elif chunk["type"] == "WordBoundary":
pass
def get_speaker(ref_audio,server_url):
files = {"wav_file": ("reference.wav", open(ref_audio, "rb"))}
response = requests.post(f"{server_url}/clone_speaker", files=files)
return response.json()
def xtts(text, speaker, language, server_url, stream_chunk_size) -> Iterator[bytes]:
start = time.perf_counter()
speaker["text"] = text
speaker["language"] = language
speaker["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
res = requests.post(
f"{server_url}/tts_stream",
json=speaker,
stream=True,
)
end = time.perf_counter()
print(f"xtts Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=960): #24K*20ms*2
if first:
end = time.perf_counter()
print(f"xtts Time to first chunk: {end-start}s")
first = False
if chunk:
yield chunk
print("xtts response.elapsed:", res.elapsed)
def gpt_sovits(text, character, language, server_url, emotion) -> Iterator[bytes]:
start = time.perf_counter()
req={}
req["text"] = text
req["text_language"] = language
req["character"] = character
req["emotion"] = emotion
#req["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
req["stream"] = True
res = requests.post(
f"{server_url}/tts",
json=req,
stream=True,
)
end = time.perf_counter()
print(f"gpt_sovits Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=32000): # 1280 32K*20ms*2
if first:
end = time.perf_counter()
print(f"gpt_sovits Time to first chunk: {end-start}s")
first = False
if chunk:
yield chunk
print("gpt_sovits response.elapsed:", res.elapsed)
def stream_tts(audio_stream,render):
for chunk in audio_stream:
if chunk is not None:
render.push_audio(chunk)
def txt_to_audio(text_):
if tts_type == "edgetts":
voicename = "zh-CN-YunxiaNeural"
text = text_
t = time.time()
asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
print(f'-------edge tts time:{time.time()-t:.4f}s')
elif tts_type == "gpt-sovits": #gpt_sovits
stream_tts(
gpt_sovits(
text_,
app.config['CHARACTER'], #"test", #character
"zh", #en args.language,
app.config['TTS_SERVER'], #"http://127.0.0.1:5000", #args.server_url,
app.config['EMOTION'], #emotion
),
nerfreal
)
else: #xtts
stream_tts(
xtts(
text_,
gspeaker,
"zh-cn", #en args.language,
app.config['TTS_SERVER'], #"http://localhost:9000", #args.server_url,
"20" #args.stream_chunk_size
),
nerfreal
)
@sockets.route('/humanecho') @sockets.route('/humanecho')
@ -172,17 +50,61 @@ def echo_socket(ws):
if not message or len(message)==0: if not message or len(message)==0:
return '输入信息为空' return '输入信息为空'
else: else:
txt_to_audio(message) nerfreal.put_msg_txt(message)
def llm_response(message): # def llm_response(message):
from llm.LLM import LLM # from llm.LLM import LLM
# llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None) # # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
# llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key') # # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b') # llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
response = llm.chat(message) # response = llm.chat(message)
print(response) # print(response)
return response # return response
def llm_response(message,nerfreal):
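# Streams a chat completion from the DashScope OpenAI-compatible endpoint (qwen-plus)
# and flushes accumulated text to the avatar at punctuation boundaries once more than
# 10 characters have built up, so TTS can start before the full reply has arrived.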
start = time.perf_counter()
from openai import OpenAI
client = OpenAI(
# If you have not configured the environment variable, replace this with your API Key
api_key=os.getenv("DASHSCOPE_API_KEY"),
# base_url of the DashScope SDK
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
end = time.perf_counter()
print(f"llm Time init: {end-start}s")
completion = client.chat.completions.create(
model="qwen-plus",
messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
{'role': 'user', 'content': message}],
stream=True,
# with this setting, token usage is reported in the last line of the streamed output
stream_options={"include_usage": True}
)
result=""
first = True
for chunk in completion:
if len(chunk.choices)>0:
#print(chunk.choices[0].delta.content)
if first:
end = time.perf_counter()
print(f"llm Time to first chunk: {end-start}s")
first = False
msg = chunk.choices[0].delta.content
lastpos=0
#msglist = re.split('[,.!;:,。!?]',msg)
for i, char in enumerate(msg):
if char in ",.!;:,。!?:;" :
result = result+msg[lastpos:i+1]
lastpos = i+1
if len(result)>10:
print(result)
nerfreal.put_msg_txt(result)
result=""
result = result+msg[lastpos:]
end = time.perf_counter()
print(f"llm Time to last chunk: {end-start}s")
nerfreal.put_msg_txt(result)
@sockets.route('/humanchat') @sockets.route('/humanchat')
def chat_socket(ws): def chat_socket(ws):
@ -202,47 +124,26 @@ def chat_socket(ws):
return '输入信息为空' return '输入信息为空'
else: else:
res=llm_response(message) res=llm_response(message)
txt_to_audio(res) nerfreal.put_msg_txt(res)
#####webrtc############################### #####webrtc###############################
pcs = set() pcs = set()
async def txt_to_audio_async(text_):
if tts_type == "edgetts":
voicename = "zh-CN-YunxiaNeural"
text = text_
t = time.time()
#asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
await main(voicename,text,nerfreal)
print(f'-------edge tts time:{time.time()-t:.4f}s')
elif tts_type == "gpt-sovits": #gpt_sovits
stream_tts(
gpt_sovits(
text_,
app.config['CHARACTER'], #"test", #character
"zh", #en args.language,
app.config['TTS_SERVER'], #"http://127.0.0.1:5000", #args.server_url,
app.config['EMOTION'], #emotion
),
nerfreal
)
else: #xtts
stream_tts(
xtts(
text_,
gspeaker,
"zh-cn", #en args.language,
app.config['TTS_SERVER'], #"http://localhost:9000", #args.server_url,
"20" #args.stream_chunk_size
),
nerfreal
)
#@app.route('/offer', methods=['POST']) #@app.route('/offer', methods=['POST'])
async def offer(request): async def offer(request):
params = await request.json() params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
sessionid = len(nerfreals)
for index,value in enumerate(statreals):
if value == 0:
sessionid = index
break
if sessionid>=len(nerfreals):
print('reach max session')
return -1
statreals[sessionid] = 1
pc = RTCPeerConnection() pc = RTCPeerConnection()
pcs.add(pc) pcs.add(pc)
@ -252,10 +153,20 @@ async def offer(request):
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await pc.close()
pcs.discard(pc) pcs.discard(pc)
statreals[sessionid] = 0
if pc.connectionState == "closed":
pcs.discard(pc)
statreals[sessionid] = 0
player = HumanPlayer(nerfreal) player = HumanPlayer(nerfreals[sessionid])
audio_sender = pc.addTrack(player.audio) audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video) video_sender = pc.addTrack(player.video)
capabilities = RTCRtpSender.getCapabilities("video")
preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
transceiver = pc.getTransceivers()[1]
transceiver.setCodecPreferences(preferences)
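# Reorder the negotiated video codecs so H264 is preferred, then VP8, keeping rtx
# for retransmissions (see the 'webrtc prefer h264 codec' commit above).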
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)
@ -267,18 +178,22 @@ async def offer(request):
return web.Response( return web.Response(
content_type="application/json", content_type="application/json",
text=json.dumps( text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid":sessionid}
), ),
) )
async def human(request): async def human(request):
params = await request.json() params = await request.json()
sessionid = params.get('sessionid',0)
if params.get('interrupt'):
nerfreals[sessionid].pause_talk()
if params['type']=='echo': if params['type']=='echo':
await txt_to_audio_async(params['text']) nerfreals[sessionid].put_msg_txt(params['text'])
elif params['type']=='chat': elif params['type']=='chat':
res=llm_response(params['text']) res=await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'],nerfreals[sessionid])
await txt_to_audio_async(res) #nerfreals[sessionid].put_msg_txt(res)
return web.Response( return web.Response(
content_type="application/json", content_type="application/json",
@ -287,6 +202,70 @@ async def human(request):
), ),
) )
async def humanaudio(request):
try:
form= await request.post()
sessionid = int(form.get('sessionid',0))
fileobj = form["file"]
filename=fileobj.filename
filebytes=fileobj.file.read()
nerfreals[sessionid].put_audio_file(filebytes)
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "msg":"ok"}
),
)
except Exception as e:
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": -1, "msg":"err","data": ""+e.args[0]+""}
),
)
async def set_audiotype(request):
params = await request.json()
sessionid = params.get('sessionid',0)
nerfreals[sessionid].set_curr_state(params['audiotype'],params['reinit'])
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data":"ok"}
),
)
async def record(request):
params = await request.json()
sessionid = params.get('sessionid',0)
if params['type']=='start_record':
# nerfreals[sessionid].put_msg_txt(params['text'])
nerfreals[sessionid].start_recording("data/record_lasted.mp4")
elif params['type']=='end_record':
nerfreals[sessionid].stop_recording()
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data":"ok"}
),
)
async def is_speaking(request):
params = await request.json()
sessionid = params.get('sessionid',0)
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data": nerfreals[sessionid].is_speaking()}
),
)
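# The handlers above back the new HTTP routes registered further down in this file:
# /humanaudio (upload an audio file to drive the avatar), /set_audiotype (switch the
# custom idle video/audio state), /record (start_record / end_record, written to
# data/record_lasted.mp4) and /is_speaking. A hedged example against /record,
# assuming the default listen port 8010:
#   curl -X POST http://serverip:8010/record -H 'Content-Type: application/json' \
#        -d '{"type": "start_record", "sessionid": 0}'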
async def on_shutdown(app): async def on_shutdown(app):
# close peer connections # close peer connections
coros = [pc.close() for pc in pcs] coros = [pc.close() for pc in pcs]
@ -312,17 +291,18 @@ async def run(push_url):
await pc.close() await pc.close()
pcs.discard(pc) pcs.discard(pc)
player = HumanPlayer(nerfreal) player = HumanPlayer(nerfreals[0])
audio_sender = pc.addTrack(player.audio) audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video) video_sender = pc.addTrack(player.video)
await pc.setLocalDescription(await pc.createOffer()) await pc.setLocalDescription(await pc.createOffer())
answer = await post(push_url,pc.localDescription.sdp) answer = await post(push_url,pc.localDescription.sdp)
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer')) await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
########################################## ##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
if __name__ == '__main__': if __name__ == '__main__':
multiprocessing.set_start_method('spawn')
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source") parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area") parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
@ -419,9 +399,6 @@ if __name__ == '__main__':
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
# parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft') # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
parser.add_argument('--asr_save_feats', action='store_true') parser.add_argument('--asr_save_feats', action='store_true')
# audio FPS # audio FPS
parser.add_argument('--fps', type=int, default=50) parser.add_argument('--fps', type=int, default=50)
@ -437,73 +414,108 @@ if __name__ == '__main__':
parser.add_argument('--fullbody_offset_x', type=int, default=0) parser.add_argument('--fullbody_offset_x', type=int, default=0)
parser.add_argument('--fullbody_offset_y', type=int, default=0) parser.add_argument('--fullbody_offset_y', type=int, default=0)
parser.add_argument('--customvideo', action='store_true', help="custom video") #musetalk opt
parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') parser.add_argument('--avatar_id', type=str, default='avator_1')
parser.add_argument('--customvideo_imgnum', type=int, default=1) parser.add_argument('--bbox_shift', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits # parser.add_argument('--customvideo', action='store_true', help="custom video")
# parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
# parser.add_argument('--customvideo_imgnum', type=int, default=1)
parser.add_argument('--customvideo_config', type=str, default='')
parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits cosyvoice
parser.add_argument('--REF_FILE', type=str, default=None) parser.add_argument('--REF_FILE', type=str, default=None)
parser.add_argument('--TTS_SERVER', type=str, default='http://localhost:9000') #http://127.0.0.1:5000 parser.add_argument('--REF_TEXT', type=str, default=None)
parser.add_argument('--CHARACTER', type=str, default='test') parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
parser.add_argument('--EMOTION', type=str, default='default') # parser.add_argument('--CHARACTER', type=str, default='test')
# parser.add_argument('--EMOTION', type=str, default='default')
parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
parser.add_argument('--max_session', type=int, default=1) #multi session count
parser.add_argument('--listenport', type=int, default=8010) parser.add_argument('--listenport', type=int, default=8010)
opt = parser.parse_args() opt = parser.parse_args()
app.config.from_object(opt) #app.config.from_object(opt)
print(app.config) #print(app.config)
opt.customopt = []
if opt.customvideo_config!='':
with open(opt.customvideo_config,'r') as file:
opt.customopt = json.load(file)
tts_type = opt.tts if opt.model == 'ernerf':
if tts_type == "xtts": from ernerf.nerf_triplane.provider import NeRFDataset_Test
print("Computing the latents for a new reference...") from ernerf.nerf_triplane.utils import *
gspeaker = get_speaker(opt.REF_FILE, opt.TTS_SERVER) from ernerf.nerf_triplane.network import NeRFNetwork
from nerfreal import NeRFReal
# assert test mode
opt.test = True
opt.test_train = False
#opt.train_camera =True
# explicit smoothing
opt.smooth_path = True
opt.smooth_lips = True
# assert test mode assert opt.pose != '', 'Must provide a pose source'
opt.test = True
opt.test_train = False
#opt.train_camera =True
# explicit smoothing
opt.smooth_path = True
opt.smooth_lips = True
assert opt.pose != '', 'Must provide a pose source' # if opt.O:
opt.fp16 = True
opt.cuda_ray = True
opt.exp_eye = True
opt.smooth_eye = True
# if opt.O: if opt.torso_imgs=='': #no img,use model output
opt.fp16 = True opt.torso = True
opt.cuda_ray = True
opt.exp_eye = True
opt.smooth_eye = True
if opt.torso_imgs=='': #no img,use model output # assert opt.cuda_ray, "Only support CUDA ray mode."
opt.torso = True opt.asr = True
# assert opt.cuda_ray, "Only support CUDA ray mode." if opt.patch_size > 1:
opt.asr = True # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
seed_everything(opt.seed)
print(opt)
if opt.patch_size > 1: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." model = NeRFNetwork(opt)
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
seed_everything(opt.seed)
print(opt)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') criterion = torch.nn.MSELoss(reduction='none')
model = NeRFNetwork(opt) metrics = [] # use no metric in GUI for faster initialization...
print(model)
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
criterion = torch.nn.MSELoss(reduction='none') test_loader = NeRFDataset_Test(opt, device=device).dataloader()
metrics = [] # use no metric in GUI for faster initialization... model.aud_features = test_loader._data.auds
print(model) model.eye_areas = test_loader._data.eye_area
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
test_loader = NeRFDataset_Test(opt, device=device).dataloader() # we still need test_loader to provide audio features for testing.
model.aud_features = test_loader._data.auds for _ in range(opt.max_session):
model.eye_areas = test_loader._data.eye_area nerfreal = NeRFReal(opt, trainer, test_loader)
nerfreals.append(nerfreal)
elif opt.model == 'musetalk':
from musereal import MuseReal
print(opt)
for _ in range(opt.max_session):
nerfreal = MuseReal(opt)
nerfreals.append(nerfreal)
elif opt.model == 'wav2lip':
from lipreal import LipReal
print(opt)
for _ in range(opt.max_session):
nerfreal = LipReal(opt)
nerfreals.append(nerfreal)
for _ in range(opt.max_session):
statreals.append(0)
# we still need test_loader to provide audio features for testing.
nerfreal = NeRFReal(opt, trainer, test_loader)
#txt_to_audio('我是中国人,我来自北京')
if opt.transport=='rtmp': if opt.transport=='rtmp':
thread_quit = Event() thread_quit = Event()
rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
rendthrd.start() rendthrd.start()
############################################################################# #############################################################################
@ -511,6 +523,10 @@ if __name__ == '__main__':
appasync.on_shutdown.append(on_shutdown) appasync.on_shutdown.append(on_shutdown)
appasync.router.add_post("/offer", offer) appasync.router.add_post("/offer", offer)
appasync.router.add_post("/human", human) appasync.router.add_post("/human", human)
appasync.router.add_post("/humanaudio", humanaudio)
appasync.router.add_post("/set_audiotype", set_audiotype)
appasync.router.add_post("/record", record)
appasync.router.add_post("/is_speaking", is_speaking)
appasync.router.add_static('/',path='web') appasync.router.add_static('/',path='web')
# Configure default CORS settings. # Configure default CORS settings.
@ -525,6 +541,12 @@ if __name__ == '__main__':
for route in list(appasync.router.routes()): for route in list(appasync.router.routes()):
cors.add(route) cors.add(route)
pagename='webrtcapi.html'
if opt.transport=='rtmp':
pagename='echoapi.html'
elif opt.transport=='rtcpush':
pagename='rtcpushapi.html'
print('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
def run_server(runner): def run_server(runner):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
@ -534,12 +556,14 @@ if __name__ == '__main__':
if opt.transport=='rtcpush': if opt.transport=='rtcpush':
loop.run_until_complete(run(opt.push_url)) loop.run_until_complete(run(opt.push_url))
loop.run_forever() loop.run_forever()
Thread(target=run_server, args=(web.AppRunner(appasync),)).start() #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
run_server(web.AppRunner(appasync))
print('start websocket server')
#app.on_shutdown.append(on_shutdown) #app.on_shutdown.append(on_shutdown)
#app.router.add_post("/offer", offer) #app.router.add_post("/offer", offer)
server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
server.serve_forever() # print('start websocket server')
# server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
# server.serve_forever()

baseasr.py (new file, 69 lines)

@ -0,0 +1,69 @@
import time
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
class BaseASR:
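# Shared audio front end: queues 16 kHz / 20 ms PCM frames pushed in from the TTS side,
# falls back to silence or a custom audio track when the queue is empty, and hands
# feature batches to the renderer through feat_queue; subclasses implement run_step().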
def __init__(self, opt, parent=None):
self.opt = opt
self.parent = parent
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
self.output_queue = mp.Queue()
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = opt.l
self.stride_right_size = opt.r
#self.context_size = 10
self.feat_queue = mp.Queue(2)
#self.warm_up()
def pause_talk(self):
self.queue.queue.clear()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def get_audio_frame(self):
try:
frame = self.queue.get(block=True,timeout=0.01)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
if self.parent and self.parent.curr_state>1: #play custom audio
frame = self.parent.get_audio_stream(self.parent.curr_state)
type = self.parent.curr_state
else:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def is_audio_frame_empty(self)->bool:
return self.queue.empty()
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type=self.get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
for _ in range(self.stride_left_size):
self.output_queue.get()
def run_step(self):
pass
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)

basereal.py (new file, 207 lines)

@ -0,0 +1,207 @@
import math
import torch
import numpy as np
import os
import time
import cv2
import glob
import pickle
import copy
import resampy
import queue
from queue import Queue
from threading import Thread, Event
from io import BytesIO
import soundfile as sf
import av
from fractions import Fraction
from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS
from tqdm import tqdm
def read_imgs(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
class BaseReal:
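# Session-level state for one digital human instance: TTS backend selection
# (edgetts / gpt-sovits / xtts / cosyvoice), custom idle video/audio clips loaded from
# opt.customopt, optional mp4 recording via PyAV, and speaking/pause bookkeeping.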
def __init__(self, opt):
self.opt = opt
self.sample_rate = 16000
self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000)
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
self.tts = VoitsTTS(opt,self)
elif opt.tts == "xtts":
self.tts = XTTS(opt,self)
elif opt.tts == "cosyvoice":
self.tts = CosyVoiceTTS(opt,self)
self.speaking = False
self.recording = False
self.recordq_video = Queue()
self.recordq_audio = Queue()
self.curr_state=0
self.custom_img_cycle = {}
self.custom_audio_cycle = {}
self.custom_audio_index = {}
self.custom_index = {}
self.custom_opt = {}
self.__loadcustom()
def put_msg_txt(self,msg):
self.tts.put_msg_txt(msg)
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def put_audio_file(self,filebyte):
input_stream = BytesIO(filebyte)
stream = self.__create_bytes_stream(input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk: #and self.state==State.RUNNING
self.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def is_speaking(self)->bool:
return self.speaking
def __loadcustom(self):
for item in self.opt.customopt:
print(item)
input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
self.custom_audio_cycle[item['audiotype']], sample_rate = sf.read(item['audiopath'], dtype='float32')
self.custom_audio_index[item['audiotype']] = 0
self.custom_index[item['audiotype']] = 0
self.custom_opt[item['audiotype']] = item
def init_customindex(self):
self.curr_state=0
for key in self.custom_audio_index:
self.custom_audio_index[key]=0
for key in self.custom_index:
self.custom_index[key]=0
def start_recording(self,path):
"""开始录制视频"""
if self.recording:
return
self.recording = True
self.recordq_video.queue.clear()
self.recordq_audio.queue.clear()
self.container = av.open(path, mode="w")
process_thread = Thread(target=self.record_frame, args=())
process_thread.start()
def record_frame(self):
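# Mux the recording: each loop iteration takes one 25 fps video frame and two 20 ms
# audio frames from the record queues, assigns their pts from the frame counter, and
# encodes them into the container opened in start_recording().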
videostream = self.container.add_stream("libx264", rate=25)
videostream.codec_context.time_base = Fraction(1, 25)
audiostream = self.container.add_stream("aac")
audiostream.codec_context.time_base = Fraction(1, 16000)
init = True
framenum = 0
while self.recording:
try:
videoframe = self.recordq_video.get(block=True, timeout=1)
videoframe.pts = framenum #int(round(framenum*0.04 / videostream.codec_context.time_base))
videoframe.dts = videoframe.pts
if init:
videostream.width = videoframe.width
videostream.height = videoframe.height
init = False
for packet in videostream.encode(videoframe):
self.container.mux(packet)
for k in range(2):
audioframe = self.recordq_audio.get(block=True, timeout=1)
audioframe.pts = int(round((framenum*2+k)*0.02 / audiostream.codec_context.time_base))
audioframe.dts = audioframe.pts
for packet in audiostream.encode(audioframe):
self.container.mux(packet)
framenum += 1
except queue.Empty:
print('record queue empty,')
continue
except Exception as e:
print(e)
#break
for packet in videostream.encode(None):
self.container.mux(packet)
for packet in audiostream.encode(None):
self.container.mux(packet)
self.container.close()
self.recordq_video.queue.clear()
self.recordq_audio.queue.clear()
print('record thread stop')
def stop_recording(self):
"""停止录制视频"""
if not self.recording:
return
self.recording = False
def mirror_index(self,size, index):
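# Ping-pong indexing: walk the list forward on even passes and backward on odd passes
# so cyclic playback of the custom frames has no visible jump at the wrap point.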
#size = len(self.coord_list_cycle)
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
def get_audio_stream(self,audiotype):
idx = self.custom_audio_index[audiotype]
stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk]
self.custom_audio_index[audiotype] += self.chunk
if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]:
self.curr_state = 1 #the current clip does not loop; switch to the silent state
return stream
def set_curr_state(self,audiotype, reinit):
print('set_curr_state:',audiotype)
self.curr_state = audiotype
if reinit:
self.custom_audio_index[audiotype] = 0
self.custom_index[audiotype] = 0
# def process_custom(self,audiotype:int,idx:int):
# if self.curr_state!=audiotype: #从推理切到口播
# if idx in self.switch_pos: #在卡点位置可以切换
# self.curr_state=audiotype
# self.custom_index=0
# else:
# self.custom_index+=1

data/custom_config.json (new file, 7 lines)

@ -0,0 +1,7 @@
[
{
"audiotype":2,
"imgpath":"data/customvideo/image",
"audiopath":"data/customvideo/audio.wav"
}
]
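A hedged usage sketch: pass this file via the new --customvideo_config option in app.py, then switch a session into audiotype 2 through the /set_audiotype endpoint (paths and the default port 8010 are assumptions based on the defaults above):
```
python app.py --customvideo_config data/custom_config.json
curl -X POST http://serverip:8010/set_audiotype \
  -H 'Content-Type: application/json' \
  -d '{"audiotype": 2, "reinit": true, "sessionid": 0}'
```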

View File

@ -4,13 +4,13 @@ from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__)) _src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [ nvcc_flags = [
'-O3', '-std=c++14', '-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
'-use_fast_math' '-use_fast_math'
] ]
if os.name == "posix": if os.name == "posix":
c_flags = ['-O3', '-std=c++14'] c_flags = ['-O3', '-std=c++17']
elif os.name == "nt": elif os.name == "nt":
c_flags = ['/O2', '/std:c++17'] c_flags = ['/O2', '/std:c++17']

View File

@ -5,13 +5,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__)) _src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [ nvcc_flags = [
'-O3', '-std=c++14', '-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
'-use_fast_math' '-use_fast_math'
] ]
if os.name == "posix": if os.name == "posix":
c_flags = ['-O3', '-std=c++14'] c_flags = ['-O3', '-std=c++17']
elif os.name == "nt": elif os.name == "nt":
c_flags = ['/O2', '/std:c++17'] c_flags = ['/O2', '/std:c++17']

View File

@ -4,12 +4,12 @@ from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__)) _src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [ nvcc_flags = [
'-O3', '-std=c++14', '-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
] ]
if os.name == "posix": if os.name == "posix":
c_flags = ['-O3', '-std=c++14', '-finput-charset=UTF-8'] c_flags = ['-O3', '-std=c++17', '-finput-charset=UTF-8']
elif os.name == "nt": elif os.name == "nt":
c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8'] c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8']

View File

@ -5,12 +5,12 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__)) _src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [ nvcc_flags = [
'-O3', '-std=c++14', '-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
] ]
if os.name == "posix": if os.name == "posix":
c_flags = ['-O3', '-std=c++14'] c_flags = ['-O3', '-std=c++17']
elif os.name == "nt": elif os.name == "nt":
c_flags = ['/O2', '/std:c++17'] c_flags = ['/O2', '/std:c++17']

View File

@ -4,12 +4,12 @@ from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__)) _src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [ nvcc_flags = [
'-O3', '-std=c++14', '-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
] ]
if os.name == "posix": if os.name == "posix":
c_flags = ['-O3', '-std=c++14'] c_flags = ['-O3', '-std=c++17']
elif os.name == "nt": elif os.name == "nt":
c_flags = ['/O2', '/std:c++17'] c_flags = ['/O2', '/std:c++17']

View File

@ -5,13 +5,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__)) _src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [ nvcc_flags = [
'-O3', '-std=c++14', '-O3', '-std=c++17',
# '-lineinfo', # to debug illegal memory access # '-lineinfo', # to debug illegal memory access
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
] ]
if os.name == "posix": if os.name == "posix":
c_flags = ['-O3', '-std=c++14'] c_flags = ['-O3', '-std=c++17']
elif os.name == "nt": elif os.name == "nt":
c_flags = ['/O2', '/std:c++17'] c_flags = ['/O2', '/std:c++17']

View File

@ -4,12 +4,12 @@ from torch.utils.cpp_extension import load
  _src_path = os.path.dirname(os.path.abspath(__file__))
  nvcc_flags = [
- '-O3', '-std=c++14',
+ '-O3', '-std=c++17',
- '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
  ]
  if os.name == "posix":
- c_flags = ['-O3', '-std=c++14', '-finput-charset=utf-8']
+ c_flags = ['-O3', '-std=c++17', '-finput-charset=utf-8']
  elif os.name == "nt":
  c_flags = ['/O2', '/std:c++17', '/source-charset:utf-8']

View File

@ -5,12 +5,12 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
  _src_path = os.path.dirname(os.path.abspath(__file__))
  nvcc_flags = [
- '-O3', '-std=c++14',
+ '-O3', '-std=c++17',
- '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
  ]
  if os.name == "posix":
- c_flags = ['-O3', '-std=c++14']
+ c_flags = ['-O3', '-std=c++17']
  elif os.name == "nt":
  c_flags = ['/O2', '/std:c++17']
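Note: all of the hunks above make the same two changes to the CUDA extension build flags: the language standard moves from C++14 to C++17, and nvcc gains `-allow-unsupported-compiler` so it accepts host compilers newer than it officially supports. As a minimal sketch of where these lists end up, the `load` / `CUDAExtension` helpers referenced in the hunk headers consume them roughly like this (shown for `load`; the extension name and source file names below are placeholders, not taken from this repo):

```python
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++17',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-allow-unsupported-compiler',  # let nvcc accept a newer host compiler
]
if os.name == "posix":
    c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

# 'example_ext' and the source file names are hypothetical placeholders.
_backend = load(
    name='example_ext',
    sources=[os.path.join(_src_path, 'src', f) for f in ('example.cu', 'bindings.cpp')],
    extra_cflags=c_flags,          # flags for the host C/C++ compiler
    extra_cuda_cflags=nvcc_flags,  # flags passed to nvcc
)
```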

View File

@ -1,315 +0,0 @@
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
parser = argparse.ArgumentParser()
parser.add_argument("--host",
type=str,
default="0.0.0.0",
required=False,
help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port",
type=int,
default=10095,
required=False,
help="grpc server port")
parser.add_argument("--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model from modelscope")
parser.add_argument("--asr_model_revision",
type=str,
default="v2.0.4",
help="")
parser.add_argument("--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="model from modelscope")
parser.add_argument("--asr_model_online_revision",
type=str,
default="v2.0.4",
help="")
parser.add_argument("--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="model from modelscope")
parser.add_argument("--vad_model_revision",
type=str,
default="v2.0.4",
help="")
parser.add_argument("--punc_model",
type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="model from modelscope")
parser.add_argument("--punc_model_revision",
type=str,
default="v2.0.4",
help="")
parser.add_argument("--ngpu",
type=int,
default=1,
help="0 for cpu, 1 for gpu")
parser.add_argument("--device",
type=str,
default="cuda",
help="cuda, cpu")
parser.add_argument("--ncpu",
type=int,
default=4,
help="cpu cores")
parser.add_argument("--certfile",
type=str,
default="ssl_key/server.crt",
required=False,
help="certfile for ssl")
parser.add_argument("--keyfile",
type=str,
default="ssl_key/server.key",
required=False,
help="keyfile for ssl")
args = parser.parse_args()
websocket_users = set()
print("model loading")
from funasr import AutoModel
# asr
model_asr = AutoModel(model=args.asr_model,
model_revision=args.asr_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# asr
model_asr_streaming = AutoModel(model=args.asr_model_online,
model_revision=args.asr_model_online_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# vad
model_vad = AutoModel(model=args.vad_model,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
# chunk_size=60,
)
if args.punc_model != "":
model_punc = AutoModel(model=args.punc_model,
model_revision=args.punc_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
else:
model_punc = None
print("model loaded! only support one client at the same time now!!!!")
async def ws_reset(websocket):
print("ws reset now, total num is ",len(websocket_users))
websocket.status_dict_asr_online["cache"] = {}
websocket.status_dict_asr_online["is_final"] = True
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
websocket.status_dict_punc["cache"] = {}
await websocket.close()
async def clear_websocket():
for websocket in websocket_users:
await ws_reset(websocket)
websocket_users.clear()
async def ws_serve(websocket, path):
frames = []
frames_asr = []
frames_asr_online = []
global websocket_users
# await clear_websocket()
websocket_users.add(websocket)
websocket.status_dict_asr = {}
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
websocket.status_dict_vad = {'cache': {}, "is_final": False}
websocket.status_dict_punc = {'cache': {}}
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
speech_start = False
speech_end_i = -1
websocket.wav_name = "microphone"
websocket.mode = "2pass"
print("new user connected", flush=True)
try:
async for message in websocket:
if isinstance(message, str):
messagejson = json.loads(message)
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
websocket.wav_name = messagejson.get("wav_name")
if "chunk_size" in messagejson:
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(',')
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotword"]
if "mode" in messagejson:
websocket.mode = messagejson["mode"]
websocket.status_dict_vad["chunk_size"] = int(websocket.status_dict_asr_online["chunk_size"][1]*60/websocket.chunk_interval)
if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
if not isinstance(message, str):
frames.append(message)
duration_ms = len(message)//32
websocket.vad_pre_idx += duration_ms
# asr online
frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.status_dict_asr_online["is_final"]:
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
await async_asr_online(websocket, audio_in)
except:
print(f"error in asr streaming, {websocket.status_dict_asr_online}")
frames_asr_online = []
if speech_start:
frames_asr.append(message)
# vad online
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
except:
print("error in vad")
if speech_start_i != -1:
speech_start = True
beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
frames_pre = frames[-beg_bias:]
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
if speech_end_i != -1 or not websocket.is_speaking:
# print("vad end point")
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
try:
await async_asr(websocket, audio_in)
except:
print("error in asr offline")
frames_asr = []
speech_start = False
frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {}
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
else:
frames = frames[-20:]
except websockets.ConnectionClosed:
print("ConnectionClosed...", websocket_users,flush=True)
await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("InvalidState...")
except Exception as e:
print("Exception:", e)
async def async_vad(websocket, audio_in):
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
# print(segments_result)
speech_start = -1
speech_end = -1
if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end
if segments_result[0][0] != -1:
speech_start = segments_result[0][0]
if segments_result[0][1] != -1:
speech_end = segments_result[0][1]
return speech_start, speech_end
async def async_asr(websocket, audio_in):
if len(audio_in) > 0:
# print(len(audio_in))
rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
# print("offline_asr, ", rec_result)
if model_punc is not None and len(rec_result["text"])>0:
# print("offline, before punc", rec_result, "cache", websocket.status_dict_punc)
rec_result = model_punc.generate(input=rec_result['text'], **websocket.status_dict_punc)[0]
# print("offline, after punc", rec_result)
if len(rec_result["text"])>0:
# print("offline", rec_result)
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name,"is_final":websocket.is_speaking})
await websocket.send(message)
async def async_asr_online(websocket, audio_in):
if len(audio_in) > 0:
# print(websocket.status_dict_asr_online.get("is_final", False))
rec_result = model_asr_streaming.generate(input=audio_in, **websocket.status_dict_asr_online)[0]
# print("online, ", rec_result)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
return
# websocket.status_dict_asr_online["cache"] = dict()
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name,"is_final":websocket.is_speaking})
await websocket.send(message)
if len(args.certfile)>0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Let's Encrypt, copied to this location, chown to current user and 400 permissions
ssl_cert = args.certfile
ssl_key = args.keyfile
ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
else:
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()

View File

@ -1 +0,0 @@
Open samples/html/static/index.html in a browser and enter the ASR server address; both microphone input and file input are supported.

View File

@ -1,132 +0,0 @@
<!-- index.html -->
<html>
<head>
<script type="text/javascript" src="mpegts-1.7.3.min.js"></script>
<script type="text/javascript" src="http://cdn.sockjs.org/sockjs-0.3.4.js"></script>
<script src="http://code.jquery.com/jquery-2.1.1.min.js"></script>
<script src="recorder-core.js" charset="UTF-8"></script>
<script src="wav.js" charset="UTF-8"></script>
<script src="pcm.js" charset="UTF-8"></script>
</head>
<body>
<div class="container">
<h1>metahuman voice test</h1>
<form class="form-inline" id="echo-form" name="ssbtn">
<div class="form-group">
<p>input text</p>
<textarea cols="2" rows="3" style="width:600px;height:50px;" class="form-control" id="message"></textarea>
</div>
<button type="submit" class="btn btn-default">Send</button>
</form>
<div id="log">
</div>
<video id="video_player" width="40%" controls autoplay muted></video>
</div>
<div>------------------------------------------------------------------------------------------------------------------------------</div>
<div class="div_class_topArea">
<div class="div_class_recordControl">
asr服务器地址(必填):
<br>
<input id="wssip" type="text" onchange="addresschange()" style="width: 500px;" value="wss://127.0.0.1:10095/"/>
<br>
<a id="wsslink" style="display: none;" href="#" onclick="window.open('https://127.0.0.1:10095/', '_blank')"><div id="info_wslink">点此处手工授权wss://127.0.0.1:10095/</div></a>
<br>
<br>
<div style="border:2px solid #ccc;display: none;">
选择录音模式:<br/>
<label ><input name="recoder_mode" onclick="on_recoder_mode_change()" type="radio" value="mic" checked="true"/>麦克风 </label>&nbsp;&nbsp;
<label><input name="recoder_mode" onclick="on_recoder_mode_change()" type="radio" value="file" />文件 </label>
</div>
<div id="mic_mode_div" style="border:2px solid #ccc;display:none;">
选择asr模型模式:<br/>
<label><input name="asr_mode" type="radio" value="2pass" />2pass </label>&nbsp;&nbsp;
<label><input name="asr_mode" type="radio" value="online" checked="true"/>online </label>&nbsp;&nbsp;
<label><input name="asr_mode" type="radio" value="2pass-offline" />2pass-offline </label>&nbsp;&nbsp;
<label><input name="asr_mode" type="radio" value="offline" />offline </label>
</div>
<div id="rec_mode_div" style="border:2px solid #ccc;display:none;">
<input type="file" id="upfile">
</div>
<div style="border:2px solid #ccc;display: none;">
热词设置(一行一个关键字,空格隔开权重,如"阿里巴巴 20")
<textarea rows="1" id="varHot" style=" width: 100%;height:auto" >阿里巴巴 20&#13;hello world 40</textarea>
</div>
<div style="display: none;">语音识别结果显示:</div>
<br>
<textarea rows="10" id="varArea" readonly="true" style=" width: 100%;height:auto;display: none;" ></textarea>
<br>
<div id="info_div">请点击开始</div>
<div class="div_class_buttons">
<button id="btnConnect">连接</button>
<button id="btnStart">开始</button>
<button id="btnStop">停止</button>
</div>
<audio id="audio_record" type="audio/wav" controls style="margin-top: 2px; width: 100%;display: none;"></audio>
</div>
</div>
<script src="wsconnecter.js" charset="utf-8"></script>
<script src="main.js" charset="utf-8"></script>
</body>
<script type="text/javascript" charset="utf-8">
// $(document).ready(function() {
// var host = window.location.hostname
// var ws = new WebSocket("ws://"+host+":8000/humanecho");
// //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
// ws.onopen = function() {
// console.log('Connected');
// };
// ws.onmessage = function(e) {
// console.log('Received: ' + e.data);
// data = e
// var vid = JSON.parse(data.data);
// console.log(typeof(vid),vid)
// //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
// };
// ws.onclose = function(e) {
// console.log('Closed');
// };
// flvPlayer = mpegts.createPlayer({type: 'flv', url: "http://"+host+":8080/live/livestream.flv", isLive: true, enableStashBuffer: false});
// flvPlayer.attachMediaElement(document.getElementById('video_player'));
// flvPlayer.load();
// flvPlayer.play();
// $('#echo-form').on('submit', function(e) {
// e.preventDefault();
// var message = $('#message').val();
// console.log('Sending: ' + message);
// ws.send(message);
// $('#message').val('');
// });
// });
</script>
</html>

File diff suppressed because one or more lines are too long

View File

@ -1,24 +0,0 @@
1. Start the speech recognition server
Create a virtual environment:
conda create -n funasr
conda activate funasr
Install the dependencies:
pip install torch
pip install modelscope
pip install testresources
pip install websockets
pip install torchaudio
pip install FunASR
pip install pyaudio
Start the server:
python funasr_wss_server.py --port 10095
or
python funasr_wss_server.py --host "0.0.0.0" --port 10197 --ngpu 0
https://github.com/alibaba-damo-academy/FunASR
https://zhuanlan.zhihu.com/p/649935170
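For reference, here is a minimal streaming client sketch for the server started above. It is an illustration of the JSON-plus-binary protocol parsed by `ws_serve()` in funasr_wss_server.py, not code from this repo; it assumes 16 kHz, 16-bit mono PCM audio (the 32 bytes-per-millisecond rate the server expects), the non-SSL default port, and a placeholder file name `test.pcm`:

```python
import asyncio
import json
import websockets

async def recognize(pcm_path: str, uri: str = "ws://127.0.0.1:10095"):
    async with websockets.connect(uri, subprotocols=["binary"], ping_interval=None) as ws:
        # announce stream parameters; the keys mirror what ws_serve() parses
        await ws.send(json.dumps({
            "mode": "2pass", "wav_name": pcm_path, "is_speaking": True,
            "chunk_size": "5,10,5", "chunk_interval": 10,
        }))
        chunk_bytes = 16000 * 2 * 600 // 1000   # ~600 ms of 16 kHz 16-bit mono PCM per send
        with open(pcm_path, "rb") as f:
            while chunk := f.read(chunk_bytes):
                await ws.send(chunk)            # raw PCM bytes
                await asyncio.sleep(0.6)        # pace roughly in real time
        await ws.send(json.dumps({"is_speaking": False}))  # trigger the offline pass
        while True:
            msg = json.loads(await ws.recv())
            print(msg["mode"], msg["text"])
            if msg["mode"].endswith("offline"): # final result of the 2pass pipeline
                break

# asyncio.run(recognize("test.pcm"))
```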

View File

@ -1,17 +0,0 @@
## certificate generation by yourself
A self-generated certificate may not be suitable for all browsers due to security concerns; you'd better buy or download an authenticated SSL certificate from an authorized agency.
```shell
### 1) Generate a private key
openssl genrsa -des3 -out server.key 2048
### 2) Generate a csr file
openssl req -new -key server.key -out server.csr
### 3) Remove pass
cp server.key server.key.org
openssl rsa -in server.key.org -out server.key
### 4) Generate a crt file, valid for 1 year
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
```

View File

@ -1,17 +0,0 @@
## Generate a certificate yourself
Generate a certificate (note: such a self-signed certificate is not trusted by all browsers; some can access it after manual authorization. It is best to use an officially issued SSL certificate instead.)
```shell
### 1) Generate a private key, filling in the prompts
openssl genrsa -des3 -out server.key 1024
### 2) Generate a csr file, filling in the prompts
openssl req -new -key server.key -out server.csr
### Remove the passphrase
cp server.key server.key.org
openssl rsa -in server.key.org -out server.key
### Generate a crt file valid for 1 year (365 days)
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
```

View File

@ -1,21 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
-----END CERTIFICATE-----

View File

@ -1,27 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k
vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4
5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD
jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT
ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G
NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs
MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+
vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw
jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2
eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i
C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g
6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua
jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU
qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb
aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY
zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9
xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg
9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp
3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG
P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1
7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE
cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR
dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh
cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ
W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa
-----END RSA PRIVATE KEY-----

47
lipasr.py Normal file
View File

@ -0,0 +1,47 @@
import time
import torch
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
from baseasr import BaseASR
from wav2lip import audio
class LipASR(BaseASR):
def run_step(self):
############################################## extract audio feature ##############################################
# get a frame of audio
for _ in range(self.batch_size*2):
frame,type = self.get_audio_frame()
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
# context not enough, do not run network.
if len(self.frames) <= self.stride_left_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
mel = audio.melspectrogram(inputs)
#print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
# cut off stride
left = max(0, self.stride_left_size*80/50)
right = min(len(mel[0]), len(mel[0]) - self.stride_right_size*80/50)
mel_idx_multiplier = 80.*2/self.fps
mel_step_size = 16
i = 0
mel_chunks = []
while i < (len(self.frames)-self.stride_left_size-self.stride_right_size)/2:
start_idx = int(left + i * mel_idx_multiplier)
#print(start_idx)
if start_idx + mel_step_size > len(mel[0]):
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
else:
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
i += 1
self.feat_queue.put(mel_chunks)
# discard the old part to save memory
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

281
lipreal.py Normal file
View File

@ -0,0 +1,281 @@
import math
import torch
import numpy as np
#from .utils import *
import subprocess
import os
import time
import cv2
import glob
import pickle
import copy
import queue
from queue import Queue
from threading import Thread, Event
from io import BytesIO
import multiprocessing as mp
from ttsreal import EdgeTTS,VoitsTTS,XTTS
from lipasr import LipASR
import asyncio
from av import AudioFrame, VideoFrame
from wav2lip.models import Wav2Lip
from basereal import BaseReal
#from imgcache import ImgCache
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
def _load(checkpoint_path):
if device == 'cuda':
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_model(path):
model = Wav2Lip()
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
model = model.to(device)
return model.eval()
def read_imgs(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
def __mirror_index(size, index):
#size = len(self.coord_list_cycle)
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_queue,res_frame_queue):
model = load_model("./models/wav2lip.pth")
input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
face_list_cycle = read_imgs(input_face_list)
#input_latent_list_cycle = torch.load(latents_out_path)
length = len(face_list_cycle)
index = 0
count=0
counttime=0
print('start inference')
while True:
if render_event.is_set():
starttime=time.perf_counter()
mel_batch = []
try:
mel_batch = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence=True
audio_frames = []
for _ in range(batch_size*2):
frame,type = audio_out_queue.get()
audio_frames.append((frame,type))
if type==0:
is_all_silence=False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
else:
# print('infer=======')
t=time.perf_counter()
img_batch = []
for i in range(batch_size):
idx = __mirror_index(length,index+i)
face = face_list_cycle[idx]
img_batch.append(face)
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, face.shape[0]//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
counttime += (time.perf_counter() - t)
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(pred):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
else:
time.sleep(1)
print('lipreal inference processor stop')
@torch.no_grad()
class LipReal(BaseReal):
def __init__(self, opt):
super().__init__(opt)
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.fps = opt.fps # 20 ms per frame
#### musetalk
self.avatar_id = opt.avatar_id
self.avatar_path = f"./data/avatars/{self.avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.face_imgs_path = f"{self.avatar_path}/face_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.batch_size = opt.batch_size
self.idx = 0
self.res_frame_queue = mp.Queue(self.batch_size*2)
#self.__loadmodels()
self.__loadavatar()
self.asr = LipASR(opt,self)
self.asr.warm_up()
#self.__warm_up()
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event,self.batch_size,self.face_imgs_path,
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
)).start()
# def __loadmodels(self):
# # load model weights
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.timesteps = torch.tensor([0], device=device)
# self.pe = self.pe.half()
# self.vae.vae = self.vae.vae.half()
# self.unet.model = self.unet.model.half()
def __loadavatar(self):
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.frame_list_cycle = read_imgs(input_img_list)
#self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
while not quit_event.is_set():
try:
res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
except queue.Empty:
continue
if audio_frames[0][1]!=0 and audio_frames[1][1]!=0: #both frames are silence, just use the full original image
self.speaking = False
audiotype = audio_frames[0][1]
if self.custom_index.get(audiotype) is not None: #a custom video is set for this audio type
mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
combine_frame = self.custom_img_cycle[audiotype][mirindex]
self.custom_index[audiotype] += 1
# if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
# self.curr_state = 1 #the custom video does not loop; switch back to the silent state
else:
combine_frame = self.frame_list_cycle[idx]
#combine_frame = self.imagecache.get_img(idx)
else:
self.speaking = True
bbox = self.coord_list_cycle[idx]
combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
#combine_frame = copy.deepcopy(self.imagecache.get_img(idx))
y1, y2, x1, x2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
continue
#combine_frame = get_image(ori_frame,res_frame,bbox)
#t=time.perf_counter()
combine_frame[y1:y2, x1:x2] = res_frame
#print('blending time:',time.perf_counter()-t)
image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
if self.recording:
self.recordq_video.put(new_frame)
for audio_frame in audio_frames:
frame,type = audio_frame
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
new_frame.planes[0].update(frame.tobytes())
new_frame.sample_rate=16000
# if audio_track._queue.qsize()>10:
# time.sleep(0.1)
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
if self.recording:
self.recordq_audio.put(new_frame)
print('lipreal process_frames thread stop')
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
# self.asr.warm_up()
self.tts.render(quit_event)
self.init_customindex()
process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
process_thread.start()
self.render_event.set() #start infer process render
count=0
totaltime=0
_starttime=time.perf_counter()
#_totalframe=0
while not quit_event.is_set():
# update texture every frame
# audio stream thread...
t = time.perf_counter()
self.asr.run_step()
# if video_track._queue.qsize()>=2*self.opt.batch_size:
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
if video_track._queue.qsize()>=5:
print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
# time.sleep(delay)
self.render_event.clear() #end infer process render
print('lipreal thread stop')

View File

@ -15,7 +15,7 @@ class VllmGPT:
  self.__URL = "http://{}:{}/v1/completions".format(self.host, self.port)
  self.__URL2 = "http://{}:{}/v1/chat/completions".format(self.host, self.port)
- def question(self,cont):
+ def chat(self,cont):
  chat_list = []
  # contentdb = content_db.new_instance()
  # list = contentdb.get_list('all','desc',11)
@ -77,5 +77,5 @@ class VllmGPT:
  if __name__ == "__main__":
  vllm = VllmGPT('192.168.1.3','8101')
- req = vllm.question("你叫什么名字啊今年多大了")
+ req = vllm.chat("你叫什么名字啊今年多大了")
  print(req)

View File

36
museasr.py Normal file
View File

@ -0,0 +1,36 @@
import time
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
from baseasr import BaseASR
from musetalk.whisper.audio2feature import Audio2Feature
class MuseASR(BaseASR):
def __init__(self, opt, parent,audio_processor:Audio2Feature):
super().__init__(opt,parent)
self.audio_processor = audio_processor
def run_step(self):
############################################## extract audio feature ##############################################
start_time = time.time()
for _ in range(self.batch_size*2):
audio_frame,type=self.get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
if len(self.frames) <= self.stride_left_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
whisper_feature = self.audio_processor.audio2feat(inputs)
# for feature in whisper_feature:
# self.audio_feats.append(feature)
#print(f"processing audio costs {(time.time() - start_time) * 1000}ms, inputs shape:{inputs.shape} whisper_feature len:{len(whisper_feature)}")
whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
#print(f"whisper_chunks len:{len(whisper_chunks)},self.audio_feats len:{len(self.audio_feats)},self.output_queue len:{self.output_queue.qsize()}")
#self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
self.feat_queue.put(whisper_chunks)
# discard the old part to save memory
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

318
musereal.py Normal file
View File

@ -0,0 +1,318 @@
import math
import torch
import numpy as np
#from .utils import *
import subprocess
import os
import time
import torch.nn.functional as F
import cv2
import glob
import pickle
import copy
import queue
from queue import Queue
from threading import Thread, Event
from io import BytesIO
import multiprocessing as mp
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
from ttsreal import EdgeTTS,VoitsTTS,XTTS
from museasr import MuseASR
import asyncio
from av import AudioFrame, VideoFrame
from basereal import BaseReal
from tqdm import tqdm
def read_imgs(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
def __mirror_index(size, index):
#size = len(self.coord_list_cycle)
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
@torch.no_grad()
def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
): #vae, unet, pe,timesteps
vae, unet, pe = load_diffusion_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
input_latent_list_cycle = torch.load(latents_out_path)
length = len(input_latent_list_cycle)
index = 0
count=0
counttime=0
print('start inference')
while True:
if render_event.is_set():
starttime=time.perf_counter()
try:
whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence=True
audio_frames = []
for _ in range(batch_size*2):
frame,type = audio_out_queue.get()
audio_frames.append((frame,type))
if type==0:
is_all_silence=False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
else:
# print('infer=======')
t=time.perf_counter()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(batch_size):
idx = __mirror_index(length,index+i)
latent = input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype)
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
# print('prepare time:',time.perf_counter()-t)
# t=time.perf_counter()
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
# print('unet time:',time.perf_counter()-t)
# t=time.perf_counter()
recon = vae.decode_latents(pred_latents)
# print('vae time:',time.perf_counter()-t)
#print('diffusion len=',len(recon))
counttime += (time.perf_counter() - t)
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(recon):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
else:
time.sleep(1)
print('musereal inference processor stop')
@torch.no_grad()
class MuseReal(BaseReal):
def __init__(self, opt):
super().__init__(opt)
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.fps = opt.fps # 20 ms per frame
#### musetalk
self.avatar_id = opt.avatar_id
self.video_path = '' #video_path
self.bbox_shift = opt.bbox_shift
self.avatar_path = f"./data/avatars/{self.avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.latents_out_path= f"{self.avatar_path}/latents.pt"
self.video_out_path = f"{self.avatar_path}/vid_output/"
self.mask_out_path =f"{self.avatar_path}/mask"
self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
self.avatar_info = {
"avatar_id":self.avatar_id,
"video_path":self.video_path,
"bbox_shift":self.bbox_shift
}
self.batch_size = opt.batch_size
self.idx = 0
self.res_frame_queue = mp.Queue(self.batch_size*2)
self.__loadmodels()
self.__loadavatar()
self.asr = MuseASR(opt,self,self.audio_processor)
self.asr.warm_up()
#self.__warm_up()
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
)).start() #self.vae, self.unet, self.pe,self.timesteps
def __loadmodels(self):
# load model weights
self.audio_processor= load_audio_model()
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.timesteps = torch.tensor([0], device=device)
# self.pe = self.pe.half()
# self.vae.vae = self.vae.vae.half()
# self.unet.model = self.unet.model.half()
def __loadavatar(self):
#self.input_latent_list_cycle = torch.load(self.latents_out_path)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.frame_list_cycle = read_imgs(input_img_list)
with open(self.mask_coords_path, 'rb') as f:
self.mask_coords_list_cycle = pickle.load(f)
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.mask_list_cycle = read_imgs(input_mask_list)
def __mirror_index(self, index):
size = len(self.coord_list_cycle)
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
def __warm_up(self):
self.asr.run_step()
whisper_chunks = self.asr.get_next_feat()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(self.batch_size):
idx = self.__mirror_index(self.idx+i)
latent = self.input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
print('infer=======')
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
dtype=self.unet.model.dtype)
audio_feature_batch = self.pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
pred_latents = self.unet.model(latent_batch,
self.timesteps,
encoder_hidden_states=audio_feature_batch).sample
recon = self.vae.decode_latents(pred_latents)
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
while not quit_event.is_set():
try:
res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
except queue.Empty:
continue
if audio_frames[0][1]!=0 and audio_frames[1][1]!=0: #both frames are silence, just use the full original image
self.speaking = False
audiotype = audio_frames[0][1]
if self.custom_index.get(audiotype) is not None: #a custom video is set for this audio type
mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
combine_frame = self.custom_img_cycle[audiotype][mirindex]
self.custom_index[audiotype] += 1
# if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
# self.curr_state = 1 #the custom video does not loop; switch back to the silent state
else:
combine_frame = self.frame_list_cycle[idx]
else:
self.speaking = True
bbox = self.coord_list_cycle[idx]
ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
x1, y1, x2, y2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
continue
mask = self.mask_list_cycle[idx]
mask_crop_box = self.mask_coords_list_cycle[idx]
#combine_frame = get_image(ori_frame,res_frame,bbox)
#t=time.perf_counter()
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
#print('blending time:',time.perf_counter()-t)
image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
if self.recording:
self.recordq_video.put(new_frame)
for audio_frame in audio_frames:
frame,type = audio_frame
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
new_frame.planes[0].update(frame.tobytes())
new_frame.sample_rate=16000
# if audio_track._queue.qsize()>10:
# time.sleep(0.1)
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
if self.recording:
self.recordq_audio.put(new_frame)
print('musereal process_frames thread stop')
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
# self.asr.warm_up()
self.tts.render(quit_event)
self.init_customindex()
process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
process_thread.start()
self.render_event.set() #start infer process render
count=0
totaltime=0
_starttime=time.perf_counter()
#_totalframe=0
while not quit_event.is_set(): #todo
# update texture every frame
# audio stream thread...
t = time.perf_counter()
self.asr.run_step()
#self.test_step(loop,audio_track,video_track)
# totaltime += (time.perf_counter() - t)
# count += self.opt.batch_size
# if count>=100:
# print(f"------actual avg infer fps:{count/totaltime:.4f}")
# count=0
# totaltime=0
if video_track._queue.qsize()>=1.5*self.opt.batch_size:
print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
# if video_track._queue.qsize()>=5:
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
# time.sleep(delay)
self.render_event.clear() #end infer process render
print('musereal thread stop')

47
musetalk/models/unet.py Executable file
View File

@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import math
import json
from diffusers import UNet2DConditionModel
import sys
import time
import numpy as np
import os
class PositionalEncoding(nn.Module):
def __init__(self, d_model=384, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
b, seq_len, d_model = x.size()
pe = self.pe[:, :seq_len, :]
x = x + pe.to(x.device)
return x
class UNet():
def __init__(self,
unet_config,
model_path,
use_float16=False,
):
with open(unet_config, 'r') as f:
unet_config = json.load(f)
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
if use_float16:
self.model = self.model.half()
self.model.to(self.device)
if __name__ == "__main__":
unet = UNet()

148
musetalk/models/vae.py Executable file
View File

@ -0,0 +1,148 @@
from diffusers import AutoencoderKL
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import os
class VAE():
"""
VAE (Variational Autoencoder) class for image processing.
"""
def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
"""
Initialize the VAE instance.
:param model_path: Path to the trained model.
:param resized_img: The size to which images are resized.
:param use_float16: Whether to use float16 precision.
"""
self.model_path = model_path
self.vae = AutoencoderKL.from_pretrained(self.model_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.vae.to(self.device)
if use_float16:
self.vae = self.vae.half()
self._use_float16 = True
else:
self._use_float16 = False
self.scaling_factor = self.vae.config.scaling_factor
self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
self._resized_img = resized_img
self._mask_tensor = self.get_mask_tensor()
def get_mask_tensor(self):
"""
Creates a mask tensor for image processing.
:return: A mask tensor.
"""
mask_tensor = torch.zeros((self._resized_img,self._resized_img))
mask_tensor[:self._resized_img//2,:] = 1
mask_tensor[mask_tensor< 0.5] = 0
mask_tensor[mask_tensor>= 0.5] = 1
return mask_tensor
def preprocess_img(self,img_name,half_mask=False):
"""
Preprocess an image for the VAE.
:param img_name: The image file path or a list of image file paths.
:param half_mask: Whether to apply a half mask to the image.
:return: A preprocessed image tensor.
"""
window = []
if isinstance(img_name, str):
window_fnames = [img_name]
for fname in window_fnames:
img = cv2.imread(fname)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (self._resized_img, self._resized_img),
interpolation=cv2.INTER_LANCZOS4)
window.append(img)
else:
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
window.append(img)
x = np.asarray(window) / 255.
x = np.transpose(x, (3, 0, 1, 2))
x = torch.squeeze(torch.FloatTensor(x))
if half_mask:
x = x * (self._mask_tensor>0.5)
x = self.transform(x)
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
x = x.to(self.vae.device)
return x
def encode_latents(self,image):
"""
Encode an image into latent variables.
:param image: The image tensor to encode.
:return: The encoded latent variables.
"""
with torch.no_grad():
init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
init_latents = self.scaling_factor * init_latent_dist.sample()
return init_latents
def decode_latents(self, latents):
"""
Decode latent variables back into an image.
:param latents: The latent variables to decode.
:return: A NumPy array representing the decoded image.
"""
latents = (1/ self.scaling_factor) * latents
image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
image = image[...,::-1] # RGB to BGR
return image
def get_latents_for_unet(self,img):
"""
Prepare latent variables for a U-Net model.
:param img: The image to process.
:return: A concatenated tensor of latents for U-Net input.
"""
ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
return latent_model_input
if __name__ == "__main__":
vae_mode_path = "./models/sd-vae-ft-mse/"
vae = VAE(model_path = vae_mode_path,use_float16=False)
img_path = "./results/sun001_crop/00000.png"
crop_imgs_path = "./results/sun001_crop/"
latents_out_path = "./results/latents/"
if not os.path.exists(latents_out_path):
os.mkdir(latents_out_path)
files = os.listdir(crop_imgs_path)
files.sort()
files = [file for file in files if file.split(".")[-1] == "png"]
for file in files:
index = file.split(".")[0]
img_path = crop_imgs_path + file
latents = vae.get_latents_for_unet(img_path)
print(img_path,"latents",latents.size())
#torch.save(latents,os.path.join(latents_out_path,index+".pt"))
#reload_tensor = torch.load('tensor.pt')
#print(reload_tensor.size())

348
musetalk/simple_musetalk.py Normal file
View File

@ -0,0 +1,348 @@
import argparse
import glob
import json
import os
import pickle
import shutil
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL
from face_alignment import NetworkSize
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
from tqdm import tqdm
try:
from utils.face_parsing import FaceParsing
except ModuleNotFoundError:
from musetalk.utils.face_parsing import FaceParsing
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
cap = cv2.VideoCapture(vid_path)
count = 0
while True:
if count > cut_frame:
break
ret, frame = cap.read()
if ret:
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
count += 1
else:
break
def read_imgs(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
def get_landmark_and_bbox(img_list, upperbondrange=0):
frames = read_imgs(img_list)
batch_size_fa = 1
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
coords_list = []
landmarks = []
if upperbondrange != 0:
print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
else:
print('get key_landmark and face bounding boxes with the default value')
average_range_minus = []
average_range_plus = []
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
for fb in tqdm(batches):
results = inference_topdown(model, np.asarray(fb)[0])
results = merge_data_samples(results)
keypoints = results.pred_instances.keypoints
face_land_mark = keypoints[0][23:91]
face_land_mark = face_land_mark.astype(np.int32)
# get bounding boxes by face detection
bbox = fa.get_detections_for_batch(np.asarray(fb))
# adjust the bounding box with reference to the landmarks
# Add the bounding box to a tuple and append it to the coordinates list
for j, f in enumerate(bbox):
if f is None: # no face in the image
coords_list += [coord_placeholder]
continue
half_face_coord = face_land_mark[29] # np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
range_minus = (face_land_mark[30] - face_land_mark[29])[1]
range_plus = (face_land_mark[29] - face_land_mark[28])[1]
average_range_minus.append(range_minus)
average_range_plus.append(range_plus)
if upperbondrange != 0:
half_face_coord[1] = upperbondrange + half_face_coord[1] # manual adjustment: positive values shift down toward landmark 29, negative values shift up toward landmark 28
half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
upper_bond = half_face_coord[1] - half_face_dist
f_landmark = (
np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
np.max(face_land_mark[:, 1]))
x1, y1, x2, y2 = f_landmark
if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0: # if the landmark bbox is not suitable, reuse the bbox
coords_list += [f]
w, h = f[2] - f[0], f[3] - f[1]
print("error bbox:", f)
else:
coords_list += [f_landmark]
return coords_list, frames
class FaceAlignment:
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
self.device = device
self.flip_input = flip_input
self.landmarks_type = landmarks_type
self.verbose = verbose
network_size = int(network_size)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = False
# torch.backends.cudnn.allow_tf32 = True
print('cuda start')
# Get the face detector
face_detector_module = __import__('face_detection.detection.' + face_detector,
globals(), locals(), [face_detector], 0)
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
def get_detections_for_batch(self, images):
images = images[..., ::-1]
detected_faces = self.face_detector.detect_from_batch(images.copy())
results = []
for i, d in enumerate(detected_faces):
if len(d) == 0:
results.append(None)
continue
d = d[0]
d = np.clip(d, 0, None)
x1, y1, x2, y2 = map(int, d[:-1])
results.append((x1, y1, x2, y2))
return results
def get_mask_tensor():
"""
Creates a mask tensor for image processing.
:return: A mask tensor.
"""
mask_tensor = torch.zeros((256, 256))
mask_tensor[:256 // 2, :] = 1
mask_tensor[mask_tensor < 0.5] = 0
mask_tensor[mask_tensor >= 0.5] = 1
return mask_tensor
def preprocess_img(img_name, half_mask=False):
window = []
if isinstance(img_name, str):
window_fnames = [img_name]
for fname in window_fnames:
img = cv2.imread(fname)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (256, 256),
interpolation=cv2.INTER_LANCZOS4)
window.append(img)
else:
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
window.append(img)
x = np.asarray(window) / 255.
x = np.transpose(x, (3, 0, 1, 2))
x = torch.squeeze(torch.FloatTensor(x))
if half_mask:
x = x * (get_mask_tensor() > 0.5)
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
x = normalize(x)
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
x = x.to(device)
return x
def encode_latents(image):
with torch.no_grad():
init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
init_latents = vae.config.scaling_factor * init_latent_dist.sample()
return init_latents
def get_latents_for_unet(img):
ref_image = preprocess_img(img, half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
masked_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
ref_image = preprocess_img(img, half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
ref_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
return latent_model_input
def get_crop_box(box, expand):
x, y, x1, y1 = box
x_c, y_c = (x + x1) // 2, (y + y1) // 2
w, h = x1 - x, y1 - y
s = int(max(w, h) // 2 * expand)
crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
return crop_box, s
def face_seg(image):
seg_image = fp(image)
if seg_image is None:
print("error, no person_segment")
return None
seg_image = seg_image.resize(image.size)
return seg_image
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
body = Image.fromarray(image[:, :, ::-1])
x, y, x1, y1 = face_box
# print(x1-x,y1-y)
crop_box, s = get_crop_box(face_box, expand)
x_s, y_s, x_e, y_e = crop_box
face_large = body.crop(crop_box)
ori_shape = face_large.size
mask_image = face_seg(face_large)
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
mask_image = Image.new('L', ori_shape, 0)
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
# keep upper_boundary_ratio of talking area
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio)
modified_mask_image = Image.new('L', ori_shape, 0)
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
return mask_array, crop_box
##todo simple check based on the file extension; for a more accurate check, modify this to use something like magic
def is_video_file(file_path):
video_exts = ['.mp4', '.mkv', '.flv', '.avi', '.mov'] # common video file extensions; add more as needed
file_ext = os.path.splitext(file_path)[1].lower() # get the file extension in lower case
return file_ext in video_exts
def create_dir(dir_path):
if not os.path.exists(dir_path):
os.makedirs(dir_path)
current_dir = os.path.dirname(os.path.abspath(__file__))
def create_musetalk_human(file, avatar_id):
# output path settings; normally nothing to change here
save_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}')
save_full_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/full_imgs')
create_dir(save_path)
create_dir(save_full_path)
mask_out_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/mask')
create_dir(mask_out_path)
# model data files
mask_coords_path = os.path.join(current_dir, f'{save_path}/mask_coords.pkl')
coords_path = os.path.join(current_dir, f'{save_path}/coords.pkl')
latents_out_path = os.path.join(current_dir, f'{save_path}/latents.pt')
with open(os.path.join(current_dir, f'{save_path}/avator_info.json'), "w") as f:
json.dump({
"avatar_id": avatar_id,
"video_path": file,
"bbox_shift": 5
}, f)
if os.path.isfile(file):
if is_video_file(file):
video2imgs(file, save_full_path, ext='png')
else:
shutil.copyfile(file, f"{save_full_path}/{os.path.basename(file)}")
else:
files = os.listdir(file)
files.sort()
files = [file for file in files if file.split(".")[-1] == "png"]
for filename in files:
shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))
print("extracting landmarks...")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5)
input_latent_list = []
idx = -1
    # marker used when the bbox is not sufficient
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
for bbox, frame in zip(coord_list, frame_list):
idx = idx + 1
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
latents = get_latents_for_unet(resized_crop_frame)
input_latent_list.append(latents)
frame_list_cycle = frame_list #+ frame_list[::-1]
coord_list_cycle = coord_list #+ coord_list[::-1]
input_latent_list_cycle = input_latent_list #+ input_latent_list[::-1]
mask_coords_list_cycle = []
mask_list_cycle = []
for i, frame in enumerate(tqdm(frame_list_cycle)):
cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)
face_box = coord_list_cycle[i]
mask, crop_box = get_image_prepare_material(frame, face_box)
cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
mask_coords_list_cycle += [crop_box]
mask_list_cycle.append(mask)
with open(mask_coords_path, 'wb') as f:
pickle.dump(mask_coords_list_cycle, f)
with open(coords_path, 'wb') as f:
pickle.dump(coord_list_cycle, f)
    torch.save(input_latent_list_cycle, latents_out_path)
# initialize the face alignment, dwpose (mmpose), VAE and face parsing models
device = "cuda" if torch.cuda.is_available() else "cpu"
fa = FaceAlignment(1, flip_input=False, device=device)
config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))
model = init_model(config_file, checkpoint_file, device=device)
vae = AutoencoderKL.from_pretrained(os.path.abspath(os.path.join(current_dir, '../models/sd-vae-ft-mse')))
vae.to(device)
fp = FaceParsing(os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/resnet18-5c106cde.pth')),
os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/79999_iter.pth')))
if __name__ == '__main__':
    # path to the source video, image, or folder of frames
parser = argparse.ArgumentParser()
parser.add_argument("--file",
type=str,
default=r'D:\ok\00000000.png',
)
parser.add_argument("--avatar_id",
type=str,
default='3',
)
args = parser.parse_args()
create_musetalk_human(args.file, args.avatar_id)
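For reference, a minimal sketch of calling the function directly (the source path and avatar id below are placeholders; the model weights under ../models must already be in place):

# Produces ../data/avatars/avator_4 containing full_imgs/, mask/, coords.pkl,
# mask_coords.pkl, latents.pt and avator_info.json for the MuseTalk runtime.
create_musetalk_human('/path/to/source_video.mp4', '4')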

@@ -0,0 +1,5 @@
import sys
from os.path import abspath, dirname
current_dir = dirname(abspath(__file__))
parent_dir = dirname(current_dir)
sys.path.append(parent_dir+'/utils')

musetalk/utils/blending.py (new file, 125 lines)
@@ -0,0 +1,125 @@
from PIL import Image
import numpy as np
import cv2
from face_parsing import FaceParsing
import copy
fp = FaceParsing()
def get_crop_box(box, expand):
x, y, x1, y1 = box
x_c, y_c = (x+x1)//2, (y+y1)//2
w, h = x1-x, y1-y
s = int(max(w, h)//2*expand)
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
return crop_box, s
def face_seg(image):
seg_image = fp(image)
if seg_image is None:
print("error, no person_segment")
return None
seg_image = seg_image.resize(image.size)
return seg_image
def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
#print(image.shape)
#print(face.shape)
body = Image.fromarray(image[:,:,::-1])
face = Image.fromarray(face[:,:,::-1])
x, y, x1, y1 = face_box
#print(x1-x,y1-y)
crop_box, s = get_crop_box(face_box, expand)
x_s, y_s, x_e, y_e = crop_box
face_position = (x, y)
face_large = body.crop(crop_box)
ori_shape = face_large.size
mask_image = face_seg(face_large)
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
mask_image = Image.new('L', ori_shape, 0)
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
# keep upper_boundary_ratio of talking area
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio)
modified_mask_image = Image.new('L', ori_shape, 0)
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
mask_image = Image.fromarray(mask_array)
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
body.paste(face_large, crop_box[:2], mask_image)
body = np.array(body)
return body[:,:,::-1]
def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
body = Image.fromarray(image[:,:,::-1])
x, y, x1, y1 = face_box
#print(x1-x,y1-y)
crop_box, s = get_crop_box(face_box, expand)
x_s, y_s, x_e, y_e = crop_box
face_large = body.crop(crop_box)
ori_shape = face_large.size
mask_image = face_seg(face_large)
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
mask_image = Image.new('L', ori_shape, 0)
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
# keep upper_boundary_ratio of talking area
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio)
modified_mask_image = Image.new('L', ori_shape, 0)
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
return mask_array,crop_box
# def get_image_blending(image,face,face_box,mask_array,crop_box):
# body = Image.fromarray(image[:,:,::-1])
# face = Image.fromarray(face[:,:,::-1])
# x, y, x1, y1 = face_box
# x_s, y_s, x_e, y_e = crop_box
# face_large = body.crop(crop_box)
# mask_image = Image.fromarray(mask_array)
# mask_image = mask_image.convert("L")
# face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
# body.paste(face_large, crop_box[:2], mask_image)
# body = np.array(body)
# return body[:,:,::-1]
def get_image_blending(image,face,face_box,mask_array,crop_box):
body = image
x, y, x1, y1 = face_box
x_s, y_s, x_e, y_e = crop_box
face_large = copy.deepcopy(body[y_s:y_e, x_s:x_e])
face_large[y-y_s:y1-y_s, x-x_s:x1-x_s]=face
mask_image = cv2.cvtColor(mask_array,cv2.COLOR_BGR2GRAY)
mask_image = (mask_image/255).astype(np.float32)
# mask_not = cv2.bitwise_not(mask_array)
# prospect_tmp = cv2.bitwise_and(face_large, face_large, mask=mask_array)
# background_img = body[y_s:y_e, x_s:x_e]
# background_img = cv2.bitwise_and(background_img, background_img, mask=mask_not)
# body[y_s:y_e, x_s:x_e] = prospect_tmp + background_img
#print(mask_image.shape)
#print(cv2.minMaxLoc(mask_image))
body[y_s:y_e, x_s:x_e] = cv2.blendLinear(face_large,body[y_s:y_e, x_s:x_e],mask_image,1-mask_image)
#body.paste(face_large, crop_box[:2], mask_image)
return body
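A minimal sketch of combining the two halves above; frame, res_frame and bbox are illustrative names. get_image_prepare_material returns a single-channel mask while get_image_blending converts its mask argument from BGR, so the mask is expanded to three channels here (in the preprocessing script the masks are saved with cv2.imwrite and typically re-read as 3-channel images at inference time):

# Precompute the blurred mouth-region mask and its crop box once per avatar frame,
# then blend each generated face crop (already resized to the face box) back in.
mask, crop_box = get_image_prepare_material(frame, bbox)
mask_bgr = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
# note: get_image_blending writes into `frame` in place and returns it
blended = get_image_blending(frame, res_frame, bbox, mask_bgr, crop_box)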

@@ -0,0 +1,54 @@
default_scope = 'mmpose'
# hooks
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=10),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='PoseVisualizationHook', enable=False),
badcase=dict(
type='BadCaseAnalysisHook',
enable=False,
out_dir='badcase',
metric_type='loss',
badcase_thr=5))
# custom hooks
custom_hooks = [
# Synchronize model buffers such as running_mean and running_var in BN
# at the end of each epoch
dict(type='SyncBuffersHook')
]
# multi-processing backend
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
# visualizer
vis_backends = [
dict(type='LocalVisBackend'),
# dict(type='TensorboardVisBackend'),
# dict(type='WandbVisBackend'),
]
visualizer = dict(
type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
# logger
log_processor = dict(
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
log_level = 'INFO'
load_from = None
resume = False
# file I/O backend
backend_args = dict(backend='local')
# training/validation/testing progress
train_cfg = dict(by_epoch=True)
val_cfg = dict()
test_cfg = dict()

@@ -0,0 +1,257 @@
#_base_ = ['../../../_base_/default_runtime.py']
_base_ = ['default_runtime.py']
# runtime
max_epochs = 270
stage2_num_epochs = 30
base_lr = 4e-3
train_batch_size = 32
val_batch_size = 32
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
randomness = dict(seed=21)
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
# learning rate
param_scheduler = [
dict(
type='LinearLR',
start_factor=1.0e-5,
by_epoch=False,
begin=0,
end=1000),
dict(
        # use cosine lr for the second half of training (epoch 135 to 270)
type='CosineAnnealingLR',
eta_min=base_lr * 0.05,
begin=max_epochs // 2,
end=max_epochs,
T_max=max_epochs // 2,
by_epoch=True,
convert_to_iter_based=True),
]
# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)
# codec settings
codec = dict(
type='SimCCLabel',
input_size=(288, 384),
sigma=(6., 6.93),
simcc_split_ratio=2.0,
normalize=False,
use_dark=False)
# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
_scope_='mmdet',
type='CSPNeXt',
arch='P5',
expand_ratio=0.5,
deepen_factor=1.,
widen_factor=1.,
out_indices=(4, ),
channel_attention=True,
norm_cfg=dict(type='SyncBN'),
act_cfg=dict(type='SiLU'),
init_cfg=dict(
type='Pretrained',
prefix='backbone.',
checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
)),
head=dict(
type='RTMCCHead',
in_channels=1024,
out_channels=133,
input_size=codec['input_size'],
in_featuremap_size=(9, 12),
simcc_split_ratio=codec['simcc_split_ratio'],
final_layer_kernel_size=7,
gau_cfg=dict(
hidden_dims=256,
s=128,
expansion_factor=2,
dropout_rate=0.,
drop_path=0.,
act_fn='SiLU',
use_rel_bias=False,
pos_enc=False),
loss=dict(
type='KLDiscretLoss',
use_target_weight=True,
beta=10.,
label_softmax=True),
decoder=codec),
test_cfg=dict(flip_test=True, ))
# base dataset settings
dataset_type = 'UBody2dDataset'
data_mode = 'topdown'
data_root = 'data/UBody/'
backend_args = dict(backend='local')
scenes = [
'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
]
train_datasets = [
dict(
type='CocoWholeBodyDataset',
data_root='data/coco/',
data_mode=data_mode,
ann_file='annotations/coco_wholebody_train_v1.0.json',
data_prefix=dict(img='train2017/'),
pipeline=[])
]
for scene in scenes:
train_dataset = dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file=f'annotations/{scene}/train_annotations.json',
data_prefix=dict(img='images/'),
pipeline=[],
sample_interval=10)
train_datasets.append(train_dataset)
# pipelines
train_pipeline = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(
type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(
type='Albumentation',
transforms=[
dict(type='Blur', p=0.1),
dict(type='MedianBlur', p=0.1),
dict(
type='CoarseDropout',
max_holes=1,
max_height=0.4,
max_width=0.4,
min_holes=1,
min_height=0.2,
min_width=0.2,
p=1.0),
]),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]
train_pipeline_stage2 = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(
type='RandomBBoxTransform',
shift_factor=0.,
scale_factor=[0.5, 1.5],
rotate_factor=90),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(
type='Albumentation',
transforms=[
dict(type='Blur', p=0.1),
dict(type='MedianBlur', p=0.1),
dict(
type='CoarseDropout',
max_holes=1,
max_height=0.4,
max_width=0.4,
min_holes=1,
min_height=0.2,
min_width=0.2,
p=0.5),
]),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
# data loaders
train_dataloader = dict(
batch_size=train_batch_size,
num_workers=10,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='CombinedDataset',
metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
datasets=train_datasets,
pipeline=train_pipeline,
test_mode=False,
))
val_dataloader = dict(
batch_size=val_batch_size,
num_workers=10,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type='CocoWholeBodyDataset',
data_root=data_root,
data_mode=data_mode,
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
data_prefix=dict(img='coco/val2017/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader
# hooks
default_hooks = dict(
checkpoint=dict(
save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0002,
update_buffers=True,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - stage2_num_epochs,
switch_pipeline=train_pipeline_stage2)
]
# evaluators
val_evaluator = dict(
type='CocoWholeBodyMetric',
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
test_evaluator = val_evaluator
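This appears to be the DWPose whole-body config referenced earlier by create_musetalk_human. A minimal loading sketch, assuming mmpose 1.x and that the checkpoint has been downloaded to ../models/dwpose (paths are illustrative):

from mmpose.apis import init_model  # mmpose 1.x API, as used above

config_file = 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
pose_model = init_model(config_file, checkpoint_file, device='cuda')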

@@ -0,0 +1 @@
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
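A minimal sketch of the batched interface (variable names are illustrative; frames is assumed to be a list of equally sized BGR images):

import numpy as np
from face_detection import FaceAlignment, LandmarksType

fa = FaceAlignment(LandmarksType._2D, flip_input=False, device='cuda')
batch = np.stack(frames)                      # shape (B, H, W, 3), BGR
boxes = fa.get_detections_for_batch(batch)    # per frame: (x1, y1, x2, y2) or None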

@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
__author__ = """Adrian Bulat"""
__email__ = 'adrian.bulat@nottingham.ac.uk'
__version__ = '1.0.1'
from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face

@@ -0,0 +1,240 @@
from __future__ import print_function
import os
import torch
from torch.utils.model_zoo import load_url
from enum import Enum
import numpy as np
import cv2
try:
import urllib.request as request_file
except BaseException:
import urllib as request_file
from .models import FAN, ResNetDepth
from .utils import *
class LandmarksType(Enum):
"""Enum class defining the type of landmarks to detect.
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
    ``_2halfD`` - these points represent the projection of the 3D points into a 2D plane
    ``_3D`` - detect the points ``(x,y,z)`` in a 3D space
"""
_2D = 1
_2halfD = 2
_3D = 3
class NetworkSize(Enum):
# TINY = 1
# SMALL = 2
# MEDIUM = 3
LARGE = 4
def __new__(cls, value):
member = object.__new__(cls)
member._value_ = value
return member
def __int__(self):
return self.value
class FaceAlignment:
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
self.device = device
self.flip_input = flip_input
self.landmarks_type = landmarks_type
self.verbose = verbose
network_size = int(network_size)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = False
# torch.backends.cudnn.allow_tf32 = True
print('cuda start')
# Get the face detector
face_detector_module = __import__('face_detection.detection.' + face_detector,
globals(), locals(), [face_detector], 0)
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
def get_detections_for_batch(self, images):
images = images[..., ::-1]
detected_faces = self.face_detector.detect_from_batch(images.copy())
results = []
for i, d in enumerate(detected_faces):
if len(d) == 0:
results.append(None)
continue
d = d[0]
d = np.clip(d, 0, None)
x1, y1, x2, y2 = map(int, d[:-1])
results.append((x1, y1, x2, y2))
return results
class YOLOv8_face:
def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
self.conf_threshold = conf_thres
self.iou_threshold = iou_thres
self.class_names = ['face']
self.num_classes = len(self.class_names)
# Initialize model
self.net = cv2.dnn.readNet(path)
self.input_height = 640
self.input_width = 640
self.reg_max = 16
self.project = np.arange(self.reg_max)
self.strides = (8, 16, 32)
self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
self.anchors = self.make_anchors(self.feats_hw)
def make_anchors(self, feats_hw, grid_cell_offset=0.5):
"""Generate anchors from features."""
anchor_points = {}
for i, stride in enumerate(self.strides):
h,w = feats_hw[i]
x = np.arange(0, w) + grid_cell_offset # shift x
y = np.arange(0, h) + grid_cell_offset # shift y
sx, sy = np.meshgrid(x, y)
# sy, sx = np.meshgrid(y, x)
anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
return anchor_points
def softmax(self, x, axis=1):
x_exp = np.exp(x)
        # for a column vector, use axis=0 instead
x_sum = np.sum(x_exp, axis=axis, keepdims=True)
s = x_exp / x_sum
return s
def resize_image(self, srcimg, keep_ratio=True):
top, left, newh, neww = 0, 0, self.input_width, self.input_height
if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
hw_scale = srcimg.shape[0] / srcimg.shape[1]
if hw_scale > 1:
newh, neww = self.input_height, int(self.input_width / hw_scale)
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
left = int((self.input_width - neww) * 0.5)
img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
value=(0, 0, 0)) # add border
else:
newh, neww = int(self.input_height * hw_scale), self.input_width
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
top = int((self.input_height - newh) * 0.5)
img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
value=(0, 0, 0))
else:
img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
return img, newh, neww, top, left
def detect(self, srcimg):
input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
input_img = input_img.astype(np.float32) / 255.0
blob = cv2.dnn.blobFromImage(input_img)
self.net.setInput(blob)
outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
# if isinstance(outputs, tuple):
# outputs = list(outputs)
# if float(cv2.__version__[:3])>=4.7:
        # outputs = [outputs[2], outputs[0], outputs[1]]  ### this reordering is needed for OpenCV 4.7 but not for 4.5
# Perform inference on the image
det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
return det_bboxes, det_conf, det_classid, landmarks
def post_process(self, preds, scale_h, scale_w, padh, padw):
bboxes, scores, landmarks = [], [], []
for i, pred in enumerate(preds):
stride = int(self.input_height/pred.shape[2])
pred = pred.transpose((0, 2, 3, 1))
box = pred[..., :self.reg_max * 4]
cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
# tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
tmp = box.reshape(-1, 4, self.reg_max)
bbox_pred = self.softmax(tmp, axis=-1)
bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
            bbox -= np.array([[padw, padh, padw, padh]])  ### broadcasting subtracts the padding from every row
bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
bboxes.append(bbox)
scores.append(cls)
landmarks.append(kpts)
bboxes = np.concatenate(bboxes, axis=0)
scores = np.concatenate(scores, axis=0)
landmarks = np.concatenate(landmarks, axis=0)
bboxes_wh = bboxes.copy()
bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
classIds = np.argmax(scores, axis=1)
confidences = np.max(scores, axis=1) ####max_class_confidence
mask = confidences>self.conf_threshold
        bboxes_wh = bboxes_wh[mask]  ### keep only detections above the confidence threshold
confidences = confidences[mask]
classIds = classIds[mask]
landmarks = landmarks[mask]
indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
self.iou_threshold).flatten()
if len(indices) > 0:
mlvl_bboxes = bboxes_wh[indices]
confidences = confidences[indices]
classIds = classIds[indices]
landmarks = landmarks[indices]
return mlvl_bboxes, confidences, classIds, landmarks
else:
            print('nothing detected')
return np.array([]), np.array([]), np.array([]), np.array([])
def distance2bbox(self, points, distance, max_shape=None):
x1 = points[:, 0] - distance[:, 0]
y1 = points[:, 1] - distance[:, 1]
x2 = points[:, 0] + distance[:, 2]
y2 = points[:, 1] + distance[:, 3]
if max_shape is not None:
x1 = np.clip(x1, 0, max_shape[1])
y1 = np.clip(y1, 0, max_shape[0])
x2 = np.clip(x2, 0, max_shape[1])
y2 = np.clip(y2, 0, max_shape[0])
return np.stack([x1, y1, x2, y2], axis=-1)
def draw_detections(self, image, boxes, scores, kpts):
for box, score, kp in zip(boxes, scores, kpts):
x, y, w, h = box.astype(int)
# Draw rectangle
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
for i in range(5):
cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
# cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
return image
ROOT = os.path.dirname(os.path.abspath(__file__))
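A minimal sketch of running the ONNX detector above on a single image (file paths are illustrative; the yolov8n-face.onnx weights must exist at the given path):

import cv2

detector = YOLOv8_face(path='face_detection/weights/yolov8n-face.onnx',
                       conf_thres=0.2, iou_thres=0.5)
img = cv2.imread('test_face.jpg')                                 # BGR image
boxes, confidences, class_ids, landmarks = detector.detect(img)   # xywh boxes
vis = detector.draw_detections(img.copy(), boxes, confidences, landmarks)
cv2.imwrite('test_face_det.png', vis)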

@@ -0,0 +1 @@
from .core import FaceDetector

@@ -0,0 +1,130 @@
import logging
import glob
from tqdm import tqdm
import numpy as np
import torch
import cv2
class FaceDetector(object):
"""An abstract class representing a face detector.
Any other face detection implementation must subclass it. All subclasses
must implement ``detect_from_image``, that return a list of detected
bounding boxes. Optionally, for speed considerations detect from path is
recommended.
"""
def __init__(self, device, verbose):
self.device = device
self.verbose = verbose
if verbose:
if 'cpu' in device:
logger = logging.getLogger(__name__)
logger.warning("Detection running on CPU, this may be potentially slow.")
if 'cpu' not in device and 'cuda' not in device:
if verbose:
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
raise ValueError
def detect_from_image(self, tensor_or_path):
"""Detects faces in a given image.
        This function detects the faces present in a provided image (usually BGR).
        The input can be either the image itself or the path to it.
Arguments:
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
to an image or the image itself.
Example::
>>> path_to_image = 'data/image_01.jpg'
... detected_faces = detect_from_image(path_to_image)
[A list of bounding boxes (x1, y1, x2, y2)]
>>> image = cv2.imread(path_to_image)
... detected_faces = detect_from_image(image)
[A list of bounding boxes (x1, y1, x2, y2)]
"""
raise NotImplementedError
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
"""Detects faces from all the images present in a given directory.
Arguments:
path {string} -- a string containing a path that points to the folder containing the images
Keyword Arguments:
            extensions {list} -- list of strings containing the extensions to be
                considered, in the following format: ``.extension_name``
                (default: {['.jpg', '.png']})
            recursive {bool} -- whether to scan the folder recursively (default: {False})
            show_progress_bar {bool} -- display a progress bar (default: {True})
Example:
>>> directory = 'data'
... detected_faces = detect_from_directory(directory)
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
"""
if self.verbose:
logger = logging.getLogger(__name__)
if len(extensions) == 0:
if self.verbose:
logger.error("Expected at list one extension, but none was received.")
raise ValueError
if self.verbose:
logger.info("Constructing the list of images.")
additional_pattern = '/**/*' if recursive else '/*'
files = []
for extension in extensions:
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
if self.verbose:
logger.info("Finished searching for images. %s images found", len(files))
logger.info("Preparing to run the detection.")
predictions = {}
for image_path in tqdm(files, disable=not show_progress_bar):
if self.verbose:
logger.info("Running the face detector on image: %s", image_path)
predictions[image_path] = self.detect_from_image(image_path)
if self.verbose:
logger.info("The detector was successfully run on all %s images", len(files))
return predictions
@property
def reference_scale(self):
raise NotImplementedError
@property
def reference_x_shift(self):
raise NotImplementedError
@property
def reference_y_shift(self):
raise NotImplementedError
@staticmethod
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
Arguments:
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
"""
if isinstance(tensor_or_path, str):
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
elif torch.is_tensor(tensor_or_path):
# Call cpu in case its coming from cuda
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
elif isinstance(tensor_or_path, np.ndarray):
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
else:
raise TypeError

@@ -0,0 +1 @@
from .sfd_detector import SFDDetector as FaceDetector

@@ -0,0 +1,129 @@
from __future__ import print_function
import os
import sys
import cv2
import random
import datetime
import time
import math
import argparse
import numpy as np
import torch
try:
from iou import IOU
except BaseException:
# IOU cython speedup 10x
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
sa = abs((ax2 - ax1) * (ay2 - ay1))
sb = abs((bx2 - bx1) * (by2 - by1))
x1, y1 = max(ax1, bx1), max(ay1, by1)
x2, y2 = min(ax2, bx2), min(ay2, by2)
w = x2 - x1
h = y2 - y1
if w < 0 or h < 0:
return 0.0
else:
return 1.0 * w * h / (sa + sb - w * h)
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
dw, dh = math.log(ww / aww), math.log(hh / ahh)
return dx, dy, dw, dh
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
xc, yc = dx * aww + axc, dy * ahh + ayc
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
return x1, y1, x2, y2
def nms(dets, thresh):
if 0 == len(dets):
return []
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def encode(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 4].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded boxes (tensor), Shape: [num_priors, 4]
"""
# dist b/t match center and prior's center
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, 2:])
# match wh / prior wh
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
# return target for smooth_l1_loss
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def batch_decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
return boxes

@@ -0,0 +1,114 @@
import torch
import torch.nn.functional as F
import os
import sys
import cv2
import random
import datetime
import math
import argparse
import numpy as np
import scipy.io as sio
import zipfile
from .net_s3fd import s3fd
from .bbox import *
def detect(net, img, device):
img = img - np.array([104, 117, 123])
img = img.transpose(2, 0, 1)
img = img.reshape((1,) + img.shape)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
img = torch.from_numpy(img).float().to(device)
BB, CC, HH, WW = img.size()
with torch.no_grad():
olist = net(img)
bboxlist = []
for i in range(len(olist) // 2):
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
olist = [oelem.data.cpu() for oelem in olist]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
FB, FC, FH, FW = ocls.size() # feature map size
stride = 2**(i + 2) # 4,8,16,32,64,128
anchor = stride * 4
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
for Iindex, hindex, windex in poss:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[0, 1, hindex, windex]
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
variances = [0.1, 0.2]
box = decode(loc, priors, variances)
x1, y1, x2, y2 = box[0] * 1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append([x1, y1, x2, y2, score])
bboxlist = np.array(bboxlist)
if 0 == len(bboxlist):
bboxlist = np.zeros((1, 5))
return bboxlist
def batch_detect(net, imgs, device):
imgs = imgs - np.array([104, 117, 123])
imgs = imgs.transpose(0, 3, 1, 2)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
imgs = torch.from_numpy(imgs).float().to(device)
BB, CC, HH, WW = imgs.size()
with torch.no_grad():
olist = net(imgs)
# print(olist)
bboxlist = []
for i in range(len(olist) // 2):
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
olist = [oelem.cpu() for oelem in olist]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
FB, FC, FH, FW = ocls.size() # feature map size
stride = 2**(i + 2) # 4,8,16,32,64,128
anchor = stride * 4
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
for Iindex, hindex, windex in poss:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[:, 1, hindex, windex]
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
variances = [0.1, 0.2]
box = batch_decode(loc, priors, variances)
box = box[:, 0] * 1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
bboxlist = np.array(bboxlist)
if 0 == len(bboxlist):
bboxlist = np.zeros((1, BB, 5))
return bboxlist
def flip_detect(net, img, device):
img = cv2.flip(img, 1)
b = detect(net, img, device)
bboxlist = np.zeros(b.shape)
bboxlist[:, 0] = img.shape[1] - b[:, 2]
bboxlist[:, 1] = b[:, 1]
bboxlist[:, 2] = img.shape[1] - b[:, 0]
bboxlist[:, 3] = b[:, 3]
bboxlist[:, 4] = b[:, 4]
return bboxlist
def pts_to_bb(pts):
min_x, min_y = np.min(pts, axis=0)
max_x, max_y = np.max(pts, axis=0)
return np.array([min_x, min_y, max_x, max_y])

@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class L2Norm(nn.Module):
def __init__(self, n_channels, scale=1.0):
super(L2Norm, self).__init__()
self.n_channels = n_channels
self.scale = scale
self.eps = 1e-10
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
self.weight.data *= 0.0
self.weight.data += self.scale
def forward(self, x):
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
x = x / norm * self.weight.view(1, -1, 1, 1)
return x
class s3fd(nn.Module):
def __init__(self):
super(s3fd, self).__init__()
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.conv3_3_norm = L2Norm(256, scale=10)
self.conv4_3_norm = L2Norm(512, scale=8)
self.conv5_3_norm = L2Norm(512, scale=5)
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
def forward(self, x):
h = F.relu(self.conv1_1(x))
h = F.relu(self.conv1_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv2_1(h))
h = F.relu(self.conv2_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv3_1(h))
h = F.relu(self.conv3_2(h))
h = F.relu(self.conv3_3(h))
f3_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv4_1(h))
h = F.relu(self.conv4_2(h))
h = F.relu(self.conv4_3(h))
f4_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv5_1(h))
h = F.relu(self.conv5_2(h))
h = F.relu(self.conv5_3(h))
f5_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.fc6(h))
h = F.relu(self.fc7(h))
ffc7 = h
h = F.relu(self.conv6_1(h))
h = F.relu(self.conv6_2(h))
f6_2 = h
h = F.relu(self.conv7_1(h))
h = F.relu(self.conv7_2(h))
f7_2 = h
f3_3 = self.conv3_3_norm(f3_3)
f4_3 = self.conv4_3_norm(f4_3)
f5_3 = self.conv5_3_norm(f5_3)
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
cls4 = self.fc7_mbox_conf(ffc7)
reg4 = self.fc7_mbox_loc(ffc7)
cls5 = self.conv6_2_mbox_conf(f6_2)
reg5 = self.conv6_2_mbox_loc(f6_2)
cls6 = self.conv7_2_mbox_conf(f7_2)
reg6 = self.conv7_2_mbox_loc(f7_2)
# max-out background label
chunk = torch.chunk(cls1, 4, 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
cls1 = torch.cat([bmax, chunk[3]], dim=1)
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]

@@ -0,0 +1,59 @@
import os
import cv2
from torch.utils.model_zoo import load_url
from ..core import FaceDetector
from .net_s3fd import s3fd
from .bbox import *
from .detect import *
models_urls = {
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
}
class SFDDetector(FaceDetector):
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
super(SFDDetector, self).__init__(device, verbose)
# Initialise the face detector
if not os.path.isfile(path_to_detector):
model_weights = load_url(models_urls['s3fd'])
else:
model_weights = torch.load(path_to_detector)
self.face_detector = s3fd()
self.face_detector.load_state_dict(model_weights)
self.face_detector.to(device)
self.face_detector.eval()
def detect_from_image(self, tensor_or_path):
image = self.tensor_or_path_to_ndarray(tensor_or_path)
bboxlist = detect(self.face_detector, image, device=self.device)
keep = nms(bboxlist, 0.3)
bboxlist = bboxlist[keep, :]
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
return bboxlist
def detect_from_batch(self, images):
bboxlists = batch_detect(self.face_detector, images, device=self.device)
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
return bboxlists
@property
def reference_scale(self):
return 195
@property
def reference_x_shift(self):
return 0
@property
def reference_y_shift(self):
return 0
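A minimal sketch of using the S3FD detector on its own (the image path is illustrative; s3fd.pth is downloaded from the URL above if it is not found next to the module):

from face_detection.detection.sfd import FaceDetector as SFDDetector

detector = SFDDetector(device='cuda')
faces = detector.detect_from_image('portrait.jpg')
# faces is a list of [x1, y1, x2, y2, score] entries with score > 0.5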

@@ -0,0 +1,261 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=strd, padding=padding, bias=bias)
class ConvBlock(nn.Module):
def __init__(self, in_planes, out_planes):
super(ConvBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
if in_planes != out_planes:
self.downsample = nn.Sequential(
nn.BatchNorm2d(in_planes),
nn.ReLU(True),
nn.Conv2d(in_planes, out_planes,
kernel_size=1, stride=1, bias=False),
)
else:
self.downsample = None
def forward(self, x):
residual = x
out1 = self.bn1(x)
out1 = F.relu(out1, True)
out1 = self.conv1(out1)
out2 = self.bn2(out1)
out2 = F.relu(out2, True)
out2 = self.conv2(out2)
out3 = self.bn3(out2)
out3 = F.relu(out3, True)
out3 = self.conv3(out3)
out3 = torch.cat((out1, out2, out3), 1)
if self.downsample is not None:
residual = self.downsample(residual)
out3 += residual
return out3
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class HourGlass(nn.Module):
def __init__(self, num_modules, depth, num_features):
super(HourGlass, self).__init__()
self.num_modules = num_modules
self.depth = depth
self.features = num_features
self._generate_network(self.depth)
def _generate_network(self, level):
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
if level > 1:
self._generate_network(level - 1)
else:
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
def _forward(self, level, inp):
# Upper branch
up1 = inp
up1 = self._modules['b1_' + str(level)](up1)
# Lower branch
low1 = F.avg_pool2d(inp, 2, stride=2)
low1 = self._modules['b2_' + str(level)](low1)
if level > 1:
low2 = self._forward(level - 1, low1)
else:
low2 = low1
low2 = self._modules['b2_plus_' + str(level)](low2)
low3 = low2
low3 = self._modules['b3_' + str(level)](low3)
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
return up1 + up2
def forward(self, x):
return self._forward(self.depth, x)
class FAN(nn.Module):
def __init__(self, num_modules=1):
super(FAN, self).__init__()
self.num_modules = num_modules
# Base part
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = ConvBlock(64, 128)
self.conv3 = ConvBlock(128, 128)
self.conv4 = ConvBlock(128, 256)
# Stacking part
for hg_module in range(self.num_modules):
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
self.add_module('conv_last' + str(hg_module),
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
self.add_module('l' + str(hg_module), nn.Conv2d(256,
68, kernel_size=1, stride=1, padding=0))
if hg_module < self.num_modules - 1:
self.add_module(
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('al' + str(hg_module), nn.Conv2d(68,
256, kernel_size=1, stride=1, padding=0))
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)), True)
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
x = self.conv3(x)
x = self.conv4(x)
previous = x
outputs = []
for i in range(self.num_modules):
hg = self._modules['m' + str(i)](previous)
ll = hg
ll = self._modules['top_m_' + str(i)](ll)
ll = F.relu(self._modules['bn_end' + str(i)]
(self._modules['conv_last' + str(i)](ll)), True)
# Predict heatmaps
tmp_out = self._modules['l' + str(i)](ll)
outputs.append(tmp_out)
if i < self.num_modules - 1:
ll = self._modules['bl' + str(i)](ll)
tmp_out_ = self._modules['al' + str(i)](tmp_out)
previous = previous + ll + tmp_out_
return outputs
class ResNetDepth(nn.Module):
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
self.inplanes = 64
super(ResNetDepth, self).__init__()
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x

@@ -0,0 +1,313 @@
from __future__ import print_function
import os
import sys
import time
import torch
import math
import numpy as np
import cv2
def _gaussian(
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
mean_vert=0.5):
# handle some defaults
if width is None:
width = size
if height is None:
height = size
if sigma_horz is None:
sigma_horz = sigma
if sigma_vert is None:
sigma_vert = sigma
center_x = mean_horz * width + 0.5
center_y = mean_vert * height + 0.5
gauss = np.empty((height, width), dtype=np.float32)
# generate kernel
for i in range(height):
for j in range(width):
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
if normalize:
gauss = gauss / np.sum(gauss)
return gauss
def draw_gaussian(image, point, sigma):
# Check if the gaussian is inside
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
return image
size = 6 * sigma + 1
g = _gaussian(size)
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
assert (g_x[0] > 0 and g_y[1] > 0)
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
image[image > 1] = 1
return image
def transform(point, center, scale, resolution, invert=False):
"""Generate and affine transformation matrix.
Given a set of points, a center, a scale and a targer resolution, the
function generates and affine transformation matrix. If invert is ``True``
it will produce the inverse transformation.
Arguments:
point {torch.tensor} -- the input 2D point
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
scale {float} -- the scale of the face/object
resolution {float} -- the output resolution
Keyword Arguments:
        invert {bool} -- define whether the function should produce the direct or the
inverse transformation matrix (default: {False})
"""
_pt = torch.ones(3)
_pt[0] = point[0]
_pt[1] = point[1]
h = 200.0 * scale
t = torch.eye(3)
t[0, 0] = resolution / h
t[1, 1] = resolution / h
t[0, 2] = resolution * (-center[0] / h + 0.5)
t[1, 2] = resolution * (-center[1] / h + 0.5)
if invert:
t = torch.inverse(t)
new_point = (torch.matmul(t, _pt))[0:2]
return new_point.int()
def crop(image, center, scale, resolution=256.0):
"""Center crops an image or set of heatmaps
Arguments:
image {numpy.array} -- an rgb image
center {numpy.array} -- the center of the object, usually the same as of the bounding box
scale {float} -- scale of the face
Keyword Arguments:
resolution {float} -- the size of the output cropped image (default: {256.0})
Returns:
        numpy.array -- the cropped (and resized) image
    """
    # Crops the image around the center. Input is expected to be an np.ndarray.
ul = transform([1, 1], center, scale, resolution, True)
br = transform([resolution, resolution], center, scale, resolution, True)
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
if image.ndim > 2:
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
image.shape[2]], dtype=np.int32)
newImg = np.zeros(newDim, dtype=np.uint8)
else:
        newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
newImg = np.zeros(newDim, dtype=np.uint8)
ht = image.shape[0]
wd = image.shape[1]
newX = np.array(
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
newY = np.array(
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
interpolation=cv2.INTER_LINEAR)
return newImg
def get_preds_fromhm(hm, center=None, scale=None):
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
and the scale is provided the function will return the points also in
the original coordinate frame.
Arguments:
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
Keyword Arguments:
center {torch.tensor} -- the center of the bounding box (default: {None})
scale {float} -- face scale (default: {None})
"""
max, idx = torch.max(
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
idx += 1
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
for i in range(preds.size(0)):
for j in range(preds.size(1)):
hm_ = hm[i, j, :]
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
diff = torch.FloatTensor(
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
preds[i, j].add_(diff.sign_().mul_(.25))
preds.add_(-.5)
preds_orig = torch.zeros(preds.size())
if center is not None and scale is not None:
for i in range(hm.size(0)):
for j in range(hm.size(1)):
preds_orig[i, j] = transform(
preds[i, j], center, scale, hm.size(2), True)
return preds, preds_orig
def get_preds_fromhm_batch(hm, centers=None, scales=None):
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
and the scales is provided the function will return the points also in
the original coordinate frame.
Arguments:
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
Keyword Arguments:
centers {torch.tensor} -- the centers of the bounding box (default: {None})
scales {float} -- face scales (default: {None})
"""
max, idx = torch.max(
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
idx += 1
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
for i in range(preds.size(0)):
for j in range(preds.size(1)):
hm_ = hm[i, j, :]
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
diff = torch.FloatTensor(
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
preds[i, j].add_(diff.sign_().mul_(.25))
preds.add_(-.5)
preds_orig = torch.zeros(preds.size())
if centers is not None and scales is not None:
for i in range(hm.size(0)):
for j in range(hm.size(1)):
preds_orig[i, j] = transform(
preds[i, j], centers[i], scales[i], hm.size(2), True)
return preds, preds_orig
def shuffle_lr(parts, pairs=None):
"""Shuffle the points left-right according to the axis of symmetry
of the object.
Arguments:
parts {torch.tensor} -- a 3D or 4D object containing the
heatmaps.
Keyword Arguments:
pairs {list of integers} -- [order of the flipped points] (default: {None})
"""
if pairs is None:
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
62, 61, 60, 67, 66, 65]
if parts.ndimension() == 3:
parts = parts[pairs, ...]
else:
parts = parts[:, pairs, ...]
return parts
def flip(tensor, is_label=False):
"""Flip an image or a set of heatmaps left-right
Arguments:
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
Keyword Arguments:
        is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
"""
if not torch.is_tensor(tensor):
tensor = torch.from_numpy(tensor)
if is_label:
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
else:
tensor = tensor.flip(tensor.ndimension() - 1)
return tensor
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
def appdata_dir(appname=None, roaming=False):
""" appdata_dir(appname=None, roaming=False)
Get the path to the application directory, where applications are allowed
to write user specific files (e.g. configurations). For non-user specific
data, consider using common_appdata_dir().
If appname is given, a subdir is appended (and created if necessary).
If roaming is True, will prefer a roaming directory (Windows Vista/7).
"""
# Define default user directory
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
if userDir is None:
userDir = os.path.expanduser('~')
if not os.path.isdir(userDir): # pragma: no cover
userDir = '/var/tmp' # issue #54
# Get system app data dir
path = None
if sys.platform.startswith('win'):
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
path = (path2 or path1) if roaming else (path1 or path2)
elif sys.platform.startswith('darwin'):
path = os.path.join(userDir, 'Library', 'Application Support')
# On Linux and as fallback
if not (path and os.path.isdir(path)):
path = userDir
# Maybe we should store things local to the executable (in case of a
# portable distro or a frozen application that wants to be portable)
prefix = sys.prefix
if getattr(sys, 'frozen', None):
prefix = os.path.abspath(os.path.dirname(sys.executable))
for reldir in ('settings', '../settings'):
localpath = os.path.abspath(os.path.join(prefix, reldir))
if os.path.isdir(localpath): # pragma: no cover
try:
open(os.path.join(localpath, 'test.write'), 'wb').close()
os.remove(os.path.join(localpath, 'test.write'))
except IOError:
pass # We cannot write in this directory
else:
path = localpath
break
# Get path specific for this app
if appname:
if path == userDir:
appname = '.' + appname.lstrip('.') # Make it a hidden directory
path = os.path.join(path, appname)
if not os.path.isdir(path): # pragma: no cover
os.mkdir(path)
# Done
return path

@@ -0,0 +1,57 @@
import torch
import time
import os
import cv2
import numpy as np
from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms
class FaceParsing():
def __init__(self,resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
model_pth='./models/face-parse-bisent/79999_iter.pth'):
self.net = self.model_init(resnet_path,model_pth)
self.preprocess = self.image_preprocess()
def model_init(self,
resnet_path,
model_pth):
net = BiSeNet(resnet_path)
if torch.cuda.is_available():
net.cuda()
net.load_state_dict(torch.load(model_pth))
else:
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
net.eval()
return net
def image_preprocess(self):
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def __call__(self, image, size=(512, 512)):
if isinstance(image, str):
image = Image.open(image)
width, height = image.size
with torch.no_grad():
image = image.resize(size, Image.BILINEAR)
img = self.preprocess(image)
if torch.cuda.is_available():
img = torch.unsqueeze(img, 0).cuda()
else:
img = torch.unsqueeze(img, 0)
out = self.net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
parsing[np.where(parsing>13)] = 0
parsing[np.where(parsing>=1)] = 255
parsing = Image.fromarray(parsing.astype(np.uint8))
return parsing
if __name__ == "__main__":
fp = FaceParsing()
segmap = fp('154_small.png')
segmap.save('res.png')
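
The parsing result above is a single-channel PIL image in which the retained face classes (BiSeNet labels 1-13) are set to 255 and everything else to 0, so it converts directly into a blending mask. A hedged sketch, assuming FaceParsing is imported from this module; the file name and target size are placeholders:

import numpy as np
from PIL import Image

fp = FaceParsing()                                  # loads the BiSeNet weights from ./models/face-parse-bisent/
seg = fp('face_crop.png')                           # PIL image at the 512x512 parsing resolution
mask = (np.array(seg) > 0).astype(np.float32)       # 1.0 on face pixels, 0.0 elsewhere
# Resize back to the crop size before using it as an alpha matte
mask = np.array(Image.fromarray((mask * 255).astype(np.uint8)).resize((256, 256))) / 255.0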


@ -0,0 +1,283 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from .resnet import Resnet18
# from modules.bn import InPlaceABNSync as BatchNorm2d
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size = ks,
stride = stride,
padding = padding,
bias = False)
self.bn = nn.BatchNorm2d(out_chan)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = F.relu(self.bn(x))
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
self.init_weight()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class ContextPath(nn.Module):
def __init__(self, resnet_path, *args, **kwargs):
super(ContextPath, self).__init__()
self.resnet = Resnet18(resnet_path)
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
H0, W0 = x.size()[2:]
feat8, feat16, feat32 = self.resnet(x)
H8, W8 = feat8.size()[2:]
H16, W16 = feat16.size()[2:]
H32, W32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
### This is not used, since I replace this with the resnet feature with the same size
class SpatialPath(nn.Module):
def __init__(self, *args, **kwargs):
super(SpatialPath, self).__init__()
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
feat = self.conv1(x)
feat = self.conv2(feat)
feat = self.conv3(feat)
feat = self.conv_out(feat)
return feat
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan,
out_chan//4,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.conv2 = nn.Conv2d(out_chan//4,
out_chan,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
self.init_weight()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class BiSeNet(nn.Module):
def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
super(BiSeNet, self).__init__()
self.cp = ContextPath(resnet_path)
## here self.sp is deleted
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, n_classes)
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
self.init_weight()
def forward(self, x):
H, W = x.size()[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
feat_out = self.conv_out(feat_fuse)
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
return feat_out, feat_out16, feat_out32
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
for name, child in self.named_children():
child_wd_params, child_nowd_params = child.get_params()
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
lr_mul_wd_params += child_wd_params
lr_mul_nowd_params += child_nowd_params
else:
wd_params += child_wd_params
nowd_params += child_nowd_params
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
if __name__ == "__main__":
net = BiSeNet(n_classes=19)  # pass by keyword: the first positional argument is resnet_path
net.cuda()
net.eval()
in_ten = torch.randn(16, 3, 640, 480).cuda()
out, out16, out32 = net(in_ten)
print(out.shape)
net.get_params()


@ -0,0 +1,109 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo
# from modules.bn import InPlaceABNSync as BatchNorm2d
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm2d(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan, out_chan,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = F.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum-1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class Resnet18(nn.Module):
def __init__(self, model_path):
super(Resnet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
self.init_weight(model_path)
def forward(self, x):
x = self.conv1(x)
x = F.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
def init_weight(self, model_path):
state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
self_state_dict = self.state_dict()
for k, v in state_dict.items():
if 'fc' in k: continue
self_state_dict.update({k: v})
self.load_state_dict(self_state_dict)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
if __name__ == "__main__":
net = Resnet18('./models/face-parse-bisent/resnet18-5c106cde.pth')  # model_path is required; point it at the downloaded resnet18 weights
x = torch.randn(16, 3, 224, 224)
out = net(x)
print(out[0].size())
print(out[1].size())
print(out[2].size())
net.get_params()


@ -0,0 +1,154 @@
import sys
from face_detection import FaceAlignment,LandmarksType
from os import listdir, path
import subprocess
import numpy as np
import cv2
import pickle
import os
import json
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
import torch
from tqdm import tqdm
# initialize the mmpose model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device)
# initialize the face detection model
device = "cuda" if torch.cuda.is_available() else "cpu"
fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device)
# placeholder coordinates used when no usable bbox can be found
coord_placeholder = (0.0,0.0,0.0,0.0)
def resize_landmark(landmark, w, h, new_w, new_h):
w_ratio = new_w / w
h_ratio = new_h / h
landmark_norm = landmark / [w, h]
landmark_resized = landmark_norm * [new_w, new_h]
return landmark_resized
def read_imgs(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
def get_bbox_range(img_list,upperbondrange =0):
frames = read_imgs(img_list)
batch_size_fa = 1
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
coords_list = []
landmarks = []
if upperbondrange != 0:
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
else:
print('get key_landmark and face bounding boxes with the default value')
average_range_minus = []
average_range_plus = []
for fb in tqdm(batches):
results = inference_topdown(model, np.asarray(fb)[0])
results = merge_data_samples(results)
keypoints = results.pred_instances.keypoints
face_land_mark= keypoints[0][23:91]
face_land_mark = face_land_mark.astype(np.int32)
# get bounding boxes by face detection
bbox = fa.get_detections_for_batch(np.asarray(fb))
# adjust the bounding box with reference to the landmarks
# Add the bounding box to a tuple and append it to the coordinates list
for j, f in enumerate(bbox):
if f is None: # no face in the image
coords_list += [coord_placeholder]
continue
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
average_range_minus.append(range_minus)
average_range_plus.append(range_plus)
if upperbondrange != 0:
half_face_coord[1] = upperbondrange+half_face_coord[1] # manual adjustment: positive shifts down toward landmark 29, negative shifts up toward landmark 28
text_range=f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}"
return text_range
def get_landmark_and_bbox(img_list,upperbondrange =0):
frames = read_imgs(img_list)
batch_size_fa = 1
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
coords_list = []
landmarks = []
if upperbondrange != 0:
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
else:
print('get key_landmark and face bounding boxes with the default value')
average_range_minus = []
average_range_plus = []
for fb in tqdm(batches):
results = inference_topdown(model, np.asarray(fb)[0])
results = merge_data_samples(results)
keypoints = results.pred_instances.keypoints
face_land_mark= keypoints[0][23:91]
face_land_mark = face_land_mark.astype(np.int32)
# get bounding boxes by face detection
bbox = fa.get_detections_for_batch(np.asarray(fb))
# adjust the bounding box with reference to the landmarks
# Add the bounding box to a tuple and append it to the coordinates list
for j, f in enumerate(bbox):
if f is None: # no face in the image
coords_list += [coord_placeholder]
continue
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
average_range_minus.append(range_minus)
average_range_plus.append(range_plus)
if upperbondrange != 0:
half_face_coord[1] = upperbondrange+half_face_coord[1] # manual adjustment: positive shifts down toward landmark 29, negative shifts up toward landmark 28
half_face_dist = np.max(face_land_mark[:,1]) - half_face_coord[1]
upper_bond = half_face_coord[1]-half_face_dist
f_landmark = (np.min(face_land_mark[:, 0]),int(upper_bond),np.max(face_land_mark[:, 0]),np.max(face_land_mark[:,1]))
x1, y1, x2, y2 = f_landmark
if y2-y1<=0 or x2-x1<=0 or x1<0: # if the landmark bbox is not suitable, reuse the bbox
coords_list += [f]
w,h = f[2]-f[0], f[3]-f[1]
print("error bbox:",f)
else:
coords_list += [f_landmark]
print("********************************************bbox_shift parameter adjustment**********************************************************")
print(f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}")
print("*************************************************************************************************************************************")
return coords_list,frames
if __name__ == "__main__":
img_list = ["./results/lyria/00000.png","./results/lyria/00001.png","./results/lyria/00002.png","./results/lyria/00003.png"]
crop_coord_path = "./coord_face.pkl"
coords_list,full_frames = get_landmark_and_bbox(img_list)
with open(crop_coord_path, 'wb') as f:
pickle.dump(coords_list, f)
for bbox, frame in zip(coords_list,full_frames):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
print('Cropped shape', crop_frame.shape)
#cv2.imwrite(path.join(save_dir, '{}.png'.format(i)),full_frames[i][0][y1:y2, x1:x2])
print(coords_list)
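
get_bbox_range() runs the same landmark pass as get_landmark_and_bbox() but only reports the recommended bbox_shift interval, which is handy to check before committing to a full preprocessing run. A hedged sketch, assuming both functions are in scope; the image paths, the example output, and the shift value are placeholders:

img_list = ["./results/avatar/00000.png", "./results/avatar/00001.png"]
print(get_bbox_range(img_list, upperbondrange=0))   # e.g. "Total frame:「2」 Manually adjust range : [ -7~9 ] ..."

# Once a shift inside that range is chosen, rerun the full extraction with it
coords_list, frames = get_landmark_and_bbox(img_list, upperbondrange=5)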

musetalk/utils/utils.py (new file, 75 lines)

@ -0,0 +1,75 @@
import os
import cv2
import numpy as np
import torch
ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
elif ffmpeg_path not in os.getenv('PATH'):
print("add ffmpeg to path")
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding
def load_all_model():
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
vae = VAE(model_path = "./models/sd-vae-ft-mse/")
unet = UNet(unet_config="./models/musetalk/musetalk.json",
model_path ="./models/musetalk/pytorch_model.bin")
pe = PositionalEncoding(d_model=384)
return audio_processor,vae,unet,pe
def get_file_type(video_path):
_, ext = os.path.splitext(video_path)
if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
return 'image'
elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
return 'video'
else:
return 'unsupported'
def get_video_fps(video_path):
video = cv2.VideoCapture(video_path)
fps = video.get(cv2.CAP_PROP_FPS)
video.release()
return fps
def datagen(whisper_chunks,
vae_encode_latents,
batch_size=8,
delay_frame=0):
whisper_batch, latent_batch = [], []
for i, w in enumerate(whisper_chunks):
idx = (i+delay_frame)%len(vae_encode_latents)
latent = vae_encode_latents[idx]
whisper_batch.append(w)
latent_batch.append(latent)
if len(latent_batch) >= batch_size:
whisper_batch = np.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
whisper_batch, latent_batch = [], []
# the last batch may be smaller than batch_size
if len(latent_batch) > 0:
whisper_batch = np.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
def load_audio_model():
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
return audio_processor
def load_diffusion_model():
vae = VAE(model_path = "./models/sd-vae-ft-mse/")
unet = UNet(unet_config="./models/musetalk/musetalk.json",
model_path ="./models/musetalk/pytorch_model.bin")
pe = PositionalEncoding(d_model=384)
return vae,unet,pe
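
datagen() pairs each whisper feature chunk with a VAE latent, cycling through the latents with a modulo index, and yields fixed-size batches plus one final shorter batch. A minimal sketch with dummy tensors; the latent shape is illustrative only, not the exact MuseTalk latent size:

import numpy as np
import torch
from musetalk.utils.utils import datagen

whisper_chunks = [np.zeros((50, 384), dtype=np.float32) for _ in range(10)]  # 10 audio feature windows
vae_encode_latents = [torch.zeros(1, 8, 32, 32) for _ in range(4)]           # 4 reference latents, reused cyclically

for whisper_batch, latent_batch in datagen(whisper_chunks, vae_encode_latents, batch_size=4):
    print(whisper_batch.shape, latent_batch.shape)   # (<=4, 50, 384) ndarray, (<=4, 8, 32, 32) tensor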


@ -0,0 +1,130 @@
import os
from .whisper import load_model
import soundfile as sf
import numpy as np
import time
import sys
sys.path.append("..")
class Audio2Feature():
def __init__(self,
whisper_model_type="tiny",
model_path="./models/whisper/tiny.pt"):
self.whisper_model_type = whisper_model_type
self.model = load_model(model_path) #
def get_sliced_feature(self,
feature_array,
vid_idx,
audio_feat_length=[2,2],
fps=25):
"""
Get sliced features based on a given index
:param feature_array: whisper feature array to slice from
:param vid_idx: video frame index used to locate the audio feature window
:param audio_feat_length: number of frames of context to take on the left / right
:return:
"""
length = len(feature_array)
selected_feature = []
selected_idx = []
center_idx = int(vid_idx*50/fps)
left_idx = center_idx-audio_feat_length[0]*2
right_idx = center_idx + (audio_feat_length[1]+1)*2
for idx in range(left_idx,right_idx):
idx = max(0, idx)
idx = min(length-1, idx)
x = feature_array[idx]
selected_feature.append(x)
selected_idx.append(idx)
selected_feature = np.concatenate(selected_feature, axis=0)
selected_feature = selected_feature.reshape(-1, 384)# 50*384
return selected_feature,selected_idx
def get_sliced_feature_sparse(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
"""
Get sliced features based on a given index
:param feature_array: whisper feature array to slice from
:param vid_idx: video frame index used to locate the audio feature window
:param audio_feat_length: number of frames of context to take on the left / right
:return:
"""
length = len(feature_array)
selected_feature = []
selected_idx = []
for dt in range(-audio_feat_length[0],audio_feat_length[1]+1):
left_idx = int((vid_idx+dt)*50/fps)
if left_idx<1 or left_idx>length-1:
print('test-----,left_idx=',left_idx)
left_idx = max(0, left_idx)
left_idx = min(length-1, left_idx)
x = feature_array[left_idx]
x = x[np.newaxis,:,:]
x = np.repeat(x, 2, axis=0)
selected_feature.append(x)
selected_idx.append(left_idx)
selected_idx.append(left_idx)
else:
x = feature_array[left_idx-1:left_idx+1]
selected_feature.append(x)
selected_idx.append(left_idx-1)
selected_idx.append(left_idx)
selected_feature = np.concatenate(selected_feature, axis=0)
selected_feature = selected_feature.reshape(-1, 384)# 50*384
return selected_feature,selected_idx
def feature2chunks(self,feature_array,fps,batch_size,audio_feat_length = [2,2],start=0):
whisper_chunks = []
whisper_idx_multiplier = 50./fps
i = 0
#print(f"video in {fps} FPS, audio idx in 50FPS")
for _ in range(batch_size):
# start_idx = int(i * whisper_idx_multiplier)
# if start_idx>=len(feature_array):
# break
selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i+start,audio_feat_length=audio_feat_length,fps=fps)
#print(f"i:{i},selected_idx {selected_idx}")
whisper_chunks.append(selected_feature)
i += 1
return whisper_chunks
def audio2feat(self,audio_path):
# get the sample rate of the audio
result = self.model.transcribe(audio_path)
embed_list = []
for emb in result['segments']:
encoder_embeddings = emb['encoder_embeddings']
encoder_embeddings = encoder_embeddings.transpose(0,2,1,3)
encoder_embeddings = encoder_embeddings.squeeze(0)
start_idx = int(emb['start'])
end_idx = int(emb['end'])
emb_end_idx = int((end_idx - start_idx)/2)
embed_list.append(encoder_embeddings[:emb_end_idx])
concatenated_array = np.concatenate(embed_list, axis=0)
return concatenated_array
if __name__ == "__main__":
audio_processor = Audio2Feature(model_path="../../models/whisper/whisper_tiny.pt")
audio_path = "./test.mp3"
array = audio_processor.audio2feat(audio_path)
print(array.shape)
fps = 25
whisper_idx_multiplier = 50./fps
i = 0
print(f"video in {fps} FPS, audio idx in 50FPS")
while 1:
start_idx = int(i * whisper_idx_multiplier)
selected_feature,selected_idx = audio_processor.get_sliced_feature(feature_array= array,vid_idx = i,audio_feat_length=[2,2],fps=fps)
print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
i += 1
if start_idx>len(array):
break
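
The frame-to-feature index mapping used by get_sliced_feature() above can be checked by hand: whisper features arrive at 50 per second, so a video frame index is scaled by 50/fps and a window of roughly audio_feat_length video frames on each side (two whisper steps per video frame) is gathered around it. A worked example:

fps, vid_idx, audio_feat_length = 25, 10, [2, 2]
center_idx = int(vid_idx * 50 / fps)                     # 20
left_idx = center_idx - audio_feat_length[0] * 2         # 16
right_idx = center_idx + (audio_feat_length[1] + 1) * 2  # 26 (exclusive)
# -> whisper indices 16..25 are gathered, then reshaped to (-1, 384)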


@ -0,0 +1,116 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
return model.to(device)
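
load_model() accepts either a name from _MODELS (downloading and checksum-verifying the checkpoint) or a local file path, which is how Audio2Feature uses it earlier in this diff. A hedged sketch; the package path is inferred from this repo layout and the file paths are placeholders:

import torch
from musetalk.whisper.whisper import load_model

model = load_model("./models/whisper/tiny.pt",
                   device="cuda" if torch.cuda.is_available() else "cpu")
result = model.transcribe("./test.mp3")              # same call audio2feat() builds on
print(result["segments"][0]["start"], result["segments"][0]["end"])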


@ -0,0 +1,4 @@
from .transcribe import cli
cli()

File diff suppressed because it is too large.


@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}


@ -0,0 +1 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

Binary file not shown.


@ -0,0 +1 @@
{"<|endoftext|>": 50257}

File diff suppressed because it is too large.


@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}


@ -0,0 +1 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long


@ -0,0 +1,125 @@
import os
from functools import lru_cache
from typing import Union
import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F
from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
try:
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length))
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
"""
Compute the log-Mel spectrogram of an audio waveform.
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
Returns
-------
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[:, :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
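
Chaining the three helpers above gives the standard whisper front end: decode to 16 kHz mono, pad or trim to a 30-second window, then compute the fixed-size log-Mel input. A minimal sketch; the import assumes the same package layout as above and the audio path is a placeholder:

from musetalk.whisper.whisper import load_audio, log_mel_spectrogram, pad_or_trim

audio = load_audio("./test.mp3")          # float32 waveform at SAMPLE_RATE = 16000
audio = pad_or_trim(audio)                # exactly N_SAMPLES = 480000 samples
mel = log_mel_spectrogram(audio)          # torch.Tensor of shape (80, N_FRAMES) = (80, 3000)
print(mel.shape, mel.dtype)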


@ -0,0 +1,729 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
if TYPE_CHECKING:
from .model import Whisper
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as a list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
This is performed outside the main decode loop in order to not interfere with kv-caching.
Returns
-------
language_tokens : Tensor, shape = (n_audio,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
n_audio = mel.shape[0]
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[str] = None # language that the audio is in; uses detected language if None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
# prompt, prefix, and token suppression
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
suppress_blank: bool = True # this will suppress blank outputs
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
# implementation details
fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
audio_features: Tensor
language: str
encoder_embeddings: np.ndarray
decoder_embeddings: np.ndarray
language_probs: Optional[Dict[str, float]] = None
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
class Inference:
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
"""Perform a forward pass on the decoder and return per-token logits"""
raise NotImplementedError
def rearrange_kv_cache(self, source_indices) -> None:
"""Update the key-value cache according to the updated beams"""
raise NotImplementedError
def cleanup_caching(self) -> None:
"""Clean up any resources or hooks after decoding is finished"""
pass
class PyTorchInference(Inference):
def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = {}
self.hooks = []
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
return_val = self.model.decoder(tokens, audio_features,
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
return return_val
def cleanup_caching(self):
for hook in self.hooks:
hook.remove()
self.kv_cache = {}
self.hooks = []
def rearrange_kv_cache(self, source_indices):
for module, tensor in self.kv_cache.items():
# update the key/value cache to contain the selected sequences
self.kv_cache[module] = tensor[source_indices].detach()
class SequenceRanker:
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
"""
raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
"""
Select the sample with the highest log probabilities, penalized using either
a simple length normalization or Google NMT paper's length penalty
"""
def __init__(self, length_penalty: Optional[float]):
self.length_penalty = length_penalty
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
def scores(logprobs, lengths):
result = []
for logprob, length in zip(logprobs, lengths):
if self.length_penalty is None:
penalty = length
else:
# from the Google NMT paper
penalty = ((5 + length) / 6) ** self.length_penalty
result.append(logprob / penalty)
return result
# get the sequence with the highest score
lengths = [[len(t) for t in s] for s in tokens]
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
----------
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
sum_logprobs : Tensor, shape = (n_batch)
cumulative log probabilities for each sequence
Returns
-------
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
the tokens, appended with the selected next token
completed : bool
True if all sequences have reached the end of text
"""
raise NotImplementedError
def finalize(
self, tokens: Tensor, sum_logprobs: Tensor
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
"""Finalize search and return the final candidate sequences
Parameters
----------
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence
sum_logprobs : Tensor, shape = (n_audio, n_group)
cumulative log probabilities for each sequence
Returns
-------
tokens : Sequence[Sequence[Tensor]], length = n_audio
sequence of Tensors containing candidate token sequences, for each audio input
sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above
"""
raise NotImplementedError
class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
return tokens, completed
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
# make sure each sequence has at least one EOT token at the end
tokens = F.pad(tokens, (0, 1), value=self.eot)
return tokens, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None: # for the first update
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = F.log_softmax(logits.float(), dim=-1)
next_tokens, source_indices, finished_sequences = [], [], []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens[idx].tolist()
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
new_logprob = (sum_logprobs[idx] + logprob).item()
sequence = tuple(prefix + [token.item()])
scores[sequence] = new_logprob
sources[sequence] = idx
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
sum_logprobs[len(next_tokens)] = scores[sequence]
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = torch.tensor(next_tokens, device=tokens.device)
self.inference.rearrange_kv_cache(source_indices)
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
previously_finished[seq] = newly_finished[seq]
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if len(sequences) < self.beam_size: # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
if len(sequences) >= self.beam_size:
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
]
return tokens, sum_logprobs
class LogitFilter:
def apply(self, logits: Tensor, tokens: Tensor) -> None:
"""Apply any filtering or masking to logits in-place
Parameters
----------
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
"""
raise NotImplementedError
class SuppressBlank(LogitFilter):
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
def apply(self, logits: Tensor, tokens: Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
class SuppressTokens(LogitFilter):
def __init__(self, suppress_tokens: Sequence[int]):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: Tensor, tokens: Tensor):
logits[:, self.suppress_tokens] = -np.inf
class ApplyTimestampRules(LogitFilter):
def __init__(
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
self.max_initial_timestamp_index = max_initial_timestamp_index
def apply(self, logits: Tensor, tokens: Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
# apply the `max_initial_timestamp` option
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
class DecodingTask:
inference: Inference
sequence_ranker: SequenceRanker
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
if self.options.without_timestamps:
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
self.sample_begin: int = len(self.initial_tokens)
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching
self.inference = PyTorchInference(model, len(self.initial_tokens))
# sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
self.decoder = BeamSearchDecoder(
options.beam_size, tokenizer.eot, self.inference, options.patience
)
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: applies various rules to suppress or penalize certain tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
if self.options.suppress_tokens:
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
self.logit_filters.append(
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.beam_size is not None and options.best_of is not None:
raise ValueError("beam_size and best_of can't be given together")
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt
if prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
)
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens = self.options.suppress_tokens
if isinstance(suppress_tokens, str):
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens)))
def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
result = self.model.encoder(mel, include_embeddings)
if include_embeddings:
audio_features, embeddings = result
else:
audio_features = result
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
if include_embeddings:
return audio_features, embeddings
else:
return audio_features
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
return languages, lang_probs
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
try:
embeddings = []
for i in range(self.sample_len):
logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
token_embeddings = token_embeddings[:, :, -1]
# Append embeddings together
embeddings.append(token_embeddings)
# apply the logit filters, e.g. for suppressing or applying a penalty to certain tokens
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
if completed:
embeddings = embeddings[:-1]
embeddings = np.stack(embeddings, 2)
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs, embeddings
@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
# encoder forward pass
forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
audio_features, encoder_embeddings = forward_pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(audio_features=features, language=language, language_probs=probs)
for features, language, probs in zip(audio_features, languages, language_probs)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
return [
DecodingResult(
audio_features=features,
language=language,
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
encoder_embeddings=encoder_embeddings,
decoder_embeddings=decoder_embeddings
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
the Whisper model instance
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
result = DecodingTask(model, options).run(mel)
if single:
result = result[0]
return result
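# A minimal usage sketch for decode(): it assumes `model` is a loaded Whisper instance
# from this package and `mel` is an (80, 3000) log-Mel tensor already on model.device;
# the option fields follow the DecodingOptions dataclass defined earlier in this file.
options = DecodingOptions(task="transcribe", language="en", temperature=0.0, fp16=False)
result = decode(model, mel, options)             # one 30-second segment -> one DecodingResult
print(result.language, result.text)
print(result.encoder_embeddings.shape)           # per-layer encoder states stacked by run()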

View File

@ -0,0 +1,290 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
)
class Conv1d(nn.Conv1d):
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
q = self.query(x)
if kv_cache is None or xa is None:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache.get(self.key, self.key(xa))
v = kv_cache.get(self.value, self.value(xa))
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv)
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
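# Side note on the scaling above: multiplying q and k each by (n_state // n_head) ** -0.25
# is algebraically the same as dividing the attention logits by sqrt(d_head). A small
# standalone check with illustrative shapes:
_d_head = 64
_q = torch.randn(2, 8, 10, _d_head)
_k = torch.randn(2, 8, 10, _d_head)
_scale = _d_head ** -0.25
_split = (_q * _scale) @ (_k * _scale).transpose(-1, -2)    # scaling used in qkv_attention
_standard = (_q @ _k.transpose(-1, -2)) / _d_head ** 0.5    # textbook scaled dot-product
assert torch.allclose(_split, _standard, atol=1e-5)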
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor, include_embeddings: bool = False):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
include_embeddings: bool
whether to include intermediate steps in the output
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
if include_embeddings:
embeddings = [x.cpu().detach().numpy()]
for block in self.blocks:
x = block(x)
if include_embeddings:
embeddings.append(x.cpu().detach().numpy())
x = self.ln_post(x)
if include_embeddings:
embeddings = np.stack(embeddings, axis=1)
return x, embeddings
else:
return x
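# Shape sketch for AudioEncoder with include_embeddings=True, assuming base-model sized
# dimensions (n_mels=80, n_ctx=1500, n_state=512, 8 heads, 6 layers); the numbers are
# illustrative rather than tied to a particular checkpoint.
_encoder = AudioEncoder(n_mels=80, n_ctx=1500, n_state=512, n_head=8, n_layer=6)
_mel = torch.randn(1, 80, 3000)                 # roughly 30 s of log-Mel frames
_x, _embeddings = _encoder(_mel, include_embeddings=True)
print(_x.shape)                                 # torch.Size([1, 1500, 512]) final encoder states
print(_embeddings.shape)                        # (1, 7, 1500, 512): block input plus one entry per block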
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
the encoded audio features to be attended on
include_embeddings : bool
Whether to include intermediate values in the output to this function
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = x.to(xa.dtype)
if include_embeddings:
embeddings = [x.cpu().detach().numpy()]
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
if include_embeddings:
embeddings.append(x.cpu().detach().numpy())
x = self.ln(x)
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
if include_embeddings:
embeddings = np.stack(embeddings, axis=1)
return logits, embeddings
else:
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
def embed_audio(self, mel: torch.Tensor):
return self.encoder.forward(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder.forward(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
cache[module] = output # save as-is, for the first token or cross attention
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
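# Usage sketch for install_kv_cache_hooks(): incremental greedy decoding, assuming
# `model` is a loaded multilingual Whisper instance and `mel` is a (1, 80, 3000)
# spectrogram on model.device; 50258 is assumed to be the <|startoftranscript|> id.
audio_features = model.embed_audio(mel)                  # (1, n_audio_ctx, n_audio_state)
cache, hooks = model.install_kv_cache_hooks()
tokens = torch.tensor([[50258]], device=model.device)
try:
    for _ in range(10):
        # once the cache holds previous positions, only the newest token is fed in
        logits = model.decoder(tokens[:, -1:] if cache else tokens,
                               audio_features, kv_cache=cache)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
finally:
    for hook in hooks:
        hook.remove()                                    # uninstall the forward hooks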

View File

@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer
from .english import EnglishTextNormalizer

View File

@ -0,0 +1,71 @@
import re
import unicodedata
import regex
# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
"œ": "oe",
"Œ": "OE",
"ø": "o",
"Ø": "O",
"æ": "ae",
"Æ": "AE",
"ß": "ss",
"": "SS",
"đ": "d",
"Đ": "D",
"ð": "d",
"Ð": "D",
"þ": "th",
"Þ": "th",
"ł": "l",
"Ł": "L",
}
def remove_symbols_and_diacritics(s: str, keep=""):
"""
Replace any other markers, symbols, and punctuations with a space,
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
c
if c in keep
else ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else ""
if unicodedata.category(c) == "Mn"
else " "
if unicodedata.category(c)[0] in "MSP"
else c
for c in unicodedata.normalize("NFKD", s)
)
def remove_symbols(s: str):
"""
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
self.split_letters = split_letters
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = self.clean(s).lower()
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
return s
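# Usage sketch for BasicTextNormalizer; the comment shows the output expected from the
# rules above rather than a captured run.
_normalizer = BasicTextNormalizer(remove_diacritics=True)
print(_normalizer("Hello, [laughs] (noise) Zürich!"))
# expected roughly: "hello zurich" -- bracketed spans removed, punctuation turned into
# spaces, and the umlaut dropped by remove_symbols_and_diacritics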

File diff suppressed because it is too large

View File

@ -0,0 +1,543 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union
from more_itertools import windowed
from .basic import remove_symbols_and_diacritics
class EnglishNumberNormalizer:
"""
Convert any spelled-out numbers into arabic numbers, while handling:
- remove any commas
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
- spell out `one` and `ones`
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
"""
def __init__(self):
super().__init__()
self.zeros = {"o", "oh", "zero"}
self.ones = {
name: i
for i, name in enumerate(
[
"one",
"two",
"three",
"four",
"five",
"six",
"seven",
"eight",
"nine",
"ten",
"eleven",
"twelve",
"thirteen",
"fourteen",
"fifteen",
"sixteen",
"seventeen",
"eighteen",
"nineteen",
],
start=1,
)
}
self.ones_plural = {
"sixes" if name == "six" else name + "s": (value, "s")
for name, value in self.ones.items()
}
self.ones_ordinal = {
"zeroth": (0, "th"),
"first": (1, "st"),
"second": (2, "nd"),
"third": (3, "rd"),
"fifth": (5, "th"),
"twelfth": (12, "th"),
**{
name + ("h" if name.endswith("t") else "th"): (value, "th")
for name, value in self.ones.items()
if value > 3 and value != 5 and value != 12
},
}
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
self.tens = {
"twenty": 20,
"thirty": 30,
"forty": 40,
"fifty": 50,
"sixty": 60,
"seventy": 70,
"eighty": 80,
"ninety": 90,
}
self.tens_plural = {
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
}
self.tens_ordinal = {
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
}
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
self.multipliers = {
"hundred": 100,
"thousand": 1_000,
"million": 1_000_000,
"billion": 1_000_000_000,
"trillion": 1_000_000_000_000,
"quadrillion": 1_000_000_000_000_000,
"quintillion": 1_000_000_000_000_000_000,
"sextillion": 1_000_000_000_000_000_000_000,
"septillion": 1_000_000_000_000_000_000_000_000,
"octillion": 1_000_000_000_000_000_000_000_000_000,
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
}
self.multipliers_plural = {
name + "s": (value, "s") for name, value in self.multipliers.items()
}
self.multipliers_ordinal = {
name + "th": (value, "th") for name, value in self.multipliers.items()
}
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
self.decimals = {*self.ones, *self.tens, *self.zeros}
self.preceding_prefixers = {
"minus": "-",
"negative": "-",
"plus": "+",
"positive": "+",
}
self.following_prefixers = {
"pound": "£",
"pounds": "£",
"euro": "",
"euros": "",
"dollar": "$",
"dollars": "$",
"cent": "¢",
"cents": "¢",
}
self.prefixes = set(
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
)
self.suffixers = {
"per": {"cent": "%"},
"percent": "%",
}
self.specials = {"and", "double", "triple", "point"}
self.words = set(
[
key
for mapping in [
self.zeros,
self.ones,
self.ones_suffixed,
self.tens,
self.tens_suffixed,
self.multipliers,
self.multipliers_suffixed,
self.preceding_prefixers,
self.following_prefixers,
self.suffixers,
self.specials,
]
for key in mapping
]
)
self.literal_words = {"one", "ones"}
def process_words(self, words: List[str]) -> Iterator[str]:
prefix: Optional[str] = None
value: Optional[Union[str, int]] = None
skip = False
def to_fraction(s: str):
try:
return Fraction(s)
except ValueError:
return None
def output(result: Union[str, int]):
nonlocal prefix, value
result = str(result)
if prefix is not None:
result = prefix + result
value = None
prefix = None
return result
if len(words) == 0:
return
for prev, current, next in windowed([None] + words + [None], 3):
if skip:
skip = False
continue
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
has_prefix = current[0] in self.prefixes
current_without_prefix = current[1:] if has_prefix else current
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
# arabic numbers (potentially with signs and fractions)
f = to_fraction(current_without_prefix)
assert f is not None
if value is not None:
if isinstance(value, str) and value.endswith("."):
# concatenate decimals / ip address components
value = str(value) + str(current)
continue
else:
yield output(value)
prefix = current[0] if has_prefix else prefix
if f.denominator == 1:
value = f.numerator # store integers as int
else:
value = current_without_prefix
elif current not in self.words:
# non-numeric words
if value is not None:
yield output(value)
yield output(current)
elif current in self.zeros:
value = str(value or "") + "0"
elif current in self.ones:
ones = self.ones[current]
if value is None:
value = ones
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10: # replace the last zero with the digit
assert value[-1] == "0"
value = value[:-1] + str(ones)
else:
value = str(value) + str(ones)
elif ones < 10:
if value % 10 == 0:
value += ones
else:
value = str(value) + str(ones)
else: # eleven to nineteen
if value % 100 == 0:
value += ones
else:
value = str(value) + str(ones)
elif current in self.ones_suffixed:
# ordinal or cardinal; yield the number right away
ones, suffix = self.ones_suffixed[current]
if value is None:
yield output(str(ones) + suffix)
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10:
assert value[-1] == "0"
yield output(value[:-1] + str(ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
elif ones < 10:
if value % 10 == 0:
yield output(str(value + ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
else: # eleven to nineteen
if value % 100 == 0:
yield output(str(value + ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
value = None
elif current in self.tens:
tens = self.tens[current]
if value is None:
value = tens
elif isinstance(value, str):
value = str(value) + str(tens)
else:
if value % 100 == 0:
value += tens
else:
value = str(value) + str(tens)
elif current in self.tens_suffixed:
# ordinal or cardinal; yield the number right away
tens, suffix = self.tens_suffixed[current]
if value is None:
yield output(str(tens) + suffix)
elif isinstance(value, str):
yield output(str(value) + str(tens) + suffix)
else:
if value % 100 == 0:
yield output(str(value + tens) + suffix)
else:
yield output(str(value) + str(tens) + suffix)
elif current in self.multipliers:
multiplier = self.multipliers[current]
if value is None:
value = multiplier
elif isinstance(value, str) or value == 0:
f = to_fraction(value)
p = f * multiplier if f is not None else None
if f is not None and p.denominator == 1:
value = p.numerator
else:
yield output(value)
value = multiplier
else:
before = value // 1000 * 1000
residual = value % 1000
value = before + residual * multiplier
elif current in self.multipliers_suffixed:
multiplier, suffix = self.multipliers_suffixed[current]
if value is None:
yield output(str(multiplier) + suffix)
elif isinstance(value, str):
f = to_fraction(value)
p = f * multiplier if f is not None else None
if f is not None and p.denominator == 1:
yield output(str(p.numerator) + suffix)
else:
yield output(value)
yield output(str(multiplier) + suffix)
else: # int
before = value // 1000 * 1000
residual = value % 1000
value = before + residual * multiplier
yield output(str(value) + suffix)
value = None
elif current in self.preceding_prefixers:
# apply prefix (positive, minus, etc.) if it precedes a number
if value is not None:
yield output(value)
if next in self.words or next_is_numeric:
prefix = self.preceding_prefixers[current]
else:
yield output(current)
elif current in self.following_prefixers:
# apply prefix (dollars, cents, etc.) only after a number
if value is not None:
prefix = self.following_prefixers[current]
yield output(value)
else:
yield output(current)
elif current in self.suffixers:
# apply suffix symbols (percent -> '%')
if value is not None:
suffix = self.suffixers[current]
if isinstance(suffix, dict):
if next in suffix:
yield output(str(value) + suffix[next])
skip = True
else:
yield output(value)
yield output(current)
else:
yield output(str(value) + suffix)
else:
yield output(current)
elif current in self.specials:
if next not in self.words and not next_is_numeric:
# apply special handling only if the next word can be numeric
if value is not None:
yield output(value)
yield output(current)
elif current == "and":
# ignore "and" after hundreds, thousands, etc.
if prev not in self.multipliers:
if value is not None:
yield output(value)
yield output(current)
elif current == "double" or current == "triple":
if next in self.ones or next in self.zeros:
repeats = 2 if current == "double" else 3
ones = self.ones.get(next, 0)
value = str(value or "") + str(ones) * repeats
skip = True
else:
if value is not None:
yield output(value)
yield output(current)
elif current == "point":
if next in self.decimals or next_is_numeric:
value = str(value or "") + "."
else:
# should all have been covered at this point
raise ValueError(f"Unexpected token: {current}")
else:
# all should have been covered at this point
raise ValueError(f"Unexpected token: {current}")
if value is not None:
yield output(value)
def preprocess(self, s: str):
# replace "<number> and a half" with "<number> point five"
results = []
segments = re.split(r"\band\s+a\s+half\b", s)
for i, segment in enumerate(segments):
if len(segment.strip()) == 0:
continue
if i == len(segments) - 1:
results.append(segment)
else:
results.append(segment)
last_word = segment.rsplit(maxsplit=2)[-1]
if last_word in self.decimals or last_word in self.multipliers:
results.append("point five")
else:
results.append("and a half")
s = " ".join(results)
# put a space at number/letter boundary
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
# but remove spaces which could be a suffix
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
return s
def postprocess(self, s: str):
def combine_cents(m: Match):
try:
currency = m.group(1)
integer = m.group(2)
cents = int(m.group(3))
return f"{currency}{integer}.{cents:02d}"
except ValueError:
return m.string
def extract_cents(m: Match):
try:
return f"¢{int(m.group(1))}"
except ValueError:
return m.string
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
# write "one(s)" instead of "1(s)", just for the readability
s = re.sub(r"\b1(s?)\b", r"one\1", s)
return s
def __call__(self, s: str):
s = self.preprocess(s)
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
s = self.postprocess(s)
return s
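# Quick sketch of EnglishNumberNormalizer on its own; the comments show the outputs
# expected from the rules above, not captured output.
_num = EnglishNumberNormalizer()
print(_num("one oh one"))            # expected "101": successive single digits read as nominal
print(_num("twenty one dollars"))    # expected "$21": the currency word becomes a prefix symbol
print(_num("nineteen sixties"))      # expected "1960s": plural/ordinal suffixes are kept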
class EnglishSpellingNormalizer:
"""
Applies British-American spelling mappings as listed in [1].
[1] https://www.tysto.com/uk-us-spelling-list.html
"""
def __init__(self):
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
self.mapping = json.load(open(mapping_path))
def __call__(self, s: str):
return " ".join(self.mapping.get(word, word) for word in s.split())
class EnglishTextNormalizer:
def __init__(self):
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
self.replacers = {
# common contractions
r"\bwon't\b": "will not",
r"\bcan't\b": "can not",
r"\blet's\b": "let us",
r"\bain't\b": "aint",
r"\by'all\b": "you all",
r"\bwanna\b": "want to",
r"\bgotta\b": "got to",
r"\bgonna\b": "going to",
r"\bi'ma\b": "i am going to",
r"\bimma\b": "i am going to",
r"\bwoulda\b": "would have",
r"\bcoulda\b": "could have",
r"\bshoulda\b": "should have",
r"\bma'am\b": "madam",
# contractions in titles/prefixes
r"\bmr\b": "mister ",
r"\bmrs\b": "missus ",
r"\bst\b": "saint ",
r"\bdr\b": "doctor ",
r"\bprof\b": "professor ",
r"\bcapt\b": "captain ",
r"\bgov\b": "governor ",
r"\bald\b": "alderman ",
r"\bgen\b": "general ",
r"\bsen\b": "senator ",
r"\brep\b": "representative ",
r"\bpres\b": "president ",
r"\brev\b": "reverend ",
r"\bhon\b": "honorable ",
r"\basst\b": "assistant ",
r"\bassoc\b": "associate ",
r"\blt\b": "lieutenant ",
r"\bcol\b": "colonel ",
r"\bjr\b": "junior ",
r"\bsr\b": "senior ",
r"\besq\b": "esquire ",
# prefect tenses, ideally it should be any past participles, but it's harder..
r"'d been\b": " had been",
r"'s been\b": " has been",
r"'d gone\b": " had gone",
r"'s gone\b": " has gone",
r"'d done\b": " had done", # "'s done" is ambiguous
r"'s got\b": " has got",
# general contractions
r"n't\b": " not",
r"'re\b": " are",
r"'s\b": " is",
r"'d\b": " would",
r"'ll\b": " will",
r"'t\b": " not",
r"'ve\b": " have",
r"'m\b": " am",
}
self.standardize_numbers = EnglishNumberNormalizer()
self.standardize_spellings = EnglishSpellingNormalizer()
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = re.sub(self.ignore_patterns, "", s)
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
for pattern, replacement in self.replacers.items():
s = re.sub(pattern, replacement, s)
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
s = self.standardize_numbers(s)
s = self.standardize_spellings(s)
# now remove prefix/suffix symbols that are not preceded/followed by numbers
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
s = re.sub(r"([^0-9])%", r"\1 ", s)
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
return s
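# End-to-end sketch of EnglishTextNormalizer; constructing it assumes the english.json
# spelling mapping ships next to this module, and the comment shows the expected result,
# not captured output.
_norm = EnglishTextNormalizer()
print(_norm("Mr. Smith won't pay $1,000, uh, he'll pay twenty dollars!"))
# expected roughly: "mister smith will not pay $1000 he will pay $20"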

View File

@ -0,0 +1,331 @@
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import GPT2TokenizerFast
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
@dataclass(frozen=True)
class Tokenizer:
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
tokenizer: "GPT2TokenizerFast"
language: Optional[str]
sot_sequence: Tuple[int]
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)
@property
@lru_cache()
def eot(self) -> int:
return self.tokenizer.eos_token_id
@property
@lru_cache()
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
@property
@lru_cache()
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
@property
@lru_cache()
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
@property
@lru_cache()
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
@property
@lru_cache()
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
@property
@lru_cache()
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
@property
@lru_cache()
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(f"This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
)
)
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
raise KeyError(f"Language {self.language} not found in tokenizer.")
@property
@lru_cache()
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
):
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
@property
@lru_cache()
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
@property
@lru_cache()
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text)
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(os.path.dirname(__file__), "assets", name)
tokenizer = GPT2TokenizerFast.from_pretrained(path)
specials = [
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
language: Optional[str] = None,
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
language = language or "en"
else:
tokenizer_name = "gpt2"
task = None
language = None
tokenizer = build_tokenizer(name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
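# Usage sketch for get_tokenizer(); it assumes the GPT2TokenizerFast files under
# assets/multilingual ship with this package.
_tok = get_tokenizer(multilingual=True, task="transcribe", language="en")
print(_tok.sot_sequence)                                   # ids of <|startoftranscript|>, <|en|>, <|transcribe|>
print(_tok.decode_with_timestamps([_tok.timestamp_begin + 54]))
# expected "<|1.08|>": timestamp tokens advance in 0.02 s steps above the special-token range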

View File

@ -0,0 +1,207 @@
import argparse
import os
import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
import numpy as np
import torch
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
if TYPE_CHECKING:
from .model import Whisper
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
force_extraction: bool = False,
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary with segment-level details ("segments"); in this stripped-down variant each
segment carries only the window's start/end mel-frame indices and its encoder embeddings.
"""
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
mel = log_mel_spectrogram(audio)
all_segments = []
def add_segment(
*, start: float, end: float, encoder_embeddings
):
all_segments.append(
{
"start": start,
"end": end,
"encoder_embeddings":encoder_embeddings,
}
)
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
seek = 0
previous_seek_value = seek
sample_skip = 3000  # 3000 mel frames (one N_FRAMES window, about 30 s of audio) per step
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
while seek < num_frames:
# seek is the index of the first mel frame of the current window
end_seek = min(seek + sample_skip, num_frames)
segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
single = segment.ndim == 2
if single:
segment = segment.unsqueeze(0)
if dtype == torch.float16:
segment = segment.half()
audio_features, embeddings = model.encoder(segment, include_embeddings = True)
encoder_embeddings = embeddings
#print(f"encoder_embeddings shape {encoder_embeddings.shape}")
add_segment(
start=seek,
end=end_seek,
#text_tokens=tokens,
#result=result,
encoder_embeddings=encoder_embeddings,
)
seek+=sample_skip
return dict(segments=all_segments)
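# Usage sketch: this stripped-down transcribe() only gathers per-window encoder
# embeddings and produces no text. It assumes `model` is a loaded Whisper instance from
# this package and that "sample.wav" is a placeholder audio path.
result = transcribe(model, "sample.wav", fp16=False)
for seg in result["segments"]:
    # each window spans sample_skip mel frames (about 30 s) and stores the per-layer
    # encoder states with shape (1, n_layer + 1, n_audio_ctx, n_audio_state)
    print(seg["start"], seg["end"], seg["encoder_embeddings"].shape)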
def cli():
from . import available_models
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
args["language"] = "en"
temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
else:
temperature = [temperature]
threads = args.pop("threads")
if threads > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
audio_basename = os.path.basename(audio_path)
# save TXT
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result["segments"], file=txt)
# save VTT
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt)
# save SRT
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
if __name__ == '__main__':
cli()

View File

@ -0,0 +1,87 @@
import zlib
from typing import Iterator, TextIO
def exact_div(x, y):
assert x % y == 0
return x // y
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def optional_int(string):
return None if string == "None" else int(string)
def optional_float(string):
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
return len(text) / len(zlib.compress(text.encode("utf-8")))
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)
def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
Write a transcript to a file in SRT format.
Example usage:
from pathlib import Path
from whisper.utils import write_srt
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)

View File

@ -4,50 +4,19 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
#import pyaudio
import soundfile as sf
import resampy
import queue import queue
from queue import Queue from queue import Queue
#from collections import deque #from collections import deque
from threading import Thread, Event from threading import Thread, Event
from io import BytesIO
from baseasr import BaseASR
def _read_frame(stream, exit_event, queue, chunk): class NerfASR(BaseASR):
def __init__(self, opt, parent):
while True: super().__init__(opt,parent)
if exit_event.is_set():
print(f'[INFO] read frame thread ends')
break
frame = stream.read(chunk, exception_on_overflow=False)
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
queue.put(frame)
def _play_frame(stream, exit_event, queue, chunk):
while True:
if exit_event.is_set():
print(f'[INFO] play frame thread ends')
break
frame = queue.get()
frame = (frame * 32767).astype(np.int16).tobytes()
stream.write(frame, chunk)
class ASR:
def __init__(self, opt):
self.opt = opt
self.play = opt.asr_play #false
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.mode = 'live' if opt.asr_wav == '' else 'file'
if 'esperanto' in self.opt.asr_model: if 'esperanto' in self.opt.asr_model:
self.audio_dim = 44 self.audio_dim = 44
elif 'deepspeech' in self.opt.asr_model: elif 'deepspeech' in self.opt.asr_model:
@ -62,40 +31,11 @@ class ASR:
self.context_size = opt.m self.context_size = opt.m
self.stride_left_size = opt.l self.stride_left_size = opt.l
self.stride_right_size = opt.r self.stride_right_size = opt.r
self.text = '[START]\n'
self.terminated = False
self.frames = []
self.inwarm = False
# pad left frames # pad left frames
if self.stride_left_size > 0: if self.stride_left_size > 0:
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
self.exit_event = Event()
#self.audio_instance = pyaudio.PyAudio() #not need
# create input stream
if self.mode == 'file': #live mode
self.file_stream = self.create_file_stream()
else:
self.queue = Queue()
self.input_stream = BytesIO()
self.output_queue = Queue()
# start a background process to read frames
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
#self.queue = Queue()
#self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
# play out the audio too...?
if self.play:
self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
self.output_queue = Queue()
self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
# current location of audio
self.idx = 0
# create wav2vec model # create wav2vec model
print(f'[INFO] loading ASR model {self.opt.asr_model}...') print(f'[INFO] loading ASR model {self.opt.asr_model}...')
if 'hubert' in self.opt.asr_model: if 'hubert' in self.opt.asr_model:
@ -105,10 +45,6 @@ class ASR:
self.processor = AutoProcessor.from_pretrained(opt.asr_model) self.processor = AutoProcessor.from_pretrained(opt.asr_model)
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
# prepare to save logits
if self.opt.asr_save_feats:
self.all_feats = []
# the extracted features # the extracted features
# use a loop queue to efficiently record endless features: [f--t---][-------][-------] # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
self.feat_buffer_size = 4 self.feat_buffer_size = 4
@ -124,8 +60,20 @@ class ASR:
# warm up steps needed: mid + right + window_size + attention_size # warm up steps needed: mid + right + window_size + attention_size
self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3 self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
self.listening = False def get_audio_frame(self):
self.playing = False try:
frame = self.queue.get(block=False)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
if self.parent and self.parent.curr_state>1: # play the custom audio for the current state
frame = self.parent.get_audio_stream(self.parent.curr_state)
type = self.parent.curr_state
else:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_next_feat(self): #get audio embedding to nerf def get_next_feat(self): #get audio embedding to nerf
# return a [1/8, 16] window, for the next input to nerf side. # return a [1/8, 16] window, for the next input to nerf side.
@ -167,29 +115,19 @@ class ASR:
def run_step(self): def run_step(self):
if self.terminated:
return
# get a frame of audio # get a frame of audio
frame,type = self.__get_audio_frame() frame,type = self.get_audio_frame()
self.frames.append(frame)
# the last frame # put to output
if frame is None: self.output_queue.put((frame,type))
# terminate, but always run the network for the left frames # context not enough, do not run network.
self.terminated = True if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
else: return
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
# context not enough, do not run network.
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk] inputs = np.concatenate(self.frames) # [N * chunk]
# discard the old part to save memory # discard the old part to save memory
if not self.terminated: self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
#print(f'[INFO] frame_to_text... ') #print(f'[INFO] frame_to_text... ')
#t = time.time() #t = time.time()
@ -197,10 +135,6 @@ class ASR:
#print(f'-------wav2vec time:{time.time()-t:.4f}s') #print(f'-------wav2vec time:{time.time()-t:.4f}s')
feats = logits # better lips-sync than labels feats = logits # better lips-sync than labels
# save feats
if self.opt.asr_save_feats:
self.all_feats.append(feats)
# record the feats efficiently.. (no concat, constant memory) # record the feats efficiently.. (no concat, constant memory)
start = self.feat_buffer_idx * self.context_size start = self.feat_buffer_idx * self.context_size
end = start + feats.shape[0] end = start + feats.shape[0]
@ -212,51 +146,28 @@ class ASR:
# self.text = self.text + ' ' + text # self.text = self.text + ' ' + text
# will only run once at termination
if self.terminated: # if self.terminated:
self.text += '\n[END]' # self.text += '\n[END]'
print(self.text) # print(self.text)
if self.opt.asr_save_feats: # if self.opt.asr_save_feats:
print(f'[INFO] save all feats for training purpose... ') # print(f'[INFO] save all feats for training purpose... ')
feats = torch.cat(self.all_feats, dim=0) # [N, C] # feats = torch.cat(self.all_feats, dim=0) # [N, C]
# print('[INFO] before unfold', feats.shape) # # print('[INFO] before unfold', feats.shape)
window_size = 16 # window_size = 16
padding = window_size // 2 # padding = window_size // 2
feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M] # feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1] # feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1] # unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C] # unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
# print('[INFO] after unfold', unfold_feats.shape) # # print('[INFO] after unfold', unfold_feats.shape)
# save to a npy file # # save to a npy file
if 'esperanto' in self.opt.asr_model: # if 'esperanto' in self.opt.asr_model:
output_path = self.opt.asr_wav.replace('.wav', '_eo.npy') # output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
else: # else:
output_path = self.opt.asr_wav.replace('.wav', '.npy') # output_path = self.opt.asr_wav.replace('.wav', '.npy')
np.save(output_path, unfold_feats.cpu().numpy()) # np.save(output_path, unfold_feats.cpu().numpy())
print(f"[INFO] saved logits to {output_path}") # print(f"[INFO] saved logits to {output_path}")
def __get_audio_frame(self):
if self.inwarm: # warm up
return np.zeros(self.chunk, dtype=np.float32),1
if self.mode == 'file':
if self.idx < self.file_stream.shape[0]:
frame = self.file_stream[self.idx: self.idx + self.chunk]
self.idx = self.idx + self.chunk
return frame,0
else:
return None,0
else:
try:
frame = self.queue.get(block=False)
type = 0
print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
self.idx = self.idx + self.chunk
return frame,type
def __frame_to_text(self, frame):
@ -277,8 +188,8 @@ class ASR:
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
# do not cut right if terminated.
# if self.terminated:
# right = logits.shape[1]
logits = logits[:, left:right]
@ -298,60 +209,23 @@ class ASR:
return logits[0], None,None #predicted_ids[0], transcription # [N,]
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
def warm_up(self):
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
#for _ in range(self.stride_left_size):
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
for _ in range(self.warm_up_steps):
self.run_step()
#if torch.cuda.is_available():
# torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
#self.clear_queue()
#####not used function#####################################
'''
def push_audio(self,buffer): #push audio pcm from tts
print(f'[INFO] push_audio {len(buffer)}')
if self.opt.tts == "xtts" or self.opt.tts == "gpt-sovits":
if len(buffer)>0:
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
if self.opt.tts == "xtts":
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
else:
stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.__create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
# if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
else: #edge tts
self.input_stream.write(buffer)
if len(buffer)<=0:
self.input_stream.seek(0)
stream = self.__create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
#if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self.input_stream.seek(0)
self.input_stream.truncate()
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def __init_queue(self):
self.frames = []
self.queue.queue.clear()
@ -360,10 +234,6 @@ class ASR:
self.tail = 8
# attention window...
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
def before_push_audio(self):
self.__init_queue()
self.warm_up()
def run(self):
@ -380,73 +250,6 @@ class ASR:
if self.play:
self.output_queue.queue.clear()
def warm_up(self):
#self.listen()
self.inwarm = True
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
#for _ in range(self.stride_left_size):
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
for _ in range(self.warm_up_steps):
self.run_step()
#if torch.cuda.is_available():
# torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
self.inwarm = False
#self.clear_queue()
'''
def create_file_stream(self):
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
return stream
def create_pyaudio_stream(self):
import pyaudio
print(f'[INFO] creating live audio stream ...')
audio = pyaudio.PyAudio()
# get devices
info = audio.get_host_api_info_by_index(0)
n_devices = info.get('deviceCount')
for i in range(0, n_devices):
if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
print(f'[INFO] choose audio device {name}, id {i}')
break
# get stream
stream = audio.open(input_device_index=i,
format=pyaudio.paInt16,
channels=1,
rate=self.sample_rate,
input=True,
frames_per_buffer=self.chunk)
return audio, stream
'''
#####not used function#####################################
def listen(self):
# start
if self.mode == 'live' and not self.listening:
@ -489,6 +292,25 @@ class ASR:
# live mode: also print the result text.
self.text += '\n[END]'
print(self.text)
def _read_frame(stream, exit_event, queue, chunk):
while True:
if exit_event.is_set():
print(f'[INFO] read frame thread ends')
break
frame = stream.read(chunk, exception_on_overflow=False)
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
queue.put(frame)
def _play_frame(stream, exit_event, queue, chunk):
while True:
if exit_event.is_set():
print(f'[INFO] play frame thread ends')
break
frame = queue.get()
frame = (frame * 32767).astype(np.int16).tobytes()
stream.write(frame, chunk)
#########################################################
if __name__ == '__main__':
@ -522,4 +344,5 @@ if __name__ == '__main__':
raise ValueError("DeepSpeech features should not use this code to extract...")
with ASR(opt) as asr:
asr.run()
'''


@ -8,61 +8,79 @@ import os
import time
import torch.nn.functional as F
import cv2
import glob
from nerfasr import NerfASR
from ttsreal import EdgeTTS,VoitsTTS,XTTS
from asrreal import ASR
import asyncio import asyncio
from av import AudioFrame, VideoFrame
from basereal import BaseReal
class NeRFReal: #from imgcache import ImgCache
from tqdm import tqdm
def read_imgs(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
class NeRFReal(BaseReal):
def __init__(self, opt, trainer, data_loader, debug=True):
super().__init__(opt)
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.debug = debug
self.training = False
self.step = 0 # training step
self.trainer = trainer
self.data_loader = data_loader
# use dataloader's bg
#bg_img = data_loader._data.bg_img #.view(1, -1, 3)
#if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
# bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
#self.bg_color = bg_img.view(1, -1, 3)
# audio features (from dataloader, only used in non-playing mode)
#self.audio_features = data_loader._data.auds # [N, 29, 16]
#self.audio_idx = 0
#self.frame_total_num = data_loader._data.end_index
#print("frame_total_num:",self.frame_total_num)
# control eye
#self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
# playing seq from dataloader, or pause.
self.playing = True #False todo
self.loader = iter(data_loader)
frame_total_num = data_loader._data.end_index
if opt.fullbody:
input_img_list = glob.glob(os.path.join(self.opt.fullbody_img, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
#print('input_img_list:',input_img_list)
self.fullbody_list_cycle = read_imgs(input_img_list[:frame_total_num])
#self.imagecache = ImgCache(frame_total_num,self.opt.fullbody_img,1000)
#self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
#self.need_update = True # camera moved, should reset accumulation
#self.spp = 1 # sample per pixel
#self.mode = 'image' # choose from ['image', 'depth']
#self.dynamic_resolution = False # assert False!
#self.downscale = 1
#self.train_steps = 16
#self.ind_index = 0
#self.ind_num = trainer.model.individual_codes.shape[0]
#self.customimg_index = 0
# build asr
if self.opt.asr: self.asr = NerfASR(opt,self)
self.asr = ASR(opt) self.asr.warm_up()
self.asr.warm_up()
'''
video_path = 'video_stream'
@ -108,108 +126,115 @@ class NeRFReal:
def __exit__(self, exc_type, exc_value, traceback):
if self.opt.asr:
self.asr.stop()
def push_audio(self,chunk):
self.asr.push_audio(chunk)
def before_push_audio(self):
self.asr.before_push_audio()
# def mirror_index(self, index):
# size = self.opt.customvideo_imgnum
# turn = index // size
# res = index % size
# if turn % 2 == 0:
# return res
# else:
# return size - res - 1
def prepare_buffer(self, outputs):
if self.mode == 'image':
return outputs['image']
else:
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
def test_step(self,loop=None,audio_track=None,video_track=None):
#starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
#starter.record()
if self.playing: try:
try: data = next(self.loader)
data = next(self.loader) except StopIteration:
except StopIteration: self.loader = iter(self.data_loader)
self.loader = iter(self.data_loader) data = next(self.loader)
data = next(self.loader)
if self.opt.asr:
if self.opt.asr: # use the live audio stream
# use the live audio stream data['auds'] = self.asr.get_next_feat()
data['auds'] = self.asr.get_next_feat()
audiotype = 0 audiotype1 = 0
if self.opt.transport=='rtmp': audiotype2 = 0
for _ in range(2): #send audio
frame,type = self.asr.get_audio_out() for i in range(2):
audiotype += type frame,type = self.asr.get_audio_out()
#print(f'[INFO] get_audio_out shape ',frame.shape) if i==0:
self.streamer.stream_frame_audio(frame) audiotype1 = type
else: else:
for _ in range(2): audiotype2 = type
frame,type = self.asr.get_audio_out() #print(f'[INFO] get_audio_out shape ',frame.shape)
audiotype += type if self.opt.transport=='rtmp':
frame = (frame * 32767).astype(np.int16) self.streamer.stream_frame_audio(frame)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) else: #webrtc
new_frame.planes[0].update(frame.tobytes()) frame = (frame * 32767).astype(np.int16)
new_frame.sample_rate=16000 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
# if audio_track._queue.qsize()>10: new_frame.planes[0].update(frame.tobytes())
# time.sleep(0.1) new_frame.sample_rate=16000
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop) asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
#t = time.time()
if self.opt.customvideo and audiotype!=0: # if self.opt.transport=='rtmp':
self.loader = iter(self.data_loader) #init # for _ in range(2):
imgindex = self.mirror_index(self.customimg_index) # frame,type = self.asr.get_audio_out()
#print('custom img index:',imgindex) # audiotype += type
image = cv2.imread(os.path.join(self.opt.customvideo_img, str(int(imgindex))+'.png')) # #print(f'[INFO] get_audio_out shape ',frame.shape)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # self.streamer.stream_frame_audio(frame)
# else: #webrtc
# for _ in range(2):
# frame,type = self.asr.get_audio_out()
# audiotype += type
# frame = (frame * 32767).astype(np.int16)
# new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
# new_frame.planes[0].update(frame.tobytes())
# new_frame.sample_rate=16000
# # if audio_track._queue.qsize()>10:
# # time.sleep(0.1)
# asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
#t = time.time()
if audiotype1!=0 and audiotype2!=0: #both audio frames are silence
self.speaking = False
else:
self.speaking = True
if audiotype1!=0 and audiotype2!=0 and self.custom_index.get(audiotype1) is not None: #not the inference video and a custom video is configured
mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype1]),self.custom_index[audiotype1])
#imgindex = self.mirror_index(self.customimg_index)
#print('custom img index:',imgindex)
#image = cv2.imread(os.path.join(self.opt.customvideo_img, str(int(imgindex))+'.png'))
image = self.custom_img_cycle[audiotype1][mirindex]
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
self.custom_index[audiotype1] += 1
if self.opt.transport=='rtmp':
self.streamer.stream_frame(image)
else:
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
else: #inference video, pasted back onto the full-body frame
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
#print('-------ernerf time: ',time.time()-t)
#print(f'[INFO] outputs shape ',outputs['image'].shape)
image = (outputs['image'] * 255).astype(np.uint8)
if not self.opt.fullbody:
if self.opt.transport=='rtmp':
self.streamer.stream_frame(image)
else:
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
self.customimg_index += 1 else: #fullbody human
else: #print("frame index:",data['index'])
self.customimg_index = 0 #image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
outputs = self.trainer.test_gui_with_data(data, self.W, self.H) image_fullbody = self.fullbody_list_cycle[data['index'][0]]
#print('-------ernerf time: ',time.time()-t) #image_fullbody = self.imagecache.get_img(data['index'][0])
#print(f'[INFO] outputs shape ',outputs['image'].shape) image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB)
image = (outputs['image'] * 255).astype(np.uint8) start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标
if not self.opt.fullbody: start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标
if self.opt.transport=='rtmp': image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image
self.streamer.stream_frame(image) if self.opt.transport=='rtmp':
else: self.streamer.stream_frame(image_fullbody)
new_frame = VideoFrame.from_ndarray(image, format="rgb24") else:
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
else: #fullbody human asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
#print("frame index:",data['index'])
image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB)
start_x = self.opt.fullbody_offset_x # x offset of the face patch within the merged full-body image
start_y = self.opt.fullbody_offset_y # y offset of the face patch within the merged full-body image
image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image
if self.opt.transport=='rtmp':
self.streamer.stream_frame(image_fullbody)
else:
new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
#self.pipe.stdin.write(image.tostring())
else:
if self.audio_features is not None:
auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
else:
auds = None
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
#ender.record()
#torch.cuda.synchronize()
#t = starter.elapsed_time(ender)
@ -218,6 +243,8 @@ class NeRFReal:
#if self.opt.asr:
# self.asr.warm_up()
self.init_customindex()
if self.opt.transport=='rtmp':
from rtmp_streaming import StreamerConfig, Streamer
fps=25
@ -246,14 +273,15 @@ class NeRFReal:
totaltime=0
_starttime=time.perf_counter()
_totalframe=0
self.tts.render(quit_event)
while not quit_event.is_set(): #todo
# update texture every frame
# audio stream thread...
t = time.perf_counter()
if self.opt.asr and self.playing:
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
for _ in range(2):
self.asr.run_step()
self.test_step(loop,audio_track,video_track)
totaltime += (time.perf_counter() - t)
count += 1
@ -262,7 +290,14 @@ class NeRFReal:
print(f"------actual avg infer fps:{count/totaltime:.4f}") print(f"------actual avg infer fps:{count/totaltime:.4f}")
count=0 count=0
totaltime=0 totaltime=0
if self.opt.transport=='rtmp':
delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
if delay > 0:
time.sleep(delay)
else:
if video_track._queue.qsize()>=5:
#print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
print('nerfreal thread stop')


@ -23,13 +23,21 @@ soundfile
einops
configargparse
lpips==0.1.3
imageio-ffmpeg
transformers
edge_tts==6.1.11
flask
flask_sockets
opencv-python-headless
aiortc
aiohttp_cors
ffmpeg-python
omegaconf
diffusers
accelerate
librosa
openai


@ -1,99 +0,0 @@
# Uses the GPT-SoVITS approach (Bert-SoVITS suits long-audio training; GPT-SoVITS gives fast inference on short audio)
## Deploy TTS inference
git clone https://github.com/X-T-E-R/GPT-SoVITS-Inference.git
## 1. Install dependencies
```
conda create -n GPTSoVits python=3.9
conda activate GPTSoVits
bash install.sh
```
Download the pretrained models from [GPT-SoVITS Models](https://huggingface.co/lj1995/GPT-SoVITS) and place them in `GPT_SoVITS\pretrained_models`.
## 2. Model Folder Format
Model files can be downloaded from https://www.yuque.com/xter/zibxlp/gsximn7ditzgispg
Put the downloaded model files under the `trained` directory, e.g. `trained/Character1/`
Put the pth / ckpt / wav files in it; the wav file should be named after its prompt text.
For example:
```
trained
--hutao
----hutao-e75.ckpt
----hutao_e60_s3360.pth
----hutao said something.wav
```
## 3. Startup
### 3.1 Backend service:
python Inference/src/tts_backend.py
If you get an error saying cmudict cannot be found, download https://github.com/nltk/nltk_data, rename the `packages` folder to `nltk_data`, and put it in your home directory.
### 3.2 Character management:
python Inference/src/Character_Manager.py
Open it in a browser to manage characters and emotions.
### 3.3 Test the TTS function:
python Inference/src/TTS_Webui.py
## 4. API Reference
### 4.1 Character and Emotion List
To obtain the supported characters and their corresponding emotions, please visit the following URL:
- URL: `http://127.0.0.1:5000/character_list`
- Returns: A JSON format list of characters and corresponding emotions
- Method: `GET`
```
{
"Hanabi": [
"default",
"Normal",
"Yandere",
],
"Hutao": [
"default"
]
}
```
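For reference, a minimal client sketch for this endpoint (assuming the backend from step 3.1 is listening on 127.0.0.1:5000 as above; adjust host and port to your deployment):
```
import requests

# Query the character list exposed by the local GPT-SoVITS-Inference backend (step 3.1).
resp = requests.get("http://127.0.0.1:5000/character_list", timeout=10)
resp.raise_for_status()
for character, emotions in resp.json().items():
    print(character, "->", ", ".join(emotions))
```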
### 4.2 Text-to-Speech
- URL: `http://127.0.0.1:5000/tts`
- Returns: Audio on success. Error message on failure.
- Method: `GET`/`POST`
```
{
"method": "POST",
"body": {
"character": "${chaName}",
"emotion": "${Emotion}",
"text": "${speakText}",
"text_language": "${textLanguage}",
"batch_size": ${batch_size},
"speed": ${speed},
"top_k": ${topK},
"top_p": ${topP},
"temperature": ${temperature},
"stream": "${stream}",
"format": "${Format}",
"save_temp": "${saveTemp}"
}
}
```
##### Parameter Explanation
- **text**: The text to be converted, URL encoding is recommended.
- **character**: Character folder name, pay attention to case sensitivity, full/half width, and language.
- **emotion**: Character emotion, must be an actually supported emotion of the character, otherwise, the default emotion will be used.
- **text_language**: Text language (auto / zh / en / ja), default is multilingual mixed.
- **top_k**, **top_p**, **temperature**: GPT model parameters, no need to modify if unfamiliar.
- **batch_size**: How many batches at a time, can be increased for faster processing if you have a powerful computer, integer, default is 1.
- **speed**: Speech speed, default is 1.0.
- **save_temp**: Whether to save temporary files, when true, the backend will save the generated audio, and subsequent identical requests will directly return that data, default is false.
- **stream**: Whether to stream, when true, audio will be returned sentence by sentence, default is false.
- **format**: Output format; default is WAV, allowed values are MP3 / WAV / OGG.
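Putting these parameters together, a minimal sketch of a non-streaming `/tts` call (the character name `Hutao` and the output filename are placeholder assumptions; with `stream` set to true the audio is returned sentence by sentence instead):
```
import requests

# Request synthesized speech from the local backend (step 3.1) and save it to a wav file.
payload = {
    "character": "Hutao",      # placeholder: must match a folder under trained/
    "emotion": "default",
    "text": "你好，这是一段测试文本。",
    "text_language": "auto",
    "batch_size": 1,
    "speed": 1.0,
    "stream": "false",         # "true" would return audio sentence by sentence
    "format": "wav",
}
resp = requests.post("http://127.0.0.1:5000/tts", json=payload, timeout=120)
resp.raise_for_status()
with open("tts_out.wav", "wb") as f:
    f.write(resp.content)
```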
## Deploy TTS training
https://github.com/RVC-Boss/GPT-SoVITS
Deploy it according to that project's documentation, then copy the trained models into the inference service's `trained` directory.


@ -1,28 +0,0 @@
import requests
import pyaudio
# URL for streaming the audio (you can freely change this to a POST request)
stream_url = 'http://127.0.0.1:5000/tts?text=这是一段测试文本,旨在通过多种语言风格和复杂性的内容来全面检验文本到语音系统的性能。接下来,我们会探索各种主题和语言结构,包括文学引用、技术性描述、日常会话以及诗歌等。首先,让我们从一段简单的描述性文本开始:“在一个阳光明媚的下午,一位年轻的旅者站在山顶上,眺望着下方那宽广而繁忙的城市。他的心中充满了对未来的憧憬和对旅途的期待。”这段文本测试了系统对自然景观描写的处理能力和情感表达的细腻程度。&stream=true'
# initialize pyaudio
p = pyaudio.PyAudio()
# open an audio output stream
stream = p.open(format=p.get_format_from_width(2),
channels=1,
rate=32000,
output=True)
# fetch the audio stream with requests (you can freely change this to a POST request)
response = requests.get(stream_url, stream=True)
# read data chunks and play them
for data in response.iter_content(chunk_size=1024):
stream.write(data)
# stop and close the stream
stream.stop_stream()
stream.close()
# terminate pyaudio
p.terminate()

ttsreal.py Normal file

@ -0,0 +1,310 @@
import time
import numpy as np
import soundfile as sf
import resampy
import asyncio
import edge_tts
from typing import Iterator
import requests
import queue
from queue import Queue
from io import BytesIO
from threading import Thread, Event
from enum import Enum
class State(Enum):
RUNNING=0
PAUSE=1
class BaseTTS:
def __init__(self, opt, parent):
self.opt=opt
self.parent = parent
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.input_stream = BytesIO()
self.msgqueue = Queue()
self.state = State.RUNNING
def pause_talk(self):
self.msgqueue.queue.clear()
self.state = State.PAUSE
def put_msg_txt(self,msg):
if len(msg)>0:
self.msgqueue.put(msg)
def render(self,quit_event):
process_thread = Thread(target=self.process_tts, args=(quit_event,))
process_thread.start()
def process_tts(self,quit_event):
while not quit_event.is_set():
try:
msg = self.msgqueue.get(block=True, timeout=1)
self.state=State.RUNNING
except queue.Empty:
continue
self.txt_to_audio(msg)
print('ttsreal thread stop')
def txt_to_audio(self,msg):
pass
###########################################################################################
class EdgeTTS(BaseTTS):
def txt_to_audio(self,msg):
voicename = "zh-CN-YunxiaNeural"
text = msg
t = time.time()
asyncio.new_event_loop().run_until_complete(self.__main(voicename,text))
print(f'-------edge tts time:{time.time()-t:.4f}s')
if self.input_stream.getbuffer().nbytes<=0: #edgetts err
print('edgetts err!!!!!')
return
self.input_stream.seek(0)
stream = self.__create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk and self.state==State.RUNNING:
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
#if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self.input_stream.seek(0)
self.input_stream.truncate()
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
async def __main(self,voicename: str, text: str):
try:
communicate = edge_tts.Communicate(text, voicename)
#with open(OUTPUT_FILE, "wb") as file:
first = True
async for chunk in communicate.stream():
if first:
first = False
if chunk["type"] == "audio" and self.state==State.RUNNING:
#self.push_audio(chunk["data"])
self.input_stream.write(chunk["data"])
#file.write(chunk["data"])
elif chunk["type"] == "WordBoundary":
pass
except Exception as e:
print(e)
###########################################################################################
class VoitsTTS(BaseTTS):
def txt_to_audio(self,msg):
self.stream_tts(
self.gpt_sovits(
msg,
self.opt.REF_FILE,
self.opt.REF_TEXT,
"zh", #en args.language,
self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
)
)
def gpt_sovits(self, text, reffile, reftext,language, server_url) -> Iterator[bytes]:
start = time.perf_counter()
req={
'text':text,
'text_lang':language,
'ref_audio_path':reffile,
'prompt_text':reftext,
'prompt_lang':language,
'media_type':'raw',
'streaming_mode':True
}
# req["text"] = text
# req["text_language"] = language
# req["character"] = character
# req["emotion"] = emotion
# #req["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
# req["streaming_mode"] = True
try:
res = requests.post(
f"{server_url}/tts",
json=req,
stream=True,
)
end = time.perf_counter()
print(f"gpt_sovits Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=12800): # 1280 32K*20ms*2
if first:
end = time.perf_counter()
print(f"gpt_sovits Time to first chunk: {end-start}s")
first = False
if chunk and self.state==State.RUNNING:
yield chunk
#print("gpt_sovits response.elapsed:", res.elapsed)
except Exception as e:
print(e)
def stream_tts(self,audio_stream):
for chunk in audio_stream:
if chunk is not None and len(chunk)>0:
stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.__create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
###########################################################################################
class CosyVoiceTTS(BaseTTS):
def txt_to_audio(self,msg):
self.stream_tts(
self.cosy_voice(
msg,
self.opt.REF_FILE,
self.opt.REF_TEXT,
"zh", #en args.language,
self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
)
)
def cosy_voice(self, text, reffile, reftext,language, server_url) -> Iterator[bytes]:
start = time.perf_counter()
payload = {
'tts_text': text,
'prompt_text': reftext
}
try:
files = [('prompt_wav', ('prompt_wav', open(reffile, 'rb'), 'application/octet-stream'))]
res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True)
end = time.perf_counter()
print(f"cosy_voice Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=8820): # 882 22.05K*20ms*2
if first:
end = time.perf_counter()
print(f"cosy_voice Time to first chunk: {end-start}s")
first = False
if chunk and self.state==State.RUNNING:
yield chunk
except Exception as e:
print(e)
def stream_tts(self,audio_stream):
for chunk in audio_stream:
if chunk is not None and len(chunk)>0:
stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
stream = resampy.resample(x=stream, sr_orig=22050, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.__create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
###########################################################################################
class XTTS(BaseTTS):
def __init__(self, opt, parent):
super().__init__(opt,parent)
self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER)
def txt_to_audio(self,msg):
self.stream_tts(
self.xtts(
msg,
self.speaker,
"zh-cn", #en args.language,
self.opt.TTS_SERVER, #"http://localhost:9000", #args.server_url,
"20" #args.stream_chunk_size
)
)
def get_speaker(self,ref_audio,server_url):
files = {"wav_file": ("reference.wav", open(ref_audio, "rb"))}
response = requests.post(f"{server_url}/clone_speaker", files=files)
return response.json()
def xtts(self,text, speaker, language, server_url, stream_chunk_size) -> Iterator[bytes]:
start = time.perf_counter()
speaker["text"] = text
speaker["language"] = language
speaker["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
try:
res = requests.post(
f"{server_url}/tts_stream",
json=speaker,
stream=True,
)
end = time.perf_counter()
print(f"xtts Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2
if first:
end = time.perf_counter()
print(f"xtts Time to first chunk: {end-start}s")
first = False
if chunk:
yield chunk
except Exception as e:
print(e)
def stream_tts(self,audio_stream):
for chunk in audio_stream:
if chunk is not None and len(chunk)>0:
stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.__create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk

wav2lip/audio.py Normal file

@ -0,0 +1,136 @@
import librosa
import librosa.filters
import numpy as np
# import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from .hparams import hparams as hp
def load_wav(path, sr):
return librosa.core.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
#proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))
def save_wavenet_wav(wav, path, sr):
librosa.output.write_wav(path, wav, sr=sr)
def preemphasis(wav, k, preemphasize=True):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k, inv_preemphasize=True):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def get_hop_size():
hop_size = hp.hop_size
if hop_size is None:
assert hp.frame_shift_ms is not None
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
return hop_size
def linearspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def melspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def _lws_processor():
import lws
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
def _stft(y):
if hp.use_lws:
return _lws_processor().stft(y).T
else:
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram
"""
pad = (fsize - fshift)
if length % fshift == 0:
M = (length + pad * 2 - fsize) // fshift + 1
else:
M = (length + pad * 2 - fsize) // fshift + 2
return M
def pad_lr(x, fsize, fshift):
"""Compute left and right padding
"""
M = num_frames(len(x), fsize, fshift)
pad = (fsize - fshift)
T = len(x) + 2 * pad
r = (M - 1) * fshift + fsize - T
return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
# Conversions
_mel_basis = None
def _linear_to_mel(spectogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectogram)
def _build_mel_basis():
assert hp.fmax <= hp.sample_rate // 2
return librosa.filters.mel(sr=float(hp.sample_rate), n_fft=hp.n_fft, n_mels=hp.num_mels,
fmin=hp.fmin, fmax=hp.fmax)
def _amp_to_db(x):
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _normalize(S):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
-hp.max_abs_value, hp.max_abs_value)
else:
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
if hp.symmetric_mels:
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
else:
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
def _denormalize(D):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return (((np.clip(D, -hp.max_abs_value,
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ hp.min_level_db)
else:
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
if hp.symmetric_mels:
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
else:
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)


@ -0,0 +1 @@
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
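For illustration, a minimal sketch of how the batched interface below can be driven (the import path, frame paths, and device are placeholder assumptions; frames in a batch are expected to share the same resolution):
```
import cv2
import numpy as np
from face_detection import FaceAlignment, LandmarksType

# Detect one face box per frame for a small batch of equally sized BGR frames.
fa = FaceAlignment(LandmarksType._2D, flip_input=False, device='cuda')
frames = [cv2.imread(p) for p in ['frame0.jpg', 'frame1.jpg']]  # placeholder paths
batch = np.asarray(frames)                                      # [N, H, W, 3]
for i, box in enumerate(fa.get_detections_for_batch(batch)):
    print(i, box)  # (x1, y1, x2, y2) or None if no face was found
```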


@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
__author__ = """Adrian Bulat"""
__email__ = 'adrian.bulat@nottingham.ac.uk'
__version__ = '1.0.1'
from .api import FaceAlignment, LandmarksType, NetworkSize


@ -0,0 +1,79 @@
from __future__ import print_function
import os
import torch
from torch.utils.model_zoo import load_url
from enum import Enum
import numpy as np
import cv2
try:
import urllib.request as request_file
except BaseException:
import urllib as request_file
from .models import FAN, ResNetDepth
from .utils import *
class LandmarksType(Enum):
"""Enum class defining the type of landmarks to detect.
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
``_2halfD`` - this points represent the projection of the 3D points into 3D
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
"""
_2D = 1
_2halfD = 2
_3D = 3
class NetworkSize(Enum):
# TINY = 1
# SMALL = 2
# MEDIUM = 3
LARGE = 4
def __new__(cls, value):
member = object.__new__(cls)
member._value_ = value
return member
def __int__(self):
return self.value
ROOT = os.path.dirname(os.path.abspath(__file__))
class FaceAlignment:
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
self.device = device
self.flip_input = flip_input
self.landmarks_type = landmarks_type
self.verbose = verbose
network_size = int(network_size)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
# Get the face detector
face_detector_module = __import__('face_detection.detection.' + face_detector,
globals(), locals(), [face_detector], 0)
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
def get_detections_for_batch(self, images):
images = images[..., ::-1]
detected_faces = self.face_detector.detect_from_batch(images.copy())
results = []
for i, d in enumerate(detected_faces):
if len(d) == 0:
results.append(None)
continue
d = d[0]
d = np.clip(d, 0, None)
x1, y1, x2, y2 = map(int, d[:-1])
results.append((x1, y1, x2, y2))
return results


@ -0,0 +1 @@
from .core import FaceDetector


@ -0,0 +1,130 @@
import logging
import glob
from tqdm import tqdm
import numpy as np
import torch
import cv2
class FaceDetector(object):
"""An abstract class representing a face detector.
Any other face detection implementation must subclass it. All subclasses
must implement ``detect_from_image``, that return a list of detected
bounding boxes. Optionally, for speed considerations detect from path is
recommended.
"""
def __init__(self, device, verbose):
self.device = device
self.verbose = verbose
if verbose:
if 'cpu' in device:
logger = logging.getLogger(__name__)
logger.warning("Detection running on CPU, this may be potentially slow.")
if 'cpu' not in device and 'cuda' not in device:
if verbose:
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
raise ValueError
def detect_from_image(self, tensor_or_path):
"""Detects faces in a given image.
This function detects the faces present in a provided BGR(usually)
image. The input can be either the image itself or the path to it.
Arguments:
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
to an image or the image itself.
Example::
>>> path_to_image = 'data/image_01.jpg'
... detected_faces = detect_from_image(path_to_image)
[A list of bounding boxes (x1, y1, x2, y2)]
>>> image = cv2.imread(path_to_image)
... detected_faces = detect_from_image(image)
[A list of bounding boxes (x1, y1, x2, y2)]
"""
raise NotImplementedError
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
"""Detects faces from all the images present in a given directory.
Arguments:
path {string} -- a string containing a path that points to the folder containing the images
Keyword Arguments:
extensions {list} -- list of string containing the extensions to be
consider in the following format: ``.extension_name`` (default:
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
folder recursively (default: {False}) show_progress_bar {bool} --
display a progressbar (default: {True})
Example:
>>> directory = 'data'
... detected_faces = detect_from_directory(directory)
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
"""
if self.verbose:
logger = logging.getLogger(__name__)
if len(extensions) == 0:
if self.verbose:
logger.error("Expected at list one extension, but none was received.")
raise ValueError
if self.verbose:
logger.info("Constructing the list of images.")
additional_pattern = '/**/*' if recursive else '/*'
files = []
for extension in extensions:
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
if self.verbose:
logger.info("Finished searching for images. %s images found", len(files))
logger.info("Preparing to run the detection.")
predictions = {}
for image_path in tqdm(files, disable=not show_progress_bar):
if self.verbose:
logger.info("Running the face detector on image: %s", image_path)
predictions[image_path] = self.detect_from_image(image_path)
if self.verbose:
logger.info("The detector was successfully run on all %s images", len(files))
return predictions
@property
def reference_scale(self):
raise NotImplementedError
@property
def reference_x_shift(self):
raise NotImplementedError
@property
def reference_y_shift(self):
raise NotImplementedError
@staticmethod
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
Arguments:
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
"""
if isinstance(tensor_or_path, str):
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
elif torch.is_tensor(tensor_or_path):
# Call cpu in case its coming from cuda
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
elif isinstance(tensor_or_path, np.ndarray):
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
else:
raise TypeError


@ -0,0 +1 @@
from .sfd_detector import SFDDetector as FaceDetector


@ -0,0 +1,129 @@
from __future__ import print_function
import os
import sys
import cv2
import random
import datetime
import time
import math
import argparse
import numpy as np
import torch
try:
from iou import IOU
except BaseException:
# IOU cython speedup 10x
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
sa = abs((ax2 - ax1) * (ay2 - ay1))
sb = abs((bx2 - bx1) * (by2 - by1))
x1, y1 = max(ax1, bx1), max(ay1, by1)
x2, y2 = min(ax2, bx2), min(ay2, by2)
w = x2 - x1
h = y2 - y1
if w < 0 or h < 0:
return 0.0
else:
return 1.0 * w * h / (sa + sb - w * h)
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
dw, dh = math.log(ww / aww), math.log(hh / ahh)
return dx, dy, dw, dh
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
xc, yc = dx * aww + axc, dy * ahh + ayc
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
return x1, y1, x2, y2
def nms(dets, thresh):
if 0 == len(dets):
return []
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def encode(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 4].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded boxes (tensor), Shape: [num_priors, 4]
"""
# dist b/t match center and prior's center
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, 2:])
# match wh / prior wh
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
# return target for smooth_l1_loss
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def batch_decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
return boxes


@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F
import os
import sys
import cv2
import random
import datetime
import math
import argparse
import numpy as np
import scipy.io as sio
import zipfile
from .net_s3fd import s3fd
from .bbox import *
def detect(net, img, device):
img = img - np.array([104, 117, 123])
img = img.transpose(2, 0, 1)
img = img.reshape((1,) + img.shape)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
img = torch.from_numpy(img).float().to(device)
BB, CC, HH, WW = img.size()
with torch.no_grad():
olist = net(img)
bboxlist = []
for i in range(len(olist) // 2):
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
olist = [oelem.data.cpu() for oelem in olist]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
FB, FC, FH, FW = ocls.size() # feature map size
stride = 2**(i + 2) # 4,8,16,32,64,128
anchor = stride * 4
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
for Iindex, hindex, windex in poss:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[0, 1, hindex, windex]
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
variances = [0.1, 0.2]
box = decode(loc, priors, variances)
x1, y1, x2, y2 = box[0] * 1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append([x1, y1, x2, y2, score])
bboxlist = np.array(bboxlist)
if 0 == len(bboxlist):
bboxlist = np.zeros((1, 5))
return bboxlist
def batch_detect(net, imgs, device):
imgs = imgs - np.array([104, 117, 123])
imgs = imgs.transpose(0, 3, 1, 2)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
imgs = torch.from_numpy(imgs).float().to(device)
BB, CC, HH, WW = imgs.size()
with torch.no_grad():
olist = net(imgs)
bboxlist = []
for i in range(len(olist) // 2):
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
olist = [oelem.data.cpu() for oelem in olist]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
FB, FC, FH, FW = ocls.size() # feature map size
stride = 2**(i + 2) # 4,8,16,32,64,128
anchor = stride * 4
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
for Iindex, hindex, windex in poss:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[:, 1, hindex, windex]
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
variances = [0.1, 0.2]
box = batch_decode(loc, priors, variances)
box = box[:, 0] * 1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
bboxlist = np.array(bboxlist)
if 0 == len(bboxlist):
bboxlist = np.zeros((1, BB, 5))
return bboxlist
def flip_detect(net, img, device):
img = cv2.flip(img, 1)
b = detect(net, img, device)
bboxlist = np.zeros(b.shape)
bboxlist[:, 0] = img.shape[1] - b[:, 2]
bboxlist[:, 1] = b[:, 1]
bboxlist[:, 2] = img.shape[1] - b[:, 0]
bboxlist[:, 3] = b[:, 3]
bboxlist[:, 4] = b[:, 4]
return bboxlist
def pts_to_bb(pts):
min_x, min_y = np.min(pts, axis=0)
max_x, max_y = np.max(pts, axis=0)
return np.array([min_x, min_y, max_x, max_y])


@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class L2Norm(nn.Module):
def __init__(self, n_channels, scale=1.0):
super(L2Norm, self).__init__()
self.n_channels = n_channels
self.scale = scale
self.eps = 1e-10
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
self.weight.data *= 0.0
self.weight.data += self.scale
def forward(self, x):
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
x = x / norm * self.weight.view(1, -1, 1, 1)
return x
class s3fd(nn.Module):
def __init__(self):
super(s3fd, self).__init__()
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.conv3_3_norm = L2Norm(256, scale=10)
self.conv4_3_norm = L2Norm(512, scale=8)
self.conv5_3_norm = L2Norm(512, scale=5)
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
def forward(self, x):
h = F.relu(self.conv1_1(x))
h = F.relu(self.conv1_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv2_1(h))
h = F.relu(self.conv2_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv3_1(h))
h = F.relu(self.conv3_2(h))
h = F.relu(self.conv3_3(h))
f3_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv4_1(h))
h = F.relu(self.conv4_2(h))
h = F.relu(self.conv4_3(h))
f4_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv5_1(h))
h = F.relu(self.conv5_2(h))
h = F.relu(self.conv5_3(h))
f5_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.fc6(h))
h = F.relu(self.fc7(h))
ffc7 = h
h = F.relu(self.conv6_1(h))
h = F.relu(self.conv6_2(h))
f6_2 = h
h = F.relu(self.conv7_1(h))
h = F.relu(self.conv7_2(h))
f7_2 = h
f3_3 = self.conv3_3_norm(f3_3)
f4_3 = self.conv4_3_norm(f4_3)
f5_3 = self.conv5_3_norm(f5_3)
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
cls4 = self.fc7_mbox_conf(ffc7)
reg4 = self.fc7_mbox_loc(ffc7)
cls5 = self.conv6_2_mbox_conf(f6_2)
reg5 = self.conv6_2_mbox_loc(f6_2)
cls6 = self.conv7_2_mbox_conf(f7_2)
reg6 = self.conv7_2_mbox_loc(f7_2)
# max-out background label
chunk = torch.chunk(cls1, 4, 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
cls1 = torch.cat([bmax, chunk[3]], dim=1)
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]


@ -0,0 +1,59 @@
import os
import cv2
from torch.utils.model_zoo import load_url
from ..core import FaceDetector
from .net_s3fd import s3fd
from .bbox import *
from .detect import *
models_urls = {
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
}
class SFDDetector(FaceDetector):
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
super(SFDDetector, self).__init__(device, verbose)
# Initialise the face detector
if not os.path.isfile(path_to_detector):
model_weights = load_url(models_urls['s3fd'])
else:
model_weights = torch.load(path_to_detector)
self.face_detector = s3fd()
self.face_detector.load_state_dict(model_weights)
self.face_detector.to(device)
self.face_detector.eval()
def detect_from_image(self, tensor_or_path):
image = self.tensor_or_path_to_ndarray(tensor_or_path)
bboxlist = detect(self.face_detector, image, device=self.device)
keep = nms(bboxlist, 0.3)
bboxlist = bboxlist[keep, :]
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
return bboxlist
def detect_from_batch(self, images):
bboxlists = batch_detect(self.face_detector, images, device=self.device)
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
return bboxlists
@property
def reference_scale(self):
return 195
@property
def reference_x_shift(self):
return 0
@property
def reference_y_shift(self):
return 0


@ -0,0 +1,261 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=strd, padding=padding, bias=bias)
class ConvBlock(nn.Module):
def __init__(self, in_planes, out_planes):
super(ConvBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
if in_planes != out_planes:
self.downsample = nn.Sequential(
nn.BatchNorm2d(in_planes),
nn.ReLU(True),
nn.Conv2d(in_planes, out_planes,
kernel_size=1, stride=1, bias=False),
)
else:
self.downsample = None
def forward(self, x):
residual = x
out1 = self.bn1(x)
out1 = F.relu(out1, True)
out1 = self.conv1(out1)
out2 = self.bn2(out1)
out2 = F.relu(out2, True)
out2 = self.conv2(out2)
out3 = self.bn3(out2)
out3 = F.relu(out3, True)
out3 = self.conv3(out3)
out3 = torch.cat((out1, out2, out3), 1)
if self.downsample is not None:
residual = self.downsample(residual)
out3 += residual
return out3
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class HourGlass(nn.Module):
def __init__(self, num_modules, depth, num_features):
super(HourGlass, self).__init__()
self.num_modules = num_modules
self.depth = depth
self.features = num_features
self._generate_network(self.depth)
def _generate_network(self, level):
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
if level > 1:
self._generate_network(level - 1)
else:
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
def _forward(self, level, inp):
# Upper branch
up1 = inp
up1 = self._modules['b1_' + str(level)](up1)
# Lower branch
low1 = F.avg_pool2d(inp, 2, stride=2)
low1 = self._modules['b2_' + str(level)](low1)
if level > 1:
low2 = self._forward(level - 1, low1)
else:
low2 = low1
low2 = self._modules['b2_plus_' + str(level)](low2)
low3 = low2
low3 = self._modules['b3_' + str(level)](low3)
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
return up1 + up2
def forward(self, x):
return self._forward(self.depth, x)
class FAN(nn.Module):
def __init__(self, num_modules=1):
super(FAN, self).__init__()
self.num_modules = num_modules
# Base part
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = ConvBlock(64, 128)
self.conv3 = ConvBlock(128, 128)
self.conv4 = ConvBlock(128, 256)
# Stacking part
for hg_module in range(self.num_modules):
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
self.add_module('conv_last' + str(hg_module),
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
self.add_module('l' + str(hg_module), nn.Conv2d(256,
68, kernel_size=1, stride=1, padding=0))
if hg_module < self.num_modules - 1:
self.add_module(
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('al' + str(hg_module), nn.Conv2d(68,
256, kernel_size=1, stride=1, padding=0))
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)), True)
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
x = self.conv3(x)
x = self.conv4(x)
previous = x
outputs = []
for i in range(self.num_modules):
hg = self._modules['m' + str(i)](previous)
ll = hg
ll = self._modules['top_m_' + str(i)](ll)
ll = F.relu(self._modules['bn_end' + str(i)]
(self._modules['conv_last' + str(i)](ll)), True)
# Predict heatmaps
tmp_out = self._modules['l' + str(i)](ll)
outputs.append(tmp_out)
if i < self.num_modules - 1:
ll = self._modules['bl' + str(i)](ll)
tmp_out_ = self._modules['al' + str(i)](tmp_out)
previous = previous + ll + tmp_out_
return outputs
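As a shape check, FAN takes a 3x256x256 crop and emits one 68-channel heatmap tensor per stacked hourglass module at a quarter of the input resolution; the released 2D landmark weights use four modules (the module count here is an assumption). A small sketch using the class defined above:

import torch

fan = FAN(num_modules=4).eval()
with torch.no_grad():
    heatmaps = fan(torch.randn(1, 3, 256, 256))   # list with one tensor per hourglass module
print(len(heatmaps), heatmaps[-1].shape)          # 4 torch.Size([1, 68, 64, 64])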
class ResNetDepth(nn.Module):
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
self.inplanes = 64
super(ResNetDepth, self).__init__()
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x


@@ -0,0 +1,313 @@
from __future__ import print_function
import os
import sys
import time
import torch
import math
import numpy as np
import cv2
def _gaussian(
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
mean_vert=0.5):
# handle some defaults
if width is None:
width = size
if height is None:
height = size
if sigma_horz is None:
sigma_horz = sigma
if sigma_vert is None:
sigma_vert = sigma
center_x = mean_horz * width + 0.5
center_y = mean_vert * height + 0.5
gauss = np.empty((height, width), dtype=np.float32)
# generate kernel
for i in range(height):
for j in range(width):
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
if normalize:
gauss = gauss / np.sum(gauss)
return gauss
def draw_gaussian(image, point, sigma):
# Check if the gaussian is inside
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
return image
size = 6 * sigma + 1
g = _gaussian(size)
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
assert (g_x[0] > 0 and g_y[1] > 0)
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
image[image > 1] = 1
return image
def transform(point, center, scale, resolution, invert=False):
"""Generate and affine transformation matrix.
Given a set of points, a center, a scale and a targer resolution, the
function generates and affine transformation matrix. If invert is ``True``
it will produce the inverse transformation.
Arguments:
point {torch.tensor} -- the input 2D point
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
scale {float} -- the scale of the face/object
resolution {float} -- the output resolution
Keyword Arguments:
invert {bool} -- define whether the function should produce the direct or the
inverse transformation matrix (default: {False})
"""
_pt = torch.ones(3)
_pt[0] = point[0]
_pt[1] = point[1]
h = 200.0 * scale
t = torch.eye(3)
t[0, 0] = resolution / h
t[1, 1] = resolution / h
t[0, 2] = resolution * (-center[0] / h + 0.5)
t[1, 2] = resolution * (-center[1] / h + 0.5)
if invert:
t = torch.inverse(t)
new_point = (torch.matmul(t, _pt))[0:2]
return new_point.int()
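For intuition, transform() maps between the original image frame and a resolution-sized crop/heatmap frame, where scale is interpreted so that h = 200 * scale pixels of the image span the full output. A small worked example, assuming a face centred at (300, 200) with scale 1.5 (so a 300 px region maps onto 64 heatmap pixels):

import torch

center = torch.tensor([300.0, 200.0])
scale = 1.5

print(transform([300, 200], center, scale, 64))              # tensor([32, 32], dtype=torch.int32): the centre maps to the centre
print(transform([32, 32], center, scale, 64, invert=True))   # roughly tensor([300, 200]); integer truncation can be off by a pixel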
def crop(image, center, scale, resolution=256.0):
"""Center crops an image or set of heatmaps
Arguments:
image {numpy.array} -- an RGB image
center {numpy.array} -- the center of the object, usually the center of the bounding box
scale {float} -- scale of the face
Keyword Arguments:
resolution {float} -- the size of the output cropped image (default: {256.0})
Returns:
numpy.array -- the crop, resized to (resolution, resolution)
"""
# Crop around the center point; input is expected to be an np.ndarray
ul = transform([1, 1], center, scale, resolution, True)
br = transform([resolution, resolution], center, scale, resolution, True)
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
if image.ndim > 2:
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
image.shape[2]], dtype=np.int32)
newImg = np.zeros(newDim, dtype=np.uint8)
else:
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
newImg = np.zeros(newDim, dtype=np.uint8)
ht = image.shape[0]
wd = image.shape[1]
newX = np.array(
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
newY = np.array(
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
interpolation=cv2.INTER_LINEAR)
return newImg
def get_preds_fromhm(hm, center=None, scale=None):
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
and the scale is provided the function will return the points also in
the original coordinate frame.
Arguments:
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
Keyword Arguments:
center {torch.tensor} -- the center of the bounding box (default: {None})
scale {float} -- face scale (default: {None})
"""
max, idx = torch.max(
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
idx += 1
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
for i in range(preds.size(0)):
for j in range(preds.size(1)):
hm_ = hm[i, j, :]
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
diff = torch.FloatTensor(
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
preds[i, j].add_(diff.sign_().mul_(.25))
preds.add_(-.5)
preds_orig = torch.zeros(preds.size())
if center is not None and scale is not None:
for i in range(hm.size(0)):
for j in range(hm.size(1)):
preds_orig[i, j] = transform(
preds[i, j], center, scale, hm.size(2), True)
return preds, preds_orig
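get_preds_fromhm reads each heatmap's argmax, shifts it a quarter pixel toward the stronger neighbour, and, when center/scale are supplied, maps it back to image coordinates with transform(..., invert=True). A small sketch with a synthetic peak planted at a known location:

import torch

hm = torch.zeros(1, 68, 64, 64)
hm[:, :, 40, 25] = 1.0                 # put every landmark's peak at x=25, y=40 (0-indexed)
preds, preds_orig = get_preds_fromhm(hm, center=torch.tensor([300.0, 200.0]), scale=1.5)
print(preds[0, 0])                     # tensor([25.5000, 40.5000]) in heatmap coordinates
print(preds_orig[0, 0])                # the same point mapped back into the original image frame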
def get_preds_fromhm_batch(hm, centers=None, scales=None):
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
and the scales is provided the function will return the points also in
the original coordinate frame.
Arguments:
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
Keyword Arguments:
centers {torch.tensor} -- the centers of the bounding box (default: {None})
scales {float} -- face scales (default: {None})
"""
max, idx = torch.max(
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
idx += 1
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
for i in range(preds.size(0)):
for j in range(preds.size(1)):
hm_ = hm[i, j, :]
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
diff = torch.FloatTensor(
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
preds[i, j].add_(diff.sign_().mul_(.25))
preds.add_(-.5)
preds_orig = torch.zeros(preds.size())
if centers is not None and scales is not None:
for i in range(hm.size(0)):
for j in range(hm.size(1)):
preds_orig[i, j] = transform(
preds[i, j], centers[i], scales[i], hm.size(2), True)
return preds, preds_orig
def shuffle_lr(parts, pairs=None):
"""Shuffle the points left-right according to the axis of symmetry
of the object.
Arguments:
parts {torch.tensor} -- a 3D or 4D object containing the
heatmaps.
Keyword Arguments:
pairs {list of integers} -- [order of the flipped points] (default: {None})
"""
if pairs is None:
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
62, 61, 60, 67, 66, 65]
if parts.ndimension() == 3:
parts = parts[pairs, ...]
else:
parts = parts[:, pairs, ...]
return parts
def flip(tensor, is_label=False):
"""Flip an image or a set of heatmaps left-right
Arguments:
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
Keyword Arguments:
is_label {bool} -- denote whether the input is an image or a set of heatmaps (default: {False})
"""
if not torch.is_tensor(tensor):
tensor = torch.from_numpy(tensor)
if is_label:
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
else:
tensor = tensor.flip(tensor.ndimension() - 1)
return tensor
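flip() and shuffle_lr() exist mainly for test-time flip augmentation: run the landmark network on the mirrored image, then mirror the resulting heatmaps back while swapping the left/right landmark channels before fusing. A hedged sketch, reusing the FAN class from the models file earlier in this diff (fusing by simple averaging is an assumption):

import torch

net = FAN(num_modules=4).eval()
img = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    hm = net(img)[-1]                                 # heatmaps from the last stack
    hm_mirror = net(flip(img))[-1]                    # heatmaps of the horizontally flipped input
fused = (hm + flip(hm_mirror, is_label=True)) / 2     # un-mirror, swap L/R channels, average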
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
def appdata_dir(appname=None, roaming=False):
""" appdata_dir(appname=None, roaming=False)
Get the path to the application directory, where applications are allowed
to write user specific files (e.g. configurations). For non-user specific
data, consider using common_appdata_dir().
If appname is given, a subdir is appended (and created if necessary).
If roaming is True, will prefer a roaming directory (Windows Vista/7).
"""
# Define default user directory
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
if userDir is None:
userDir = os.path.expanduser('~')
if not os.path.isdir(userDir): # pragma: no cover
userDir = '/var/tmp' # issue #54
# Get system app data dir
path = None
if sys.platform.startswith('win'):
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
path = (path2 or path1) if roaming else (path1 or path2)
elif sys.platform.startswith('darwin'):
path = os.path.join(userDir, 'Library', 'Application Support')
# On Linux and as fallback
if not (path and os.path.isdir(path)):
path = userDir
# Maybe we should store things local to the executable (in case of a
# portable distro or a frozen application that wants to be portable)
prefix = sys.prefix
if getattr(sys, 'frozen', None):
prefix = os.path.abspath(os.path.dirname(sys.executable))
for reldir in ('settings', '../settings'):
localpath = os.path.abspath(os.path.join(prefix, reldir))
if os.path.isdir(localpath): # pragma: no cover
try:
open(os.path.join(localpath, 'test.write'), 'wb').close()
os.remove(os.path.join(localpath, 'test.write'))
except IOError:
pass # We cannot write in this directory
else:
path = localpath
break
# Get path specific for this app
if appname:
if path == userDir:
appname = '.' + appname.lstrip('.') # Make it a hidden directory
path = os.path.join(path, appname)
if not os.path.isdir(path): # pragma: no cover
os.mkdir(path)
# Done
return path

wav2lip/genavatar.py Normal file

@@ -0,0 +1,125 @@
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import torch
import pickle
import face_detection
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
parser.add_argument('--img_size', default=96, type=int)
parser.add_argument('--avatar_id', default='wav2lip_avatar1', type=str)
parser.add_argument('--video_path', default='', type=str)
parser.add_argument('--nosmooth', default=False, action='store_true',
help='Prevent smoothing face detections over a short temporal window')
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
help='Padding (top, bottom, left, right). Please adjust to include chin at least')
parser.add_argument('--face_det_batch_size', type=int,
help='Batch size for face detection', default=16)
args = parser.parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
def osmakedirs(path_list):
for path in path_list:
os.makedirs(path, exist_ok=True)
def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
cap = cv2.VideoCapture(vid_path)
count = 0
while True:
if count > cut_frame:
break
ret, frame = cap.read()
if ret:
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
count += 1
else:
break
def read_imgs(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
def get_smoothened_boxes(boxes, T):
for i in range(len(boxes)):
if i + T > len(boxes):
window = boxes[len(boxes) - T:]
else:
window = boxes[i : i + T]
boxes[i] = np.mean(window, axis=0)
return boxes
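get_smoothened_boxes replaces each detection box with the mean of a sliding window of up to T consecutive boxes, damping per-frame jitter in the face crops (pass --nosmooth to skip it). Note that it averages in place, so later windows see already-smoothed boxes. A tiny numeric sketch:

import numpy as np

boxes = np.array([[10., 10., 110., 110.],
                  [14., 10., 114., 110.],
                  [ 6., 10., 106., 110.],
                  [10., 14., 110., 114.]])
smoothed = get_smoothened_boxes(boxes.copy(), T=2)
print(smoothed[0])    # mean of the first two boxes: [12. 10. 112. 110.]
print(smoothed[-1])   # last window is clipped to the final two (already-smoothed) boxes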
def face_detect(images):
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
flip_input=False, device=device)
batch_size = args.face_det_batch_size
while 1:
predictions = []
try:
for i in tqdm(range(0, len(images), batch_size)):
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
except RuntimeError:
if batch_size == 1:
raise RuntimeError('Image too big to run face detection on GPU even with batch size 1. Please downscale the input video.')
batch_size //= 2
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
continue
break
results = []
pady1, pady2, padx1, padx2 = args.pads
for rect, image in zip(predictions, images):
if rect is None:
cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
y1 = max(0, rect[1] - pady1)
y2 = min(image.shape[0], rect[3] + pady2)
x1 = max(0, rect[0] - padx1)
x2 = min(image.shape[1], rect[2] + padx2)
results.append([x1, y1, x2, y2])
boxes = np.array(results)
if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
del detector
return results
if __name__ == "__main__":
avatar_path = f"./results/avatars/{args.avatar_id}"
full_imgs_path = f"{avatar_path}/full_imgs"
face_imgs_path = f"{avatar_path}/face_imgs"
coords_path = f"{avatar_path}/coords.pkl"
osmakedirs([avatar_path,full_imgs_path,face_imgs_path])
print(args)
#if os.path.isfile(args.video_path):
video2imgs(args.video_path, full_imgs_path, ext = 'png')
input_img_list = sorted(glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
frames = read_imgs(input_img_list)
face_det_results = face_detect(frames)
coord_list = []
idx = 0
for frame,coords in face_det_results:
#x1, y1, x2, y2 = bbox
resized_crop_frame = cv2.resize(frame,(args.img_size, args.img_size)) #,interpolation = cv2.INTER_LANCZOS4)
cv2.imwrite(f"{face_imgs_path}/{idx:08d}.png", resized_crop_frame)
coord_list.append(coords)
idx = idx + 1
with open(coords_path, 'wb') as f:
pickle.dump(coord_list, f)
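The script above materialises one avatar folder: full frames in full_imgs/, face crops resized to --img_size (96 by default) in face_imgs/, and the per-frame crop boxes pickled as (y1, y2, x1, x2) tuples in coords.pkl. A minimal sketch of reading those artefacts back, assuming the default avatar_id:

import glob, os, pickle
import cv2

avatar_path = './results/avatars/wav2lip_avatar1'
with open(os.path.join(avatar_path, 'coords.pkl'), 'rb') as f:
    coord_list = pickle.load(f)                                            # one (y1, y2, x1, x2) per frame

frame = cv2.imread(sorted(glob.glob(f'{avatar_path}/full_imgs/*.png'))[0])
face = cv2.imread(sorted(glob.glob(f'{avatar_path}/face_imgs/*.png'))[0])  # 96x96 crop saved above
y1, y2, x1, x2 = map(int, coord_list[0])
frame[y1:y2, x1:x2] = cv2.resize(face, (x2 - x1, y2 - y1))                 # paste the crop back into its box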

wav2lip/hparams.py Normal file

@@ -0,0 +1,101 @@
from glob import glob
import os
def get_image_list(data_root, split):
filelist = []
with open('filelists/{}.txt'.format(split)) as f:
for line in f:
line = line.strip()
if ' ' in line: line = line.split()[0]
filelist.append(os.path.join(data_root, line))
return filelist
class HParams:
def __init__(self, **kwargs):
self.data = {}
for key, value in kwargs.items():
self.data[key] = value
def __getattr__(self, key):
if key not in self.data:
raise AttributeError("'HParams' object has no attribute %s" % key)
return self.data[key]
def set_hparam(self, key, value):
self.data[key] = value
# Default hyperparameters
hparams = HParams(
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
# network
rescale=True, # Whether to rescale audio prior to preprocessing
rescaling_max=0.9, # Rescaling value
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
# Does not work if n_ffit is not multiple of hop_size!!
use_lws=False,
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
# Mel and Linear spectrograms normalization/scaling and clipping
signal_normalization=True,
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
symmetric_mels=True,
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
# faster and cleaner convergence)
max_abs_value=4.,
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
# be too big to avoid gradient explosion,
# not too small for fast convergence)
# Contribution by @begeekmyfriend
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
# levels. Also allows for better G&L phase reconstruction)
preemphasize=True, # whether to apply filter
preemphasis=0.97, # filter coefficient.
# Limits
min_level_db=-100,
ref_level_db=20,
fmin=55,
# Set this to 55 if your speaker is male; if female, 95 should help remove noise. (To
# be tuned depending on the dataset. Pitch info: male~[65, 260], female~[100, 525])
fmax=7600, # To be increased/reduced depending on data.
###################### Our training parameters #################################
img_size=96,
fps=25,
batch_size=16,
initial_learning_rate=1e-4,
nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
num_workers=16,
checkpoint_interval=3000,
eval_interval=3000,
save_optimizer_state=True,
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
syncnet_batch_size=64,
syncnet_lr=1e-4,
syncnet_eval_interval=10000,
syncnet_checkpoint_interval=10000,
disc_wt=0.07,
disc_initial_learning_rate=1e-4,
)
def hparams_debug_string():
values = hparams.data  # HParams defines no values() method; read the underlying dict directly
hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
return "Hyperparameters:\n" + "\n".join(hp)


@@ -0,0 +1,2 @@
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
from .syncnet import SyncNet_color

wav2lip/models/conv.py Normal file

@@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.nn import functional as F
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
return self.act(out)
class nonorm_Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
)
self.act = nn.LeakyReLU(0.01, inplace=True)
def forward(self, x):
out = self.conv_block(x)
return self.act(out)
class Conv2dTranspose(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
def forward(self, x):
out = self.conv_block(x)
return self.act(out)
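These wrappers keep the Wav2Lip network definitions compact: Conv2d is conv + batch norm + ReLU with an optional identity shortcut (which requires cin == cout and a shape-preserving stride/padding so the addition is valid), and Conv2dTranspose is the upsampling counterpart used in the decoder. A small shape check using the classes above:

import torch

x = torch.randn(2, 32, 48, 48)
block = Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)
up = Conv2dTranspose(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)

y = block(x)     # residual add keeps the (2, 32, 48, 48) shape
z = up(y)        # stride-2 transposed conv doubles the spatial size: (2, 16, 96, 96)
print(y.shape, z.shape)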

Some files were not shown because too many files have changed in this diff.