From 2a5695d80edbf938c474aec2b59964fdb8eba2b8 Mon Sep 17 00:00:00 2001 From: ChenXL97 <908926798@qq.com> Date: Tue, 14 Nov 2023 23:08:07 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=B7=A5=E5=85=B7=E8=B0=83?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- robowaiter/llm_client/tool_api.py | 34 ++++++++++---------------- robowaiter/llm_client/tool_register.py | 12 ++++----- robowaiter/scene/tasks/Open_tasks.py | 2 +- run_robowaiter.py | 2 +- 4 files changed, 21 insertions(+), 29 deletions(-) diff --git a/robowaiter/llm_client/tool_api.py b/robowaiter/llm_client/tool_api.py index 9069663..d924b56 100644 --- a/robowaiter/llm_client/tool_api.py +++ b/robowaiter/llm_client/tool_api.py @@ -38,6 +38,12 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5): for _ in range(max_retry): if response["choices"][0]["message"].get("function_call"): function_call = response["choices"][0]["message"]["function_call"] + if "sub_task" in function_call["name"]: + return { + "Answer": "好的", + "Goal": function_call["arguments"] + } + logger.info(f"Function Call Response: {function_call}") function_args = json.loads(function_call["arguments"]) tool_response = dispatch_tool(function_call["name"], function_args) @@ -53,31 +59,17 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5): ) else: reply = response["choices"][0]["message"]["content"] + return { + "Answer": reply, + "Goal": None + } logger.info(f"Final Reply: \n{reply}") return response = get_response(**params) + if __name__ == "__main__": - - # chat_messages = [ - # { - # "role": "system", - # "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.", - # }, - # { - # "role": "user", - # "content": "你好,给我讲一个故事,大概100字" - # } - # ] - # create_chat_completion("chatglm3-6b", chat_messages, use_stream=False) - - - # query = "你是谁" - # run_conversation(query, stream=False) - # - # logger.info("\n=========== next conversation ===========") - - query = "洗手间在哪儿" - run_conversation(query, functions=functions, stream=False) + query = "给我一杯咖啡" + print(run_conversation(query, functions=functions, stream=False)) diff --git a/robowaiter/llm_client/tool_register.py b/robowaiter/llm_client/tool_register.py index 503a9b7..01022ab 100644 --- a/robowaiter/llm_client/tool_register.py +++ b/robowaiter/llm_client/tool_register.py @@ -126,25 +126,25 @@ def get_tools() -> dict: @register_tool def create_sub_task( - goal: Annotated[str, '用于子任务的目标状态集合', True] + goal: Annotated[str, '子任务需要达到的目标条件集合', True] ) -> str: """ - 当需要完成具身任务(如做咖啡,拿放物体,扫地,前往某位置)时,调用该函数,根据用户的提示进行意图理解,生成子任务的目标状态集合,以一阶逻辑的形式来表示,例如:前往桌子的目标状态为{At(Robot,Table)},做咖啡的目标状态为{On(Coffee,Bar)}等 + 当需要完成具身任务(如做咖啡,拿放物体,扫地,前往某位置)时,调用该函数,根据用户的提示进行意图理解,生成子任务的目标状态集合 `goal`(以一阶逻辑的形式表示)。 """ return goal @register_tool def get_object_info( - object: Annotated[str, '需要判断所在位置的物体', True] + obj: Annotated[str, '需要获取信息的物体名称', True] ) -> str: """ - 在场景中找到相邻的物体,并说出 `object` 在输出物体的附近 + 获取场景中指定物体 `object` 的信息 """ near_object = None - if object == "Table": + if obj == "Table": near_object = "Bar" - if object == "洗手间": + if obj == "洗手间": near_object = "大门" return near_object diff --git a/robowaiter/scene/tasks/Open_tasks.py b/robowaiter/scene/tasks/Open_tasks.py index 01700a2..4ed52eb 100644 --- a/robowaiter/scene/tasks/Open_tasks.py +++ b/robowaiter/scene/tasks/Open_tasks.py @@ -18,7 +18,7 @@ class SceneOT(Scene): super().__init__(robot) # 在这里加入场景中发生的事件 self.event_list = [ - (5,self.create_chat_event("给我做一杯咖啡")) # (事件发生的时间,事件函数) + (5,self.create_chat_event("给我一杯咖啡")) # (事件发生的时间,事件函数) ] def _reset(self): diff --git a/run_robowaiter.py b/run_robowaiter.py index 92c6804..54f37c9 100644 --- a/run_robowaiter.py +++ b/run_robowaiter.py @@ -1,7 +1,7 @@ import os from robowaiter import Robot, task_map -TASK_NAME = 'VLN' +TASK_NAME = 'OT' # create robot project_path = "./robowaiter"