修改工具调用
This commit is contained in:
parent
d7ee6a1249
commit
2a5695d80e
|
@ -38,6 +38,12 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5):
|
|||
for _ in range(max_retry):
|
||||
if response["choices"][0]["message"].get("function_call"):
|
||||
function_call = response["choices"][0]["message"]["function_call"]
|
||||
if "sub_task" in function_call["name"]:
|
||||
return {
|
||||
"Answer": "好的",
|
||||
"Goal": function_call["arguments"]
|
||||
}
|
||||
|
||||
logger.info(f"Function Call Response: {function_call}")
|
||||
function_args = json.loads(function_call["arguments"])
|
||||
tool_response = dispatch_tool(function_call["name"], function_args)
|
||||
|
@ -53,31 +59,17 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5):
|
|||
)
|
||||
else:
|
||||
reply = response["choices"][0]["message"]["content"]
|
||||
return {
|
||||
"Answer": reply,
|
||||
"Goal": None
|
||||
}
|
||||
logger.info(f"Final Reply: \n{reply}")
|
||||
return
|
||||
|
||||
response = get_response(**params)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# chat_messages = [
|
||||
# {
|
||||
# "role": "system",
|
||||
# "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.",
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "你好,给我讲一个故事,大概100字"
|
||||
# }
|
||||
# ]
|
||||
# create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
|
||||
|
||||
|
||||
# query = "你是谁"
|
||||
# run_conversation(query, stream=False)
|
||||
#
|
||||
# logger.info("\n=========== next conversation ===========")
|
||||
|
||||
query = "洗手间在哪儿"
|
||||
run_conversation(query, functions=functions, stream=False)
|
||||
query = "给我一杯咖啡"
|
||||
print(run_conversation(query, functions=functions, stream=False))
|
||||
|
|
|
@ -126,25 +126,25 @@ def get_tools() -> dict:
|
|||
|
||||
@register_tool
|
||||
def create_sub_task(
|
||||
goal: Annotated[str, '用于子任务的目标状态集合', True]
|
||||
goal: Annotated[str, '子任务需要达到的目标条件集合', True]
|
||||
) -> str:
|
||||
"""
|
||||
当需要完成具身任务(如做咖啡,拿放物体,扫地,前往某位置)时,调用该函数,根据用户的提示进行意图理解,生成子任务的目标状态集合,以一阶逻辑的形式来表示,例如:前往桌子的目标状态为{At(Robot,Table)},做咖啡的目标状态为{On(Coffee,Bar)}等
|
||||
当需要完成具身任务(如做咖啡,拿放物体,扫地,前往某位置)时,调用该函数,根据用户的提示进行意图理解,生成子任务的目标状态集合 `goal`(以一阶逻辑的形式表示)。
|
||||
"""
|
||||
|
||||
return goal
|
||||
|
||||
@register_tool
|
||||
def get_object_info(
|
||||
object: Annotated[str, '需要判断所在位置的物体', True]
|
||||
obj: Annotated[str, '需要获取信息的物体名称', True]
|
||||
) -> str:
|
||||
"""
|
||||
在场景中找到相邻的物体,并说出 `object` 在输出物体的附近
|
||||
获取场景中指定物体 `object` 的信息
|
||||
"""
|
||||
near_object = None
|
||||
if object == "Table":
|
||||
if obj == "Table":
|
||||
near_object = "Bar"
|
||||
if object == "洗手间":
|
||||
if obj == "洗手间":
|
||||
near_object = "大门"
|
||||
|
||||
return near_object
|
||||
|
|
|
@ -18,7 +18,7 @@ class SceneOT(Scene):
|
|||
super().__init__(robot)
|
||||
# 在这里加入场景中发生的事件
|
||||
self.event_list = [
|
||||
(5,self.create_chat_event("给我做一杯咖啡")) # (事件发生的时间,事件函数)
|
||||
(5,self.create_chat_event("给我一杯咖啡")) # (事件发生的时间,事件函数)
|
||||
]
|
||||
|
||||
def _reset(self):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
from robowaiter import Robot, task_map
|
||||
|
||||
TASK_NAME = 'VLN'
|
||||
TASK_NAME = 'OT'
|
||||
|
||||
# create robot
|
||||
project_path = "./robowaiter"
|
||||
|
|
Loading…
Reference in New Issue