RoboWaiter/BTExpansionCode/llm_test/main.py

305 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
import numpy as np
from llm_test.ERNIE_Bot_4 import LLMERNIE
from sympy.parsing.sympy_parser import parse_expr
from sympy import symbols, Not, Or, And, to_dnf
from sympy import symbols, simplify_logic
import re
from dataset.data_process_check import format_check,word_correct,goal_transfer_ls_set
def get_feedback_prompt_last(id,prompt,result,error_list,error_black_set):
    """Append feedback built from the accumulated blacklists to ``prompt``.

    NOTE(review): the indentation of this file was lost; the nesting used here
    (blacklist handling inside the ``else`` branch) is reconstructed and should
    be confirmed against the original file.

    Args:
        id: Question identifier. Unused here; kept for the call signature.
        prompt: Prompt text to extend.
        result: The rejected answer. Unused here; kept for the call signature.
        error_list: Four-item sequence. ``error_list[0]`` is ``None`` unless
            the answer had a syntax error; items 1-3 are collections of the
            words reported as format errors, as missing from [Condition] and
            as missing from [Object], respectively.
        error_black_set: Three-item mutable sequence of sets accumulating the
            words of those three categories. Updated in place.

    Returns:
        ``prompt`` followed by a newline and the feedback message.
    """
    error_message = ""
    if error_list[0] is not None:
        error_message += "It contains syntax errors or illegal characters."
    else:
        # One template per blacklist slot; slot k collects error_list[k + 1].
        templates = (
            "\"{}\" have format errors. The answer should consist only of ~, |, &, and the given [Condition] and [Object].\n",
            "\"{}\" are not in [Condition]. Please select the closest predicates from the [Condition] table to form the answer.\n",
            "\"{}\" are not in [Object]. Please select the closest parameter from the [Object] table to form the answer.\n",
        )
        for slot, template in enumerate(templates):
            new_words = error_list[slot + 1]
            if new_words:
                error_black_set[slot] |= set(new_words)
            # Report every category that has accumulated words, but never emit
            # a line about an empty category ('"" are not in [Condition]').
            if error_black_set[slot]:
                # Sorted so the prompt does not depend on set hash order.
                error_message += template.format(", ".join(sorted(error_black_set[slot])))
    print("** Error_message: ",error_message)
    prompt += "\n"+ error_message
    return prompt
def get_feedback_prompt0123(id,prompt,result,error_list,error_black_set):
    """Append feedback for the categories that failed in this answer.

    Unlike a report of all accumulated blacklists, a category is mentioned only
    when the current ``error_list`` contains words for it; the line then lists
    every word accumulated for that category so far.

    NOTE(review): the indentation of this file was lost; the nesting used here
    (each message built inside its own ``if``) is reconstructed and should be
    confirmed against the original file.

    Args:
        id: Question identifier. Unused here; kept for the call signature.
        prompt: Prompt text to extend.
        result: The rejected answer. Unused here; kept for the call signature.
        error_list: Four-item sequence. ``error_list[0]`` is ``None`` unless
            the answer had a syntax error; items 1-3 are collections of the
            words reported as format errors, as missing from [Condition] and
            as missing from [Object], respectively.
        error_black_set: Three-item mutable sequence of sets accumulating the
            words of those three categories. Updated in place.

    Returns:
        ``prompt`` followed by a newline and the feedback message.
    """
    error_message = ""
    if error_list[0] is not None:
        error_message += "It contains syntax errors or illegal characters."
    else:
        # One template per blacklist slot; slot k collects error_list[k + 1].
        templates = (
            "\"{}\" have format errors. The answer should consist only of ~, |, &, and the given [Condition] and [Object].\n",
            "\"{}\" are not in [Condition]. Please select the closest predicates from the [Condition] table to form the answer.\n",
            "\"{}\" are not in [Object]. Please select the closest parameter from the [Object] table to form the answer.\n",
        )
        for slot, template in enumerate(templates):
            new_words = error_list[slot + 1]
            if new_words:
                error_black_set[slot] |= set(new_words)
                # Sorted so the prompt does not depend on set hash order.
                error_message += template.format(", ".join(sorted(error_black_set[slot])))
    print("** Error_message: ",error_message)
    prompt += "\n"+ error_message
    return prompt
