Update DealChat.py
This commit is contained in:
parent
f5bfe0c5b1
commit
3e48f5a39c
|
@ -4,9 +4,8 @@ from robowaiter.behavior_lib._base.Act import Act
|
||||||
from robowaiter.llm_client.multi_rounds import ask_llm, new_history
|
from robowaiter.llm_client.multi_rounds import ask_llm, new_history
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import spacy
|
# import spacy
|
||||||
|
# nlp = spacy.load('en_core_web_lg')
|
||||||
nlp = spacy.load('en_core_web_lg')
|
|
||||||
|
|
||||||
|
|
||||||
class DealChat(Act):
|
class DealChat(Act):
|
||||||
|
@ -63,65 +62,65 @@ class DealChat(Act):
|
||||||
|
|
||||||
self.scene.robot.expand_sub_task_tree(goal_set)
|
self.scene.robot.expand_sub_task_tree(goal_set)
|
||||||
|
|
||||||
def get_object_info(self,**args):
|
# def get_object_info(self,**args):
|
||||||
try:
|
# try:
|
||||||
obj = args['obj']
|
# obj = args['obj']
|
||||||
|
#
|
||||||
self.function_success = True
|
# self.function_success = True
|
||||||
except:
|
# except:
|
||||||
obj = None
|
# obj = None
|
||||||
print("参数解析错误")
|
# print("参数解析错误")
|
||||||
|
#
|
||||||
near_object = "None"
|
# near_object = "None"
|
||||||
|
#
|
||||||
# 场景中现有物品
|
# # 场景中现有物品
|
||||||
cur_things = set()
|
# cur_things = set()
|
||||||
for item in self.status.objects:
|
# for item in self.status.objects:
|
||||||
cur_things.add(item.name)
|
# cur_things.add(item.name)
|
||||||
# obj与现有物品进行相似度匹配
|
# # obj与现有物品进行相似度匹配
|
||||||
query_token = nlp(obj)
|
# query_token = nlp(obj)
|
||||||
for w in self.all_loc_en:
|
# for w in self.all_loc_en:
|
||||||
word_token = nlp(w)
|
# word_token = nlp(w)
|
||||||
similarity = query_token.similarity(word_token)
|
# similarity = query_token.similarity(word_token)
|
||||||
if similarity > max_similarity:
|
# if similarity > max_similarity:
|
||||||
max_similarity = similarity
|
# max_similarity = similarity
|
||||||
similar_word = w
|
# similar_word = w
|
||||||
print("max_similarity:",max_similarity,"similar_word:",similar_word)
|
# print("max_similarity:",max_similarity,"similar_word:",similar_word)
|
||||||
|
#
|
||||||
if similar_word: # 存在同义词说明场景中存在该物品
|
# if similar_word: # 存在同义词说明场景中存在该物品
|
||||||
near_object = random.choices(list(cur_things), k=5) # 返回场景中的5个物品
|
# near_object = random.choices(list(cur_things), k=5) # 返回场景中的5个物品
|
||||||
|
#
|
||||||
if obj == "洗手间":
|
# if obj == "洗手间":
|
||||||
near_object = "大门"
|
# near_object = "大门"
|
||||||
|
#
|
||||||
return near_object
|
# return near_object
|
||||||
|
#
|
||||||
def find_location(self, **args):
|
# def find_location(self, **args):
|
||||||
try:
|
# try:
|
||||||
location = args['obj']
|
# location = args['obj']
|
||||||
self.function_success = True
|
# self.function_success = True
|
||||||
except:
|
# except:
|
||||||
obj = None
|
# obj = None
|
||||||
print("参数解析错误")
|
# print("参数解析错误")
|
||||||
|
#
|
||||||
near_location = None
|
# near_location = None
|
||||||
# 用户咨询的地点
|
# # 用户咨询的地点
|
||||||
query_token = nlp(location)
|
# query_token = nlp(location)
|
||||||
max_similarity = 0
|
# max_similarity = 0
|
||||||
similar_word = None
|
# similar_word = None
|
||||||
# 到自己维护的地点列表中找同义词
|
# # 到自己维护的地点列表中找同义词
|
||||||
for w in self.all_loc_en:
|
# for w in self.all_loc_en:
|
||||||
word_token = nlp(w)
|
# word_token = nlp(w)
|
||||||
similarity = query_token.similarity(word_token)
|
# similarity = query_token.similarity(word_token)
|
||||||
if similarity > max_similarity:
|
# if similarity > max_similarity:
|
||||||
max_similarity = similarity
|
# max_similarity = similarity
|
||||||
similar_word = w
|
# similar_word = w
|
||||||
print("similarity:", max_similarity, "similar_word:", similar_word)
|
# print("similarity:", max_similarity, "similar_word:", similar_word)
|
||||||
# 存在同义词说明客户咨询的地点有效
|
# # 存在同义词说明客户咨询的地点有效
|
||||||
if similar_word:
|
# if similar_word:
|
||||||
mp = list(self.loc_map_en[similar_word])
|
# mp = list(self.loc_map_en[similar_word])
|
||||||
near_location = random.choice(mp)
|
# near_location = random.choice(mp)
|
||||||
return near_location
|
# return near_location
|
||||||
|
|
||||||
def stop_serve(self,**args):
|
def stop_serve(self,**args):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue