更新检索增强

This commit is contained in:
ChenXL97 2023-11-28 17:49:40 +08:00
parent 2e4d563aad
commit 39f416acea
4 changed files with 25 additions and 23 deletions

View File

@ -94,7 +94,7 @@ class DealChat(Act):
if obj in d.keys():
result = d[obj]
else:
result = "None"
result = "没有"
return result
# max_similarity = 0.02

View File

@ -90,7 +90,7 @@ def get_response(sentence, history, allow_function_call = True):
if sentence in fix_questions_dict:
time.sleep(2)
return parse_fix_question(sentence)
return True, parse_fix_question(sentence)
params = dict(model="RoboWaiter")
params['messages'] = role_system + list(history)
@ -100,7 +100,7 @@ def get_response(sentence, history, allow_function_call = True):
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False, verify=False)
decoded_line = response.json()
return decoded_line
return False, decoded_line
def deal_response(response, history, func_map=None ):
if response["choices"][0]["message"].get("function_call"):
@ -143,7 +143,7 @@ def deal_response(response, history, func_map=None ):
def ask_llm(question,history, func_map=None, retry=3):
response = get_response(question, history)
fixed, response = get_response(question, history)
function_call,result = deal_response(response, history, func_map)
if function_call:
@ -161,7 +161,7 @@ def ask_llm(question,history, func_map=None, retry=3):
history.append(message)
else:
response = get_response(None, history,allow_function_call=False)
fixed, response = get_response(None, history,allow_function_call=False)
_,result = deal_response(response, history, func_map)

View File

@ -68,7 +68,7 @@ def parse_fix_question(question):
'function_call': {'name': func, 'arguments': args}}
response["choices"][0]["message"] = message
return response
return response, question["answer"]
def get_response(sentence, history, allow_function_call = True):
if sentence:
@ -76,7 +76,7 @@ def get_response(sentence, history, allow_function_call = True):
retrieval_result = retrieval.get_result(sentence)
if retrieval_result is not None:
time.sleep(2)
time.sleep(1.2)
# 处理多轮
if retrieval_result["answer"] == "multi_rounds" and len(history) >= 2:
print("触发多轮检索")
@ -87,7 +87,8 @@ def get_response(sentence, history, allow_function_call = True):
break
retrieval_result = retrieval.get_result(last_content + sentence)
if retrieval_result is not None:
return parse_fix_question(retrieval_result)
response, answer = parse_fix_question(retrieval_result)
return True,response, answer
params = dict(model="RoboWaiter")
params['messages'] = role_system + list(history)
@ -97,7 +98,7 @@ def get_response(sentence, history, allow_function_call = True):
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False, verify=False)
decoded_line = response.json()
return decoded_line
return False, decoded_line, None
def deal_response(response, history, func_map=None ):
if response["choices"][0]["message"].get("function_call"):
@ -140,26 +141,27 @@ def deal_response(response, history, func_map=None ):
def ask_llm(question,history, func_map=None, retry=3):
response = get_response(question, history)
fixed, response, answer = get_response(question, history)
print(f"response: {response}")
function_call,result = deal_response(response, history, func_map)
sentence = response["choices"][0]["message"]["content"]
if function_call:
if function_call == "create_sub_task":
result = single_round(sentence,
"你是机器人服务员,请把以下句子换一种表述方式对顾客说,但是意思不变,尽量简短:\n")
# elif function_call in ["get_object_info","find_location"] :
if fixed:
if function_call == "create_sub_task":
result = single_round(answer,
"你是机器人服务员,请把以下句子换一种表述方式对顾客说,但是意思不变,尽量简短:\n")
# elif function_call in ["get_object_info","find_location"] :
else:
result = single_round(f"你是机器人服务员,顾客想知道{question}, 你的具身场景查询返回的是{result},把返回的英文名词翻译成中文,请把按照以下句子对顾客说,{answer}, 尽量简短。\n")
message = {'role': 'assistant', 'content': result, 'name': None,
'function_call': None}
history.append(message)
else:
result = single_round(f"你是机器人服务员,顾客想知道{question}, 你的具身场景查询返回的是{result},把返回的英文名词翻译成中文,请把按照以下句子对顾客说,{sentence}, 尽量简短。\n")
message = {'role': 'assistant', 'content': result, 'name': None,
'function_call': None}
history.append(message)
# response = get_response(None, history,allow_function_call=False)
# _,result = deal_response(response, history, func_map)
_,response,_ = get_response(None, history,allow_function_call=False)
_,result = deal_response(response, history, func_map)
print(f'{len(history)}条历史记录:')

Binary file not shown.

Before

Width:  |  Height:  |  Size: 92 KiB

After

Width:  |  Height:  |  Size: 114 KiB