diff --git a/robowaiter/llm_client/tool_api.py b/robowaiter/llm_client/tool_api.py
index 2780a14..03a9fb3 100644
--- a/robowaiter/llm_client/tool_api.py
+++ b/robowaiter/llm_client/tool_api.py
@@ -5,16 +5,14 @@
 from colorama import init, Fore
 from loguru import logger
 from tool_register import get_tools, dispatch_tool
-
-init(autoreset=True)
-
-# Test the response with Python code
 import requests
 import json
 import urllib3
 
+init(autoreset=True)
+
 ########################################
-# This file implements simple communication with the LLM
+# This file implements communication with the LLM as well as tool calling
 ########################################
 
 # Ignore HTTPS security warnings
@@ -77,10 +75,10 @@ if __name__ == "__main__":
 
-    query = "你是谁"
-    run_conversation(query, stream=False)
+    # query = "你是谁"
+    # run_conversation(query, stream=False)
+    #
+    # logger.info("\n=========== next conversation ===========")
 
-    logger.info("\n=========== next conversation ===========")
-
-    query = "帮我查询北京的天气怎么样"
+    query = "洗手间在哪儿"
     run_conversation(query, functions=functions, stream=False)
diff --git a/robowaiter/llm_client/tool_api_request.py b/robowaiter/llm_client/tool_api_request.py
deleted file mode 100644
index ebaa273..0000000
--- a/robowaiter/llm_client/tool_api_request.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Test the response with a curl command
-# curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
-#     -H "Content-Type: application/json" \
-#     -d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}"
-
-# Test the response with Python code
-import requests
-import json
-
-import urllib3
-########################################
-# This file implements simple communication with the LLM
-########################################
-
-# Ignore HTTPS security warnings
-urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
-base_url = "https://45.125.46.134:25344"  # Address of the local deployment, or the API address you use to access the model
-
-def create_chat_completion(model, messages, use_stream=False):
-    data = {
-        "model": model,  # model name
-        "messages": messages,  # conversation history
-        "stream": use_stream,  # whether to stream the response
-        "max_tokens": 100,  # maximum number of tokens to generate
-        "temperature": 0.8,  # temperature
-        "top_p": 0.8,  # sampling probability
-    }
-
-    response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream, verify=False)
-    if response.status_code == 200:
-        if use_stream:
-            # Handle the streaming response
-            for line in response.iter_lines():
-                if line:
-                    decoded_line = line.decode('utf-8')[6:]
-                    try:
-                        response_json = json.loads(decoded_line)
-                        content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "")
-                        print(content)
-                    except:
-                        print("Special Token:", decoded_line)
-        else:
-            # Handle the non-streaming response
-            decoded_line = response.json()
-            print(decoded_line)
-            content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
-            print(content)
-    else:
-        print("Error:", response.status_code)
-        return None
-
-
-if __name__ == "__main__":
-    chat_messages = [
-        {
-            "role": "system",
-            "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.",
-        },
-        {
-            "role": "user",
-            "content": "你好,给我讲一个故事,大概100字"
-        }
-    ]
-    create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
-
-
diff --git a/robowaiter/llm_client/tool_register.py b/robowaiter/llm_client/tool_register.py
index 39ae2fc..ac04c32 100644
--- a/robowaiter/llm_client/tool_register.py
+++ b/robowaiter/llm_client/tool_register.py
@@ -64,50 +64,91 @@ def get_tools() -> dict:
 
 
 # Tool Definitions
-@register_tool
-def random_number_generator(
-    seed: Annotated[int, 'The random seed used by the generator', True],
-    range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
-) -> int:
-    """
-    Generates a random number x, s.t. range[0] <= x < range[1]
-    """
-    if not isinstance(seed, int):
-        raise TypeError("Seed must be an integer")
-    if not isinstance(range, tuple):
-        raise TypeError("Range must be a tuple")
-    if not isinstance(range[0], int) or not isinstance(range[1], int):
-        raise TypeError("Range must be a tuple of integers")
+# @register_tool
+# def random_number_generator(
+#     seed: Annotated[int, 'The random seed used by the generator', True],
+#     range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
+# ) -> int:
+#     """
+#     Generates a random number x, s.t. range[0] <= x < range[1]
+#     """
+#     if not isinstance(seed, int):
+#         raise TypeError("Seed must be an integer")
+#     if not isinstance(range, tuple):
+#         raise TypeError("Range must be a tuple")
+#     if not isinstance(range[0], int) or not isinstance(range[1], int):
+#         raise TypeError("Range must be a tuple of integers")
+#
+#     import random
+#     return random.Random(seed).randint(*range)
 
-    import random
-    return random.Random(seed).randint(*range)
 
+# @register_tool
+# def get_weather(
+#     city_name: Annotated[str, 'The name of the city to be queried', True],
+# ) -> str:
+#     """
+#     Get the current weather for `city_name`
+#     """
+#
+#     if not isinstance(city_name, str):
+#         raise TypeError("City name must be a string")
+#
+#     key_selection = {
+#         "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
+#     }
+#     import requests
+#     try:
+#         resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
+#         resp.raise_for_status()
+#         resp = resp.json()
+#         ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
+#     except:
+#         import traceback
+#         ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
+#
+#     return str(ret)
+
+
+# @register_tool
+# def add(
+#     a: Annotated[int, 'The first number to add', True],
+#     b: Annotated[int, 'The second number to add', True]
+# ) -> int:
+#     """
+#     Get the value of `a` + `b`
+#     """
+#
+#     if (not isinstance(a, int)) or (not isinstance(b, int)):
+#         raise TypeError("The numbers to add must be integers")
+#
+#     return int(a+b)
 
 @register_tool
-def get_weather(
-    city_name: Annotated[str, 'The name of the city to be queried', True],
+def create_sub_task(
+    goal: Annotated[str, 'The goal-state set of the sub-task', True]
 ) -> str:
     """
-    Get the current weather for `city_name`
+    Call this function when an embodied task has to be carried out (e.g. making coffee, picking up or placing an object, sweeping the floor, going to a location). Understand the intent behind the user's prompt and generate the goal-state set of the sub-task, expressed in first-order logic, e.g. the goal state for going to the table is {At(Robot,Table)} and the goal state for making coffee is {On(Coffee,Bar)}.
     """
-
-    if not isinstance(city_name, str):
-        raise TypeError("City name must be a string")
+    return goal
 
-    key_selection = {
-        "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
-    }
-    import requests
-    try:
-        resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
-        resp.raise_for_status()
-        resp = resp.json()
-        ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
-    except:
-        import traceback
-        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
+@register_tool
+def find_near_object(
+    object: Annotated[str, 'The object whose location needs to be determined', True]
+) -> str:
+    """
+    Find the neighbouring object in the scene and state that `object` is near the returned object
+    """
+    near_object = None
+    if object == "Table":
+        near_object = "Bar"
+    if object == "洗手间":
+        near_object = "大门"
+
+    return near_object
 
-    return str(ret)
 
 
 if __name__ == "__main__":
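
Note (not part of the patch): the two new tools plug into the register_tool/get_tools/dispatch_tool machinery that tool_api.py imports, so they can be smoke-tested without the model in the loop. The sketch below is a minimal illustration; it assumes dispatch_tool takes the tool name plus a parameter dict (check its definition in tool_register.py if the signature differs) and that the script runs from robowaiter/llm_client so the import resolves.

    from tool_register import get_tools, dispatch_tool

    # get_tools() returns the schema dict that tool_api.py passes to the model as `functions`;
    # after this patch it should expose create_sub_task and find_near_object.
    functions = get_tools()
    print(sorted(functions.keys()))

    # create_sub_task simply echoes the goal-state set produced for the sub-task,
    # e.g. {At(Robot,Table)} for "go to the table".
    print(dispatch_tool("create_sub_task", {"goal": "{At(Robot,Table)}"}))

    # find_near_object maps a queried object to a hard-coded neighbour:
    # "Table" -> "Bar", "洗手间" -> "大门".
    print(dispatch_tool("find_near_object", {"object": "洗手间"}))

One design note: find_near_object falls through to None for any object outside its two hard-coded cases, so extending the lookup table, or returning an explicit "not found" message, would give the model a clearer reply for unknown objects.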