feat: 完善修改成自动绝对路径,添加接口生成

This commit is contained in:
Yun 2024-06-23 14:51:58 +08:00
parent 6eb03ecbff
commit 18d7db35a7
2 changed files with 84 additions and 52 deletions

49
app.py
View File

@ -1,29 +1,22 @@
# server.py # server.py
from flask import Flask, render_template, send_from_directory, request, jsonify import argparse
from flask_sockets import Sockets import asyncio
import base64
import time
import json import json
import gevent
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
import os
import re
import numpy as np
from threading import Thread, Event
import multiprocessing import multiprocessing
from threading import Thread, Event
from aiohttp import web
import aiohttp import aiohttp
import aiohttp_cors import aiohttp_cors
from aiohttp import web
from aiortc import RTCPeerConnection, RTCSessionDescription from aiortc import RTCPeerConnection, RTCSessionDescription
from flask import Flask
from flask_sockets import Sockets
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
from musetalk.simple_musetalk import create_musetalk_human
from webrtc import HumanPlayer from webrtc import HumanPlayer
import argparse
import shutil
import asyncio
app = Flask(__name__) app = Flask(__name__)
sockets = Sockets(app) sockets = Sockets(app)
global nerfreal global nerfreal
@ -135,6 +128,27 @@ async def human(request):
) )
async def handle_create_musetalk(request):
reader = await request.multipart()
# 处理文件部分
file_part = await reader.next()
filename = file_part.filename
file_data = await file_part.read() # 读取文件的内容
# 注意:确保这个文件路径是可写的
with open(filename, 'wb') as f:
f.write(file_data)
# 处理整数部分
part = await reader.next()
avatar_id = int(await part.text())
create_musetalk_human(filename, avatar_id)
os.remove(filename)
return web.json_response({
'status': 'success',
'filename': filename,
'int_value': avatar_id,
})
async def on_shutdown(app): async def on_shutdown(app):
# close peer connections # close peer connections
coros = [pc.close() for pc in pcs] coros = [pc.close() for pc in pcs]
@ -405,6 +419,7 @@ if __name__ == '__main__':
appasync.on_shutdown.append(on_shutdown) appasync.on_shutdown.append(on_shutdown)
appasync.router.add_post("/offer", offer) appasync.router.add_post("/offer", offer)
appasync.router.add_post("/human", human) appasync.router.add_post("/human", human)
appasync.router.add_post("/create_musetalk", handle_create_musetalk)
appasync.router.add_static('/', path='web') appasync.router.add_static('/', path='web')
# Configure default CORS settings. # Configure default CORS settings.

View File

