Update DealChat.py

This commit is contained in:
Netceor 2023-11-19 14:24:36 +08:00
parent f5bfe0c5b1
commit 3e48f5a39c
1 changed files with 61 additions and 62 deletions

View File

@ -4,9 +4,8 @@ from robowaiter.behavior_lib._base.Act import Act
from robowaiter.llm_client.multi_rounds import ask_llm, new_history
import random
import spacy
nlp = spacy.load('en_core_web_lg')
# import spacy
# nlp = spacy.load('en_core_web_lg')
class DealChat(Act):
@ -63,65 +62,65 @@ class DealChat(Act):
self.scene.robot.expand_sub_task_tree(goal_set)
def get_object_info(self,**args):
try:
obj = args['obj']
self.function_success = True
except:
obj = None
print("参数解析错误")
near_object = "None"
# 场景中现有物品
cur_things = set()
for item in self.status.objects:
cur_things.add(item.name)
# obj与现有物品进行相似度匹配
query_token = nlp(obj)
for w in self.all_loc_en:
word_token = nlp(w)
similarity = query_token.similarity(word_token)
if similarity > max_similarity:
max_similarity = similarity
similar_word = w
print("max_similarity:",max_similarity,"similar_word:",similar_word)
if similar_word: # 存在同义词说明场景中存在该物品
near_object = random.choices(list(cur_things), k=5) # 返回场景中的5个物品
if obj == "洗手间":
near_object = "大门"
return near_object
def find_location(self, **args):
try:
location = args['obj']
self.function_success = True
except:
obj = None
print("参数解析错误")
near_location = None
# 用户咨询的地点
query_token = nlp(location)
max_similarity = 0
similar_word = None
# 到自己维护的地点列表中找同义词
for w in self.all_loc_en:
word_token = nlp(w)
similarity = query_token.similarity(word_token)
if similarity > max_similarity:
max_similarity = similarity
similar_word = w
print("similarity:", max_similarity, "similar_word:", similar_word)
# 存在同义词说明客户咨询的地点有效
if similar_word:
mp = list(self.loc_map_en[similar_word])
near_location = random.choice(mp)
return near_location
# def get_object_info(self,**args):
# try:
# obj = args['obj']
#
# self.function_success = True
# except:
# obj = None
# print("参数解析错误")
#
# near_object = "None"
#
# # 场景中现有物品
# cur_things = set()
# for item in self.status.objects:
# cur_things.add(item.name)
# # obj与现有物品进行相似度匹配
# query_token = nlp(obj)
# for w in self.all_loc_en:
# word_token = nlp(w)
# similarity = query_token.similarity(word_token)
# if similarity > max_similarity:
# max_similarity = similarity
# similar_word = w
# print("max_similarity:",max_similarity,"similar_word:",similar_word)
#
# if similar_word: # 存在同义词说明场景中存在该物品
# near_object = random.choices(list(cur_things), k=5) # 返回场景中的5个物品
#
# if obj == "洗手间":
# near_object = "大门"
#
# return near_object
#
# def find_location(self, **args):
# try:
# location = args['obj']
# self.function_success = True
# except:
# obj = None
# print("参数解析错误")
#
# near_location = None
# # 用户咨询的地点
# query_token = nlp(location)
# max_similarity = 0
# similar_word = None
# # 到自己维护的地点列表中找同义词
# for w in self.all_loc_en:
# word_token = nlp(w)
# similarity = query_token.similarity(word_token)
# if similarity > max_similarity:
# max_similarity = similarity
# similar_word = w
# print("similarity:", max_similarity, "similar_word:", similar_word)
# # 存在同义词说明客户咨询的地点有效
# if similar_word:
# mp = list(self.loc_map_en[similar_word])
# near_location = random.choice(mp)
# return near_location
def stop_serve(self,**args):