livetalking/llm/Qwen.py

39 lines
1.4 KiB
Python
Raw Normal View History

2024-01-27 19:38:13 +08:00
import logging
import os

import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set at import time, before any model is loaded by this module.
# CUDA_LAUNCH_BLOCKING=1 is the CUDA switch for synchronous kernel launches,
# which makes device-side errors easier to localise (presumably a debugging
# aid here -- confirm it is still wanted in production).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
class Qwen:
    """Local chat wrapper around a Hugging Face Qwen chat checkpoint.

    Only local inference is implemented. An HTTP API variant (it would look
    much like the Linly-api one) has not been written yet; contributions are
    welcome.
    """

    # Bilingual fallback text returned by chat() when generation fails.
    ERROR_MESSAGE = "对不起,你的请求出错了,请再次尝试。\nSorry, your request has encountered an error. Please try again.\n"

    def __init__(self, model_path: str = "Qwen/Qwen-1_8B-Chat") -> None:
        """Load the model and tokenizer found at ``model_path``.

        Args:
            model_path: Hub id or local directory handed to ``from_pretrained``.
        """
        self.model, self.tokenizer = self.init_model(model_path)
        # Scratch dict; chat() stores the most recent prompt under "question".
        self.data = {}

    def init_model(self, path: str = "Qwen/Qwen-1_8B-Chat"):
        """Load the causal LM (in eval mode) and its tokenizer from ``path``.

        Both objects are loaded from the same ``path`` so that a custom
        checkpoint never ends up paired with the default model's weights.

        Args:
            path: Hub id or local directory handed to ``from_pretrained``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        # trust_remote_code=True lets the checkpoint's own modeling code be
        # loaded; only point this at sources you trust.
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            trust_remote_code=True,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        return model, tokenizer

    def chat(self, question: str) -> str:
        """Answer ``question`` with a single, history-free model turn.

        The prompt sent to the model is stored in ``self.data["question"]``.

        Args:
            question: The user's message.

        Returns:
            The model's response, or ``ERROR_MESSAGE`` if generation raised.
        """
        prompt = f"{question} ### Instruction:{question} ### Response:"
        self.data["question"] = prompt
        try:
            response, history = self.model.chat(self.tokenizer, prompt, history=None)
            print(history)
            return response
        except Exception:
            # Best-effort: callers always get a string back. Only Exception is
            # caught so Ctrl-C / SystemExit still propagate, and the traceback
            # is logged instead of being silently dropped.
            logging.getLogger(__name__).exception("Qwen chat request failed")
            return self.ERROR_MESSAGE
def test():
    """Smoke test: load the default checkpoint, ask one question, print the reply."""
    bot = Qwen(model_path="Qwen/Qwen-1_8B-Chat")
    print(bot.chat(question="如何应对压力?"))
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()