From c9e8dc3981419e787c6ec453ae52f564aa57cf3d Mon Sep 17 00:00:00 2001 From: ChenXL97 <908926798@qq.com> Date: Tue, 14 Nov 2023 12:10:23 +0800 Subject: [PATCH] =?UTF-8?q?=E7=8E=B0=E5=9C=A8=E5=8F=AF=E4=BB=A5=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E5=A4=A7=E6=A8=A1=E5=9E=8B=E7=9A=84function=20call?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 1 + robowaiter/behavior_lib/_base/Behavior.py | 5 +- robowaiter/llm_client/openai_api_request.py | 67 ++++++++++++ robowaiter/llm_client/tool_api.py | 86 +++++++++++++++ robowaiter/llm_client/tool_api_request.py | 67 ++++++++++++ robowaiter/llm_client/tool_register.py | 115 ++++++++++++++++++++ 6 files changed, 337 insertions(+), 4 deletions(-) create mode 100644 robowaiter/llm_client/openai_api_request.py create mode 100644 robowaiter/llm_client/tool_api.py create mode 100644 robowaiter/llm_client/tool_api_request.py create mode 100644 robowaiter/llm_client/tool_register.py diff --git a/requirements.txt b/requirements.txt index 3c8b9f6..7659ed1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 +loguru \ No newline at end of file diff --git a/robowaiter/behavior_lib/_base/Behavior.py b/robowaiter/behavior_lib/_base/Behavior.py index b80c634..36f95aa 100644 --- a/robowaiter/behavior_lib/_base/Behavior.py +++ b/robowaiter/behavior_lib/_base/Behavior.py @@ -24,10 +24,7 @@ class Bahavior(ptree.behaviour.Behaviour): return ins_name def __init__(self,*args): - name = self.__class__.__name__ - if len(args)>0: - name = f'{name}({",".join(list(args))})' - self.name = name + self.name = Bahavior.get_ins_name(*args) #get valid args # self.valid_arg_list = [] # lines = self.valid_params.strip().splitlines() diff --git a/robowaiter/llm_client/openai_api_request.py b/robowaiter/llm_client/openai_api_request.py new file mode 100644 index 0000000..ebaa273 --- /dev/null +++ b/robowaiter/llm_client/openai_api_request.py @@ -0,0 +1,67 @@ +# 使用curl命令测试返回 +# curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \ +# -H "Content-Type: application/json" \ +# -d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}" + +# 使用Python代码测返回 +import requests +import json + +import urllib3 +######################################## +# 该文件实现了与大模型的简单通信 +######################################## + +# 忽略https的安全性警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址 + +def create_chat_completion(model, messages, use_stream=False): + data = { + "model": model, # 模型名称 + "messages": messages, # 会话历史 + "stream": use_stream, # 是否流式响应 + "max_tokens": 100, # 最多生成字数 + "temperature": 0.8, # 温度 + "top_p": 0.8, # 采样概率 + } + + response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream, verify=False) + if response.status_code == 200: + if use_stream: + # 处理流式响应 + for line in response.iter_lines(): + if line: + decoded_line = line.decode('utf-8')[6:] + try: + response_json = json.loads(decoded_line) + content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "") + print(content) + except: + print("Special Token:", decoded_line) + else: + # 处理非流式响应 + decoded_line = response.json() + print(decoded_line) + content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "") + print(content) + else: + print("Error:", response.status_code) + return None + + +if __name__ == "__main__": + chat_messages = [ + { + "role": "system", + "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.", + }, + { + "role": "user", + "content": "你好,给我讲一个故事,大概100字" + } + ] + create_chat_completion("chatglm3-6b", chat_messages, use_stream=False) + + diff --git a/robowaiter/llm_client/tool_api.py b/robowaiter/llm_client/tool_api.py new file mode 100644 index 0000000..2780a14 --- /dev/null +++ b/robowaiter/llm_client/tool_api.py @@ -0,0 +1,86 @@ +import json + +import openai +from colorama import init, Fore +from loguru import logger + +from tool_register import get_tools, dispatch_tool + +init(autoreset=True) + +# 使用Python代码测返回 +import requests +import json + +import urllib3 +######################################## +# 该文件实现了与大模型的简单通信 +######################################## + +# 忽略https的安全性警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址 + +def get_response(**kwargs): + data = kwargs + + response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=data["stream"], verify=False) + decoded_line = response.json() + return decoded_line + +functions = get_tools() + +def run_conversation(query: str, stream=False, functions=None, max_retry=5): + params = dict(model="chatglm3", messages=[{"role": "user", "content": query}], stream=stream) + if functions: + params["functions"] = functions + response = get_response(**params) + + for _ in range(max_retry): + if response["choices"][0]["message"].get("function_call"): + function_call = response["choices"][0]["message"]["function_call"] + logger.info(f"Function Call Response: {function_call}") + function_args = json.loads(function_call["arguments"]) + tool_response = dispatch_tool(function_call["name"], function_args) + logger.info(f"Tool Call Response: {tool_response}") + + params["messages"].append(response["choices"][0]["message"]) + params["messages"].append( + { + "role": "function", + "name": function_call["name"], + "content": tool_response, # 调用函数返回结果 + } + ) + else: + reply = response["choices"][0]["message"]["content"] + logger.info(f"Final Reply: \n{reply}") + return + + response = get_response(**params) + + +if __name__ == "__main__": + + # chat_messages = [ + # { + # "role": "system", + # "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.", + # }, + # { + # "role": "user", + # "content": "你好,给我讲一个故事,大概100字" + # } + # ] + # create_chat_completion("chatglm3-6b", chat_messages, use_stream=False) + + + + query = "你是谁" + run_conversation(query, stream=False) + + logger.info("\n=========== next conversation ===========") + + query = "帮我查询北京的天气怎么样" + run_conversation(query, functions=functions, stream=False) diff --git a/robowaiter/llm_client/tool_api_request.py b/robowaiter/llm_client/tool_api_request.py new file mode 100644 index 0000000..ebaa273 --- /dev/null +++ b/robowaiter/llm_client/tool_api_request.py @@ -0,0 +1,67 @@ +# 使用curl命令测试返回 +# curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \ +# -H "Content-Type: application/json" \ +# -d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}" + +# 使用Python代码测返回 +import requests +import json + +import urllib3 +######################################## +# 该文件实现了与大模型的简单通信 +######################################## + +# 忽略https的安全性警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址 + +def create_chat_completion(model, messages, use_stream=False): + data = { + "model": model, # 模型名称 + "messages": messages, # 会话历史 + "stream": use_stream, # 是否流式响应 + "max_tokens": 100, # 最多生成字数 + "temperature": 0.8, # 温度 + "top_p": 0.8, # 采样概率 + } + + response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream, verify=False) + if response.status_code == 200: + if use_stream: + # 处理流式响应 + for line in response.iter_lines(): + if line: + decoded_line = line.decode('utf-8')[6:] + try: + response_json = json.loads(decoded_line) + content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "") + print(content) + except: + print("Special Token:", decoded_line) + else: + # 处理非流式响应 + decoded_line = response.json() + print(decoded_line) + content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "") + print(content) + else: + print("Error:", response.status_code) + return None + + +if __name__ == "__main__": + chat_messages = [ + { + "role": "system", + "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.", + }, + { + "role": "user", + "content": "你好,给我讲一个故事,大概100字" + } + ] + create_chat_completion("chatglm3-6b", chat_messages, use_stream=False) + + diff --git a/robowaiter/llm_client/tool_register.py b/robowaiter/llm_client/tool_register.py new file mode 100644 index 0000000..39ae2fc --- /dev/null +++ b/robowaiter/llm_client/tool_register.py @@ -0,0 +1,115 @@ +import inspect +import traceback +from copy import deepcopy +from pprint import pformat +from types import GenericAlias +from typing import get_origin, Annotated + +_TOOL_HOOKS = {} +_TOOL_DESCRIPTIONS = {} + + +def register_tool(func: callable): + tool_name = func.__name__ + tool_description = inspect.getdoc(func).strip() + python_params = inspect.signature(func).parameters + tool_params = [] + for name, param in python_params.items(): + annotation = param.annotation + if annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter `{name}` missing type annotation") + if get_origin(annotation) != Annotated: + raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") + + typ, (description, required) = annotation.__origin__, annotation.__metadata__ + typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ + if not isinstance(description, str): + raise TypeError(f"Description for `{name}` must be a string") + if not isinstance(required, bool): + raise TypeError(f"Required for `{name}` must be a bool") + + tool_params.append({ + "name": name, + "description": description, + "type": typ, + "required": required + }) + tool_def = { + "name": tool_name, + "description": tool_description, + "params": tool_params + } + + print("[registered tool] " + pformat(tool_def)) + _TOOL_HOOKS[tool_name] = func + _TOOL_DESCRIPTIONS[tool_name] = tool_def + + return func + + +def dispatch_tool(tool_name: str, tool_params: dict) -> str: + if tool_name not in _TOOL_HOOKS: + return f"Tool `{tool_name}` not found. Please use a provided tool." + tool_call = _TOOL_HOOKS[tool_name] + try: + ret = tool_call(**tool_params) + except: + ret = traceback.format_exc() + return str(ret) + + +def get_tools() -> dict: + return deepcopy(_TOOL_DESCRIPTIONS) + + +# Tool Definitions + +@register_tool +def random_number_generator( + seed: Annotated[int, 'The random seed used by the generator', True], + range: Annotated[tuple[int, int], 'The range of the generated numbers', True], +) -> int: + """ + Generates a random number x, s.t. range[0] <= x < range[1] + """ + if not isinstance(seed, int): + raise TypeError("Seed must be an integer") + if not isinstance(range, tuple): + raise TypeError("Range must be a tuple") + if not isinstance(range[0], int) or not isinstance(range[1], int): + raise TypeError("Range must be a tuple of integers") + + import random + return random.Random(seed).randint(*range) + + +@register_tool +def get_weather( + city_name: Annotated[str, 'The name of the city to be queried', True], +) -> str: + """ + Get the current weather for `city_name` + """ + + if not isinstance(city_name, str): + raise TypeError("City name must be a string") + + key_selection = { + "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"], + } + import requests + try: + resp = requests.get(f"https://wttr.in/{city_name}?format=j1") + resp.raise_for_status() + resp = resp.json() + ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} + except: + import traceback + ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() + + return str(ret) + + +if __name__ == "__main__": + print(dispatch_tool("get_weather", {"city_name": "beijing"})) + print(get_tools())