feat: add a simple script to auto-generate a musetalk digital-human avatar

Author: Yun
Date: 2024-06-20 20:21:37 +08:00
parent 5da818b9d9
commit c0682408c5
4 changed files with 454 additions and 84 deletions

View File

@@ -172,6 +172,13 @@ python -m scripts.realtime_inference --inference_config configs/inference/realti
After running, copy the files under results/avatars into this project's data/avatars directory
```
```bash
# Alternatively, use simple_musetalk.py in the local musetalk directory
cd musetalk
python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
# the avatar files are generated directly under data/avatars
```
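Once an avatar has been generated this way it can be served directly. A minimal sketch, assuming the directory created by `--avatar_id 2` above is `data/avatars/avator_2` and using only flags that app.py already exposes:

```bash
# back in the project root, start the server with the generated avatar
cd ..
python app.py --model musetalk --avatar_id avator_2 --transport webrtc
```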
### 3.10 Using the wav2lip model
RTMP push is not supported yet
- Download the model

app.py (182 changed lines)
View File

@@ -1,5 +1,5 @@
# server.py
from flask import Flask, render_template, send_from_directory, request, jsonify
from flask_sockets import Sockets
import base64
import time
@@ -10,7 +10,7 @@ from geventwebsocket.handler import WebSocketHandler
import os
import re
import numpy as np
from threading import Thread, Event
import multiprocessing
from aiohttp import web
@@ -24,16 +24,15 @@ import argparse
import shutil
import asyncio

app = Flask(__name__)
sockets = Sockets(app)
global nerfreal


@sockets.route('/humanecho')
def echo_socket(ws):
    # get the WebSocket object
    # ws = request.environ.get('wsgi.websocket')
    # if it is not available, return an error message
    if not ws:
        print('未建立连接!')
@@ -42,11 +41,11 @@ def echo_socket(ws):
    else:
        print('建立连接!')
        while True:
            message = ws.receive()
            if not message or len(message) == 0:
                return '输入信息为空'
            else:
                nerfreal.put_msg_txt(message)
@@ -54,15 +53,16 @@ def llm_response(message):
    from llm.LLM import LLM
    # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
    # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
    llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b')
    response = llm.chat(message)
    print(response)
    return response


@sockets.route('/humanchat')
def chat_socket(ws):
    # get the WebSocket object
    # ws = request.environ.get('wsgi.websocket')
    # if it is not available, return an error message
    if not ws:
        print('未建立连接!')
@@ -71,18 +71,20 @@ def chat_socket(ws):
    else:
        print('建立连接!')
        while True:
            message = ws.receive()
            if len(message) == 0:
                return '输入信息为空'
            else:
                res = llm_response(message)
                nerfreal.put_msg_txt(res)


#####webrtc###############################
pcs = set()


# @app.route('/offer', methods=['POST'])
async def offer(request):
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -106,7 +108,7 @@ async def offer(request):
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
    return web.Response(
        content_type="application/json",
@@ -115,36 +117,40 @@ async def offer(request):
        ),
    )


async def human(request):
    params = await request.json()
    if params['type'] == 'echo':
        nerfreal.put_msg_txt(params['text'])
    elif params['type'] == 'chat':
        res = await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'])
        nerfreal.put_msg_txt(res)
    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"code": 0, "data": "ok"}
        ),
    )


async def on_shutdown(app):
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()


async def post(url, data):
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                return await response.text()
    except aiohttp.ClientError as e:
        print(f'Error: {e}')


async def run(push_url):
    pc = RTCPeerConnection()
    pcs.add(pc)
@@ -161,8 +167,10 @@ async def run(push_url):
    video_sender = pc.addTrack(player.video)
    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))


##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
@@ -181,14 +189,20 @@ if __name__ == '__main__':
    ### training options
    parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
    parser.add_argument('--num_rays', type=int, default=4096 * 16,
                        help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=16,
                        help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=16,
                        help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0,
                        help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16,
                        help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096,
                        help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
    ### loss set
    parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
@@ -199,27 +213,35 @@ if __name__ == '__main__':
    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
    parser.add_argument('--bg_img', type=str, default='white', help="background image")
    parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
    parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
    parser.add_argument('--fix_eye', type=float, default=-1,
                        help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
    parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
    parser.add_argument('--torso_shrink', type=float, default=0.8,
                        help="shrink bg coords to allow more flexibility in deform")
    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', type=int, default=0,
                        help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=1,
                        help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1 / 256,
                        help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10,
                        help="threshold for density grid to be occupied (sigma)")
    parser.add_argument('--density_thresh_torso', type=float, default=0.01,
                        help="threshold for density grid to be occupied (alpha)")
    parser.add_argument('--patch_size', type=int, default=1,
                        help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
    parser.add_argument('--init_lips', action='store_true', help="init lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
@@ -237,12 +259,15 @@ if __name__ == '__main__':
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
    ### else
    parser.add_argument('--att', type=int, default=2,
                        help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
    parser.add_argument('--aud', type=str, default='',
                        help="audio source (empty will load the default, else should be a path to a npy file)")
    parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
    parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
    parser.add_argument('--ind_num', type=int, default=10000,
                        help="number of individual codes, should be larger than training dataset size")
    parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
@@ -251,7 +276,8 @@ if __name__ == '__main__':
    parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
    parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
    parser.add_argument('--smooth_path', action='store_true',
                        help="brute-force smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")
    # asr
@@ -259,8 +285,8 @@ if __name__ == '__main__':
    parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
    parser.add_argument('--asr_play', action='store_true', help="play out the audio")
    # parser.add_argument('--asr_model', type=str, default='deepspeech')
    parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
    # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
    # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
@@ -279,7 +305,7 @@ if __name__ == '__main__':
    parser.add_argument('--fullbody_offset_x', type=int, default=0)
    parser.add_argument('--fullbody_offset_y', type=int, default=0)
    # musetalk opt
    parser.add_argument('--avatar_id', type=str, default='avator_1')
    parser.add_argument('--bbox_shift', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=16)
@@ -289,33 +315,35 @@ if __name__ == '__main__':
    parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
    parser.add_argument('--customvideo_imgnum', type=int, default=1)
    parser.add_argument('--tts', type=str, default='edgetts')  # xtts gpt-sovits
    parser.add_argument('--REF_FILE', type=str, default=None)
    parser.add_argument('--REF_TEXT', type=str, default=None)
    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880')  # http://localhost:9000
    # parser.add_argument('--CHARACTER', type=str, default='test')
    # parser.add_argument('--EMOTION', type=str, default='default')
    parser.add_argument('--model', type=str, default='ernerf')  # musetalk wav2lip
    parser.add_argument('--transport', type=str, default='rtcpush')  # rtmp webrtc rtcpush
    parser.add_argument('--push_url', type=str,
                        default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # rtmp://localhost/live/livestream
    parser.add_argument('--listenport', type=int, default=8010)
    opt = parser.parse_args()
    # app.config.from_object(opt)
    # print(app.config)
    if opt.model == 'ernerf':
        from ernerf.nerf_triplane.provider import NeRFDataset_Test
        from ernerf.nerf_triplane.utils import *
        from ernerf.nerf_triplane.network import NeRFNetwork
        from nerfreal import NeRFReal
        # assert test mode
        opt.test = True
        opt.test_train = False
        # opt.train_camera =True
        # explicit smoothing
        opt.smooth_path = True
        opt.smooth_lips = True
@@ -328,7 +356,7 @@ if __name__ == '__main__':
        opt.exp_eye = True
        opt.smooth_eye = True
        if opt.torso_imgs == '':  # no img, use model output
            opt.torso = True
        # assert opt.cuda_ray, "Only support CUDA ray mode."
@@ -344,9 +372,10 @@ if __name__ == '__main__':
        model = NeRFNetwork(opt)
        criterion = torch.nn.MSELoss(reduction='none')
        metrics = []  # use no metric in GUI for faster initialization...
        print(model)
        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16,
                          metrics=metrics, use_checkpoint=opt.ckpt)
        test_loader = NeRFDataset_Test(opt, device=device).dataloader()
        model.aud_features = test_loader._data.auds
@@ -356,17 +385,19 @@ if __name__ == '__main__':
        nerfreal = NeRFReal(opt, trainer, test_loader)
    elif opt.model == 'musetalk':
        from musereal import MuseReal
        print(opt)
        nerfreal = MuseReal(opt)
    elif opt.model == 'wav2lip':
        from lipreal import LipReal
        print(opt)
        nerfreal = LipReal(opt)
    # txt_to_audio('我是中国人,我来自北京')
    if opt.transport == 'rtmp':
        thread_quit = Event()
        rendthrd = Thread(target=nerfreal.render, args=(thread_quit,))
        rendthrd.start()
    #############################################################################
@@ -374,35 +405,36 @@ if __name__ == '__main__':
    appasync.on_shutdown.append(on_shutdown)
    appasync.router.add_post("/offer", offer)
    appasync.router.add_post("/human", human)
    appasync.router.add_static('/', path='web')
    # Configure default CORS settings.
    cors = aiohttp_cors.setup(appasync, defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
    })
    # Configure CORS on all routes.
    for route in list(appasync.router.routes()):
        cors.add(route)

    def run_server(runner):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
        loop.run_until_complete(site.start())
        if opt.transport == 'rtcpush':
            loop.run_until_complete(run(opt.push_url))
        loop.run_forever()

    Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
    print('start websocket server')
    # app.on_shutdown.append(on_shutdown)
    # app.router.add_post("/offer", offer)
    server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
    server.serve_forever()
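For a quick check of the /human endpoint wired up above (it reads a JSON body with `type` set to `echo` or `chat` plus a `text` field, and the aiohttp server listens on `--listenport`, default 8010), a hand-rolled request might look like this sketch:

```bash
# echo mode pushes the text straight to the avatar; chat mode routes it through llm_response first
curl -X POST http://localhost:8010/human \
     -H 'Content-Type: application/json' \
     -d '{"type": "echo", "text": "hello"}'
```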

musetalk/simple_musetalk.py (new file, 331 lines)
View File

@@ -0,0 +1,331 @@
import argparse
import glob
import json
import os
import pickle
import shutil
import sys

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL
from face_alignment import NetworkSize
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
from tqdm import tqdm

from utils.face_parsing import FaceParsing


def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    cap = cv2.VideoCapture(vid_path)
    count = 0
    while True:
        if count > cut_frame:
            break
        ret, frame = cap.read()
        if ret:
            cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
            count += 1
        else:
            break


def read_imgs(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def get_landmark_and_bbox(img_list, upperbondrange=0):
    frames = read_imgs(img_list)
    batch_size_fa = 1
    batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
    coords_list = []
    landmarks = []
    if upperbondrange != 0:
        print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
    else:
        print('get key_landmark and face bounding boxes with the default value')
    average_range_minus = []
    average_range_plus = []
    for fb in tqdm(batches):
        results = inference_topdown(model, np.asarray(fb)[0])
        results = merge_data_samples(results)
        keypoints = results.pred_instances.keypoints
        face_land_mark = keypoints[0][23:91]
        face_land_mark = face_land_mark.astype(np.int32)
        # get bounding boxes by face detection
        bbox = fa.get_detections_for_batch(np.asarray(fb))
        # adjust the bounding box according to the landmarks
        # add the bounding box to a tuple and append it to the coordinates list
        for j, f in enumerate(bbox):
            if f is None:  # no face in the image
                coords_list += [coord_placeholder]
                continue
            half_face_coord = face_land_mark[29]  # np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
            range_minus = (face_land_mark[30] - face_land_mark[29])[1]
            range_plus = (face_land_mark[29] - face_land_mark[28])[1]
            average_range_minus.append(range_minus)
            average_range_plus.append(range_plus)
            if upperbondrange != 0:
                half_face_coord[1] = upperbondrange + half_face_coord[1]  # manual shift: positive moves the bound down (landmark 29), negative moves it up (landmark 28)
            half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
            upper_bond = half_face_coord[1] - half_face_dist
            f_landmark = (
                np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
                np.max(face_land_mark[:, 1]))
            x1, y1, x2, y2 = f_landmark
            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:  # if the landmark bbox is not suitable, reuse the detector bbox
                coords_list += [f]
                w, h = f[2] - f[0], f[3] - f[1]
                print("error bbox:", f)
            else:
                coords_list += [f_landmark]
    return coords_list, frames
class FaceAlignment:
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose
        network_size = int(network_size)
        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True
            # torch.backends.cuda.matmul.allow_tf32 = False
            # torch.backends.cudnn.benchmark = True
            # torch.backends.cudnn.deterministic = False
            # torch.backends.cudnn.allow_tf32 = True
            print('cuda start')
        # Get the face detector
        face_detector_module = __import__('face_detection.detection.' + face_detector,
                                          globals(), locals(), [face_detector], 0)
        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())
        results = []
        for i, d in enumerate(detected_faces):
            if len(d) == 0:
                results.append(None)
                continue
            d = d[0]
            d = np.clip(d, 0, None)
            x1, y1, x2, y2 = map(int, d[:-1])
            results.append((x1, y1, x2, y2))
        return results


def get_mask_tensor():
    """
    Creates a mask tensor for image processing.
    :return: A mask tensor.
    """
    mask_tensor = torch.zeros((256, 256))
    mask_tensor[:256 // 2, :] = 1
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    return mask_tensor


def preprocess_img(img_name, half_mask=False):
    window = []
    if isinstance(img_name, str):
        window_fnames = [img_name]
        for fname in window_fnames:
            img = cv2.imread(fname)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (256, 256),
                             interpolation=cv2.INTER_LANCZOS4)
            window.append(img)
    else:
        img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
        window.append(img)
    x = np.asarray(window) / 255.
    x = np.transpose(x, (3, 0, 1, 2))
    x = torch.squeeze(torch.FloatTensor(x))
    if half_mask:
        x = x * (get_mask_tensor() > 0.5)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    x = normalize(x)
    x = x.unsqueeze(0)  # [1, 3, 256, 256] torch tensor
    x = x.to(device)
    return x


def encode_latents(image):
    with torch.no_grad():
        init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
    init_latents = vae.config.scaling_factor * init_latent_dist.sample()
    return init_latents


def get_latents_for_unet(img):
    ref_image = preprocess_img(img, half_mask=True)  # [1, 3, 256, 256] RGB, torch tensor
    masked_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    ref_image = preprocess_img(img, half_mask=False)  # [1, 3, 256, 256] RGB, torch tensor
    ref_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
    return latent_model_input


def get_crop_box(box, expand):
    x, y, x1, y1 = box
    x_c, y_c = (x + x1) // 2, (y + y1) // 2
    w, h = x1 - x, y1 - y
    s = int(max(w, h) // 2 * expand)
    crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
    return crop_box, s


def face_seg(image):
    seg_image = fp(image)
    if seg_image is None:
        print("error, no person_segment")
        return None
    seg_image = seg_image.resize(image.size)
    return seg_image


def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
    body = Image.fromarray(image[:, :, ::-1])
    x, y, x1, y1 = face_box
    # print(x1-x,y1-y)
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box
    face_large = body.crop(crop_box)
    ori_shape = face_large.size
    mask_image = face_seg(face_large)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new('L', ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    # keep upper_boundary_ratio of talking area
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new('L', ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    return mask_array, crop_box


def create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else "cpu"
fa = FaceAlignment(1, flip_input=False, device=device)
config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device)
vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse")
vae.to(device)
fp = FaceParsing()

if __name__ == '__main__':
    # path of the input video file
    parser = argparse.ArgumentParser()
    parser.add_argument("--file",
                        type=str,
                        default=r'D:\ok\test.mp4',
                        )
    parser.add_argument("--avatar_id",
                        type=str,
                        default='1',
                        )
    args = parser.parse_args()
    file = args.file
    # output locations; usually no need to change these
    save_path = f'../data/avatars/avator_{args.avatar_id}'
    save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs'
    create_dir(save_path)
    create_dir(save_full_path)
    mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask'
    create_dir(mask_out_path)
    # model files produced for the avatar
    mask_coords_path = f'{save_path}/mask_coords.pkl'
    coords_path = f'{save_path}/coords.pkl'
    latents_out_path = f'{save_path}/latents.pt'
    with open(f'{save_path}/avator_info.json', "w") as f:
        json.dump({
            "avatar_id": args.avatar_id,
            "video_path": file,
            "bbox_shift": 5
        }, f)
    if os.path.isfile(file):
        video2imgs(file, save_full_path, ext='png')
    else:
        files = os.listdir(file)
        files.sort()
        files = [file for file in files if file.split(".")[-1] == "png"]
        for filename in files:
            shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
    input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))
    print("extracting landmarks...")
    coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5)
    input_latent_list = []
    idx = -1
    # placeholder marker used when no suitable bbox is found
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)
    for bbox, frame in zip(coord_list, frame_list):
        idx = idx + 1
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = get_latents_for_unet(resized_crop_frame)
        input_latent_list.append(latents)
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
    mask_coords_list_cycle = []
    mask_list_cycle = []
    for i, frame in enumerate(tqdm(frame_list_cycle)):
        cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)
        face_box = coord_list_cycle[i]
        mask, crop_box = get_image_prepare_material(frame, face_box)
        cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
        mask_coords_list_cycle += [crop_box]
        mask_list_cycle.append(mask)
    with open(mask_coords_path, 'wb') as f:
        pickle.dump(mask_coords_list_cycle, f)
    with open(coords_path, 'wb') as f:
        pickle.dump(coord_list_cycle, f)
    torch.save(input_latent_list_cycle, os.path.join(latents_out_path))
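For reference, the save paths in the script above mean that a run with `--avatar_id 2` leaves the following files behind (an expected layout inferred from the save calls, not output captured from a real run):

```bash
# inspect the generated avatar directory (relative to musetalk/, as in the script)
ls ../data/avatars/avator_2
# expected: full_imgs/  mask/  avator_info.json  coords.pkl  mask_coords.pkl  latents.pt
```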

View File

@@ -7,18 +7,18 @@ from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms


class FaceParsing():
    def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='./models/face-parse-bisent/79999_iter.pth'):
        self.net = self.model_init(resnet_path, model_pth)
        self.preprocess = self.image_preprocess()

    def model_init(self, resnet_path, model_pth):
        net = BiSeNet(resnet_path)
        if torch.cuda.is_available():
            net.cuda()
            net.load_state_dict(torch.load(model_pth))
        else:
            net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
        net.eval()
@@ -44,13 +44,13 @@ class FaceParsing():
        img = torch.unsqueeze(img, 0)
        out = self.net(img)[0]
        parsing = out.squeeze(0).cpu().numpy().argmax(0)
        parsing[np.where(parsing > 13)] = 0
        parsing[np.where(parsing >= 1)] = 255
        parsing = Image.fromarray(parsing.astype(np.uint8))
        return parsing


if __name__ == "__main__":
    fp = FaceParsing()
    segmap = fp('154_small.png')
    segmap.save('res.png')
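With the constructor change above, the BiSeNet weight locations can be overridden by the caller instead of being hard-coded. A sketch of how a caller running from the musetalk directory might point at the shared models folder (the `../models/...` locations are an assumption, mirroring the other model paths used in this commit):

```python
# hypothetical: override the default weight paths when instantiating FaceParsing
from utils.face_parsing import FaceParsing

fp = FaceParsing(
    resnet_path='../models/face-parse-bisent/resnet18-5c106cde.pth',
    model_pth='../models/face-parse-bisent/79999_iter.pth',
)
```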