"""WebSocket speech-recognition server built on FunASR ``AutoModel``.

Clients send JSON text frames to configure the session and binary frames
carrying audio.  Depending on ``mode`` ("online", "offline" or "2pass") the
server replies with streaming partial results, VAD-segmented results that
are optionally punctuated, or both.
"""
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl

parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument("--asr_model", type=str, default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", help="model from modelscope")
parser.add_argument("--asr_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument("--asr_model_online", type=str, default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", help="model from modelscope")
parser.add_argument("--asr_model_online_revision", type=str, default="v2.0.4", help="")
parser.add_argument("--vad_model", type=str, default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", help="model from modelscope")
parser.add_argument("--vad_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument("--punc_model", type=str, default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", help="model from modelscope")
parser.add_argument("--punc_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument("--certfile", type=str, default="ssl_key/server.crt", required=False, help="certfile for ssl")
parser.add_argument("--keyfile", type=str, default="ssl_key/server.key", required=False, help="keyfile for ssl")
args = parser.parse_args()

# Every currently connected client socket.
websocket_users = set()

print("model loading")
# Imported here (not at the top) so the banner above is printed first.
from funasr import AutoModel

# Keyword arguments shared by every model constructed below.
_common_model_kwargs = dict(
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    device=args.device,
    disable_pbar=True,
    disable_log=True,
)

# asr (offline / second pass)
model_asr = AutoModel(
    model=args.asr_model,
    model_revision=args.asr_model_revision,
    **_common_model_kwargs,
)
# asr (streaming / first pass)
model_asr_streaming = AutoModel(
    model=args.asr_model_online,
    model_revision=args.asr_model_online_revision,
    **_common_model_kwargs,
)
# vad
model_vad = AutoModel(
    model=args.vad_model,
    model_revision=args.vad_model_revision,
    **_common_model_kwargs,
    # chunk_size=60,
)

# punctuation is optional: pass --punc_model "" to disable it
if args.punc_model != "":
    model_punc = AutoModel(
        model=args.punc_model,
        model_revision=args.punc_model_revision,
        **_common_model_kwargs,
    )
else:
    model_punc = None

print("model loaded! only support one client at the same time now!!!!")


async def ws_reset(websocket):
    """Clear the per-connection model caches of *websocket* and close it."""
    print("ws reset now, total num is ", len(websocket_users))

    websocket.status_dict_asr_online["cache"] = {}
    websocket.status_dict_asr_online["is_final"] = True
    websocket.status_dict_vad["cache"] = {}
    websocket.status_dict_vad["is_final"] = True
    websocket.status_dict_punc["cache"] = {}

    await websocket.close()


async def clear_websocket():
    """Reset and close every registered connection, then empty the registry."""
    # Iterate over a snapshot: ws_reset() awaits, and another coroutine may
    # add to or remove from the set while this one is suspended.
    for websocket in list(websocket_users):
        await ws_reset(websocket)
    websocket_users.clear()


def _apply_config(websocket, config):
    """Copy the recognised keys of a decoded JSON config frame onto *websocket*.

    Recognised keys: ``is_speaking``, ``chunk_interval``, ``wav_name``,
    ``chunk_size`` (list, or comma-separated string), ``encoder_chunk_look_back``,
    ``decoder_chunk_look_back``, ``hotword`` and ``mode``.  Unknown keys are
    ignored.
    """
    if "is_speaking" in config:
        websocket.is_speaking = config["is_speaking"]
        # The streaming decoder is told to finalise once the client stops.
        websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
    if "chunk_interval" in config:
        websocket.chunk_interval = config["chunk_interval"]
    if "wav_name" in config:
        websocket.wav_name = config.get("wav_name")
    if "chunk_size" in config:
        chunk_size = config["chunk_size"]
        if isinstance(chunk_size, str):
            chunk_size = chunk_size.split(",")
        websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
    if "encoder_chunk_look_back" in config:
        websocket.status_dict_asr_online["encoder_chunk_look_back"] = config["encoder_chunk_look_back"]
    if "decoder_chunk_look_back" in config:
        websocket.status_dict_asr_online["decoder_chunk_look_back"] = config["decoder_chunk_look_back"]
    if "hotword" in config:
        websocket.status_dict_asr["hotword"] = config["hotword"]
    if "mode" in config:
        websocket.mode = config["mode"]


async def ws_serve(websocket, path=None):
    """Handle one client connection.

    Text frames are JSON configuration (see ``_apply_config``); binary frames
    are audio.  Each audio frame is buffered, fed to the streaming recogniser
    every ``chunk_interval`` frames (modes "2pass"/"online"), and passed
    through VAD.  When VAD reports an end point, or the client sets
    ``is_speaking`` to false, the buffered segment is sent to the offline
    recogniser (modes "2pass"/"offline").

    Args:
        websocket: the connection object supplied by ``websockets``; session
            state is stored on it as attributes.
        path: request path.  Unused; optional so the handler works whether
            the installed ``websockets`` passes one argument or two.
    """
    frames = []  # recent audio frames, kept so a segment start can reach back
    frames_asr = []  # frames of the current VAD segment (offline pass)
    frames_asr_online = []  # frames awaiting the next streaming decode

    websocket_users.add(websocket)
    websocket.status_dict_asr = {}
    websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
    websocket.status_dict_vad = {"cache": {}, "is_final": False}
    websocket.status_dict_punc = {"cache": {}}
    websocket.chunk_interval = 10
    websocket.vad_pre_idx = 0
    # Default until the client says otherwise; previously unset, so audio
    # arriving before any config frame raised AttributeError.
    websocket.is_speaking = True
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    speech_start = False
    speech_end_i = -1
    print("new user connected", flush=True)

    try:
        async for message in websocket:
            is_text = isinstance(message, str)
            if is_text:
                _apply_config(websocket, json.loads(message))

            # Derive the VAD chunk size from the streaming chunk size.  If the
            # client never sent "chunk_size" the key is left unset —
            # NOTE(review): this relies on the models falling back to their
            # own defaults; confirm against the FunASR version in use.
            online_chunk_size = websocket.status_dict_asr_online.get("chunk_size")
            if online_chunk_size is not None:
                websocket.status_dict_vad["chunk_size"] = int(
                    online_chunk_size[1] * 60 / websocket.chunk_interval
                )

            # A text frame with nothing buffered requires no audio work.
            if is_text and not frames_asr_online and not frames_asr:
                continue

            if not is_text:
                frames.append(message)
                # 32 bytes per millisecond — presumably 16 kHz, 16-bit mono
                # PCM; verify against the client.
                duration_ms = len(message) // 32
                websocket.vad_pre_idx += duration_ms

                # asr online
                frames_asr_online.append(message)
                websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
                if (
                    len(frames_asr_online) % websocket.chunk_interval == 0
                    or websocket.status_dict_asr_online["is_final"]
                ):
                    if websocket.mode in ("2pass", "online"):
                        audio_in = b"".join(frames_asr_online)
                        try:
                            await async_asr_online(websocket, audio_in)
                        except Exception as e:
                            # Not a bare except: task cancellation must propagate.
                            print(f"error in asr streaming, {websocket.status_dict_asr_online}", e)
                    frames_asr_online = []

                if speech_start:
                    frames_asr.append(message)

                # vad online
                try:
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                except Exception as e:
                    print("error in vad", e)
                    # Treat a failed VAD call as "no boundary in this frame";
                    # otherwise speech_start_i may be unbound or stale below.
                    speech_start_i, speech_end_i = -1, -1

                if speech_start_i != -1:
                    speech_start = True
                    # Number of buffered frames to reach back to the segment
                    # start.  The divisor is clamped so a frame shorter than
                    # 32 bytes cannot cause ZeroDivisionError.
                    beg_bias = (websocket.vad_pre_idx - speech_start_i) // max(duration_ms, 1)
                    # NOTE(review): when beg_bias is 0 this slice is the whole
                    # buffer (frames[-0:] == frames[:]); kept as in the original.
                    frames_asr = frames[-beg_bias:]

            # asr punc offline
            if speech_end_i != -1 or not websocket.is_speaking:
                if websocket.mode in ("2pass", "offline"):
                    audio_in = b"".join(frames_asr)
                    try:
                        await async_asr(websocket, audio_in)
                    except Exception as e:
                        print("error in asr offline", e)
                frames_asr = []
                speech_start = False
                frames_asr_online = []
                websocket.status_dict_asr_online["cache"] = {}
                if not websocket.is_speaking:
                    # End of the utterance stream: reset VAD state completely.
                    websocket.vad_pre_idx = 0
                    frames = []
                    websocket.status_dict_vad["cache"] = {}
                else:
                    # Keep a short tail for the next segment's look-back.
                    frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users, flush=True)
        await ws_reset(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
    finally:
        # A clean close ends the `async for` without raising, so removal must
        # not depend on the ConnectionClosed handler.
        websocket_users.discard(websocket)


async def async_vad(websocket, audio_in):
    """Run VAD on one audio frame.

    Returns:
        ``(speech_start, speech_end)`` taken from the single segment the model
        reported, or ``(-1, -1)`` when it reported zero or several segments.
        Either element is -1 when the model reported -1 for it.
    """
    segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]

    if len(segments_result) != 1:
        return -1, -1
    segment = segments_result[0]
    return segment[0], segment[1]


async def async_asr(websocket, audio_in):
    """Recognise a full segment offline, punctuate it, and send the result.

    Nothing is sent when *audio_in* is empty or the recognised text is empty.
    """
    if len(audio_in) == 0:
        return

    rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
    if model_punc is not None and len(rec_result["text"]) > 0:
        rec_result = model_punc.generate(input=rec_result["text"], **websocket.status_dict_punc)[0]

    if len(rec_result["text"]) > 0:
        mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
        # NOTE(review): "is_final" carries is_speaking as-is; kept unchanged
        # because existing clients may depend on it — confirm intended meaning.
        message = json.dumps(
            {
                "mode": mode,
                "text": rec_result["text"],
                "wav_name": websocket.wav_name,
                "is_final": websocket.is_speaking,
            }
        )
        await websocket.send(message)


async def async_asr_online(websocket, audio_in):
    """Run the streaming recogniser on a chunk and send any partial text.

    In "2pass" mode the result of a final chunk is not sent (the offline
    pass reports that segment), but the model is still invoked first.
    """
    if len(audio_in) == 0:
        return

    rec_result = model_asr_streaming.generate(input=audio_in, **websocket.status_dict_asr_online)[0]
    if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
        return

    if len(rec_result["text"]):
        mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
        message = json.dumps(
            {
                "mode": mode,
                "text": rec_result["text"],
                "wav_name": websocket.wav_name,
                "is_final": websocket.is_speaking,
            }
        )
        await websocket.send(message)


async def _main():
    """Start the WebSocket server (TLS when --certfile is non-empty) and run forever."""
    ssl_context = None
    if len(args.certfile) > 0:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
        ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)

    # The server is created inside the running loop; building it at module
    # level and fetching the loop with get_event_loop() fails on current
    # Python / websockets releases.
    async with websockets.serve(
        ws_serve,
        args.host,
        args.port,
        subprotocols=["binary"],
        ping_interval=None,
        ssl=ssl_context,
    ):
        await asyncio.Future()  # never completes: serve until the process exits


asyncio.run(_main())