def get_feedback_prompt(id,prompt1,prompt2,question,result,error_list,error_black_set):
    """Build a retry prompt: ``prompt1 + prompt2`` plus a [Blacklist] section.

    When the answer had no syntax error, the words in ``error_list`` are merged
    into ``error_black_set`` and the accumulated blacklists are appended to the
    prompt together with instructions on how to replace blacklisted words.
    When the answer had a syntax error, no blacklist section is added.

    NOTE(review): the indentation of this file was lost; the nesting used here
    (blacklist section built only when there is no syntax error) is
    reconstructed and should be confirmed against the original file.

    Args:
        id: Question identifier. Unused here; kept for the call signature.
        prompt1: First part of the base prompt.
        prompt2: Second part of the base prompt.
        question: The question text. Unused here; kept for the call signature.
        result: The rejected answer. Unused here; kept for the call signature.
        error_list: Four-item sequence. ``error_list[0]`` is ``None`` unless
            the answer had a syntax error; items 1-3 are collections of other
            illegal words/characters, illegal conditions and illegal objects.
        error_black_set: Three-item mutable sequence of sets accumulating the
            words of those three categories. Updated in place.

    Returns:
        The prompt to use for the retry.
    """
    error_message = ""
    er_word0 = er_word1 = er_word2 = ""
    if error_list[0] is None:
        # Slot k of the blacklist collects error_list[k + 1].
        for slot in range(3):
            new_words = error_list[slot + 1]
            if new_words:
                error_black_set[slot] |= set(new_words)
        # Sorted so the prompt does not depend on set hash order.
        er_word0, er_word1, er_word2 = (
            ", ".join(sorted(error_black_set[slot])) for slot in range(3)
        )
        error_message += f"\n[Blacklist]\n<Illegal Condition>=[{er_word1}]\n<Illegal Object>=[{er_word2}]\n<Other Illegal Words or Characters>=[{er_word0}]\n"
        error_message += "\n[Blacklist] Contains restricted elements.\n"+\
                         "If a word from <Illegal Condition> is encountered, choose the nearest parameter with a similar meaning from the [Condition] table to formulate the answer.\n"+\
                         "If a word from <Illegal Object> is encountered, choose the nearest parameter with a similar meaning from the [Object] table to formulate the answer."
    print("** Blacklist: ",f"[Blacklist]\n<Illegal Characters>=[{er_word0}]\n<Illegal Condition>=[{er_word1}]\n<Illegal Object>=[{er_word2}]")
    prompt = prompt1+prompt2
    prompt += error_message
    print(error_message)
    return prompt
# ---------------------------------------------------------------------------
# Input files: instruction/goal datasets by difficulty, a small test split,
# and the two halves of the base prompt.
# ---------------------------------------------------------------------------
easy_data_set_file = "../dataset/easy_instr_goal.txt"
medium_data_set_file = "../dataset/medium_instr_goal.txt"
hard_data_set_file = "../dataset/hard_instr_goal.txt"
test_data_set_file = "../dataset/test.txt"
data_set_file = "../dataset/data100.txt"

# Alternative prompt pair: "prompt1.txt" / "prompt2.txt".
prompt_file1 = "prompt_test1.txt"
prompt_file2 = "prompt_test2.txt"

test_num = 1


def _read_stripped(path):
    """Return the UTF-8 text stored at ``path`` without surrounding whitespace."""
    with open(path, 'r', encoding="utf-8") as handle:
        return handle.read().strip()


easy_data_set = _read_stripped(easy_data_set_file)
medium_data_set = _read_stripped(medium_data_set_file)
hard_data_set = _read_stripped(hard_data_set_file)
test_data_set = _read_stripped(test_data_set_file)

# Dataset evaluated by this run. Alternatives that were used before:
#   data_set = hard_data_set
#   data_set = easy_data_set + medium_data_set + hard_data_set
data_set = test_data_set

prompt1 = _read_stripped(prompt_file1)
prompt2 = _read_stripped(prompt_file2)
prompt = prompt1 + prompt2
# Samples are separated by blank lines; each sample is a two-line section:
# the instruction, then a line holding the expected goal ("Goal: ...").
sections = re.split(r'\n\s*\n', data_set)
count = 0
llm = LLMERNIE()

# One list of received answers per sample, indexed by the sample's tag.
outputs_list = [[] for _ in sections]
question_list, correct_answer_list, correct_answer_ls_set = [], [], []

