修改工具调用

This commit is contained in:
ChenXL97 2023-11-14 23:08:07 +08:00
parent d7ee6a1249
commit 2a5695d80e
4 changed files with 21 additions and 29 deletions

View File

@ -38,6 +38,12 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5):
for _ in range(max_retry):
if response["choices"][0]["message"].get("function_call"):
function_call = response["choices"][0]["message"]["function_call"]
if "sub_task" in function_call["name"]:
return {
"Answer": "好的",
"Goal": function_call["arguments"]
}
logger.info(f"Function Call Response: {function_call}")
function_args = json.loads(function_call["arguments"])
tool_response = dispatch_tool(function_call["name"], function_args)
@ -53,31 +59,17 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5):
)
else:
reply = response["choices"][0]["message"]["content"]
return {
"Answer": reply,
"Goal": None
}
logger.info(f"Final Reply: \n{reply}")
return
response = get_response(**params)
if __name__ == "__main__":
# chat_messages = [
# {
# "role": "system",
# "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.",
# },
# {
# "role": "user",
# "content": "你好给我讲一个故事大概100字"
# }
# ]
# create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
# query = "你是谁"
# run_conversation(query, stream=False)
#
# logger.info("\n=========== next conversation ===========")
query = "洗手间在哪儿"
run_conversation(query, functions=functions, stream=False)
query = "给我一杯咖啡"
print(run_conversation(query, functions=functions, stream=False))

View File

@ -126,25 +126,25 @@ def get_tools() -> dict:
@register_tool
def create_sub_task(
goal: Annotated[str, '用于子任务的目标状态集合', True]
goal: Annotated[str, '子任务需要达到的目标条件集合', True]
) -> str:
"""
当需要完成具身任务如做咖啡拿放物体扫地前往某位置调用该函数根据用户的提示进行意图理解生成子任务的目标状态集合以一阶逻辑的形式来表示例如前往桌子的目标状态为{At(Robot,Table)}做咖啡的目标状态为{On(Coffee,Bar)}
当需要完成具身任务如做咖啡拿放物体扫地前往某位置调用该函数根据用户的提示进行意图理解生成子任务的目标状态集合 `goal`以一阶逻辑的形式表示
"""
return goal
@register_tool
def get_object_info(
object: Annotated[str, '需要判断所在位置的物体', True]
obj: Annotated[str, '需要获取信息的物体名称', True]
) -> str:
"""
在场景中找到相邻的物体并说出 `object` 在输出物体的附近
获取场景中指定物体 `object` 的信息
"""
near_object = None
if object == "Table":
if obj == "Table":
near_object = "Bar"
if object == "洗手间":
if obj == "洗手间":
near_object = "大门"
return near_object

View File

@ -18,7 +18,7 @@ class SceneOT(Scene):
super().__init__(robot)
# 在这里加入场景中发生的事件
self.event_list = [
(5,self.create_chat_event("给我一杯咖啡")) # (事件发生的时间,事件函数)
(5,self.create_chat_event("给我一杯咖啡")) # (事件发生的时间,事件函数)
]
def _reset(self):

View File

@ -1,7 +1,7 @@
import os
from robowaiter import Robot, task_map
TASK_NAME = 'VLN'
TASK_NAME = 'OT'
# create robot
project_path = "./robowaiter"