From 54fcbb8cc7515acc43eb1641446aa74474c8bdd2 Mon Sep 17 00:00:00 2001
From: waani
Date: Wed, 17 Apr 2024 10:11:47 +0800
Subject: [PATCH] fix some problems for Qwen

---
 llm/LLM.py  |  2 +-
 llm/Qwen.py | 17 +++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/llm/LLM.py b/llm/LLM.py
index 3157e8e..e279d26 100644
--- a/llm/LLM.py
+++ b/llm/LLM.py
@@ -26,7 +26,7 @@ class LLM:
         elif model_name == 'ChatGPT':
             llm = ChatGPT(model_path, api_key=api_key)
         elif model_name == 'Qwen':
-            llm = Qwen(self.mode, model_path)
+            llm = Qwen(model_path)
         elif model_name == 'VllmGPT':
             llm = VllmGPT()
         return llm
diff --git a/llm/Qwen.py b/llm/Qwen.py
index 6c79b31..ab2341b 100644
--- a/llm/Qwen.py
+++ b/llm/Qwen.py
@@ -8,18 +8,19 @@ class Qwen:
     def __init__(self, model_path="Qwen/Qwen-1_8B-Chat") -> None:
         '''No API version for now; it would be similar to Linly-api, feel free to implement one if interested'''
         self.model, self.tokenizer = self.init_model(model_path)
-
-    def init_model(self, path = "Qwen/Qwen-1_8B-Chat"):
-        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat",
-                                                     device_map="auto",
+        self.data = {}
+
+    def init_model(self, path="Qwen/Qwen-1_8B-Chat"):
+        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat",
+                                                     device_map="auto",
                                                      trust_remote_code=True).eval()
         tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
-        return model, tokenizer
-
+        return model, tokenizer
+
     def chat(self, question):
-        self.data["question"] = f"{self.prompt} ### Instruction:{question} ### Response:"
+        self.data["question"] = f"{question} ### Instruction:{question} ### Response:"
         try:
             response, history = self.model.chat(self.tokenizer, self.data["question"], history=None)
             print(history)
@@ -30,7 +31,7 @@ class Qwen:
 
 def test():
     llm = Qwen(model_path="Qwen/Qwen-1_8B-Chat")
-    answer = llm.generate("How should I cope with stress?")
+    answer = llm.chat(question="How should I cope with stress?")
     print(answer)
 
 if __name__ == '__main__':
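
A minimal sketch of the patched call path, mirroring test() in llm/Qwen.py: the
Qwen(model_path=...) constructor and chat(question=...) method are taken from the
hunks above, while the "from llm.Qwen import Qwen" import is an assumption based
on the file layout, not something the patch itself shows.

    # Minimal smoke test, assuming the repo root is on PYTHONPATH and that
    # transformers can download Qwen/Qwen-1_8B-Chat (trust_remote_code=True).
    from llm.Qwen import Qwen

    llm = Qwen(model_path="Qwen/Qwen-1_8B-Chat")  # loads model + tokenizer via init_model()
    answer = llm.chat(question="How should I cope with stress?")
    print(answer)

The llm/LLM.py hunk is consistent with this signature: Qwen.__init__ accepts only
model_path, so the stray self.mode argument is dropped from the factory branch.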