修改固定组合
This commit is contained in:
parent
7e63d9658b
commit
f8e4f0053c
|
@ -1,51 +1,55 @@
|
|||
import py_trees as ptree
|
||||
from robowaiter.behavior_lib._base.Act import Act
|
||||
from robowaiter.llm_client.ask_llm import ask_llm
|
||||
|
||||
from robowaiter.llm_client.multi_rounds import ask_llm,new_history
|
||||
|
||||
class DealChat(Act):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.chat_history = ""
|
||||
self.function_success = False
|
||||
self.func_map = {
|
||||
"create_sub_task": self.create_sub_task
|
||||
}
|
||||
|
||||
def _update(self) -> ptree.common.Status:
|
||||
# if self.scene.status?
|
||||
name,sentence = self.scene.state['chat_list'].pop(0)
|
||||
|
||||
if name == "Goal":
|
||||
self.create_sub_task(sentence)
|
||||
self.create_sub_task(goal=sentence)
|
||||
return ptree.common.Status.RUNNING
|
||||
|
||||
if name not in self.scene.state["chat_history"]:
|
||||
self.scene.state["chat_history"][name] = new_history()
|
||||
|
||||
history = self.scene.state["chat_history"][name]
|
||||
self.scene.state["attention"]["customer"] = name
|
||||
self.scene.state["serve_state"] = {
|
||||
"last_chat_time": self.scene.time,
|
||||
}
|
||||
|
||||
function_call, response = ask_llm(sentence,history,func_map=self.func_map)
|
||||
|
||||
self.chat_history += sentence + '\n'
|
||||
|
||||
res_dict = ask_llm(sentence)
|
||||
answer = res_dict["Answer"]
|
||||
self.scene.chat_bubble(answer) # 机器人输出对话
|
||||
self.chat_history += answer + '\n'
|
||||
|
||||
goal = res_dict["Goal"]
|
||||
if goal:
|
||||
if "{" not in goal:
|
||||
goal = {str(goal)}
|
||||
else:
|
||||
goal=eval(goal)
|
||||
|
||||
if goal is not None:
|
||||
print(f'goal:{goal}')
|
||||
|
||||
self.create_sub_task(goal)
|
||||
|
||||
if self.scene.show_bubble:
|
||||
self.scene.chat_bubble(f"{answer}")
|
||||
self.scene.chat_bubble(response) # 机器人输出对话
|
||||
|
||||
return ptree.common.Status.RUNNING
|
||||
|
||||
|
||||
def create_sub_task(self,goal):
|
||||
self.scene.robot.expand_sub_task_tree(goal)
|
||||
def create_sub_task(self,**args):
|
||||
try:
|
||||
goal = args['goal']
|
||||
|
||||
w = goal.split(")")
|
||||
goal_set = set()
|
||||
goal_set.add(w[0] + ")")
|
||||
if len(w)>1:
|
||||
for x in w[1:]:
|
||||
if x != "":
|
||||
goal_set.add(x[0] + ")")
|
||||
self.function_success = True
|
||||
except:
|
||||
print("参数解析错误")
|
||||
|
||||
self.scene.robot.expand_sub_task_tree(goal_set)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
你好
|
||||
您好,我是这家咖啡厅的服务员,请问您要点什么?
|
||||
|
||||
做一杯咖啡
|
||||
好的,我马上做咖啡
|
||||
create_sub_task
|
||||
{"goal":"On(Coffee,CoffeeTable)"}
|
||||
|
||||
不用了
|
||||
好的,您有需要再跟我说
|
||||
stop_serve
|
||||
{}
|
||||
|
||||
来一号桌
|
||||
好的,我马上来一号桌
|
||||
create_sub_task
|
||||
{"goal":"At(Robot,Table1)"}
|
|
@ -1,29 +1,162 @@
|
|||
import json
|
||||
|
||||
import openai
|
||||
from colorama import init, Fore
|
||||
from loguru import logger
|
||||
import json
|
||||
from robowaiter.llm_client.tool_register import get_tools, dispatch_tool
|
||||
import requests
|
||||
import json
|
||||
from collections import deque
|
||||
|
||||
import urllib3
|
||||
from robowaiter.llm_client.tool_api import run_conversation
|
||||
import copy
|
||||
init(autoreset=True)
|
||||
from robowaiter.utils import get_root_path
|
||||
import os
|
||||
import re
|
||||
from robowaiter.llm_client.single_round import single_round
|
||||
########################################
|
||||
# 该文件实现了与大模型的简单通信、多轮对话,输入end表示对话结束
|
||||
# 该文件实现了与大模型的通信以及工具调用
|
||||
########################################
|
||||
|
||||
# 忽略https的安全性警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
url = "https://45.125.46.134:25344/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
base_url = "https://45.125.46.134:25344" # 本地部署的地址,或者使用你访问模型的API地址
|
||||
|
||||
#在这里输入你的问题
|
||||
k=input()
|
||||
data_memory=[]
|
||||
n=1
|
||||
while k!='end':
|
||||
question_now=k
|
||||
user_dict={"role": "user","content":question_now}
|
||||
data_memory.append(user_dict)
|
||||
#print(data_memory)
|
||||
response = run_conversation(str(data_memory))
|
||||
answer=str(response)
|
||||
print(answer)
|
||||
assistant_dict={"role": "assistant","content":answer}
|
||||
data_memory.append(assistant_dict)
|
||||
n=n+2
|
||||
k=input()
|
||||
root_path = get_root_path()
|
||||
# load test questions
|
||||
file_path = os.path.join(root_path,"robowaiter/llm_client/data/fix_questions.txt")
|
||||
|
||||
fix_questions_dict = {}
|
||||
with open(file_path,'r',encoding="utf-8") as f:
|
||||
#读取所有行
|
||||
lines = f.read().strip()
|
||||
sections = re.split(r'\n\s*\n', lines)
|
||||
for s in sections:
|
||||
x = s.split()
|
||||
fix_questions_dict[x[0]] = x[1:]
|
||||
|
||||
functions = get_tools()
|
||||
role_system = [{
|
||||
"role": "system",
|
||||
"content": "你是RoboWaiter,一个由HPCL团队开发的机器人服务员,你在咖啡厅工作。接受顾客的指令并调用工具函数来完成各种服务任务。如果顾客问你们这里有什么,或者想要点单,你说我们咖啡厅提供咖啡,水,点心,酸奶等食物。如果顾客不需要你了,你就回到吧台招待。如果顾客叫你去做某事,你回复:好的,我马上去做这件事。",
|
||||
}]
|
||||
|
||||
def new_history(max_length=7):
|
||||
history = deque(maxlen=max_length)
|
||||
|
||||
return history
|
||||
|
||||
def new_response():
|
||||
return {'choices': [{'index': 0, 'message':{} }]}
|
||||
|
||||
def parse_fix_question(question):
|
||||
response = new_response()
|
||||
fix_ans = fix_questions_dict[question]
|
||||
if len(fix_ans)<=1: #简单对话
|
||||
message = {'role': 'assistant', 'content': fix_ans, 'name': None,
|
||||
'function_call': None}
|
||||
else:
|
||||
reply, func,args = fix_ans
|
||||
# tool_response = dispatch_tool(function_call["name"], json.loads(args))
|
||||
# logger.info(f"Tool Call Response: {tool_response}")
|
||||
message = {'role': 'assistant',
|
||||
'content': f"\n <|assistant|> {func}({args})\n ```python\ntool_call(goal={args})\n```",
|
||||
'name': None,
|
||||
'function_call': {'name': func, 'arguments': args}}
|
||||
|
||||
response["choices"][0]["message"] = message
|
||||
return response
|
||||
|
||||
def get_response(sentence, history, allow_function_call = True):
|
||||
if sentence:
|
||||
history.append({"role": "user", "content": sentence})
|
||||
|
||||
if sentence in fix_questions_dict:
|
||||
return parse_fix_question(sentence)
|
||||
|
||||
params = dict(model="RoboWaiter")
|
||||
params['messages'] = role_system + list(history)
|
||||
if allow_function_call:
|
||||
params["functions"] = functions
|
||||
|
||||
|
||||
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False, verify=False)
|
||||
decoded_line = response.json()
|
||||
return decoded_line
|
||||
|
||||
def deal_response(response, history, func_map=None ):
|
||||
if response["choices"][0]["message"].get("function_call"):
|
||||
function_call = response["choices"][0]["message"]["function_call"]
|
||||
logger.info(f"Function Call Response: {function_call}")
|
||||
|
||||
function_name = function_call["name"]
|
||||
function_args = json.loads(function_call["arguments"])
|
||||
if func_map:
|
||||
tool_response = func_map[function_name](**function_args)
|
||||
else:
|
||||
try:
|
||||
tool_response = dispatch_tool(function_call["name"], function_args)
|
||||
logger.info(f"Tool Call Response: {tool_response}")
|
||||
except:
|
||||
logger.info(f"重试工具调用")
|
||||
# tool_response = dispatch_tool(function_call["name"], function_args)
|
||||
return function_name,None
|
||||
|
||||
return_message = response["choices"][0]["message"]
|
||||
|
||||
history.append(return_message)
|
||||
t = {
|
||||
"role": "function",
|
||||
"name": function_call["name"],
|
||||
"content": str(tool_response), # 调用函数返回结果
|
||||
}
|
||||
|
||||
history.append(t)
|
||||
return function_call["name"], return_message
|
||||
|
||||
else:
|
||||
return_message = response["choices"][0]["message"]
|
||||
reply = return_message["content"]
|
||||
|
||||
history.append(return_message)
|
||||
logger.info(f"Final Reply: \n{reply}")
|
||||
|
||||
return False, reply
|
||||
|
||||
|
||||
def ask_llm(question,history, func_map=None, retry=3):
|
||||
response = get_response(question, history)
|
||||
|
||||
function_call,result = deal_response(response, history, func_map)
|
||||
if function_call:
|
||||
if question in fix_questions_dict:
|
||||
reply = fix_questions_dict[question][0]
|
||||
result = single_round(reply,"你是机器人服务员,请把以下句子换一种表述方式对顾客说,但是意思不变,尽量简短:\n")
|
||||
message = {'role': 'assistant', 'content': result, 'name': None,
|
||||
'function_call': None}
|
||||
history.append(message)
|
||||
|
||||
else:
|
||||
response = get_response(None, history,allow_function_call=False)
|
||||
_,result = deal_response(response, history, func_map)
|
||||
|
||||
|
||||
print(f'{len(history)}条历史记录:')
|
||||
for x in history:
|
||||
print(x)
|
||||
return function_call, result
|
||||
|
||||
if __name__ == "__main__":
|
||||
question = input("\n顾客:")
|
||||
history = new_history()
|
||||
n = 1
|
||||
max_retry = 2
|
||||
|
||||
while question != 'end':
|
||||
function_call, return_message = ask_llm(question,history)
|
||||
|
||||
|
||||
question = input("\n顾客:")
|
||||
|
|
|
@ -9,7 +9,7 @@ import urllib3
|
|||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
def single_round(question):
|
||||
def single_round(question,prefix=""):
|
||||
url = "https://45.125.46.134:25344/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
|
@ -21,7 +21,7 @@ def single_round(question):
|
|||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": question
|
||||
"content": prefix + question
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
@ -147,6 +147,17 @@ def create_sub_task(
|
|||
|
||||
return goal
|
||||
|
||||
|
||||
@register_tool
|
||||
def stop_serve(
|
||||
) -> bool:
|
||||
"""
|
||||
当顾客通过任何形式表示不再需要服务时,调用该函数
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
|
||||
# @register_tool
|
||||
# def get_object_info(
|
||||
# obj: Annotated[str, '需要获取信息的物体名称', True]
|
||||
|
|
|
@ -71,6 +71,7 @@ class Scene:
|
|||
"greeted_customers":set(),
|
||||
"attention":{},
|
||||
"serve_state":{},
|
||||
"chat_history":{}
|
||||
}
|
||||
"""
|
||||
status:
|
||||
|
@ -439,7 +440,7 @@ class Scene:
|
|||
def chat_bubble(self, message):
|
||||
stub.ControlRobot(
|
||||
GrabSim_pb2.ControlInfo(
|
||||
scene=self.sceneID, type=0, action=1, content=message
|
||||
scene=self.sceneID, type=0, action=1, content=message.strip()
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -18,28 +18,14 @@ class SceneOT(Scene):
|
|||
super().__init__(robot)
|
||||
# 在这里加入场景中发生的事件
|
||||
self.new_event_list = [
|
||||
# (9,self.add_walkers,([[0, 880]],)),
|
||||
# (10,self.walker_walk_to,(2,50,500))
|
||||
# (5, self.set_goal("On(Yogurt,Table4)"))
|
||||
(3, self.customer_say, ("System","来一号桌"))
|
||||
# (5, self.set_goal("At(Robot,BrightTable4)"))
|
||||
]
|
||||
|
||||
def _reset(self):
|
||||
# self.add_walkers([[0, 880], [250, 1200]])
|
||||
pass
|
||||
|
||||
# 展示顾客,前8个id是小孩,后面都是大人
|
||||
for i in range(4):
|
||||
self.add_walker(i,50,300 + i * 50)
|
||||
name1 = self.walker_index2mem(1)
|
||||
name2 = self.walker_index2mem(3)
|
||||
|
||||
self.remove_walker(0,2)
|
||||
|
||||
index1 = self.state["customer_mem"][name1]
|
||||
index2 = self.state["customer_mem"][name2]
|
||||
|
||||
self.walker_bubble(name1,f"我是第{index1}个")
|
||||
self.walker_bubble(name2,f"我是第{index2}个")
|
||||
|
||||
def _run(self):
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
"""
|
||||
人提出请求,机器人完成任务
|
||||
1. 做咖啡(固定动画):接收到做咖啡指令、走到咖啡机、拿杯子、操作咖啡机、取杯子、送到客人桌子上
|
||||
2. 倒水
|
||||
3. 夹点心
|
||||
|
||||
具体描述:设计一套点单规则(如菜单包含咖啡、水、点心等),按照规则拟造随机的订单。在收到订单后,通过大模型让机器人输出合理的备餐计划,并尝试在模拟环境中按照这个规划实现任务。
|
||||
|
||||
"""
|
||||
|
||||
# todo: 接收点单信息,大模型生成任务规划
|
||||
|
||||
from robowaiter.scene.scene import Scene
|
||||
|
||||
class SceneOT(Scene):
|
||||
|
||||
def __init__(self, robot):
|
||||
super().__init__(robot)
|
||||
# 在这里加入场景中发生的事件
|
||||
self.new_event_list = [
|
||||
# (9,self.add_walkers,([[0, 880]],)),
|
||||
# (10,self.walker_walk_to,(2,50,500))
|
||||
# (5, self.set_goal("On(Yogurt,Table4)"))
|
||||
# (5, self.set_goal("At(Robot,BrightTable4)"))
|
||||
]
|
||||
|
||||
def _reset(self):
|
||||
# self.add_walkers([[0, 880], [250, 1200]])
|
||||
|
||||
# 展示顾客,前8个id是小孩,后面都是大人
|
||||
for i in range(4):
|
||||
self.add_walker(i,50,300 + i * 50)
|
||||
name1 = self.walker_index2mem(1)
|
||||
name2 = self.walker_index2mem(3)
|
||||
|
||||
self.remove_walker(0,2)
|
||||
|
||||
index1 = self.state["customer_mem"][name1]
|
||||
index2 = self.state["customer_mem"][name2]
|
||||
|
||||
self.walker_bubble(name1,f"我是第{index1}个")
|
||||
self.walker_bubble(name2,f"我是第{index2}个")
|
||||
|
||||
def _run(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import os
|
||||
from robowaiter.robot.robot import Robot
|
||||
|
||||
robot = Robot()
|
||||
|
||||
# create task
|
||||
task = SceneOT(robot)
|
||||
task.reset()
|
||||
task.run()
|
Loading…
Reference in New Issue