Support multi-round conversations in tool_api_multi_round.py

This commit is contained in:
ChenXL97 2023-11-16 15:04:27 +08:00
parent 13776107e7
commit b032548d55
8 changed files with 171 additions and 23 deletions

View File

@@ -32,6 +32,7 @@ functions = get_tools()
 def run_conversation(query: str, stream=False, max_retry=5):
     params = dict(model="chatglm3", messages=[{"role": "user", "content": query}], stream=stream)
     params["functions"] = functions
+    print(params)
     response = get_response(**params)
     for _ in range(max_retry):

View File

@@ -0,0 +1,78 @@
import json

import requests
import urllib3
from colorama import init
from loguru import logger

from robowaiter.llm_client.tool_register import get_tools, dispatch_tool

init(autoreset=True)

########################################
# This file implements communication with
# the LLM and tool calling.
########################################

# Suppress HTTPS certificate warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

base_url = "https://45.125.46.134:25344"  # locally deployed address, or the API endpoint you use to reach the model


def get_response(**kwargs):
    data = kwargs
    response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=data["stream"], verify=False)
    decoded_line = response.json()
    return decoded_line


functions = get_tools()

if __name__ == "__main__":
    question = input("\n顾客:")
    data_memory = [{
        "role": "system",
        "content": "你是RoboWaiter,一个由HPCL团队开发的机器人服务员,你在咖啡厅工作。接受顾客的指令并调用工具函数来完成各种服务任务。",
    }]
    max_retry = 5
    params = dict(model="RoboWaiter", messages=data_memory, stream=False)
    params["functions"] = functions
    while question != 'end':
        user_dict = {"role": "user", "content": question}
        params["messages"].append(user_dict)
        response = get_response(**params)
        for _ in range(max_retry):
            if response["choices"][0]["message"].get("function_call"):
                function_call = response["choices"][0]["message"]["function_call"]
                logger.info(f"Function Call Response: {function_call}")
                function_args = json.loads(function_call["arguments"])
                tool_response = dispatch_tool(function_call["name"], function_args)
                logger.info(f"Tool Call Response: {tool_response}")
                # Append the assistant's function-call message, then the tool
                # result, so the model sees the full call history next round.
                return_message = response["choices"][0]["message"]
                params["messages"].append(return_message)
                t = {
                    "role": "function",
                    "name": function_call["name"],
                    "content": tool_response,  # result returned by the tool call
                }
                params["messages"].append(t)
                response = get_response(**params)
            else:
                return_message = response["choices"][0]["message"]
                reply = return_message["content"]
                params["messages"].append(return_message)
                logger.info(f"Final Reply: \n{reply}")
                break
        question = input("\n顾客:")
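
One quirk worth noting: `get_response` forwards `stream` to `requests.post` but always parses the body with `response.json()`, so `stream=True` would not actually yield incremental tokens. A minimal sketch of how streamed chunks could be consumed instead, assuming the server emits OpenAI-style `data:` SSE lines (the `iter_stream` helper is hypothetical, not part of this commit):

import json
import requests

def iter_stream(base_url, **kwargs):
    # Hypothetical helper: iterate over OpenAI-style "data: {...}" chunks
    # instead of parsing a single JSON body.
    resp = requests.post(f"{base_url}/v1/chat/completions",
                         json=kwargs, stream=True, verify=False)
    for raw in resp.iter_lines():
        if not raw:
            continue
        line = raw.decode("utf-8")
        if line.startswith("data: "):
            payload = line[len("data: "):]
            if payload.strip() == "[DONE]":
                break
            yield json.loads(payload)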

View File

@@ -0,0 +1,74 @@
import json

import requests
import urllib3
from colorama import init
from loguru import logger

from robowaiter.llm_client.tool_register import get_tools, dispatch_tool

init(autoreset=True)

########################################
# This file implements communication with
# the LLM and tool calling.
########################################

# Suppress HTTPS certificate warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

base_url = "https://45.125.46.134:25344"  # locally deployed address, or the API endpoint you use to reach the model


def get_response(**kwargs):
    data = kwargs
    response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=data["stream"], verify=False)
    decoded_line = response.json()
    return decoded_line


functions = get_tools()


def run_conversation(query: str, stream=False, max_retry=5):
    params = dict(model="chatglm3", messages=[{"role": "user", "content": query}], stream=stream)
    params["functions"] = functions
    response = get_response(**params)
    for _ in range(max_retry):
        if response["choices"][0]["message"].get("function_call"):
            function_call = response["choices"][0]["message"]["function_call"]
            logger.info(f"Function Call Response: {function_call}")
            # A sub-task call is short-circuited: answer with a canned reply
            # and return the parsed first-order-logic goal.
            if "sub_task" in function_call["name"]:
                return {
                    "Answer": "好的",
                    "Goal": json.loads(function_call["arguments"])["goal"],
                }
            function_args = json.loads(function_call["arguments"])
            tool_response = dispatch_tool(function_call["name"], function_args)
            logger.info(f"Tool Call Response: {tool_response}")
            params["messages"].append(response["choices"][0]["message"])
            params["messages"].append(
                {
                    "role": "function",
                    "name": function_call["name"],
                    "content": tool_response,  # result returned by the tool call
                }
            )
        else:
            reply = response["choices"][0]["message"]["content"]
            logger.info(f"Final Reply: \n{reply}")
            return {
                "Answer": reply,
                "Goal": None,
            }
        response = get_response(**params)


if __name__ == "__main__":
    query = "可以带我去吗"
    print(run_conversation(query, stream=False))
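
Unlike the multi-round script, `run_conversation` returns a small dict rather than raw API output: "Answer" carries the natural-language reply, and "Goal" carries the first-order-logic goal string when a sub-task was requested (otherwise None). A hedged sketch of consuming that result, assuming it runs in the same module as `run_conversation` (the query string is illustrative):

result = run_conversation("请帮我打开空调", stream=False)
print(result["Answer"])                     # reply text for the customer
if result["Goal"] is not None:
    print(f"Planner goal: {result['Goal']}")  # e.g. "Is(AC,On)"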

View File

@@ -132,8 +132,13 @@ def create_sub_task(
     当需要完成具身任务(如做咖啡、拿放物体、扫地、前往某位置)时调用该函数,根据用户的提示进行意图理解,生成子任务的目标状态集合 `goal`(以一阶逻辑的形式表示用户意图)。
     做一杯咖啡,则该函数的参数为 "On(Coffee,Bar)",
     前往一号桌,则该函数的参数为 "At(Robot,Table1)",
+    前往二号桌,则该函数的参数为 "At(Robot,Table2)",
     打开空调,则该函数的参数为 "Is(AC,On)",
     关空调,则该函数的参数为 "Is(AC,Off)",
+    打开窗帘,则该函数的参数为 "Is(Curtain,On)",
+    关闭窗帘,则该函数的参数为 "Is(Curtain,Off)",
+    拖地,则该函数的参数为 "Is(Floor,Clean)",
+    打开大厅灯,则该函数的参数为 "Is(HallLight,On)",
     """
     return goal
@@ -143,9 +148,9 @@ def get_object_info(
     obj: Annotated[str, '需要获取信息的物体名称', True]
 ) -> str:
     """
-    获取场景中指定物体 `object` 在哪里
-    如果 `object` 是一个地点(例如洗手间),则输出地方
-    如果 `object` 是一个咖啡,则输出桌子
+    获取场景中指定物体 `object` 在哪里,不涉及到具体的执行任务。
+    如果 `object` 是一个地点(例如洗手间),则输出大门。
+    如果 `object` 是咖啡,则输出桌子(咖啡在桌子上)。
     如果 `object` 是空桌子,则输出一号桌
     """
     near_object = None
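
To make the mapping concrete: given the expanded docstring above, the model is expected to emit a function_call whose arguments JSON-encode one of these goal strings. A small illustration of parsing such a payload (the literal values are examples, not taken from a real model response):

import json

# Example payload shape for a create_sub_task call ("open the curtains")
function_call = {
    "name": "create_sub_task",
    "arguments": '{"goal": "Is(Curtain,On)"}',
}
goal = json.loads(function_call["arguments"])["goal"]
assert goal == "Is(Curtain,On)"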

View File

@@ -16,8 +16,9 @@ class SceneAT(Scene):
         super().__init__(robot)

     def _reset(self):
-        self.add_walker(1085, 2630, 220)
-        self.control_walker([self.walker_control_generator(0, False, 100, 755, 1900, 180)])
+        # self.add_walker(1085, 2630, 220)
+        # self.control_walker([self.walker_control_generator(0, False, 100, 755, 1900, 180)])
+        pass

     def _run(self):

View File

@@ -22,8 +22,8 @@ class SceneGQA(Scene):
     def _reset(self):
         # self.clean_walker()
-        self.add_walkers()
+        # self.add_walkers([[50, 500, 90]])
+        pass
         # self.walker_bubble("洗手间在哪里")
         # self.control_walker([self.walker_control_generator(0, False, 100, 755, 1900, 180)])

View File

@@ -19,8 +19,7 @@ class SceneOT(Scene):
         super().__init__(robot)
         # Events that occur in the scene are added here
         self.event_list = [
-            # (5, self.create_chat_event("给我一杯咖啡"))  # (event time, event function)
-            (5, self.create_chat_event("我有点热,能开个空调吗?"))  # (event time, event function)
+            (5, self.create_chat_event("我有点热,能开个空调吗?")),
         ]

     def _reset(self):

@@ -30,7 +29,6 @@ class SceneOT(Scene):
         print("scene.walkers:", scene.walkers)
         cont = scene.walkers[0].name + ":我有点热,能开个空调吗?"
         self.control_robot_action(0, 3, cont)
-        # self.control_walker([self.walker_control_generator(0, False, 100, 755, 1900, 180)])
         pass

     def _run(self):

View File

@@ -18,20 +18,11 @@ class SceneOT(Scene):
         super().__init__(robot)
         # Events that occur in the scene are added here
         self.event_list = [
-            # (5, self.create_chat_event("做一杯咖啡")),
-            (5, self.create_chat_event("感觉有点冷,可以关一下空调吗")),
+            (5, self.create_chat_event("给我一杯咖啡"))
         ]

     def _reset(self):
-        scene = self.add_walkers([[50, 300, 0]])
-        # time.sleep(2.0)
-        # print("我有点热,能开个空调吗?")
-        print("scene.walkers:", scene.walkers)
-        cont = scene.walkers[0].name + ":我有点热,能开个空调吗?"
-        self.control_robot_action(0, 3, cont)
-        # self.add_walker(1085, 2630, 220)
-        # self.control_walker([self.walker_control_generator(0, False, 100, 755, 1900, 180)])
+        pass

     def _run(self):
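
Taken together, the scene edits follow one pattern: each Scene subclass declares timed chat events in `event_list` and strips walker setup out of `_reset`. A minimal consolidated sketch of that pattern (the stub base class below only makes the example self-contained; the real `Scene` comes from the robowaiter package, whose internals this diff does not show):

class Scene:
    # Stub standing in for robowaiter's Scene base class (assumption).
    def __init__(self, robot):
        self.robot = robot

    def create_chat_event(self, text):
        return lambda: print(text)  # stand-in for the real chat event

class SceneCoffeeDemo(Scene):
    def __init__(self, robot):
        super().__init__(robot)
        self.event_list = [
            # (seconds after scene start, event function)
            (5, self.create_chat_event("给我一杯咖啡")),
        ]

    def _reset(self):
        pass  # walker setup removed, as in this commit

    def _run(self):
        pass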