import math
import torch
import numpy as np

#from .utils import *
import subprocess
import os
import time
import torch.nn.functional as F
import cv2
import glob

from nerfasr import NerfASR
from ttsreal import EdgeTTS,VoitsTTS,XTTS

import asyncio
from av import AudioFrame, VideoFrame

from basereal import BaseReal
#from imgcache import ImgCache

from tqdm import tqdm

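# Read every image in img_list into memory. cv2.imread returns BGR ndarrays
# (or None for unreadable files, which would be appended as-is).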
def read_imgs(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


class NeRFReal(BaseReal):
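    """Real-time NeRF talking-head renderer.

    Pulls per-frame pose/audio samples from the trainer's data loader, drives
    the mouth with live ASR audio features, and streams the rendered frames
    over RTMP or WebRTC.
    """
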
    def __init__(self, opt, trainer, data_loader, debug=True):
        super().__init__(opt)
        #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H

        self.trainer = trainer
        self.data_loader = data_loader

        # use dataloader's bg
        #bg_img = data_loader._data.bg_img #.view(1, -1, 3)
        #if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
        #    bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
        #self.bg_color = bg_img.view(1, -1, 3)

        # audio features (from dataloader, only used in non-playing mode)
        #self.audio_features = data_loader._data.auds # [N, 29, 16]
        #self.audio_idx = 0

        #self.frame_total_num = data_loader._data.end_index
        #print("frame_total_num:",self.frame_total_num)

        # control eye
        #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()

        # playing seq from dataloader, or pause.
        self.loader = iter(data_loader)
        frame_total_num = data_loader._data.end_index
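        # Full-body mode: preload the body frames. The glob pattern matches
        # .jpg/.jpeg/.png in upper or lower case; files are sorted by their
        # numeric stem so list order matches the dataset frame indices.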
        if opt.fullbody:
            input_img_list = glob.glob(os.path.join(self.opt.fullbody_img, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            #print('input_img_list:',input_img_list)
            self.fullbody_list_cycle = read_imgs(input_img_list[:frame_total_num])
            #self.imagecache = ImgCache(frame_total_num,self.opt.fullbody_img,1000)

        #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        #self.need_update = True # camera moved, should reset accumulation
        #self.spp = 1 # sample per pixel
        #self.mode = 'image' # choose from ['image', 'depth']

        #self.dynamic_resolution = False # assert False!
        #self.downscale = 1
        #self.train_steps = 16

        #self.ind_index = 0
        #self.ind_num = trainer.model.individual_codes.shape[0]

        #self.customimg_index = 0

        # build asr
        self.asr = NerfASR(opt,self)
        self.asr.warm_up()
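
        # Legacy FFmpeg-over-FIFO push pipeline, kept for reference only: the
        # block below is a string literal (never executed) and references
        # fps/push_url that are not defined in this scope.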
        '''
        video_path = 'video_stream'
        if not os.path.exists(video_path):
            os.mkfifo(video_path, mode=0o777)
        audio_path = 'audio_stream'
        if not os.path.exists(audio_path):
            os.mkfifo(audio_path, mode=0o777)
        width=450
        height=450
        command = ['ffmpeg',
                    '-y', #'-an',
                    #'-re',
                    '-f', 'rawvideo',
                    '-vcodec','rawvideo',
                    '-pix_fmt', 'rgb24', # pixel format
                    '-s', "{}x{}".format(width, height),
                    '-r', str(fps),
                    '-i', video_path,
                    '-f', 's16le',
                    '-acodec','pcm_s16le',
                    '-ac', '1',
                    '-ar', '16000',
                    '-i', audio_path,
                    #'-fflags', '+genpts',
                    '-map', '0:v',
                    '-map', '1:a',
                    #'-copyts',
                    '-acodec', 'aac',
                    '-pix_fmt', 'yuv420p', #'-vcodec', "h264",
                    #"-rtmp_buffer", "100",
                    '-f' , 'flv',
                    push_url]
        self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE)
        self.fifo_video = open(video_path, 'wb')
        self.fifo_audio = open(audio_path, 'wb')
        #self.test_step()
        '''

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.opt.asr:
            self.asr.stop()

    # def mirror_index(self, index):
    #     size = self.opt.customvideo_imgnum
    #     turn = index // size
    #     res = index % size
    #     if turn % 2 == 0:
    #         return res
    #     else:
    #         return size - res - 1
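    # Note: superseded by the two-argument mirror_index(size, index) inherited
    # from BaseReal; see its call site in test_step below.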

    def test_step(self, loop=None, audio_track=None, video_track=None):
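        """Produce one video frame (40 ms) plus its two 20 ms audio chunks.

        Pulls the next pose/audio sample from the data loader, either renders
        a frame with the NeRF trainer or substitutes a custom idle clip, and
        pushes the result to the RTMP streamer or the WebRTC tracks.
        """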
        #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        #starter.record()

        try:
            data = next(self.loader)
        except StopIteration:
            # dataset exhausted: restart the iterator so playback loops
            self.loader = iter(self.data_loader)
            data = next(self.loader)

        if self.opt.asr:
            # use the live audio stream
            data['auds'] = self.asr.get_next_feat()

        audiotype1 = 0
        audiotype2 = 0
        # send audio: two 20 ms chunks per 40 ms video frame
        for i in range(2):
            frame, atype = self.asr.get_audio_out()
            if i == 0:
                audiotype1 = atype
            else:
                audiotype2 = atype
            #print(f'[INFO] get_audio_out shape ',frame.shape)
            if self.opt.transport == 'rtmp':
                self.streamer.stream_frame_audio(frame)
            else:  # webrtc: convert float32 samples in [-1,1] to 16-bit PCM
                frame = (frame * 32767).astype(np.int16)
                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                new_frame.planes[0].update(frame.tobytes())
                new_frame.sample_rate = 16000
                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)

        #t = time.time()
        if audiotype1 != 0 and audiotype2 != 0:  # both audio chunks are silence
            self.speaking = False
        else:
            self.speaking = True

        if audiotype1 != 0 and audiotype2 != 0 and self.custom_index.get(audiotype1) is not None:  # not an inference frame, and a custom video exists for this audio type
            mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype1]), self.custom_index[audiotype1])
            #imgindex = self.mirror_index(self.customimg_index)
            #print('custom img index:',imgindex)
            #image = cv2.imread(os.path.join(self.opt.customvideo_img, str(int(imgindex))+'.png'))
            image = self.custom_img_cycle[audiotype1][mirindex]
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.custom_index[audiotype1] += 1
            if self.opt.transport == 'rtmp':
                self.streamer.stream_frame(image)
            else:
                new_frame = VideoFrame.from_ndarray(image, format="rgb24")
                asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
        else:  # run NeRF inference and paste the rendered head back
            outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
            #print('-------ernerf time: ',time.time()-t)
            #print(f'[INFO] outputs shape ',outputs['image'].shape)
            image = (outputs['image'] * 255).astype(np.uint8)
            if not self.opt.fullbody:
                if self.opt.transport == 'rtmp':
                    self.streamer.stream_frame(image)
                else:
                    new_frame = VideoFrame.from_ndarray(image, format="rgb24")
                    asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
            else:  # fullbody human
                #print("frame index:",data['index'])
                #image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
                image_fullbody = self.fullbody_list_cycle[data['index'][0]]
                #image_fullbody = self.imagecache.get_img(data['index'][0])
                image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB)
                start_x = self.opt.fullbody_offset_x  # x origin of the head patch in the composited frame
                start_y = self.opt.fullbody_offset_y  # y origin of the head patch in the composited frame
                image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image
                if self.opt.transport == 'rtmp':
                    self.streamer.stream_frame(image_fullbody)
                else:
                    new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
                    asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
        #self.pipe.stdin.write(image.tostring())

        #ender.record()
        #torch.cuda.synchronize()
        #t = starter.elapsed_time(ender)

    def render(self, quit_event, loop=None, audio_track=None, video_track=None):
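        """Main streaming loop; runs until quit_event is set.

        Sets up the RTMP streamer when that transport is selected, starts TTS,
        then alternates two ASR steps with one rendered video frame, pacing
        output to 25 fps.
        """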
        #if self.opt.asr:
        #    self.asr.warm_up()

        self.init_customindex()

        if self.opt.transport == 'rtmp':
            from rtmp_streaming import StreamerConfig, Streamer
            fps = 25
            #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
            sc = StreamerConfig()
            sc.source_width = self.W
            sc.source_height = self.H
            sc.stream_width = self.W
            sc.stream_height = self.H
            if self.opt.fullbody:
                sc.source_width = self.opt.fullbody_width
                sc.source_height = self.opt.fullbody_height
                sc.stream_width = self.opt.fullbody_width
                sc.stream_height = self.opt.fullbody_height
            sc.stream_fps = fps
            sc.stream_bitrate = 1000000
            sc.stream_profile = 'baseline' #'high444' # 'main'
            sc.audio_channel = 1
            sc.sample_rate = 16000
            sc.stream_server = self.opt.push_url
            self.streamer = Streamer()
            self.streamer.init(sc)
            #self.streamer.enable_av_debug_log()

        count = 0
        totaltime = 0
        _starttime = time.perf_counter()
        _totalframe = 0

        self.tts.render(quit_event)
        while not quit_event.is_set(): #todo
            # update texture every frame
            # audio stream thread...
            t = time.perf_counter()
            # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
            for _ in range(2):
                self.asr.run_step()
            self.test_step(loop, audio_track, video_track)
            totaltime += (time.perf_counter() - t)
            count += 1
            _totalframe += 1
            if count == 100:
                print(f"------actual avg infer fps:{count/totaltime:.4f}")
                count = 0
                totaltime = 0
            if self.opt.transport == 'rtmp':
                # pace output to 25 fps: each frame owns a 40 ms slot
                delay = _starttime + _totalframe * 0.04 - time.perf_counter()
                if delay > 0:
                    time.sleep(delay)
            else:
                # webrtc backpressure: when the consumer falls behind, sleep
                # proportionally to the queue depth rather than a fixed slot
                if video_track._queue.qsize() >= 5:
                    #print('sleep qsize=',video_track._queue.qsize())
                    time.sleep(0.04 * video_track._queue.qsize() * 0.8)
        print('nerfreal thread stop')
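
# Minimal usage sketch (hypothetical wiring: `opt`, `trainer` and `test_loader`
# come from the project's ernerf setup code; `loop`, `audio_track` and
# `video_track` from the WebRTC server):
#
#   import threading
#   nerfreal = NeRFReal(opt, trainer, test_loader)
#   quit_event = threading.Event()
#   render_thread = threading.Thread(
#       target=nerfreal.render,
#       args=(quit_event, loop, audio_track, video_track))
#   render_thread.start()
#   ...
#   quit_event.set()      # stop streaming
#   render_thread.join()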