# Submit every question up front; answers are collected later by tag.
for idx, section in enumerate(sections):
    # Unpacking fails loudly if a section does not have exactly two lines.
    instruction, goal = (line.strip() for line in section.strip().splitlines())
    goal = goal.replace("Goal: ", "")
    question_list.append(instruction)
    correct_answer_list.append(goal)
    correct_answer_ls_set.append(goal_transfer_ls_set(goal))
    llm.ask(instruction, prompt=prompt, tag=idx)

total_num = len(question_list)
# Per question: three sets of blacklisted words, filled in by the feedback builders.
error_black_ls = [[set() for _ in range(3)] for _ in range(total_num)]
# Collect answers, retry badly formatted ones with blacklist feedback, and
# print running statistics after every finished question.
#
# NOTE(review): the indentation of this file was lost; the nesting below is
# reconstructed from the if/else pairing and should be confirmed.

# Answers allowed per question (first answer + retries). Also the number of
# buckets in GR_ls, which is indexed by "answers needed until well-formed".
MAX_ATTEMPTS = 6

try_times = 1
total_GR_ls = []
total_SR_ls = []
total_GCR_ls = []

finish_num = 0   # questions that reached a final verdict
SR = 0           # answers whose goal sets equal the expected goal sets
GR = 0
GCR = 0          # sum of per-question fractions of expected goal sets matched
# GR_ls[k]: number of questions whose (k+1)-th answer passed format_check.
GR_ls = np.zeros(MAX_ATTEMPTS)

try:
    while finish_num < total_num:
        result = llm.get_result()
        if not result:
            # Nothing ready yet; poll again shortly.
            time.sleep(0.01)
            continue

        # `qid` rather than `id`, so the builtin is not shadowed.
        qid, question, answer = result
        outputs_list[qid].append(answer)

        # The answer returned by the LLM is validated as-is. (A hard-coded
        # debug answer used to overwrite it here, which made every question
        # fail the format check regardless of what the model produced.)
        format_correct, error_list = format_check(answer)
        print("error_list:",error_list)
        print(f"===== id:{qid} Q: {question} =====")

        if not format_correct and len(outputs_list[qid]) < MAX_ATTEMPTS:
            # Badly formatted and attempts remain: ask again with feedback.
            print("*** answer:",answer)
            new_prompt = get_feedback_prompt(qid, prompt1,prompt2, question_list[qid],answer, error_list, error_black_ls[qid])
            llm.ask(question, prompt=new_prompt, tag=qid)
            print(f" Retry:{len(outputs_list[qid])} A:{outputs_list[qid]}")
            continue

        # Final verdict for this question (well-formed, or out of attempts).
        correct = False
        if format_correct:
            GR_ls[len(outputs_list[qid]) - 1] += 1
            expected_ls_set = correct_answer_ls_set[qid]
            answer_ls_set = goal_transfer_ls_set(answer)
            if answer_ls_set == expected_ls_set:
                SR += 1
                GCR += 1
                correct = True
            else:
                # Partial credit: share of expected goal sets found in the answer.
                GCR += sum(a_set in expected_ls_set for a_set in answer_ls_set) * 1.0 / len(expected_ls_set)
            # Snapshot the counters; copy GR_ls because it keeps being mutated.
            total_GR_ls.append(GR_ls.copy())
            total_SR_ls.append(SR)
            total_GCR_ls.append(GCR)

        finish_num += 1
        print(f"A: {outputs_list[qid]}")
        print(f"CA: {correct_answer_list[qid]}")
        # Cumulative rate of well-formed answers within 1, 2, ... attempts.
        GR_r = np.cumsum(GR_ls / finish_num)
        print(f"Correct:{correct} GR:{GR_r[0],GR_r[1],GR_r[-1]} SR:{SR / finish_num} GCR:{GCR / finish_num}")
finally:
    # Always release the LLM client, even if the loop raised.
    llm.close()

# Optional summary over the recorded snapshots (disabled):
# print("GR =:", round(np.mean(total_GR_ls), 4), "std=", round(np.std(total_GR_ls), 4))
# print("SR = ", round(np.mean(total_SR_ls), 3), "std=", round(np.std(total_SR_ls, ddof=1), 3))
# print("GCR = :", round(np.mean(total_GCR_ls), 3), "std=", round(np.std(total_GCR_ls, ddof=1), 3))