更新检索增强
This commit is contained in:
parent
2e4d563aad
commit
39f416acea
|
@ -94,7 +94,7 @@ class DealChat(Act):
|
||||||
if obj in d.keys():
|
if obj in d.keys():
|
||||||
result = d[obj]
|
result = d[obj]
|
||||||
else:
|
else:
|
||||||
result = "None"
|
result = "没有"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# max_similarity = 0.02
|
# max_similarity = 0.02
|
||||||
|
|
|
@ -90,7 +90,7 @@ def get_response(sentence, history, allow_function_call = True):
|
||||||
|
|
||||||
if sentence in fix_questions_dict:
|
if sentence in fix_questions_dict:
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
return parse_fix_question(sentence)
|
return True, parse_fix_question(sentence)
|
||||||
|
|
||||||
params = dict(model="RoboWaiter")
|
params = dict(model="RoboWaiter")
|
||||||
params['messages'] = role_system + list(history)
|
params['messages'] = role_system + list(history)
|
||||||
|
@ -100,7 +100,7 @@ def get_response(sentence, history, allow_function_call = True):
|
||||||
|
|
||||||
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False, verify=False)
|
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False, verify=False)
|
||||||
decoded_line = response.json()
|
decoded_line = response.json()
|
||||||
return decoded_line
|
return False, decoded_line
|
||||||
|
|
||||||
def deal_response(response, history, func_map=None ):
|
def deal_response(response, history, func_map=None ):
|
||||||
if response["choices"][0]["message"].get("function_call"):
|
if response["choices"][0]["message"].get("function_call"):
|
||||||
|
@ -143,7 +143,7 @@ def deal_response(response, history, func_map=None ):
|
||||||
|
|
||||||
|
|
||||||
def ask_llm(question,history, func_map=None, retry=3):
|
def ask_llm(question,history, func_map=None, retry=3):
|
||||||
response = get_response(question, history)
|
fixed, response = get_response(question, history)
|
||||||
|
|
||||||
function_call,result = deal_response(response, history, func_map)
|
function_call,result = deal_response(response, history, func_map)
|
||||||
if function_call:
|
if function_call:
|
||||||
|
@ -161,7 +161,7 @@ def ask_llm(question,history, func_map=None, retry=3):
|
||||||
history.append(message)
|
history.append(message)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = get_response(None, history,allow_function_call=False)
|
fixed, response = get_response(None, history,allow_function_call=False)
|
||||||
_,result = deal_response(response, history, func_map)
|
_,result = deal_response(response, history, func_map)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -68,7 +68,7 @@ def parse_fix_question(question):
|
||||||
'function_call': {'name': func, 'arguments': args}}
|
'function_call': {'name': func, 'arguments': args}}
|
||||||
|
|
||||||
response["choices"][0]["message"] = message
|
response["choices"][0]["message"] = message
|
||||||
return response
|
return response, question["answer"]
|
||||||
|
|
||||||
def get_response(sentence, history, allow_function_call = True):
|
def get_response(sentence, history, allow_function_call = True):
|
||||||
if sentence:
|
if sentence:
|
||||||
|
@ -76,7 +76,7 @@ def get_response(sentence, history, allow_function_call = True):
|
||||||
|
|
||||||
retrieval_result = retrieval.get_result(sentence)
|
retrieval_result = retrieval.get_result(sentence)
|
||||||
if retrieval_result is not None:
|
if retrieval_result is not None:
|
||||||
time.sleep(2)
|
time.sleep(1.2)
|
||||||
# 处理多轮
|
# 处理多轮
|
||||||
if retrieval_result["answer"] == "multi_rounds" and len(history) >= 2:
|
if retrieval_result["answer"] == "multi_rounds" and len(history) >= 2:
|
||||||
print("触发多轮检索")
|
print("触发多轮检索")
|
||||||
|
@ -87,7 +87,8 @@ def get_response(sentence, history, allow_function_call = True):
|
||||||
break
|
break
|
||||||
retrieval_result = retrieval.get_result(last_content + sentence)
|
retrieval_result = retrieval.get_result(last_content + sentence)
|
||||||
if retrieval_result is not None:
|
if retrieval_result is not None:
|
||||||
return parse_fix_question(retrieval_result)
|
response, answer = parse_fix_question(retrieval_result)
|
||||||
|
return True,response, answer
|
||||||
|
|
||||||
params = dict(model="RoboWaiter")
|
params = dict(model="RoboWaiter")
|
||||||
params['messages'] = role_system + list(history)
|
params['messages'] = role_system + list(history)
|
||||||
|
@ -97,7 +98,7 @@ def get_response(sentence, history, allow_function_call = True):
|
||||||
|
|
||||||
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False, verify=False)
|
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False, verify=False)
|
||||||
decoded_line = response.json()
|
decoded_line = response.json()
|
||||||
return decoded_line
|
return False, decoded_line, None
|
||||||
|
|
||||||
def deal_response(response, history, func_map=None ):
|
def deal_response(response, history, func_map=None ):
|
||||||
if response["choices"][0]["message"].get("function_call"):
|
if response["choices"][0]["message"].get("function_call"):
|
||||||
|
@ -140,26 +141,27 @@ def deal_response(response, history, func_map=None ):
|
||||||
|
|
||||||
|
|
||||||
def ask_llm(question,history, func_map=None, retry=3):
|
def ask_llm(question,history, func_map=None, retry=3):
|
||||||
response = get_response(question, history)
|
fixed, response, answer = get_response(question, history)
|
||||||
|
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
|
|
||||||
function_call,result = deal_response(response, history, func_map)
|
function_call,result = deal_response(response, history, func_map)
|
||||||
sentence = response["choices"][0]["message"]["content"]
|
|
||||||
if function_call:
|
if function_call:
|
||||||
|
if fixed:
|
||||||
if function_call == "create_sub_task":
|
if function_call == "create_sub_task":
|
||||||
result = single_round(sentence,
|
result = single_round(answer,
|
||||||
"你是机器人服务员,请把以下句子换一种表述方式对顾客说,但是意思不变,尽量简短:\n")
|
"你是机器人服务员,请把以下句子换一种表述方式对顾客说,但是意思不变,尽量简短:\n")
|
||||||
# elif function_call in ["get_object_info","find_location"] :
|
# elif function_call in ["get_object_info","find_location"] :
|
||||||
else:
|
else:
|
||||||
result = single_round(f"你是机器人服务员,顾客想知道{question}, 你的具身场景查询返回的是{result},把返回的英文名词翻译成中文,请把按照以下句子对顾客说,{sentence}, 尽量简短。\n")
|
result = single_round(f"你是机器人服务员,顾客想知道{question}, 你的具身场景查询返回的是{result},把返回的英文名词翻译成中文,请把按照以下句子对顾客说,{answer}, 尽量简短。\n")
|
||||||
|
|
||||||
message = {'role': 'assistant', 'content': result, 'name': None,
|
message = {'role': 'assistant', 'content': result, 'name': None,
|
||||||
'function_call': None}
|
'function_call': None}
|
||||||
history.append(message)
|
history.append(message)
|
||||||
|
else:
|
||||||
# response = get_response(None, history,allow_function_call=False)
|
_,response,_ = get_response(None, history,allow_function_call=False)
|
||||||
# _,result = deal_response(response, history, func_map)
|
_,result = deal_response(response, history, func_map)
|
||||||
|
|
||||||
|
|
||||||
print(f'{len(history)}条历史记录:')
|
print(f'{len(history)}条历史记录:')
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 92 KiB After Width: | Height: | Size: 114 KiB |
Loading…
Reference in New Issue