fix some problem for qwen

This commit is contained in:
waani 2024-04-17 10:11:47 +08:00
parent fafa862ba1
commit 54fcbb8cc7
2 changed files with 10 additions and 9 deletions

View File

@ -26,7 +26,7 @@ class LLM:
elif model_name == 'ChatGPT':
llm = ChatGPT(model_path, api_key=api_key)
elif model_name == 'Qwen':
llm = Qwen(self.mode, model_path)
llm = Qwen(model_path)
elif model_name == 'VllmGPT':
llm = VllmGPT()
return llm

View File

@ -8,18 +8,19 @@ class Qwen:
def __init__(self, model_path="Qwen/Qwen-1_8B-Chat") -> None:
'''暂时不写api版本,与Linly-api相类似,感兴趣可以实现一下'''
self.model, self.tokenizer = self.init_model(model_path)
def init_model(self, path = "Qwen/Qwen-1_8B-Chat"):
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat",
device_map="auto",
self.data = {}
def init_model(self, path="Qwen/Qwen-1_8B-Chat"):
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat",
device_map="auto",
trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
return model, tokenizer
return model, tokenizer
def chat(self, question):
self.data["question"] = f"{self.prompt} ### Instruction:{question} ### Response:"
self.data["question"] = f"{question} ### Instruction:{question} ### Response:"
try:
response, history = self.model.chat(self.tokenizer, self.data["question"], history=None)
print(history)
@ -30,7 +31,7 @@ class Qwen:
def test():
llm = Qwen(model_path="Qwen/Qwen-1_8B-Chat")
answer = llm.generate("如何应对压力?")
answer = llm.chat(question="如何应对压力?")
print(answer)
if __name__ == '__main__':