feat: 完善修改成自动绝对路径,添加接口生成
This commit is contained in:
parent
6eb03ecbff
commit
18d7db35a7
49
app.py
49
app.py
|
@ -1,29 +1,22 @@
|
|||
# server.py
|
||||
from flask import Flask, render_template, send_from_directory, request, jsonify
|
||||
from flask_sockets import Sockets
|
||||
import base64
|
||||
import time
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import gevent
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
from threading import Thread, Event
|
||||
import multiprocessing
|
||||
from threading import Thread, Event
|
||||
|
||||
from aiohttp import web
|
||||
import aiohttp
|
||||
import aiohttp_cors
|
||||
from aiohttp import web
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from flask import Flask
|
||||
from flask_sockets import Sockets
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler
|
||||
|
||||
from musetalk.simple_musetalk import create_musetalk_human
|
||||
from webrtc import HumanPlayer
|
||||
|
||||
import argparse
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
|
||||
app = Flask(__name__)
|
||||
sockets = Sockets(app)
|
||||
global nerfreal
|
||||
|
@ -135,6 +128,27 @@ async def human(request):
|
|||
)
|
||||
|
||||
|
||||
async def handle_create_musetalk(request):
|
||||
reader = await request.multipart()
|
||||
# 处理文件部分
|
||||
file_part = await reader.next()
|
||||
filename = file_part.filename
|
||||
file_data = await file_part.read() # 读取文件的内容
|
||||
# 注意:确保这个文件路径是可写的
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(file_data)
|
||||
# 处理整数部分
|
||||
part = await reader.next()
|
||||
avatar_id = int(await part.text())
|
||||
create_musetalk_human(filename, avatar_id)
|
||||
os.remove(filename)
|
||||
return web.json_response({
|
||||
'status': 'success',
|
||||
'filename': filename,
|
||||
'int_value': avatar_id,
|
||||
})
|
||||
|
||||
|
||||
async def on_shutdown(app):
|
||||
# close peer connections
|
||||
coros = [pc.close() for pc in pcs]
|
||||
|
@ -405,6 +419,7 @@ if __name__ == '__main__':
|
|||
appasync.on_shutdown.append(on_shutdown)
|
||||
appasync.router.add_post("/offer", offer)
|
||||
appasync.router.add_post("/human", human)
|
||||
appasync.router.add_post("/create_musetalk", handle_create_musetalk)
|
||||
appasync.router.add_static('/', path='web')
|
||||
|
||||
# Configure default CORS settings.
|
||||
|
|
|
@ -4,7 +4,6 @@ import json
|
|||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -17,7 +16,10 @@ from mmpose.apis import inference_topdown, init_model
|
|||
from mmpose.structures import merge_data_samples
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.face_parsing import FaceParsing
|
||||
try:
|
||||
from utils.face_parsing import FaceParsing
|
||||
except ModuleNotFoundError:
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
|
||||
|
||||
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
|
||||
|
@ -55,6 +57,7 @@ def get_landmark_and_bbox(img_list, upperbondrange=0):
|
|||
print('get key_landmark and face bounding boxes with the default value')
|
||||
average_range_minus = []
|
||||
average_range_plus = []
|
||||
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
|
||||
for fb in tqdm(batches):
|
||||
results = inference_topdown(model, np.asarray(fb)[0])
|
||||
results = merge_data_samples(results)
|
||||
|
@ -235,57 +238,47 @@ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand
|
|||
return mask_array, crop_box
|
||||
|
||||
|
||||
##todo 简单根据文件后缀判断 要更精确的可以自己修改 使用 magic
|
||||
def is_video_file(file_path):
|
||||
video_exts = ['.mp4', '.mkv', '.flv', '.avi', '.mov'] # 这里列出了一些常见的视频文件扩展名,可以根据需要添加更多
|
||||
file_ext = os.path.splitext(file_path)[1].lower() # 获取文件扩展名并转换为小写
|
||||
return file_ext in video_exts
|
||||
|
||||
|
||||
def create_dir(dir_path):
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# initialize the mmpose model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
fa = FaceAlignment(1, flip_input=False, device=device)
|
||||
config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
|
||||
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
|
||||
model = init_model(config_file, checkpoint_file, device=device)
|
||||
vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse")
|
||||
vae.to(device)
|
||||
fp = FaceParsing()
|
||||
if __name__ == '__main__':
|
||||
# 视频文件地址
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--file",
|
||||
type=str,
|
||||
default=r'D:\ok\test.mp4',
|
||||
)
|
||||
parser.add_argument("--avatar_id",
|
||||
type=str,
|
||||
default='1',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
file = args.file
|
||||
|
||||
def create_musetalk_human(file, avatar_id):
|
||||
# 保存文件设置 可以不动
|
||||
save_path = f'../data/avatars/avator_{args.avatar_id}'
|
||||
save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs'
|
||||
save_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}')
|
||||
save_full_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/full_imgs')
|
||||
create_dir(save_path)
|
||||
create_dir(save_full_path)
|
||||
mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask'
|
||||
mask_out_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/mask')
|
||||
create_dir(mask_out_path)
|
||||
|
||||
# 模型
|
||||
mask_coords_path = f'{save_path}/mask_coords.pkl'
|
||||
coords_path = f'{save_path}/coords.pkl'
|
||||
latents_out_path = f'{save_path}/latents.pt'
|
||||
mask_coords_path = os.path.join(current_dir, f'{save_path}/mask_coords.pkl')
|
||||
coords_path = os.path.join(current_dir, f'{save_path}/coords.pkl')
|
||||
latents_out_path = os.path.join(current_dir, f'{save_path}/latents.pt')
|
||||
|
||||
with open(f'{save_path}/avator_info.json', "w") as f:
|
||||
with open(os.path.join(current_dir, f'{save_path}/avator_info.json'), "w") as f:
|
||||
json.dump({
|
||||
"avatar_id": args.avatar_id,
|
||||
"avatar_id": avatar_id,
|
||||
"video_path": file,
|
||||
"bbox_shift": 5
|
||||
}, f)
|
||||
|
||||
if os.path.isfile(file):
|
||||
video2imgs(file, save_full_path, ext='png')
|
||||
if is_video_file(file):
|
||||
video2imgs(file, save_full_path, ext='png')
|
||||
else:
|
||||
shutil.copyfile(file, f"{save_full_path}/{os.path.basename(file)}")
|
||||
else:
|
||||
files = os.listdir(file)
|
||||
files.sort()
|
||||
|
@ -316,7 +309,6 @@ if __name__ == '__main__':
|
|||
mask_list_cycle = []
|
||||
for i, frame in enumerate(tqdm(frame_list_cycle)):
|
||||
cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)
|
||||
|
||||
face_box = coord_list_cycle[i]
|
||||
mask, crop_box = get_image_prepare_material(frame, face_box)
|
||||
cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
|
||||
|
@ -329,3 +321,28 @@ if __name__ == '__main__':
|
|||
with open(coords_path, 'wb') as f:
|
||||
pickle.dump(coord_list_cycle, f)
|
||||
torch.save(input_latent_list_cycle, os.path.join(latents_out_path))
|
||||
|
||||
|
||||
# initialize the mmpose model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
fa = FaceAlignment(1, flip_input=False, device=device)
|
||||
config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
|
||||
checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))
|
||||
model = init_model(config_file, checkpoint_file, device=device)
|
||||
vae = AutoencoderKL.from_pretrained(os.path.abspath(os.path.join(current_dir, '../models/sd-vae-ft-mse')))
|
||||
vae.to(device)
|
||||
fp = FaceParsing(os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/resnet18-5c106cde.pth')),
|
||||
os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/79999_iter.pth')))
|
||||
if __name__ == '__main__':
|
||||
# 视频文件地址
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--file",
|
||||
type=str,
|
||||
default=r'D:\ok\00000000.png',
|
||||
)
|
||||
parser.add_argument("--avatar_id",
|
||||
type=str,
|
||||
default='3',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
create_musetalk_human(args.file, args.avatar_id)
|
||||
|
|
Loading…
Reference in New Issue