From c92c6dd60598760b6bbf7fcb65600cd40cdbc4a4 Mon Sep 17 00:00:00 2001 From: ChenXL97 <908926798@qq.com> Date: Fri, 24 Nov 2023 09:47:27 +0800 Subject: [PATCH] =?UTF-8?q?=E7=88=ACglm=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 + robowaiter/llm_client/passage_retrieval3.py | 9 ++-- .../llm_client/robot_result/test_robot.jsonl | 2 +- robowaiter/llm_client/single_round_crawer.py | 43 +++++++++++++++++++ 4 files changed, 51 insertions(+), 5 deletions(-) create mode 100644 robowaiter/llm_client/single_round_crawer.py diff --git a/README.md b/README.md index 62ec602..49f40bf 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ pip install -e . ### 安装UI 1. 安装 [graphviz-9.0.0](https://gitlab.com/api/v4/projects/4207231/packages/generic/graphviz-releases/9.0.0/windows_10_cmake_Release_graphviz-install-9.0.0-win64.exe) (详见[官网](https://www.graphviz.org/download/#windows)) 2. 将软件安装目录的bin文件添加到系统环境中。如电脑是Windows系统,Graphviz安装在D:\Program Files (x86)\Graphviz2.38,该目录下有bin文件,将该路径添加到电脑系统环境变量path中,即D:\Program Files (x86)\Graphviz2.38\bin。 +3. 安装向量数据库 +conda install -c conda-forge faiss ### 快速入门 1. 安装UE及Harix插件,打开默认项目并运行 diff --git a/robowaiter/llm_client/passage_retrieval3.py b/robowaiter/llm_client/passage_retrieval3.py index cc162b5..693b072 100644 --- a/robowaiter/llm_client/passage_retrieval3.py +++ b/robowaiter/llm_client/passage_retrieval3.py @@ -21,7 +21,8 @@ from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches import warnings - +from robowaiter.utils.basic import get_root_path +root_path = get_root_path() warnings.filterwarnings('ignore') os.environ["TOKENIZERS_PARALLELISM"] = "true" @@ -244,8 +245,8 @@ def retri(query): ) # parser.add_argument("--passages", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/train_robot.jsonl', help="Path to passages (.tsv file)") # parser.add_argument("--passages_embeddings", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/robot_embeddings/*', help="Glob path to encoded passages") - parser.add_argument("--passages", type=str, default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)") - parser.add_argument("--passages_embeddings", type=str, default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages") + parser.add_argument("--passages", type=str, default=f'{root_path}/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)") + parser.add_argument("--passages_embeddings", type=str, default=f'{root_path}/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages") parser.add_argument( "--output_dir", type=str, default='robot_result', help="Results are written to outputdir with data suffix" @@ -262,7 +263,7 @@ def retri(query): # "--model_name_or_path", type=str, default='C:\\Users\\huangyu\\Desktop\\RoboWaiter-main\\RoboWaiter-main\\contriever-msmarco',help="path to directory containing model weights and config file" # ) parser.add_argument( - "--model_name_or_path", type=str, default='D:\\AAAAA_EI_LLM\\UnrealProject\\RobotProject\\Plugins\\RoboWaiter\\robowaiter\\algos\\retrieval\\contriever-msmarco',help="path to directory containing model weights and config file" + "--model_name_or_path", type=str, default=f'{root_path}/robowaiter/algos/retrieval/contriever-msmarco',help="path to directory containing model weights and config file" ) parser.add_argument("--no_fp16", action="store_true", help="inference in fp32") parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question") diff --git a/robowaiter/llm_client/robot_result/test_robot.jsonl b/robowaiter/llm_client/robot_result/test_robot.jsonl index 1abd9ec..9a0bdbe 100644 --- a/robowaiter/llm_client/robot_result/test_robot.jsonl +++ b/robowaiter/llm_client/robot_result/test_robot.jsonl @@ -1 +1 @@ -{"id": 1, "question": "你能把空调打开一下吗?", "ctxs": [{"id": "505", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8559918"}, {"id": "313", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8559918"}, {"id": "312", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8559918"}, {"id": "120", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8559918"}, {"id": "119", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8559918"}]} +{"id": 1, "question": "你能把空调打开一下吗?", "ctxs": [{"id": "505", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "313", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "312", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "120", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "119", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}]} diff --git a/robowaiter/llm_client/single_round_crawer.py b/robowaiter/llm_client/single_round_crawer.py new file mode 100644 index 0000000..d0c88c6 --- /dev/null +++ b/robowaiter/llm_client/single_round_crawer.py @@ -0,0 +1,43 @@ + +import requests +import urllib3 +######################################## +# 该文件实现了与大模型的简单通信 +######################################## + +# 忽略https的安全性警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +def single_round(question,prefix=""): + url = "https://45.125.46.134:25344/v1/chat/completions" + headers = {"Content-Type": "application/json"} + data = { + "model": "RoboWaiter", + "messages": [ + { + "role": "system", + "content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。" + }, + { + "role": "user", + "content": prefix + question + } + ] + } + + response = requests.post(url, headers=headers, json=data, verify=False) + + if response.status_code == 200: + result = response.json() + return result['choices'][0]['message']['content'].strip() + else: + return "大模型请求失败:", response.status_code + + +if __name__ == '__main__': + question = ''' + 给我一杯拿铁 + ''' + + print(single_round(question)) \ No newline at end of file