vllm integration

yanyuxiyangzk@126.com 2024-04-03 20:33:40 +08:00
parent 607be33781
commit fe963ed543
2 changed files with 11 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
 from llm.Qwen import Qwen
 from llm.Gemini import Gemini
 from llm.ChatGPT import ChatGPT
-
+from llm.VllmGPT import VllmGPT
 
 def test_Qwen(question = "如何应对压力?", mode='offline', model_path="Qwen/Qwen-1_8B-Chat"):
     llm = Qwen(mode, model_path)
@@ -18,8 +18,8 @@ class LLM:
         self.mode = mode
 
     def init_model(self, model_name, model_path, api_key=None, proxy_url=None):
-        if model_name not in ['Qwen', 'Gemini', 'ChatGPT']:
-            raise ValueError("model_name must be 'ChatGPT', 'Qwen', or 'Gemini'(其他模型还未集成)")
+        if model_name not in ['Qwen', 'Gemini', 'ChatGPT', 'VllmGPT']:
+            raise ValueError("model_name must be 'ChatGPT', 'VllmGPT', 'Qwen', or 'Gemini'(其他模型还未集成)")
 
         if model_name == 'Gemini':
             llm = Gemini(model_path, api_key, proxy_url)
@@ -27,6 +27,8 @@ class LLM:
             llm = ChatGPT(model_path, api_key=api_key)
         elif model_name == 'Qwen':
             llm = Qwen(self.mode, model_path)
+        elif model_name == 'VllmGPT':
+            llm = VllmGPT()
 
         return llm
 
@@ -41,8 +43,10 @@ class LLM:
         print(answer)
 
 if __name__ == '__main__':
-    llm = LLM()
-    llm.test_Gemini(api_key='你的API Key', proxy_url=None)
+    # llm = LLM()
+    # llm.test_Gemini(api_key='你的API Key', proxy_url=None)
     # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='AIzaSyBWAWfT8zsyAZcRIXLS5Vzlw8KKCN9qsAg', proxy_url='http://172.31.71.58:7890')
     # response = llm.chat("如何应对压力?")
+    llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b',api_key='', proxy_url='http://172.31.71.58:7890')
+    response = llm.chat("如何应对压力?")
     # print(response)
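The commit routes 'VllmGPT' through the factory, but only the constructor of the client appears in this diff. For orientation, a minimal sketch of what the chat path of such a client typically looks like against vLLM's demo api_server; the class name VllmGPTSketch, the /generate endpoint, the payload fields, and the response shape are assumptions, not code from this commit:

import requests

class VllmGPTSketch:
    # Hypothetical stand-in for the VllmGPT client; mirrors the
    # constructor shown in this commit's second file.
    def __init__(self, host="192.168.1.3", port="8101",
                 model="THUDM/chatglm3-6b", max_tokens="1024"):
        self.host = host
        self.port = port
        self.model = model
        self.max_tokens = max_tokens

    def chat(self, question):
        # POST the prompt to the vLLM server. vLLM's demo api_server
        # replies with {"text": ["<prompt + completion>"]}; the model is
        # fixed at server launch, so no model field is sent here.
        url = f"http://{self.host}:{self.port}/generate"
        payload = {"prompt": question, "max_tokens": int(self.max_tokens)}
        resp = requests.post(url, json=payload)
        resp.raise_for_status()
        return resp.json()["text"][0]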

View File

@@ -4,8 +4,8 @@ import requests
 class VllmGPT:
 
-    def __init__(self, host="127.0.0.1",
-                 port="8000",
+    def __init__(self, host="192.168.1.3",
+                 port="8101",
                  model="THUDM/chatglm3-6b",
                  max_tokens="1024"):
         self.host = host
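The new defaults (192.168.1.3:8101) point the client at a vLLM server on the LAN. Assuming the demo api_server is the target (the diff does not confirm which server the author runs), it would be launched along these lines:

python -m vllm.entrypoints.api_server --model THUDM/chatglm3-6b --host 0.0.0.0 --port 8101

vLLM also ships an OpenAI-compatible server (vllm.entrypoints.openai.api_server) that serves /v1/completions instead of /generate; whichever one runs must match the path the VllmGPT client requests.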