2024-01-10 11:16:12 +08:00
import time
2024-04-10 19:59:13 +08:00
import numpy as np
2024-01-10 11:16:12 +08:00
from llm_test . ERNIE_Bot_4 import LLMERNIE
from sympy . parsing . sympy_parser import parse_expr
from sympy import symbols , Not , Or , And , to_dnf
from sympy import symbols , simplify_logic
import re
2024-04-10 19:59:13 +08:00
from dataset . data_process_check import format_check , word_correct , goal_transfer_ls_set
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
def get_feedback_prompt_last ( id , prompt , result , error_list , error_black_set ) :
# wrong_format_set,wrong_predicate_set,wrong_object_set = error_list
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
error_message = " "
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
if error_list [ 0 ] != None :
error_message + = " It contains syntax errors or illegal characters. "
# ("It Contains syntax errors or illegal characters that cannot be converted to disjunctive normal form (DNF) using sympy.to_dnf. ")
# "Please check the syntax in your input and ensure there are no prohibited characters. The answer should consist only of ~, |, &, and the given [Condition] and [Object].. ")
else :
if error_list [ 1 ] != set ( ) :
# error_strings = ", ".join(error_list[1])
# error_message += f"\"{error_strings}\" have format errors. They should consist only of ~, |, &, and the given [Condition] and [Object].\n"
error_black_set [ 0 ] | = set ( error_list [ 1 ] )
if error_list [ 2 ] != set ( ) :
# error_strings = ", ".join(error_list[2])
# error_message += f"\"{error_strings}\" are not in [Condition]. Please select the closest predicates from the [Condition] table to form the answer.\n"
error_black_set [ 1 ] | = set ( error_list [ 2 ] )
if error_list [ 3 ] != set ( ) :
# error_strings = ", ".join(error_list[3])
# error_message += f"\"{error_strings}\" are not in [Object]. Please select the closest parameter from the [Object] table to form the answer.\n"
error_black_set [ 2 ] | = set ( error_list [ 3 ] )
# error_strings = "Do not include: "+", ".join(list(error_black_set))+"." +"Please select the closest parameter from the [Condition] and [Object] table to form the answer."
er_word0 = " , " . join ( list ( error_black_set [ 0 ] ) )
er_word1 = " , " . join ( list ( error_black_set [ 1 ] ) )
er_word2 = " , " . join ( list ( error_black_set [ 2 ] ) )
error_message + = f " \" { er_word0 } \" have format errors. The answer should consist only of ~, |, &, and the given [Condition] and [Object]. \n "
error_message + = f " \" { er_word1 } \" are not in [Condition]. Please select the closest predicates from the [Condition] table to form the answer. \n "
error_message + = f " \" { er_word2 } \" are not in [Object]. Please select the closest parameter from the [Object] table to form the answer. \n "
print ( " ** Error_message: " , error_message )
prompt + = " \n " + error_message
return prompt
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
def get_feedback_prompt0123 ( id , prompt , result , error_list , error_black_set ) :
# wrong_format_set,wrong_predicate_set,wrong_object_set = error_list
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
error_message = " "
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
if error_list [ 0 ] != None :
error_message + = " It contains syntax errors or illegal characters. "
else :
if error_list [ 1 ] != set ( ) :
error_black_set [ 0 ] | = set ( error_list [ 1 ] )
er_word0 = " , " . join ( list ( error_black_set [ 0 ] ) )
error_message + = f " \" { er_word0 } \" have format errors. The answer should consist only of ~, |, &, and the given [Condition] and [Object]. \n "
if error_list [ 2 ] != set ( ) :
error_black_set [ 1 ] | = set ( error_list [ 2 ] )
er_word1 = " , " . join ( list ( error_black_set [ 1 ] ) )
error_message + = f " \" { er_word1 } \" are not in [Condition]. Please select the closest predicates from the [Condition] table to form the answer. \n "
if error_list [ 3 ] != set ( ) :
error_black_set [ 2 ] | = set ( error_list [ 3 ] )
er_word2 = " , " . join ( list ( error_black_set [ 2 ] ) )
error_message + = f " \" { er_word2 } \" are not in [Object]. Please select the closest parameter from the [Object] table to form the answer. \n "
print ( " ** Error_message: " , error_message )
prompt + = " \n " + error_message
return prompt
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
def get_feedback_prompt ( id , prompt1 , prompt2 , question , result , error_list , error_black_set ) :
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
error_message = " "
er_word0 = " "
er_word1 = " "
er_word2 = " "
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
if error_list [ 0 ] != None :
error_message = " "
2024-01-10 11:16:12 +08:00
else :
2024-04-10 19:59:13 +08:00
if error_list [ 1 ] != set ( ) :
error_black_set [ 0 ] | = set ( error_list [ 1 ] )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
if error_list [ 2 ] != set ( ) :
error_black_set [ 1 ] | = set ( error_list [ 2 ] )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
if error_list [ 3 ] != set ( ) :
error_black_set [ 2 ] | = set ( error_list [ 3 ] )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
er_word0 = " , " . join ( list ( error_black_set [ 0 ] ) )
er_word1 = " , " . join ( list ( error_black_set [ 1 ] ) )
er_word2 = " , " . join ( list ( error_black_set [ 2 ] ) )
error_message + = f " \n [Blacklist] \n <Illegal Condition>=[ { er_word1 } ] \n <Illegal Object>=[ { er_word2 } ] \n <Other Illegal Words or Characters>=[ { er_word0 } ] \n "
error_message + = " \n [Blacklist] Contains restricted elements. \n " + \
" If a word from <Illegal Condition> is encountered, choose the nearest parameter with a similar meaning from the [Condition] table to formulate the answer. \n " + \
" If a word from <Illegal Object> is encountered, choose the nearest parameter with a similar meaning from the [Object] table to formulate the answer. "
print ( " ** Blacklist: " , f " [Blacklist] \n <Illegal Characters>=[ { er_word0 } ] \n <Illegal Condition>=[ { er_word1 } ] \n <Illegal Object>=[ { er_word2 } ] " )
# prompt = prompt1+prompt2+error_message
# prompt = error_message
# print(prompt)
prompt = prompt1 + prompt2
# prompt+= question + result
prompt + = error_message
print ( error_message )
2024-01-10 11:16:12 +08:00
return prompt
2024-04-10 19:59:13 +08:00
# data_set_file = "easy.txt"
easy_data_set_file = " ../dataset/easy_instr_goal.txt "
medium_data_set_file = " ../dataset/medium_instr_goal.txt "
hard_data_set_file = " ../dataset/hard_instr_goal.txt "
test_data_set_file = " ../dataset/test.txt "
data_set_file = " ../dataset/data100.txt "
prompt_file1 = " prompt_test1.txt "
prompt_file2 = " prompt_test2.txt "
# prompt_file1 = "prompt1.txt"
# prompt_file2 = "prompt2.txt"
test_num = 1
with open ( easy_data_set_file , ' r ' , encoding = " utf-8 " ) as f :
easy_data_set = f . read ( ) . strip ( )
with open ( medium_data_set_file , ' r ' , encoding = " utf-8 " ) as f :
medium_data_set = f . read ( ) . strip ( )
with open ( hard_data_set_file , ' r ' , encoding = " utf-8 " ) as f :
hard_data_set = f . read ( ) . strip ( )
# with open(data_set_file, 'r', encoding="utf-8") as f:
# data_set = f.read().strip()
# easy_sections = re.split(r'\n\s*\n', easy_data_set)
# print("easy:",len(easy_sections))
# medium_sections = re.split(r'\n\s*\n', medium_data_set)
# print("medium:",len(medium_sections))
# hard_sections = re.split(r'\n\s*\n', hard_data_set)
# print("hard:",len(hard_sections))
# data_set = easy_data_set & medium_data_set
# print(data_set)
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
with open ( test_data_set_file , ' r ' , encoding = " utf-8 " ) as f :
test_data_set = f . read ( ) . strip ( )
data_set = test_data_set
# data_set =hard_data_set
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
# data_set = easy_data_set + medium_data_set + hard_data_set
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
with open ( prompt_file1 , ' r ' , encoding = " utf-8 " ) as f :
prompt1 = f . read ( ) . strip ( )
with open ( prompt_file2 , ' r ' , encoding = " utf-8 " ) as f :
prompt2 = f . read ( ) . strip ( )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
prompt = prompt1 + prompt2
2024-01-10 11:16:12 +08:00
sections = re . split ( r ' \ n \ s* \ n ' , data_set )
2024-04-10 19:59:13 +08:00
# print("data_set:",len(sections))
2024-01-10 11:16:12 +08:00
count = 0
2024-04-10 19:59:13 +08:00
2024-01-10 11:16:12 +08:00
llm = LLMERNIE ( )
question_list = [ ]
2024-04-10 19:59:13 +08:00
correct_answer_list = [ ]
correct_answer_ls_set = [ ]
2024-01-10 11:16:12 +08:00
outputs_list = [ [ ] for _ in range ( len ( sections ) ) ]
2024-04-10 19:59:13 +08:00
2024-01-10 11:16:12 +08:00
# 批量提交问题
2024-04-10 19:59:13 +08:00
# for i,s in enumerate(sections[:test_num]):
for i , s in enumerate ( sections ) :
x , y = s . strip ( ) . splitlines ( )
2024-01-10 11:16:12 +08:00
x = x . strip ( )
2024-04-10 19:59:13 +08:00
y = y . strip ( ) . replace ( " Goal: " , " " )
2024-01-10 11:16:12 +08:00
# print(f"x: {x.strip()}, y: {y.strip()}")
question_list . append ( x )
correct_answer_list . append ( y )
2024-04-10 19:59:13 +08:00
correct_answer_ls_set . append ( goal_transfer_ls_set ( y ) )
llm . ask ( x , prompt = prompt , tag = i )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
total_num = len ( question_list )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
error_black_ls = [ [ set ( ) , set ( ) , set ( ) ] for _ in range ( total_num ) ]
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
try_times = 1
# total_GR_ls = np.zeros(5)
total_GR_ls = [ ]
total_SR_ls = [ ]
total_GCR_ls = [ ]
# for time in range(try_times):
2024-01-10 11:16:12 +08:00
finish_num = 0
SR = 0
GR = 0
2024-04-10 19:59:13 +08:00
GCR = 0
# 统计语法正确的数量
GR_ls = np . zeros ( 6 )
2024-01-10 11:16:12 +08:00
while finish_num < total_num :
result = llm . get_result ( )
2024-04-10 19:59:13 +08:00
2024-01-10 11:16:12 +08:00
if result :
# print(result)
id , question , answer = result
# print(correct_answer_list[id])
outputs_list [ id ] . append ( answer )
#如果不正确,且回答次数<5, 则反馈
2024-04-10 19:59:13 +08:00
answer = " (On_Juice_Table6 | ~Exist_Juice=>On_Coffee_Table6 ) & ( ~Low_AC | Open_Curtain ) "
2024-01-10 11:16:12 +08:00
format_correct , error_list = format_check ( answer )
2024-04-10 19:59:13 +08:00
print ( " error_list: " , error_list )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
print ( f " ===== id: { id } Q: { question } ===== " )
2024-01-10 11:16:12 +08:00
2024-04-10 19:59:13 +08:00
if not format_correct :
if len ( outputs_list [ id ] ) < 6 :
print ( " *** answer: " , answer )
# new_prompt = get_feedback_prompt(id,prompt1,prompt2,answer,error_list,error_black_ls[id])
new_prompt = get_feedback_prompt ( id , prompt1 , prompt2 , question_list [ id ] , answer , error_list , error_black_ls [ id ] )
llm . ask ( question , prompt = new_prompt , tag = id )
# llm.ask(question, prompt=prompt, tag=id)
print ( f " Retry: { len ( outputs_list [ id ] ) } A: { outputs_list [ id ] } " )
# print(f"id: {id} Retry:{len(outputs_list[id])} A:{outputs_list[id]}, Q: {question}")
2024-01-10 11:16:12 +08:00
else :
finish_num + = 1
2024-04-10 19:59:13 +08:00
print ( f " A: { outputs_list [ id ] } " )
print ( f " CA: { correct_answer_list [ id ] } " )
GR_r = GR_ls / finish_num
gr_s = 0
for i in range ( len ( GR_r ) ) :
gr_s + = GR_r [ i ]
GR_r [ i ] = gr_s
print ( f " Correct:False GR: { GR_r [ 0 ] , GR_r [ 1 ] , GR_r [ - 1 ] } SR: { SR / finish_num } GCR: { GCR / finish_num } " )
2024-01-10 11:16:12 +08:00
else :
2024-04-10 19:59:13 +08:00
GR_ls [ len ( outputs_list [ id ] ) - 1 ] + = 1
2024-01-10 11:16:12 +08:00
correct = False
2024-04-10 19:59:13 +08:00
answer_ls_set = goal_transfer_ls_set ( answer )
if answer_ls_set == correct_answer_ls_set [ id ] :
2024-01-10 11:16:12 +08:00
SR + = 1
2024-04-10 19:59:13 +08:00
GCR + = 1
2024-01-10 11:16:12 +08:00
correct = True
2024-04-10 19:59:13 +08:00
else :
GCR + = len ( [ a_set for a_set in answer_ls_set if a_set in correct_answer_ls_set [ id ] ] ) * 1.0 / len ( correct_answer_ls_set [ id ] )
# GCR += len(answer_ls_set-correct_answer_ls_set)*1.0/len(correct_answer_ls_set)
2024-01-10 11:16:12 +08:00
# print(f"correct_num: {correct_num}")
# print()
finish_num + = 1
2024-04-10 19:59:13 +08:00
# print(f"Correct:{correct} GR:{GR/finish_num} SR:{SR/finish_num} A:{outputs_list[id]}, Q: {question}")
# print("id=",id,"answer:", answer, " == correct_answer_list[id]:", correct_answer_list[id])
print ( f " A: { outputs_list [ id ] } " )
print ( f " CA: { correct_answer_list [ id ] } " )
GR_r = GR_ls / finish_num
gr_s = 0
for i in range ( len ( GR_r ) ) :
gr_s + = GR_r [ i ]
GR_r [ i ] = gr_s
print ( f " Correct: { correct } GR: { GR_r [ 0 ] , GR_r [ 1 ] , GR_r [ - 1 ] } SR: { SR / finish_num } GCR: { GCR / finish_num } " )
total_GR_ls . append ( GR_ls )
total_SR_ls . append ( SR )
total_GCR_ls . append ( GCR )
2024-01-10 11:16:12 +08:00
else :
time . sleep ( 0.01 )
llm . close ( )
# print(f"result: {result}")
# print(f"y: {y}")
# print(f"count: {count}")
# print()
#
# if result == y:
# count += 1
2024-04-10 19:59:13 +08:00
# print("GR =:", round(np.mean(total_GR_ls), 4), "std=", round(np.std(total_GR_ls), 4))
# print("SR = ", round(np.mean(total_SR_ls), 3), "std=", round(np.std(total_SR_ls, ddof=1), 3))
# print("GCR = :", round(np.mean(total_GCR_ls), 3), "std=", round(np.std(total_GCR_ls, ddof=1), 3))
2024-01-10 11:16:12 +08:00