feat: add a simple script to auto-generate a MuseTalk digital human
parent 5da818b9d9
commit c0682408c5
README.md

@@ -172,6 +172,13 @@ python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
 After running, copy the files under results/avatars into data/avatars of this project
 ```
+
+```bash
+You can also try simple_musetalk.py in this directory
+cd musetalk
+python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
+After running, the avatar data is generated directly under data/avatars
+```

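For reference, this is roughly what simple_musetalk.py leaves behind for `--avatar_id 2` (layout inferred from the script added later in this commit; the `avator_` spelling comes from the code):

```
data/avatars/avator_2/
├── full_imgs/          # extracted video frames plus the reversed copy
├── mask/               # blurred face-region masks, one per frame
├── avator_info.json    # avatar_id, video_path, bbox_shift
├── coords.pkl          # face bounding boxes per frame
├── mask_coords.pkl     # crop boxes for the masks
└── latents.pt          # VAE latents fed to the MuseTalk UNet
```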
 ### 3.10 Using the wav2lip model
 rtmp streaming is not supported yet
 - Download the model
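Since rtmp is not available for this model yet, a launch would use one of the WebRTC transports exposed by app.py. A minimal sketch (flag names taken from the argument parser below; the wav2lip model download steps follow in the original README):

```bash
python app.py --model wav2lip --transport webrtc
```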
app.py (182 changed lines)
@@ -1,5 +1,5 @@
 # server.py
-from flask import Flask, render_template,send_from_directory,request, jsonify
+from flask import Flask, render_template, send_from_directory, request, jsonify
 from flask_sockets import Sockets
 import base64
 import time
@@ -10,7 +10,7 @@ from geventwebsocket.handler import WebSocketHandler
 import os
 import re
 import numpy as np
-from threading import Thread,Event
+from threading import Thread, Event
 import multiprocessing

 from aiohttp import web
@@ -24,16 +24,15 @@ import argparse
 import shutil
 import asyncio


 app = Flask(__name__)
 sockets = Sockets(app)
 global nerfreal


 @sockets.route('/humanecho')
 def echo_socket(ws):
     # 获取WebSocket对象
-    #ws = request.environ.get('wsgi.websocket')
+    # ws = request.environ.get('wsgi.websocket')
     # 如果没有获取到,返回错误信息
     if not ws:
         print('未建立连接!')
@@ -42,11 +41,11 @@ def echo_socket(ws):
     else:
         print('建立连接!')
         while True:
             message = ws.receive()
-            if not message or len(message)==0:
+
+            if not message or len(message) == 0:
                 return '输入信息为空'
             else:
                 nerfreal.put_msg_txt(message)

@@ -54,15 +53,16 @@ def llm_response(message):
     from llm.LLM import LLM
     # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
     # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
-    llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
+    llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b')
     response = llm.chat(message)
     print(response)
     return response


 @sockets.route('/humanchat')
 def chat_socket(ws):
     # 获取WebSocket对象
-    #ws = request.environ.get('wsgi.websocket')
+    # ws = request.environ.get('wsgi.websocket')
     # 如果没有获取到,返回错误信息
     if not ws:
         print('未建立连接!')
@@ -71,18 +71,20 @@ def chat_socket(ws):
     else:
         print('建立连接!')
         while True:
             message = ws.receive()
-            if len(message)==0:
+
+            if len(message) == 0:
                 return '输入信息为空'
             else:
-                res=llm_response(message)
+                res = llm_response(message)
                 nerfreal.put_msg_txt(res)


 #####webrtc###############################
 pcs = set()

-#@app.route('/offer', methods=['POST'])
+
+# @app.route('/offer', methods=['POST'])
 async def offer(request):
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -106,7 +108,7 @@ async def offer(request):
     answer = await pc.createAnswer()
     await pc.setLocalDescription(answer)

-    #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
+    # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})

     return web.Response(
         content_type="application/json",
@@ -115,36 +117,40 @@ async def offer(request):
         ),
     )


 async def human(request):
     params = await request.json()

-    if params['type']=='echo':
+    if params['type'] == 'echo':
         nerfreal.put_msg_txt(params['text'])
-    elif params['type']=='chat':
-        res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
+    elif params['type'] == 'chat':
+        res = await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
         nerfreal.put_msg_txt(res)

     return web.Response(
         content_type="application/json",
         text=json.dumps(
-            {"code": 0, "data":"ok"}
+            {"code": 0, "data": "ok"}
         ),
     )


 async def on_shutdown(app):
     # close peer connections
     coros = [pc.close() for pc in pcs]
     await asyncio.gather(*coros)
     pcs.clear()

-async def post(url,data):
+
+async def post(url, data):
     try:
         async with aiohttp.ClientSession() as session:
-            async with session.post(url,data=data) as response:
+            async with session.post(url, data=data) as response:
                 return await response.text()
     except aiohttp.ClientError as e:
         print(f'Error: {e}')


 async def run(push_url):
     pc = RTCPeerConnection()
     pcs.add(pc)
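For context, the `/human` route registered later in this file takes a JSON body with `type` ('echo' or 'chat') and `text`, and answers with `{"code": 0, "data": "ok"}`. A quick way to exercise it from the command line (a sketch, assuming the default `--listenport 8010`):

```bash
curl -X POST http://localhost:8010/human \
     -H 'Content-Type: application/json' \
     -d '{"type": "echo", "text": "hello"}'
```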
@@ -161,8 +167,10 @@ async def run(push_url):
     video_sender = pc.addTrack(player.video)

     await pc.setLocalDescription(await pc.createOffer())
-    answer = await post(push_url,pc.localDescription.sdp)
-    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
+    answer = await post(push_url, pc.localDescription.sdp)
+    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
+

 ##########################################
 # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
 # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
@@ -181,14 +189,20 @@ if __name__ == '__main__':

     ### training options
     parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
-    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
+    parser.add_argument('--num_rays', type=int, default=4096 * 16,
+                        help="num rays sampled per image for each training step")
     parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
-    parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
-    parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
-    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
-    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
-    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--max_steps', type=int, default=16,
+                        help="max num steps sampled per ray (only valid when using --cuda_ray)")
+    parser.add_argument('--num_steps', type=int, default=16,
+                        help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--upsample_steps', type=int, default=0,
+                        help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--update_extra_interval', type=int, default=16,
+                        help="iter interval to update extra status (only valid when using --cuda_ray)")
+    parser.add_argument('--max_ray_batch', type=int, default=4096,
+                        help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

     ### loss set
     parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
@@ -199,27 +213,35 @@ if __name__ == '__main__':

     ### network backbone options
     parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")

     parser.add_argument('--bg_img', type=str, default='white', help="background image")
     parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
     parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
-    parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
+    parser.add_argument('--fix_eye', type=float, default=-1,
+                        help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
     parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")

-    parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")
+    parser.add_argument('--torso_shrink', type=float, default=0.8,
+                        help="shrink bg coords to allow more flexibility in deform")

     ### dataset options
     parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
-    parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
+    parser.add_argument('--preload', type=int, default=0,
+                        help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
     # (the default value is for the fox dataset)
-    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
+    parser.add_argument('--bound', type=float, default=1,
+                        help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
     parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
     parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
-    parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+    parser.add_argument('--dt_gamma', type=float, default=1 / 256,
+                        help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
     parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
-    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
-    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
-    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
+    parser.add_argument('--density_thresh', type=float, default=10,
+                        help="threshold for density grid to be occupied (sigma)")
+    parser.add_argument('--density_thresh_torso', type=float, default=0.01,
+                        help="threshold for density grid to be occupied (alpha)")
+    parser.add_argument('--patch_size', type=int, default=1,
+                        help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

     parser.add_argument('--init_lips', action='store_true', help="init lips region")
     parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
@@ -237,12 +259,15 @@ if __name__ == '__main__':
     parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

     ### else
-    parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
-    parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
+    parser.add_argument('--att', type=int, default=2,
+                        help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
+    parser.add_argument('--aud', type=str, default='',
+                        help="audio source (empty will load the default, else should be a path to a npy file)")
     parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")

     parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
-    parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")
+    parser.add_argument('--ind_num', type=int, default=10000,
+                        help="number of individual codes, should be larger than training dataset size")

     parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
@@ -251,7 +276,8 @@ if __name__ == '__main__':
     parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")

     parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
-    parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
+    parser.add_argument('--smooth_path', action='store_true',
+                        help="brute-force smooth camera pose trajectory with a window size")
     parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

     # asr
@@ -259,8 +285,8 @@ if __name__ == '__main__':
     parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
     parser.add_argument('--asr_play', action='store_true', help="play out the audio")

-    #parser.add_argument('--asr_model', type=str, default='deepspeech')
+    # parser.add_argument('--asr_model', type=str, default='deepspeech')
     parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #
     # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
     # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
@@ -279,7 +305,7 @@ if __name__ == '__main__':
     parser.add_argument('--fullbody_offset_x', type=int, default=0)
     parser.add_argument('--fullbody_offset_y', type=int, default=0)

-    #musetalk opt
+    # musetalk opt
     parser.add_argument('--avatar_id', type=str, default='avator_1')
     parser.add_argument('--bbox_shift', type=int, default=5)
     parser.add_argument('--batch_size', type=int, default=16)
@@ -289,33 +315,35 @@ if __name__ == '__main__':
     parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
     parser.add_argument('--customvideo_imgnum', type=int, default=1)

-    parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits
+    parser.add_argument('--tts', type=str, default='edgetts')  # xtts gpt-sovits
     parser.add_argument('--REF_FILE', type=str, default=None)
     parser.add_argument('--REF_TEXT', type=str, default=None)
-    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
+    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880')  # http://localhost:9000
     # parser.add_argument('--CHARACTER', type=str, default='test')
     # parser.add_argument('--EMOTION', type=str, default='default')

-    parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
+    parser.add_argument('--model', type=str, default='ernerf')  # musetalk wav2lip

-    parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
-    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
+    parser.add_argument('--transport', type=str, default='rtcpush')  # rtmp webrtc rtcpush
+    parser.add_argument('--push_url', type=str,
+                        default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # rtmp://localhost/live/livestream

     parser.add_argument('--listenport', type=int, default=8010)

     opt = parser.parse_args()
-    #app.config.from_object(opt)
-    #print(app.config)
+    # app.config.from_object(opt)
+    # print(app.config)

     if opt.model == 'ernerf':
         from ernerf.nerf_triplane.provider import NeRFDataset_Test
         from ernerf.nerf_triplane.utils import *
         from ernerf.nerf_triplane.network import NeRFNetwork
         from nerfreal import NeRFReal

         # assert test mode
         opt.test = True
         opt.test_train = False
-        #opt.train_camera =True
+        # opt.train_camera =True
         # explicit smoothing
         opt.smooth_path = True
         opt.smooth_lips = True
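Tying the new options together, a launch for the MuseTalk avatar generated by simple_musetalk.py might look like this. This is only a sketch: it assumes the data/avatars/avator_2 folder from the README example above and that --avatar_id names that folder.

```bash
python app.py --model musetalk --avatar_id avator_2 --transport webrtc
```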
@@ -328,7 +356,7 @@ if __name__ == '__main__':
         opt.exp_eye = True
         opt.smooth_eye = True

-        if opt.torso_imgs=='': #no img,use model output
+        if opt.torso_imgs == '':  # no img,use model output
             opt.torso = True

         # assert opt.cuda_ray, "Only support CUDA ray mode."
@@ -344,9 +372,10 @@ if __name__ == '__main__':
         model = NeRFNetwork(opt)

         criterion = torch.nn.MSELoss(reduction='none')
-        metrics = [] # use no metric in GUI for faster initialization...
+        metrics = []  # use no metric in GUI for faster initialization...
         print(model)
-        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
+        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16,
+                          metrics=metrics, use_checkpoint=opt.ckpt)

         test_loader = NeRFDataset_Test(opt, device=device).dataloader()
         model.aud_features = test_loader._data.auds
@@ -356,17 +385,19 @@ if __name__ == '__main__':
         nerfreal = NeRFReal(opt, trainer, test_loader)
     elif opt.model == 'musetalk':
         from musereal import MuseReal

         print(opt)
         nerfreal = MuseReal(opt)
     elif opt.model == 'wav2lip':
         from lipreal import LipReal

         print(opt)
         nerfreal = LipReal(opt)

-    #txt_to_audio('我是中国人,我来自北京')
-    if opt.transport=='rtmp':
+    # txt_to_audio('我是中国人,我来自北京')
+    if opt.transport == 'rtmp':
         thread_quit = Event()
-        rendthrd = Thread(target=nerfreal.render,args=(thread_quit,))
+        rendthrd = Thread(target=nerfreal.render, args=(thread_quit,))
         rendthrd.start()

     #############################################################################
@@ -374,35 +405,36 @@ if __name__ == '__main__':
     appasync.on_shutdown.append(on_shutdown)
     appasync.router.add_post("/offer", offer)
     appasync.router.add_post("/human", human)
-    appasync.router.add_static('/',path='web')
+    appasync.router.add_static('/', path='web')

     # Configure default CORS settings.
     cors = aiohttp_cors.setup(appasync, defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
     })
     # Configure CORS on all routes.
     for route in list(appasync.router.routes()):
         cors.add(route)


     def run_server(runner):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         loop.run_until_complete(runner.setup())
         site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
         loop.run_until_complete(site.start())
-        if opt.transport=='rtcpush':
+        if opt.transport == 'rtcpush':
             loop.run_until_complete(run(opt.push_url))
         loop.run_forever()


     Thread(target=run_server, args=(web.AppRunner(appasync),)).start()

     print('start websocket server')
-    #app.on_shutdown.append(on_shutdown)
-    #app.router.add_post("/offer", offer)
+    # app.on_shutdown.append(on_shutdown)
+    # app.router.add_post("/offer", offer)
     server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
     server.serve_forever()
musetalk/simple_musetalk.py (new file)

@@ -0,0 +1,331 @@
import argparse
import glob
import json
import os
import pickle
import shutil
import sys

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL
from face_alignment import NetworkSize
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
from tqdm import tqdm

from utils.face_parsing import FaceParsing


def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    cap = cv2.VideoCapture(vid_path)
    count = 0
    while True:
        if count > cut_frame:
            break
        ret, frame = cap.read()
        if ret:
            cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
            count += 1
        else:
            break


def read_imgs(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def get_landmark_and_bbox(img_list, upperbondrange=0):
    frames = read_imgs(img_list)
    batch_size_fa = 1
    batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
    coords_list = []
    landmarks = []
    if upperbondrange != 0:
        print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
    else:
        print('get key_landmark and face bounding boxes with the default value')
    average_range_minus = []
    average_range_plus = []
    for fb in tqdm(batches):
        results = inference_topdown(model, np.asarray(fb)[0])
        results = merge_data_samples(results)
        keypoints = results.pred_instances.keypoints
        face_land_mark = keypoints[0][23:91]
        face_land_mark = face_land_mark.astype(np.int32)

        # get bounding boxes by face detection
        bbox = fa.get_detections_for_batch(np.asarray(fb))

        # adjust the bounding box with reference to the landmarks
        # Add the bounding box to a tuple and append it to the coordinates list
        for j, f in enumerate(bbox):
            if f is None:  # no face in the image
                coords_list += [coord_placeholder]
                continue

            half_face_coord = face_land_mark[29]  # np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
            range_minus = (face_land_mark[30] - face_land_mark[29])[1]
            range_plus = (face_land_mark[29] - face_land_mark[28])[1]
            average_range_minus.append(range_minus)
            average_range_plus.append(range_plus)
            if upperbondrange != 0:
                half_face_coord[1] = upperbondrange + half_face_coord[1]  # manual adjustment: positive shifts down (towards landmark 29), negative shifts up (towards landmark 28)
            half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
            upper_bond = half_face_coord[1] - half_face_dist

            f_landmark = (
                np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
                np.max(face_land_mark[:, 1]))
            x1, y1, x2, y2 = f_landmark

            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:  # if the landmark bbox is not suitable, reuse the bbox
                coords_list += [f]
                w, h = f[2] - f[0], f[3] - f[1]
                print("error bbox:", f)
            else:
                coords_list += [f_landmark]
    return coords_list, frames


class FaceAlignment:
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose

        network_size = int(network_size)
        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True
            # torch.backends.cuda.matmul.allow_tf32 = False
            # torch.backends.cudnn.benchmark = True
            # torch.backends.cudnn.deterministic = False
            # torch.backends.cudnn.allow_tf32 = True
            print('cuda start')

        # Get the face detector
        face_detector_module = __import__('face_detection.detection.' + face_detector,
                                          globals(), locals(), [face_detector], 0)

        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())
        results = []

        for i, d in enumerate(detected_faces):
            if len(d) == 0:
                results.append(None)
                continue
            d = d[0]
            d = np.clip(d, 0, None)

            x1, y1, x2, y2 = map(int, d[:-1])
            results.append((x1, y1, x2, y2))
        return results


def get_mask_tensor():
    """
    Creates a mask tensor for image processing.
    :return: A mask tensor.
    """
    mask_tensor = torch.zeros((256, 256))
    mask_tensor[:256 // 2, :] = 1
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    return mask_tensor


def preprocess_img(img_name, half_mask=False):
    window = []
    if isinstance(img_name, str):
        window_fnames = [img_name]
        for fname in window_fnames:
            img = cv2.imread(fname)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (256, 256),
                             interpolation=cv2.INTER_LANCZOS4)
            window.append(img)
    else:
        img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
        window.append(img)
    x = np.asarray(window) / 255.
    x = np.transpose(x, (3, 0, 1, 2))
    x = torch.squeeze(torch.FloatTensor(x))
    if half_mask:
        x = x * (get_mask_tensor() > 0.5)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    x = normalize(x)
    x = x.unsqueeze(0)  # [1, 3, 256, 256] torch tensor
    x = x.to(device)
    return x


def encode_latents(image):
    with torch.no_grad():
        init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
        init_latents = vae.config.scaling_factor * init_latent_dist.sample()
    return init_latents


def get_latents_for_unet(img):
    ref_image = preprocess_img(img, half_mask=True)  # [1, 3, 256, 256] RGB, torch tensor
    masked_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    ref_image = preprocess_img(img, half_mask=False)  # [1, 3, 256, 256] RGB, torch tensor
    ref_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
    return latent_model_input


def get_crop_box(box, expand):
    x, y, x1, y1 = box
    x_c, y_c = (x + x1) // 2, (y + y1) // 2
    w, h = x1 - x, y1 - y
    s = int(max(w, h) // 2 * expand)
    crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
    return crop_box, s


def face_seg(image):
    seg_image = fp(image)
    if seg_image is None:
        print("error, no person_segment")
        return None

    seg_image = seg_image.resize(image.size)
    return seg_image


def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
    body = Image.fromarray(image[:, :, ::-1])

    x, y, x1, y1 = face_box
    # print(x1-x,y1-y)
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box

    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    mask_image = face_seg(face_large)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new('L', ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))

    # keep upper_boundary_ratio of talking area
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new('L', ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    return mask_array, crop_box


def create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else "cpu"
fa = FaceAlignment(1, flip_input=False, device=device)
config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device)
vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse")
vae.to(device)
fp = FaceParsing()
if __name__ == '__main__':
    # path to the source video file
    parser = argparse.ArgumentParser()
    parser.add_argument("--file",
                        type=str,
                        default=r'D:\ok\test.mp4',
                        )
    parser.add_argument("--avatar_id",
                        type=str,
                        default='1',
                        )
    args = parser.parse_args()
    file = args.file
    # output locations; normally no need to change these
    save_path = f'../data/avatars/avator_{args.avatar_id}'
    save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs'
    create_dir(save_path)
    create_dir(save_full_path)
    mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask'
    create_dir(mask_out_path)

    # model data files
    mask_coords_path = f'{save_path}/mask_coords.pkl'
    coords_path = f'{save_path}/coords.pkl'
    latents_out_path = f'{save_path}/latents.pt'

    with open(f'{save_path}/avator_info.json', "w") as f:
        json.dump({
            "avatar_id": args.avatar_id,
            "video_path": file,
            "bbox_shift": 5
        }, f)

    if os.path.isfile(file):
        video2imgs(file, save_full_path, ext='png')
    else:
        files = os.listdir(file)
        files.sort()
        files = [file for file in files if file.split(".")[-1] == "png"]
        for filename in files:
            shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
    input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))
    print("extracting landmarks...")
    coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5)
    input_latent_list = []
    idx = -1
    # marker used when the bbox is not sufficient
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)
    for bbox, frame in zip(coord_list, frame_list):
        idx = idx + 1
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = get_latents_for_unet(resized_crop_frame)
        input_latent_list.append(latents)

    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
    mask_coords_list_cycle = []
    mask_list_cycle = []
    for i, frame in enumerate(tqdm(frame_list_cycle)):
        cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)

        face_box = coord_list_cycle[i]
        mask, crop_box = get_image_prepare_material(frame, face_box)
        cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
        mask_coords_list_cycle += [crop_box]
        mask_list_cycle.append(mask)

    with open(mask_coords_path, 'wb') as f:
        pickle.dump(mask_coords_list_cycle, f)

    with open(coords_path, 'wb') as f:
        pickle.dump(coord_list_cycle, f)
    torch.save(input_latent_list_cycle, os.path.join(latents_out_path))
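As a quick sanity check of what the script writes, the following sketch inspects the generated files (the shapes follow the comments in get_latents_for_unet; the avator_2 path matches the README example above):

```python
import pickle
import torch

base = 'data/avatars/avator_2'
with open(f'{base}/coords.pkl', 'rb') as f:
    coords = pickle.load(f)        # one bbox per frame, original frames plus the reversed copy
with open(f'{base}/mask_coords.pkl', 'rb') as f:
    mask_coords = pickle.load(f)   # one crop box per frame
latents = torch.load(f'{base}/latents.pt')  # may be shorter: frames without a detected face are skipped

print(len(coords), len(mask_coords), len(latents))
print(latents[0].shape)  # torch.Size([1, 8, 32, 32]): masked + reference latents concatenated
```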
musetalk/utils/face_parsing (FaceParsing class)

@@ -7,18 +7,18 @@ from PIL import Image
 from .model import BiSeNet
 import torchvision.transforms as transforms


 class FaceParsing():
-    def __init__(self):
-        self.net = self.model_init()
+    def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
+                 model_pth='./models/face-parse-bisent/79999_iter.pth'):
+        self.net = self.model_init(resnet_path,model_pth)
         self.preprocess = self.image_preprocess()

-    def model_init(self,
-                   resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
-                   model_pth='./models/face-parse-bisent/79999_iter.pth'):
+    def model_init(self,resnet_path, model_pth):
         net = BiSeNet(resnet_path)
         if torch.cuda.is_available():
             net.cuda()
             net.load_state_dict(torch.load(model_pth))
         else:
             net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
         net.eval()
@@ -44,13 +44,13 @@ class FaceParsing():
         img = torch.unsqueeze(img, 0)
         out = self.net(img)[0]
         parsing = out.squeeze(0).cpu().numpy().argmax(0)
-        parsing[np.where(parsing>13)] = 0
-        parsing[np.where(parsing>=1)] = 255
+        parsing[np.where(parsing > 13)] = 0
+        parsing[np.where(parsing >= 1)] = 255
         parsing = Image.fromarray(parsing.astype(np.uint8))
         return parsing


 if __name__ == "__main__":
     fp = FaceParsing()
     segmap = fp('154_small.png')
     segmap.save('res.png')