2023-10-10 20:47:32 +08:00
|
|
|
import py_trees as ptree
|
2023-11-08 15:28:01 +08:00
|
|
|
from robowaiter.behavior_lib._base.Act import Act
|
2023-11-09 16:07:02 +08:00
|
|
|
|
2023-11-30 17:36:16 +08:00
|
|
|
# from robowaiter.llm_client.multi_rounds_retri import ask_llm, new_history
|
|
|
|
from robowaiter.llm_client.multi_rounds import ask_llm, new_history
|
2023-11-19 14:21:58 +08:00
|
|
|
import random
|
2023-11-23 11:58:50 +08:00
|
|
|
from collections import deque
|
2023-11-19 14:21:58 +08:00
|
|
|
|
2023-11-28 17:50:20 +08:00
|
|
|
from translate import Translator
|
|
|
|
|
|
|
|
|
|
|
|
# English -> Simplified Chinese translator used for object names.
# Pass both languages to the constructor: the translate package builds its
# backend provider from these constructor arguments, so assigning
# `from_lang`/`to_lang` after construction does not change the language pair
# that is actually requested.
translator = Translator(from_lang="en", to_lang="zh-cn")
|
2023-11-19 14:44:22 +08:00
|
|
|
|
2023-11-28 17:21:54 +08:00
|
|
|
import spacy
|
2023-11-28 20:06:57 +08:00
|
|
|
# Chinese spaCy pipeline; its Doc.similarity is used below to match
# (translated) object names against the names known to the scene.
nlp_zh = spacy.load('zh_core_web_lg')
|
|
|
|
|
2023-11-19 14:21:58 +08:00
|
|
|
|
2023-11-23 11:58:50 +08:00
|
|
|
class History(deque):
    """Bounded per-customer chat history that mirrors new entries to the UI.

    Each ``append`` stores the message (the oldest entry is dropped once
    ``maxlen`` is reached) and forwards it to ``scene.ui_func`` as a
    ``("new_history", customer_name, message)`` event. Only ``append`` emits
    UI events; other deque mutators (``extend``, ``appendleft``, ...) do not.
    """

    def __init__(self, scene, customer_name, maxlen=7):
        """Create an empty history.

        Args:
            scene: object exposing ``ui_func(event)``; it receives one event
                per appended message.
            customer_name: name of the customer this history belongs to.
            maxlen: maximum number of retained messages (default 7).
        """
        super().__init__(maxlen=maxlen)
        self.scene = scene
        self.customer_name = customer_name

    def append(self, item) -> None:
        """Store ``item`` and notify the UI about it."""
        super().append(item)
        self.scene.ui_func(("new_history", self.customer_name, item))

    def copy(self):
        """Return a shallow copy with the same scene, customer and maxlen.

        ``deque.copy`` re-invokes the constructor as ``type(self)(self, maxlen)``,
        which does not fit this class's signature, so it is overridden here.
        The copy is filled with ``extend`` and therefore emits no UI events.
        """
        duplicate = type(self)(self.scene, self.customer_name, maxlen=self.maxlen)
        duplicate.extend(self)
        return duplicate

    # copy.copy() looks up __copy__, which deque also maps to its own copy.
    __copy__ = copy
