diff --git a/robowaiter/llm_client/find_obj_utils.py b/robowaiter/llm_client/find_obj_utils.py index 2d940ae..98f965d 100644 --- a/robowaiter/llm_client/find_obj_utils.py +++ b/robowaiter/llm_client/find_obj_utils.py @@ -35,7 +35,7 @@ all_loc_en = ['bar', 'Table', 'sofa', 'stove', 'Gate', 'light switch', 'aircondi 'cake display', 'ChargingStations', 'refrigerator', 'bookshelf'] loc_map_en = {'bar': {'工作台', '服务台', '收银台', '蛋糕柜'}, - 'Table': {'沙发', '大门', '窗户', '休闲区', '墙角', '椅子', '书架'}, + 'Table': {'大门', '休闲区', '墙角'}, 'sofa': {'餐桌', '窗户', '音响', '休闲区', '墙角', '书架'}, 'stove': {'吧台', '橱柜', '工作台', '服务台', '收银台', '蛋糕柜', '冰箱'}, 'Gate': {'吧台', '灯开关', '空调开关', '卫生间', '墙角'}, diff --git a/robowaiter/llm_client/tool_api.py b/robowaiter/llm_client/tool_api.py index 51be613..f6f1bd5 100644 --- a/robowaiter/llm_client/tool_api.py +++ b/robowaiter/llm_client/tool_api.py @@ -60,7 +60,7 @@ def run_conversation(query: str, stream=False, max_retry=5): "content": tool_response, # 调用函数返回结果 } ) - # del params["functions"] + del params["functions"] else: reply = response["choices"][0]["message"]["content"] return { @@ -116,12 +116,24 @@ if __name__ == "__main__": query = "卫生间在哪里" # print(run_conversation(query, stream=False)) - query = "我想看看冰箱,请问哪里可以找到冰箱" - print(run_conversation(query, stream=False)) - - query = "我想找个充电的地方,你能告诉我在哪儿吗" - print(run_conversation(query, stream=False)) - query = "我想找张桌子" # print(run_conversation(query, stream=False)) - # for query in \ No newline at end of file + + query = "我想看看休闲区,请问哪里可以找到休闲区" + print(run_conversation(query, stream=False)) + # + # query = "我想找个充电的地方,你能告诉我在哪儿吗" + # print(run_conversation(query, stream=False)) + + + query = "我想来一个面包" + print(run_conversation(query, stream=False)) + + query = "哪里有卫生纸" + print(run_conversation(query, stream=False)) + + query = "插座在哪里你知道吗" + print(run_conversation(query, stream=False)) + + query = "你们的咖啡厅里有香蕉吗" + print(run_conversation(query, stream=False)) \ No newline at end of file diff --git a/robowaiter/llm_client/tool_register.py b/robowaiter/llm_client/tool_register.py index b2f893e..f156858 100644 --- a/robowaiter/llm_client/tool_register.py +++ b/robowaiter/llm_client/tool_register.py @@ -11,7 +11,7 @@ import spacy _TOOL_HOOKS = {} _TOOL_DESCRIPTIONS = {} nlp = spacy.load('en_core_web_lg') - +all_object = find_obj_utils.all_loc + find_obj_utils.all_obj def register_tool(func: callable): tool_name = func.__name__ @@ -147,37 +147,27 @@ def create_sub_task( return goal -# @register_tool -# def get_object_info( -# obj: Annotated[str, '需要获取信息的物体名称', True] -# ) -> str: -# """ -# 获取场景中指定物体 `object` 在哪里,不涉及到具体的执行任务 -# 如果`object` 是一个地点,例如洗手间,则输出大门。 -# 如果`object`是咖啡,则输出桌子,咖啡在桌子上。 -# 如果`object` 是空桌子,则输出一号桌 -# """ -# near_object = None -# # if obj == "Table": -# # near_object = "Bar" -# # if obj == "洗手间": -# # near_object = "大门" -# # if obj == "空桌子": -# # near_object = "一号桌" -# if obj in find_obj_utils.all_loc: # object是一个地点 -# mp = list(find_obj_utils.loc_map[obj]) -# # near_object = random.choice(mp) -# near_object = mp -# if obj in find_obj_utils.all_obj: # object是一个物品 -# near_ls = find_obj_utils.all_loc + find_obj_utils.all_obj -# near_object = random.choices(near_ls,k=5) -# return near_object +@register_tool +def find_object( + object: Annotated[str, '客人咨询的物品', True] +) -> str: + """" + 用户想找某个物品,获取的object中文 + 在输出中可以找到 + 基于生活经验从输出中选择一个可以与object相关联的来描述物品 + """ + near_object = None + + if object in find_obj_utils.all_obj: # object是一个物品 + near_object = random.choices(all_object, k=10) + # near_object.append(object) + return near_object @register_tool def find_location( location: Annotated[str, '客人咨询的地点', True] ) -> str: - """" + """ 获取的location为英文 用户想找某个地点 """ diff --git a/robowaiter/scene/scene.py b/robowaiter/scene/scene.py index 8837713..0b86f86 100644 --- a/robowaiter/scene/scene.py +++ b/robowaiter/scene/scene.py @@ -483,11 +483,8 @@ class Scene: return False def adjust_kongtiao(self,op_type): - print("self.obj_loc:",self.obj_loc) obj_loc = self.obj_loc[:] - print("obj_loc:",obj_loc,"self.obj_loc:", self.obj_loc) obj_loc[2] -= 5 - print("obj_loc:",obj_loc) if op_type == 13: obj_loc[1] -= 2 if op_type == 14: obj_loc[1] -= 0 if op_type == 15: obj_loc[1] += 2