# livetalking/basereal.py

import math
import torch
import numpy as np
import os
import time
import cv2
import glob
import pickle
import copy
import resampy
import queue
from queue import Queue
from threading import Thread, Event
from io import BytesIO
import soundfile as sf
import av
from fractions import Fraction
from ttsreal import EdgeTTS, VoitsTTS, XTTS, CosyVoiceTTS
from tqdm import tqdm


def read_imgs(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


class BaseReal:
    def __init__(self, opt):
        self.opt = opt
        self.sample_rate = 16000
        self.chunk = self.sample_rate // opt.fps  # 320 samples per chunk (20ms * 16000 / 1000)
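        # NOTE: the "320 samples" figure assumes opt.fps == 50 (i.e. 20 ms audio frames),
        # which is not set in this file; with a different fps the chunk size is simply
        # self.sample_rate // opt.fps.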
        if opt.tts == "edgetts":
            self.tts = EdgeTTS(opt, self)
        elif opt.tts == "gpt-sovits":
            self.tts = VoitsTTS(opt, self)
        elif opt.tts == "xtts":
            self.tts = XTTS(opt, self)
        elif opt.tts == "cosyvoice":
            self.tts = CosyVoiceTTS(opt, self)

        self.speaking = False

        self.recording = False
        self.recordq_video = Queue()
        self.recordq_audio = Queue()

        self.curr_state = 0
        self.custom_img_cycle = {}
        self.custom_audio_cycle = {}
        self.custom_audio_index = {}
        self.custom_index = {}
        self.custom_opt = {}
        self.__loadcustom()

    def put_msg_txt(self, msg):
        self.tts.put_msg_txt(msg)

    def put_audio_frame(self, audio_chunk):  # 16khz 20ms pcm
        self.asr.put_audio_frame(audio_chunk)
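
    # NOTE: self.asr is never assigned in BaseReal itself; a concrete subclass is
    # expected to attach an ASR/feature object that consumes these 16 kHz, 20 ms
    # PCM chunks before put_audio_frame() is called.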

    def put_audio_file(self, filebyte):
        input_stream = BytesIO(filebyte)
        stream = self.__create_bytes_stream(input_stream)
        streamlen = stream.shape[0]
        idx = 0
        while streamlen >= self.chunk:  # and self.state==State.RUNNING
            self.put_audio_frame(stream[idx:idx + self.chunk])
            streamlen -= self.chunk
            idx += self.chunk
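
    # A minimal usage sketch for put_audio_file (assumes `real` is an instance of a
    # concrete BaseReal subclass with self.asr attached; the file name is illustrative):
    #
    #     with open('hello.wav', 'rb') as f:
    #         real.put_audio_file(f.read())  # decoded, resampled to 16 kHz, fed in 20 ms chunks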

    def __create_bytes_stream(self, byte_stream):
        # byte_stream=BytesIO(buffer)
        stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
        print(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)
        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]
        if sample_rate != self.sample_rate and stream.shape[0] > 0:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
        return stream
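
    # For example, a 2-second stereo clip recorded at 44.1 kHz (88200 samples per channel)
    # comes back from __create_bytes_stream as a mono float32 array of 32000 samples at 16 kHz.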

    def pause_talk(self):
        self.tts.pause_talk()
        self.asr.pause_talk()

    def is_speaking(self) -> bool:
        return self.speaking

    def __loadcustom(self):
        for item in self.opt.customopt:
            print(item)
            input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
            self.custom_audio_cycle[item['audiotype']], sample_rate = sf.read(item['audiopath'], dtype='float32')
            self.custom_audio_index[item['audiotype']] = 0
            self.custom_index[item['audiotype']] = 0
            self.custom_opt[item['audiotype']] = item
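
    # A hypothetical opt.customopt entry, using the keys the loop above reads
    # (the paths and audiotype value are illustrative only):
    #
    #     {"audiotype": 1, "imgpath": "data/custom/imgs", "audiopath": "data/custom/speech.wav"}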

    def init_customindex(self):
        self.curr_state = 0
        for key in self.custom_audio_index:
            self.custom_audio_index[key] = 0
        for key in self.custom_index:
            self.custom_index[key] = 0

    def start_recording(self, path):
        """Start video recording."""
        if self.recording:
            return
        self.recording = True
        self.recordq_video.queue.clear()
        self.recordq_audio.queue.clear()
        self.container = av.open(path, mode="w")
        process_thread = Thread(target=self.record_frame, args=())
        process_thread.start()

    def record_frame(self):
        videostream = self.container.add_stream("libx264", rate=25)
        videostream.codec_context.time_base = Fraction(1, 25)
        audiostream = self.container.add_stream("aac")
        audiostream.codec_context.time_base = Fraction(1, 16000)
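        # Timing plan: video runs at 25 fps (40 ms per frame) and each audio frame is
        # 20 ms, so two audio frames are muxed per video frame. With an audio time_base
        # of 1/16000, one 20 ms step is 0.02 / (1/16000) = 320 ticks, which is what the
        # pts computation in the loop below evaluates to.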
        init = True
        framenum = 0
        while self.recording:
            try:
                videoframe = self.recordq_video.get(block=True, timeout=1)
                videoframe.pts = framenum  # int(round(framenum*0.04 / videostream.codec_context.time_base))
                videoframe.dts = videoframe.pts
                if init:
                    videostream.width = videoframe.width
                    videostream.height = videoframe.height
                    init = False
                for packet in videostream.encode(videoframe):
                    self.container.mux(packet)
                for k in range(2):
                    audioframe = self.recordq_audio.get(block=True, timeout=1)
                    audioframe.pts = int(round((framenum * 2 + k) * 0.02 / audiostream.codec_context.time_base))
                    audioframe.dts = audioframe.pts
                    for packet in audiostream.encode(audioframe):
                        self.container.mux(packet)
                framenum += 1
            except queue.Empty:
                print('record queue empty,')
                continue
            except Exception as e:
                print(e)
                # break
        for packet in videostream.encode(None):
            self.container.mux(packet)
        for packet in audiostream.encode(None):
            self.container.mux(packet)
        self.container.close()
        self.recordq_video.queue.clear()
        self.recordq_audio.queue.clear()
        print('record thread stop')

    def stop_recording(self):
        """Stop video recording."""
        if not self.recording:
            return
        self.recording = False
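
    # Minimal recording sketch (assumes the rendering loop of a concrete subclass is
    # pushing av.VideoFrame / av.AudioFrame objects into recordq_video and recordq_audio;
    # the output path is illustrative):
    #
    #     real.start_recording('/tmp/session.mp4')
    #     ...                        # frames accumulate while self.recording is True
    #     real.stop_recording()      # record_frame then flushes the encoders and closes the file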

    def mirror_index(self, size, index):
        # size = len(self.coord_list_cycle)
        turn = index // size
        res = index % size
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1
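
    # e.g. with size=4, successive indices map to 0,1,2,3, 3,2,1,0, 0,1,2,3, ...
    # so the frame list is played forward and backward in turn and the loop has no visible jump.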

    def get_audio_stream(self, audiotype):
        idx = self.custom_audio_index[audiotype]
        stream = self.custom_audio_cycle[audiotype][idx:idx + self.chunk]
        self.custom_audio_index[audiotype] += self.chunk
        if self.custom_audio_index[audiotype] >= self.custom_audio_cycle[audiotype].shape[0]:
            self.curr_state = 1  # the current custom clip does not loop; switch to the silent state
        return stream

    def set_curr_state(self, audiotype, reinit):
        print('set_curr_state:', audiotype)
        self.curr_state = audiotype
        if reinit:
            self.custom_audio_index[audiotype] = 0
            self.custom_index[audiotype] = 0

    # def process_custom(self, audiotype: int, idx: int):
    #     if self.curr_state != audiotype:  # switching from inference to the custom voice-over clip
    #         if idx in self.switch_pos:    # switching is only allowed at a marked cut point
    #             self.curr_state = audiotype
    #             self.custom_index = 0
    #     else:
    #         self.custom_index += 1