|
|
|
|
|
2023-11-09 21:52:13 +08:00
|
|
|
|
2023-11-08 15:28:01 +08:00
|
|
|
class DealChat(Act):
    """Behaviour-tree action that handles one pending chat message per tick.

    Messages are popped from ``scene.state['chat_list']`` as ``(name, sentence)``
    pairs. A message from ``"Goal"`` is turned directly into a sub-task; any
    other message is answered through ``ask_llm``, which is given ``func_map``
    (presumably so the LLM can invoke those functions).
    """

    # Hard-coded answers for get_object_info, checked before any similarity search.
    _FIXED_LOCATIONS = {"保温杯": "二号桌子", "洗手间": "前门", "卫生间": "前门"}

    # A candidate name must score strictly above this spaCy similarity to match.
    _SIMILARITY_THRESHOLD = 0.02

    def __init__(self):
        super().__init__()
        self.chat_history = ""
        # Set to True by the argument-parsing functions below once their
        # arguments parse; never reset by this class.
        self.function_success = False
        # Functions handed to ask_llm(func_map=...).
        self.func_map = {
            "create_sub_task": self.create_sub_task,
            "stop_serve": self.stop_serve,
            "get_object_info": self.get_object_info,
            "get_number_of_objects": self.get_number_of_objects,
        }

    def _update(self) -> ptree.common.Status:
        """Handle the oldest pending chat message; always returns RUNNING.

        NOTE(review): assumes ``chat_list`` is a non-empty list (presumably
        guaranteed by a preceding condition node); ``pop(0)`` raises otherwise.
        """
        name, sentence = self.scene.state['chat_list'].pop(0)

        if name == "Goal":
            self.create_sub_task(goal=sentence)
            self.scene.ui_func(("new_history", "System", {
                "role": "user",
                "content": "set goal: " + sentence,
            }))
            return ptree.common.Status.RUNNING

        chat_history = self.scene.state["chat_history"]
        if name not in chat_history:
            chat_history[name] = History(self.scene, name)
        history = chat_history[name]

        # Focus on this customer and record them as not yet served.
        self.scene.state["attention"]["customer"] = name
        self.scene.state["serve_state"][name] = {
            "last_chat_time": self.scene.time,
            "served": False,
        }

        _function_call, response = ask_llm(sentence, history, func_map=self.func_map)
        self.scene.chat_bubble(response)  # the robot speaks the reply
        return ptree.common.Status.RUNNING

    def obj_name_en2zh(self, obj):
        """Translate an object name (underscores treated as spaces) to Chinese."""
        obj = translator.translate(obj.replace("_", " "))
        print("====obj:=======", obj)
        return obj

    @staticmethod
    def _parse_goal(goal):
        """Split a goal string into a set of ``)``-terminated terms.

        The first term is kept as-is; for every later non-empty piece the
        character right after the preceding ``)`` is treated as a separator
        and dropped.
        """
        first, *rest = goal.split(")")
        goal_set = {first + ")"}
        goal_set.update(part[1:] + ")" for part in rest if part)
        return goal_set

    def create_sub_task(self, **args):
        """Expand the robot's sub-task tree with the goals in ``args['goal']``.

        On a missing or unparsable ``goal`` an error is printed and the task
        tree is left untouched (previously this raised NameError).
        """
        try:
            goal_set = self._parse_goal(args['goal'])
        except (KeyError, AttributeError, TypeError):
            print("参数解析错误")
            return
        self.function_success = True
        self.scene.robot.expand_sub_task_tree(goal_set)

    @classmethod
    def _most_similar(cls, query, candidates):
        """Return the candidate most similar to ``query`` per ``nlp_zh``, or None.

        Only scores strictly above ``_SIMILARITY_THRESHOLD`` count; on ties the
        first candidate encountered wins.
        """
        query_doc = nlp_zh(query)
        best_word, best_score = None, cls._SIMILARITY_THRESHOLD
        for word in candidates:
            score = query_doc.similarity(nlp_zh(word))
            if score > best_score:
                best_word, best_score = word, score
        if best_word:
            print("max_similarity:", best_score, "similar_word:", best_word)
        return best_word

    def _nearest_place(self, xyz):
        """Return the key of ``place_have_obj_xyz_dic`` closest to ``xyz``, or None if empty.

        NOTE(review): ``place_have_obj_xyz_dic`` and ``place_en2zh_name`` are not
        defined in this class; presumably provided by ``Act`` — confirm.
        """
        places = self.place_have_obj_xyz_dic
        return min(
            places,
            key=lambda place: self.scene.getDistanc3D(places[place], xyz),
            default=None,
        )

    def get_object_info(self, **args):
        """Describe where the object ``args['obj']`` is, in Chinese.

        Returns ``"<obj>在<place>附近"``. Returns None when the argument is
        missing, no scene object matches the name, or no reference place is known.
        """
        obj = args.get('obj')
        if obj is None:
            print("参数解析错误")
            return None
        self.function_success = True

        # Hard-coded answers take priority over the similarity search.
        if obj in self._FIXED_LOCATIONS:
            return f"{obj}在{self._FIXED_LOCATIONS[obj]}附近"

        obj = self.obj_name_en2zh(obj)
        objects = self.scene.status.objects
        # Chinese names of the objects currently present in the scene.
        cur_things = {self.scene.objname_en2zh_dic[item.name] for item in objects}
        similar_word = self._most_similar(obj, cur_things)
        if not similar_word:
            return None

        similar_word_en = self.scene.objname_zh2en_dic[similar_word]
        target = next((item for item in objects if item.name == similar_word_en), None)
        if target is None:
            return None
        obj_xyz = (target.location.X, target.location.Y, target.location.Z)

        table_name = self._nearest_place(obj_xyz)
        if table_name is None:
            return None
        return obj + "在" + self.place_en2zh_name[table_name] + "附近"

    def get_number_of_objects(self, **args):
        """Report how many objects of kind ``args['obj']`` are in the scene, in Chinese.

        The name is translated to Chinese and matched against the known Chinese
        object names; a match on "Customer" is counted from ``status.walkers``.
        An unmatched name counts as 0.

        Returns ``"有<count>个<obj>"``. Returns None when the argument is missing
        or cannot be translated.
        """
        obj = args.get('obj')
        if obj is None:
            print("参数解析错误")
            return None
        self.function_success = True
        try:
            obj = self.obj_name_en2zh(obj)
        except Exception:  # translation is best-effort; report and give up
            print("参数解析错误")
            return None

        similar_word = self._most_similar(obj, self.scene.objname_zh2en_dic.keys())
        if not similar_word:
            count = 0
        else:
            similar_word_en = self.scene.objname_zh2en_dic[similar_word]
            if similar_word_en == "Customer":
                count = len(self.scene.status.walkers)
            else:
                count = sum(1 for item in self.scene.status.objects
                            if item.name == similar_word_en)
        return "有" + str(count) + "个" + obj

    def stop_serve(self, **args):
        """Mark the customer currently in attention as served and acknowledge."""
        customer = self.scene.state["attention"]["customer"]
        self.scene.state["serve_state"][customer]['served'] = True
        return "好的"
|
|
|
|
|
|
|
|
|