import json
import re

import requests
import urllib3
from colorama import init, Fore
from loguru import logger

from robowaiter.llm_client.tool_register import get_tools, dispatch_tool
|
|
|
init(autoreset=True)
|
|
|
|
|
2023-11-14 12:10:23 +08:00
|
|
|
########################################
|
2023-11-14 15:26:02 +08:00
|
|
|
# 该文件实现了与大模型的通信以及工具调用
|
2023-11-14 12:10:23 +08:00
|
|
|
########################################
|
|
|
|
|
|
|
|
# 忽略https的安全性警告
|
|
|
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|
|
|
|
2023-11-16 20:07:01 +08:00
|
|
|
base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址
|
|
|
|
|
2023-11-14 12:10:23 +08:00
|
|
|
|
|
|
|
def get_response(**kwargs):
    """POST a chat-completion request to the model server and return the parsed JSON.

    All keyword arguments are forwarded verbatim as the JSON request body
    (``model``, ``messages``, ``functions``, ``stream``, ...).

    Returns:
        dict: the decoded ``/v1/chat/completions`` response body.
    """
    data = kwargs
    # verify=False because the server presents a self-signed certificate;
    # the matching urllib3 warning is suppressed at module import time.
    response = requests.post(
        f"{base_url}/v1/chat/completions",
        json=data,
        stream=data.get("stream", False),  # tolerate callers that omit "stream"
        verify=False,
    )
    return response.json()
# Tool (function-calling) schemas advertised to the model on each request.
functions = get_tools()
|
def run_conversation(query: str, stream=False, max_retry=5):
    """Send *query* to the model, dispatching tool calls until a final reply.

    Args:
        query: the user's natural-language request.
        stream: forwarded to the chat-completions API.
        max_retry: maximum number of function-call round trips.

    Returns:
        dict with keys ``"Answer"`` (text shown to the user) and ``"Goal"``
        (a goal specification when the model created a sub-task, else None),
        or None when max_retry is exhausted without a final textual reply.
    """
    params = dict(model="chatglm3", messages=[{"role": "user", "content": query}], stream=stream)
    params["functions"] = functions
    response = get_response(**params)

    for _ in range(max_retry):
        message = response["choices"][0]["message"]
        if message.get("function_call"):
            function_call = message["function_call"]
            logger.info(f"Function Call Response: {function_call}")

            # A "sub_task" call carries the goal itself; no tool dispatch is needed.
            if "sub_task" in function_call["name"]:
                return {
                    "Answer": "好的",
                    "Goal": json.loads(function_call["arguments"])["goal"],
                }

            function_args = json.loads(function_call["arguments"])
            tool_response = dispatch_tool(function_call["name"], function_args)
            logger.info(f"Tool Call Response: {tool_response}")

            params["messages"].append(message)
            params["messages"].append(
                {
                    "role": "function",
                    "name": function_call["name"],
                    "content": tool_response,  # result returned by the tool call
                }
            )
            # Drop the tool schemas so the follow-up turn yields plain text.
            # pop() instead of del: a second function-call round would otherwise
            # raise KeyError once the key has already been removed.
            params.pop("functions", None)
        else:
            reply = message["content"]
            # Log before returning — the original logged after an earlier return
            # statement, so the log line was unreachable.
            logger.info(f"Final Reply: \n{reply}")
            return {
                "Answer": reply,
                "Goal": None,
            }

        response = get_response(**params)

    # Retries exhausted without reaching a final textual reply.
    return None
def run_conversation_for_test_only(query: str, stream=False, max_retry=5):
    """Debug variant of run_conversation that returns a flat trace string.

    Performs at most one tool-dispatch round, then returns a string combining
    the function-call trace, the tool output, and the model's follow-up
    response — intended for eyeballing tool behavior, not production use.

    Args:
        query: the user's natural-language request.
        stream: forwarded to the chat-completions API.
        max_retry: loop bound (both branches return on the first iteration,
            so this is effectively a safety cap).

    Returns:
        str: the accumulated trace.
    """
    params = dict(model="chatglm3", messages=[{"role": "user", "content": query}], stream=stream)
    params["functions"] = functions
    response = get_response(**params)

    response_string = ''

    for _ in range(max_retry):
        message = response["choices"][0]["message"]
        if message.get("function_call"):
            function_call = message["function_call"]
            response_string += f"Function Call: {function_call} \t"

            function_args = json.loads(function_call["arguments"])
            # Default so the message append below never sees an unbound name
            # when the model produced an empty function name.
            tool_response = ""
            if function_call["name"]:
                tool_response = dispatch_tool(function_call["name"], function_args)
                response_string += "Tool Call: %s \t" % re.sub(r'\n', '', tool_response)
            else:
                response_string += "LLM Cannot find the function call."

            params["messages"].append(message)
            params["messages"].append(
                {
                    "role": "function",
                    "name": function_call["name"],
                    "content": tool_response,  # tool result fed back to the model
                }
            )
            follow_up = get_response(**params)['choices'][0]
            return response_string + "\tResponse: " + str(follow_up)
        else:
            reply = message["content"]
            response_string += "Final Reply: %s" % re.sub(r'\n', '', reply)
            # Reuse the response we already hold instead of issuing a second,
            # redundant API call whose re-generated text would not match *reply*.
            return response_string + "\tResponse: " + str(response['choices'][0])

    return response_string
if __name__ == "__main__":
    # Smoke-test queries exercising the different tools end to end;
    # each prints the {"Answer": ..., "Goal": ...} dict from run_conversation.
    test_queries = [
        "卫生间在哪里",
        "我想找张桌子",
        "我想看看休闲区,请问哪里可以找到休闲区",
        "我想来一个面包",
        "哪里有卫生纸",
        "插座在哪里你知道吗",
        "你们的咖啡厅里有香蕉吗",
    ]
    for query in test_queries:
        print(run_conversation(query, stream=False))