爬glm代码
This commit is contained in:
parent
4cb54acd05
commit
c92c6dd605
|
@ -15,6 +15,8 @@ pip install -e .
|
|||
### 安装UI
|
||||
1. 安装 [graphviz-9.0.0](https://gitlab.com/api/v4/projects/4207231/packages/generic/graphviz-releases/9.0.0/windows_10_cmake_Release_graphviz-install-9.0.0-win64.exe) (详见[官网](https://www.graphviz.org/download/#windows))
|
||||
2. 将软件安装目录的bin文件添加到系统环境中。如电脑是Windows系统,Graphviz安装在D:\Program Files (x86)\Graphviz2.38,该目录下有bin文件,将该路径添加到电脑系统环境变量path中,即D:\Program Files (x86)\Graphviz2.38\bin。
|
||||
3. 安装向量数据库
|
||||
conda install -c conda-forge faiss
|
||||
|
||||
### 快速入门
|
||||
1. 安装UE及Harix插件,打开默认项目并运行
|
||||
|
|
|
@ -21,7 +21,8 @@ from robowaiter.algos.retrieval.retrieval_lm.src.data import load_passages
|
|||
|
||||
from robowaiter.algos.retrieval.retrieval_lm.src.evaluation import calculate_matches
|
||||
import warnings
|
||||
|
||||
from robowaiter.utils.basic import get_root_path
|
||||
root_path = get_root_path()
|
||||
warnings.filterwarnings('ignore')
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
|
@ -244,8 +245,8 @@ def retri(query):
|
|||
)
|
||||
# parser.add_argument("--passages", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/train_robot.jsonl', help="Path to passages (.tsv file)")
|
||||
# parser.add_argument("--passages_embeddings", type=str, default='C:/Users/huangyu/Desktop/RoboWaiter-main/RoboWaiter-main/robot_embeddings/*', help="Glob path to encoded passages")
|
||||
parser.add_argument("--passages", type=str, default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)")
|
||||
parser.add_argument("--passages_embeddings", type=str, default='D:/AAAAA_EI_LLM/UnrealProject/RobotProject/Plugins/RoboWaiter/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages")
|
||||
parser.add_argument("--passages", type=str, default=f'{root_path}/robowaiter/llm_client/train_robot.jsonl', help="Path to passages (.tsv file)")
|
||||
parser.add_argument("--passages_embeddings", type=str, default=f'{root_path}/robowaiter/algos/retrieval/robot_embeddings/*', help="Glob path to encoded passages")
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default='robot_result', help="Results are written to outputdir with data suffix"
|
||||
|
@ -262,7 +263,7 @@ def retri(query):
|
|||
# "--model_name_or_path", type=str, default='C:\\Users\\huangyu\\Desktop\\RoboWaiter-main\\RoboWaiter-main\\contriever-msmarco',help="path to directory containing model weights and config file"
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--model_name_or_path", type=str, default='D:\\AAAAA_EI_LLM\\UnrealProject\\RobotProject\\Plugins\\RoboWaiter\\robowaiter\\algos\\retrieval\\contriever-msmarco',help="path to directory containing model weights and config file"
|
||||
"--model_name_or_path", type=str, default=f'{root_path}/robowaiter/algos/retrieval/contriever-msmarco',help="path to directory containing model weights and config file"
|
||||
)
|
||||
parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
|
||||
parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
|
||||
|
|
|
@ -1 +1 @@
|
|||
{"id": 1, "question": "你能把空调打开一下吗?", "ctxs": [{"id": "505", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8559918"}, {"id": "313", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8559918"}, {"id": "312", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8559918"}, {"id": "120", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8559918"}, {"id": "119", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8559918"}]}
|
||||
{"id": 1, "question": "你能把空调打开一下吗?", "ctxs": [{"id": "505", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "313", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "312", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}, {"id": "120", "title": "你能把空调打开一下吗?", "text": "Is(AC,1)", "score": "1.8567487"}, {"id": "119", "title": "你能把空调关闭一下吗?", "text": "Is(AC,0)", "score": "1.8567487"}]}
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
|
||||
import requests
|
||||
import urllib3
|
||||
########################################
|
||||
# 该文件实现了与大模型的简单通信
|
||||
########################################
|
||||
|
||||
# 忽略https的安全性警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
def single_round(question,prefix=""):
|
||||
url = "https://45.125.46.134:25344/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": "RoboWaiter",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prefix + question
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, verify=False)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result['choices'][0]['message']['content'].strip()
|
||||
else:
|
||||
return "大模型请求失败:", response.status_code
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
question = '''
|
||||
给我一杯拿铁
|
||||
'''
|
||||
|
||||
print(single_round(question))
|
Loading…
Reference in New Issue