@ -4,7 +4,6 @@ import json
import os import os
import pickle import pickle
import shutil import shutil
import sys
import cv2 import cv2
import numpy as np import numpy as np
@ -17,7 +16,10 @@ from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples from mmpose.structures import merge_data_samples
from tqdm import tqdm from tqdm import tqdm
from utils.face_parsing import FaceParsing try:
from utils.face_parsing import FaceParsing
except ModuleNotFoundError:
from musetalk.utils.face_parsing import FaceParsing
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000): def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
@ -55,6 +57,7 @@ def get_landmark_and_bbox(img_list, upperbondrange=0):
print('get key_landmark and face bounding boxes with the default value') print('get key_landmark and face bounding boxes with the default value')
average_range_minus = [] average_range_minus = []
average_range_plus = [] average_range_plus = []
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
for fb in tqdm(batches): for fb in tqdm(batches):
results = inference_topdown(model, np.asarray(fb)[0]) results = inference_topdown(model, np.asarray(fb)[0])
results = merge_data_samples(results) results = merge_data_samples(results)
@ -235,57 +238,47 @@ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand
return mask_array, crop_box return mask_array, crop_box
##todo 简单根据文件后缀判断 要更精确的可以自己修改 使用 magic
def is_video_file(file_path):
video_exts = ['.mp4', '.mkv', '.flv', '.avi', '.mov'] # 这里列出了一些常见的视频文件扩展名,可以根据需要添加更多
file_ext = os.path.splitext(file_path)[1].lower() # 获取文件扩展名并转换为小写
return file_ext in video_exts
def create_dir(dir_path): def create_dir(dir_path):
if not os.path.exists(dir_path): if not os.path.exists(dir_path):
os.makedirs(dir_path) os.makedirs(dir_path)
sys.path.append(os.path.dirname(os.path.abspath(__file__))) current_dir = os.path.dirname(os.path.abspath(__file__))
# initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else "cpu" def create_musetalk_human(file, avatar_id):
fa = FaceAlignment(1, flip_input=False, device=device)
config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device)
vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse")
vae.to(device)
fp = FaceParsing()
if __name__ == '__main__':
# 视频文件地址
parser = argparse.ArgumentParser()
parser.add_argument("--file",
type=str,
default=r'D:\ok\test.mp4',
)
parser.add_argument("--avatar_id",
type=str,
default='1',
)
args = parser.parse_args()
file = args.file
# 保存文件设置 可以不动 # 保存文件设置 可以不动
save_path = f'../data/avatars/avator_{args.avatar_id}' save_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}')
save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs' save_full_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/full_imgs')
create_dir(save_path) create_dir(save_path)
create_dir(save_full_path) create_dir(save_full_path)
mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask' mask_out_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/mask')
create_dir(mask_out_path) create_dir(mask_out_path)
# 模型 # 模型
mask_coords_path = f'{save_path}/mask_coords.pkl' mask_coords_path = os.path.join(current_dir, f'{save_path}/mask_coords.pkl')
coords_path = f'{save_path}/coords.pkl' coords_path = os.path.join(current_dir, f'{save_path}/coords.pkl')
latents_out_path = f'{save_path}/latents.pt' latents_out_path = os.path.join(current_dir, f'{save_path}/latents.pt')
with open(f'{save_path}/avator_info.json', "w") as f: with open(os.path.join(current_dir, f'{save_path}/avator_info.json'), "w") as f:
json.dump({ json.dump({
"avatar_id": args.avatar_id, "avatar_id": avatar_id,
"video_path": file, "video_path": file,
"bbox_shift": 5 "bbox_shift": 5
}, f) }, f)
if os.path.isfile(file): if os.path.isfile(file):
if is_video_file(file):
video2imgs(file, save_full_path, ext='png') video2imgs(file, save_full_path, ext='png')
else:
shutil.copyfile(file, f"{save_full_path}/{os.path.basename(file)}")
else: else:
files = os.listdir(file) files = os.listdir(file)
files.sort() files.sort()
@ -316,7 +309,6 @@ if __name__ == '__main__':
mask_list_cycle = [] mask_list_cycle = []
for i, frame in enumerate(tqdm(frame_list_cycle)): for i, frame in enumerate(tqdm(frame_list_cycle)):
cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame) cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)
face_box = coord_list_cycle[i] face_box = coord_list_cycle[i]
mask, crop_box = get_image_prepare_material(frame, face_box) mask, crop_box = get_image_prepare_material(frame, face_box)
cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask) cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
@ -329,3 +321,28 @@ if __name__ == '__main__':
with open(coords_path, 'wb') as f: with open(coords_path, 'wb') as f:
pickle.dump(coord_list_cycle, f) pickle.dump(coord_list_cycle, f)
torch.save(input_latent_list_cycle, os.path.join(latents_out_path)) torch.save(input_latent_list_cycle, os.path.join(latents_out_path))
# initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else "cpu"
fa = FaceAlignment(1, flip_input=False, device=device)
config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))
model = init_model(config_file, checkpoint_file, device=device)
vae = AutoencoderKL.from_pretrained(os.path.abspath(os.path.join(current_dir, '../models/sd-vae-ft-mse')))
vae.to(device)
fp = FaceParsing(os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/resnet18-5c106cde.pth')),
os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/79999_iter.pth')))
if __name__ == '__main__':
# 视频文件地址
parser = argparse.ArgumentParser()
parser.add_argument("--file",
type=str,
default=r'D:\ok\00000000.png',
)
parser.add_argument("--avatar_id",
type=str,
default='3',
)
args = parser.parse_args()
create_musetalk_human(args.file, args.avatar_id)