现在可以调用大模型的function call
This commit is contained in:
parent
e5a9f4e583
commit
c9e8dc3981
|
@ -12,3 +12,4 @@ pytorch==1.11.0
|
||||||
torchvision==0.12.0
|
torchvision==0.12.0
|
||||||
torchaudio==0.11.0
|
torchaudio==0.11.0
|
||||||
cudatoolkit=11.3
|
cudatoolkit=11.3
|
||||||
|
loguru
|
|
@ -24,10 +24,7 @@ class Bahavior(ptree.behaviour.Behaviour):
|
||||||
return ins_name
|
return ins_name
|
||||||
|
|
||||||
def __init__(self,*args):
|
def __init__(self,*args):
|
||||||
name = self.__class__.__name__
|
self.name = Bahavior.get_ins_name(*args)
|
||||||
if len(args)>0:
|
|
||||||
name = f'{name}({",".join(list(args))})'
|
|
||||||
self.name = name
|
|
||||||
#get valid args
|
#get valid args
|
||||||
# self.valid_arg_list = []
|
# self.valid_arg_list = []
|
||||||
# lines = self.valid_params.strip().splitlines()
|
# lines = self.valid_params.strip().splitlines()
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
# 使用curl命令测试返回
|
||||||
|
# curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}"
|
||||||
|
|
||||||
|
# 使用Python代码测返回
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
import urllib3
|
||||||
|
########################################
|
||||||
|
# 该文件实现了与大模型的简单通信
|
||||||
|
########################################
|
||||||
|
|
||||||
|
# 忽略https的安全性警告
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址
|
||||||
|
|
||||||
|
def create_chat_completion(model, messages, use_stream=False):
|
||||||
|
data = {
|
||||||
|
"model": model, # 模型名称
|
||||||
|
"messages": messages, # 会话历史
|
||||||
|
"stream": use_stream, # 是否流式响应
|
||||||
|
"max_tokens": 100, # 最多生成字数
|
||||||
|
"temperature": 0.8, # 温度
|
||||||
|
"top_p": 0.8, # 采样概率
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream, verify=False)
|
||||||
|
if response.status_code == 200:
|
||||||
|
if use_stream:
|
||||||
|
# 处理流式响应
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
decoded_line = line.decode('utf-8')[6:]
|
||||||
|
try:
|
||||||
|
response_json = json.loads(decoded_line)
|
||||||
|
content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
||||||
|
print(content)
|
||||||
|
except:
|
||||||
|
print("Special Token:", decoded_line)
|
||||||
|
else:
|
||||||
|
# 处理非流式响应
|
||||||
|
decoded_line = response.json()
|
||||||
|
print(decoded_line)
|
||||||
|
content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
|
||||||
|
print(content)
|
||||||
|
else:
|
||||||
|
print("Error:", response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
chat_messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "你好,给我讲一个故事,大概100字"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from colorama import init, Fore
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from tool_register import get_tools, dispatch_tool
|
||||||
|
|
||||||
|
init(autoreset=True)
|
||||||
|
|
||||||
|
# 使用Python代码测返回
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
import urllib3
|
||||||
|
########################################
|
||||||
|
# 该文件实现了与大模型的简单通信
|
||||||
|
########################################
|
||||||
|
|
||||||
|
# 忽略https的安全性警告
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址
|
||||||
|
|
||||||
|
def get_response(**kwargs):
|
||||||
|
data = kwargs
|
||||||
|
|
||||||
|
response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=data["stream"], verify=False)
|
||||||
|
decoded_line = response.json()
|
||||||
|
return decoded_line
|
||||||
|
|
||||||
|
functions = get_tools()
|
||||||
|
|
||||||
|
def run_conversation(query: str, stream=False, functions=None, max_retry=5):
|
||||||
|
params = dict(model="chatglm3", messages=[{"role": "user", "content": query}], stream=stream)
|
||||||
|
if functions:
|
||||||
|
params["functions"] = functions
|
||||||
|
response = get_response(**params)
|
||||||
|
|
||||||
|
for _ in range(max_retry):
|
||||||
|
if response["choices"][0]["message"].get("function_call"):
|
||||||
|
function_call = response["choices"][0]["message"]["function_call"]
|
||||||
|
logger.info(f"Function Call Response: {function_call}")
|
||||||
|
function_args = json.loads(function_call["arguments"])
|
||||||
|
tool_response = dispatch_tool(function_call["name"], function_args)
|
||||||
|
logger.info(f"Tool Call Response: {tool_response}")
|
||||||
|
|
||||||
|
params["messages"].append(response["choices"][0]["message"])
|
||||||
|
params["messages"].append(
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"name": function_call["name"],
|
||||||
|
"content": tool_response, # 调用函数返回结果
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reply = response["choices"][0]["message"]["content"]
|
||||||
|
logger.info(f"Final Reply: \n{reply}")
|
||||||
|
return
|
||||||
|
|
||||||
|
response = get_response(**params)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# chat_messages = [
|
||||||
|
# {
|
||||||
|
# "role": "system",
|
||||||
|
# "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.",
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": "你好,给我讲一个故事,大概100字"
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
query = "你是谁"
|
||||||
|
run_conversation(query, stream=False)
|
||||||
|
|
||||||
|
logger.info("\n=========== next conversation ===========")
|
||||||
|
|
||||||
|
query = "帮我查询北京的天气怎么样"
|
||||||
|
run_conversation(query, functions=functions, stream=False)
|
|
@ -0,0 +1,67 @@
|
||||||
|
# 使用curl命令测试返回
|
||||||
|
# curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}"
|
||||||
|
|
||||||
|
# 使用Python代码测返回
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
import urllib3
|
||||||
|
########################################
|
||||||
|
# 该文件实现了与大模型的简单通信
|
||||||
|
########################################
|
||||||
|
|
||||||
|
# 忽略https的安全性警告
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址
|
||||||
|
|
||||||
|
def create_chat_completion(model, messages, use_stream=False):
|
||||||
|
data = {
|
||||||
|
"model": model, # 模型名称
|
||||||
|
"messages": messages, # 会话历史
|
||||||
|
"stream": use_stream, # 是否流式响应
|
||||||
|
"max_tokens": 100, # 最多生成字数
|
||||||
|
"temperature": 0.8, # 温度
|
||||||
|
"top_p": 0.8, # 采样概率
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream, verify=False)
|
||||||
|
if response.status_code == 200:
|
||||||
|
if use_stream:
|
||||||
|
# 处理流式响应
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
decoded_line = line.decode('utf-8')[6:]
|
||||||
|
try:
|
||||||
|
response_json = json.loads(decoded_line)
|
||||||
|
content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
||||||
|
print(content)
|
||||||
|
except:
|
||||||
|
print("Special Token:", decoded_line)
|
||||||
|
else:
|
||||||
|
# 处理非流式响应
|
||||||
|
decoded_line = response.json()
|
||||||
|
print(decoded_line)
|
||||||
|
content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
|
||||||
|
print(content)
|
||||||
|
else:
|
||||||
|
print("Error:", response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
chat_messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "你好,给我讲一个故事,大概100字"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
import inspect
|
||||||
|
import traceback
|
||||||
|
from copy import deepcopy
|
||||||
|
from pprint import pformat
|
||||||
|
from types import GenericAlias
|
||||||
|
from typing import get_origin, Annotated
|
||||||
|
|
||||||
|
_TOOL_HOOKS = {}
|
||||||
|
_TOOL_DESCRIPTIONS = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_tool(func: callable):
|
||||||
|
tool_name = func.__name__
|
||||||
|
tool_description = inspect.getdoc(func).strip()
|
||||||
|
python_params = inspect.signature(func).parameters
|
||||||
|
tool_params = []
|
||||||
|
for name, param in python_params.items():
|
||||||
|
annotation = param.annotation
|
||||||
|
if annotation is inspect.Parameter.empty:
|
||||||
|
raise TypeError(f"Parameter `{name}` missing type annotation")
|
||||||
|
if get_origin(annotation) != Annotated:
|
||||||
|
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
|
||||||
|
|
||||||
|
typ, (description, required) = annotation.__origin__, annotation.__metadata__
|
||||||
|
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
|
||||||
|
if not isinstance(description, str):
|
||||||
|
raise TypeError(f"Description for `{name}` must be a string")
|
||||||
|
if not isinstance(required, bool):
|
||||||
|
raise TypeError(f"Required for `{name}` must be a bool")
|
||||||
|
|
||||||
|
tool_params.append({
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"type": typ,
|
||||||
|
"required": required
|
||||||
|
})
|
||||||
|
tool_def = {
|
||||||
|
"name": tool_name,
|
||||||
|
"description": tool_description,
|
||||||
|
"params": tool_params
|
||||||
|
}
|
||||||
|
|
||||||
|
print("[registered tool] " + pformat(tool_def))
|
||||||
|
_TOOL_HOOKS[tool_name] = func
|
||||||
|
_TOOL_DESCRIPTIONS[tool_name] = tool_def
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_tool(tool_name: str, tool_params: dict) -> str:
|
||||||
|
if tool_name not in _TOOL_HOOKS:
|
||||||
|
return f"Tool `{tool_name}` not found. Please use a provided tool."
|
||||||
|
tool_call = _TOOL_HOOKS[tool_name]
|
||||||
|
try:
|
||||||
|
ret = tool_call(**tool_params)
|
||||||
|
except:
|
||||||
|
ret = traceback.format_exc()
|
||||||
|
return str(ret)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools() -> dict:
|
||||||
|
return deepcopy(_TOOL_DESCRIPTIONS)
|
||||||
|
|
||||||
|
|
||||||
|
# Tool Definitions
|
||||||
|
|
||||||
|
@register_tool
|
||||||
|
def random_number_generator(
|
||||||
|
seed: Annotated[int, 'The random seed used by the generator', True],
|
||||||
|
range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Generates a random number x, s.t. range[0] <= x < range[1]
|
||||||
|
"""
|
||||||
|
if not isinstance(seed, int):
|
||||||
|
raise TypeError("Seed must be an integer")
|
||||||
|
if not isinstance(range, tuple):
|
||||||
|
raise TypeError("Range must be a tuple")
|
||||||
|
if not isinstance(range[0], int) or not isinstance(range[1], int):
|
||||||
|
raise TypeError("Range must be a tuple of integers")
|
||||||
|
|
||||||
|
import random
|
||||||
|
return random.Random(seed).randint(*range)
|
||||||
|
|
||||||
|
|
||||||
|
@register_tool
|
||||||
|
def get_weather(
|
||||||
|
city_name: Annotated[str, 'The name of the city to be queried', True],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get the current weather for `city_name`
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(city_name, str):
|
||||||
|
raise TypeError("City name must be a string")
|
||||||
|
|
||||||
|
key_selection = {
|
||||||
|
"current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
|
||||||
|
}
|
||||||
|
import requests
|
||||||
|
try:
|
||||||
|
resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
|
||||||
|
resp.raise_for_status()
|
||||||
|
resp = resp.json()
|
||||||
|
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
|
||||||
|
except:
|
||||||
|
import traceback
|
||||||
|
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
|
||||||
|
|
||||||
|
return str(ret)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(dispatch_tool("get_weather", {"city_name": "beijing"}))
|
||||||
|
print(get_tools())
|
Loading…
Reference in New Issue