209 lines
5.7 KiB
Python
209 lines
5.7 KiB
Python
import time
|
||
|
||
from llm_test.ERNIE_Bot_4 import LLMERNIE
|
||
from sympy.parsing.sympy_parser import parse_expr
|
||
from sympy import symbols, Not, Or, And, to_dnf
|
||
from sympy import symbols, simplify_logic
|
||
|
||
import re
|
||
|
||
# Characters that separate atomic literals in a goal expression: parentheses,
# the sympy boolean operators "&" (And) and "|" (Or), and "~" (Not), so that a
# negated literal such as "~On_Coffee_Bar" is validated as "On_Coffee_Bar"
# instead of yielding the bogus predicate "~On".
split_characters = r'[()&|~]'

# Predicates a goal literal may start with. A set despite the "_list" name;
# it is only used for membership tests.
predicate_list = {"RobotNear", "On", "Holding", "Exists", "IsClean", "Active", "Closed", "Low"}

# Object names allowed as predicate arguments (duplicate entries removed; a
# set literal would have collapsed them anyway).
object_list = {
    'Coffee', 'Water', 'Dessert', 'Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk',
    'MilkDrink', 'Milk', 'VacuumCup', 'Chips', 'NFCJuice', 'Bernachon', 'SpringWater',
    'Bar', 'Bar2', 'WaterTable', 'CoffeeTable', 'Table1', 'Table2', 'Table3', 'BrightTable6',
    'Floor', 'Chairs',
    'AC', 'TubeLight', 'HallLight',
    'Curtain', 'ACTemperature',
}
|
||
|
||
|
||
def _parses_as_boolean(expr):
    """Return True if sympy's ``to_dnf`` accepts ``expr``, False otherwise."""
    try:
        to_dnf(expr, simplify=True)
    except Exception:
        # Parse/conversion failures surface as several exception types
        # (SympifyError, SyntaxError, TypeError, ...). Catch Exception rather
        # than using a bare except so KeyboardInterrupt/SystemExit propagate.
        return False
    return True


def format_check(result):
    """Validate an LLM-produced goal expression.

    The whole string must parse with sympy's ``to_dnf``. Every atomic literal
    in it (the non-empty pieces left after splitting on ``split_characters``)
    must also parse, and must look like ``Predicate_Object[_Object...]``,
    with the predicate in ``predicate_list`` and every object in
    ``object_list``.

    Args:
        result: The raw goal expression string returned by the LLM.

    Returns:
        ``(True, None)`` if the expression is well-formed.
        ``(False, None)`` if the whole expression does not parse.
        ``(False, [wrong_format_set, wrong_predicate_set, wrong_object_set])``
        otherwise. The three sets hold, respectively, the malformed literals,
        the unknown predicates, and the unknown objects.
    """
    # SECURITY NOTE(review): to_dnf() sympifies strings, and sympify() uses
    # eval(). `result` is model output, so this evaluates untrusted text.
    if not _parses_as_boolean(result):
        return False, None

    wrong_format_set = set()
    wrong_predicate_set = set()
    wrong_object_set = set()

    for literal in (piece.strip() for piece in re.split(split_characters, result)):
        if not literal:
            continue
        if not _parses_as_boolean(literal):
            wrong_format_set.add(literal)
            continue

        predicate, *objects = literal.split("_")
        if not objects:
            # A bare symbol with no "_"-separated arguments.
            wrong_format_set.add(literal)
            continue

        if predicate not in predicate_list:
            wrong_predicate_set.add(predicate)
        wrong_object_set.update(obj for obj in objects if obj not in object_list)

    if wrong_format_set or wrong_predicate_set or wrong_object_set:
        return False, [wrong_format_set, wrong_predicate_set, wrong_object_set]
    return True, None
|
||
|
||
def word_correct(sentence):
    """Return True if ``sentence`` is accepted by sympy's ``to_dnf``.

    The conversion result (``simplify=True``) is discarded; any error raised
    while parsing or converting means the sentence is malformed.

    NOTE(review): ``to_dnf`` sympifies strings, which uses ``eval``. Keep that
    in mind before passing untrusted text.
    """
    try:
        to_dnf(sentence, simplify=True)
    except Exception:
        # Only parse/conversion errors. A bare ``except`` would also swallow
        # KeyboardInterrupt/SystemExit.
        return False
    return True
|
||
|
||
|
||
def get_feedback_prompt(prompt, result, error_list):
    """Build the prompt used when re-asking after a malformed answer.

    Placeholder: ``result`` (the rejected answer) and ``error_list`` (the
    error sets from ``format_check``, or None) are accepted but not yet used.
    The original ``prompt`` object is returned unchanged.
    """
    feedback_prompt = prompt
    # TODO: append feedback derived from `result` and `error_list`
    # (wrong_format_set, wrong_predicate_set, wrong_object_set).
    return feedback_prompt
|
||
|
||
|
||
# def get_feedback_prompt(prompt,result):
|
||
# # split_sentences = re.split(split_characters, result)
|
||
# # split_sentences = [s.strip() for s in split_sentences if s.strip()]
|
||
# #
|
||
# # wrong_format_set = set()
|
||
# # wrong_predicate_set = set()
|
||
# # wrong_object_set = set()
|
||
# #
|
||
# # for sentence in split_sentences:
|
||
# # if sentence == "": continue
|
||
# #
|
||
# # try:
|
||
# # goal_dnf = str(to_dnf(sentence, simplify=True))
|
||
# # # 格式正确
|
||
# # word_list = sentence.split("_")
|
||
# # if len(word_list) <=1:
|
||
# # wrong_format_set.add(sentence)
|
||
# # continue
|
||
# #
|
||
# # predicate = word_list[0]
|
||
# # if predicate not in predicate_list:
|
||
# # wrong_predicate_set.add(predicate)
|
||
# #
|
||
# # for object in word_list[1:]:
|
||
# # if object not in object_list:
|
||
# # wrong_object_set.add(object)
|
||
# #
|
||
# # except:
|
||
# # wrong_format_set.add(sentence)
|
||
# #
|
||
# # print(wrong_format_set)
|
||
# # print(wrong_predicate_set)
|
||
# # print(wrong_object_set)
|
||
|
||
|
||
# a = "(On_S dfd oftdrink_Table3 & At_Chairs_Desk | & At_Chairs_Desk)"
|
||
# print(feedback(a))
|
||
#
|
||
#
|
||
#
|
||
# exit()
|
||
|
||
|
||
|
||
data_set_file = "easy.txt"
prompt_file = "prompt.txt"
# Number of dataset sections (question/answer pairs) to evaluate.
test_num = 2
# Maximum number of answers collected per question (first try + retries).
max_attempts = 5

with open(data_set_file, 'r', encoding="utf-8") as f:
    data_set = f.read().strip()

with open(prompt_file, 'r', encoding="utf-8") as f:
    prompt = f.read().strip()

# Sections are separated by blank lines. Each one is expected to hold a
# question line followed by its reference-answer line.
sections = re.split(r'\n\s*\n', data_set)

count = 0  # NOTE(review): unused; kept for the commented-out tally at the end of the file.

llm = LLMERNIE()
question_list = []
correct_answer_list = []
outputs_list = [[] for _ in range(len(sections))]

try:
    # Submit all questions up front. Answers are collected by polling below.
    for i, section in enumerate(sections[:test_num]):
        lines = section.strip().splitlines()
        if len(lines) != 2:
            raise ValueError(
                f"section {i} of {data_set_file} must have exactly 2 lines "
                f"(question, answer), got {len(lines)}"
            )
        x, y = (line.strip() for line in lines)
        question_list.append(x)
        correct_answer_list.append(y)
        llm.ask(x, prompt=prompt, tag=i)

    total_num = len(question_list)
    finish_num = 0
    SR = 0  # answers exactly equal to the reference answer
    GR = 0  # answers that pass format_check

    while finish_num < total_num:
        result = llm.get_result()
        if not result:
            time.sleep(0.01)
            continue

        qid, question, answer = result
        outputs_list[qid].append(answer)

        format_correct, error_list = format_check(answer)

        if not format_correct:
            # Malformed answer: re-ask (with feedback) until max_attempts
            # answers have been collected for this question.
            if len(outputs_list[qid]) < max_attempts:
                new_prompt = get_feedback_prompt(prompt, answer, error_list)
                llm.ask(question, prompt=new_prompt, tag=qid)
                print(f"id: {qid} Retry:{len(outputs_list[qid])} A:{outputs_list[qid]}, Q: {question}")
            else:
                finish_num += 1
        else:
            GR += 1
            correct = answer == correct_answer_list[qid]
            if correct:
                SR += 1
            finish_num += 1
            print(f"Correct:{correct} GR:{GR/finish_num} SR:{SR/finish_num} A:{outputs_list[qid]}, Q: {question}")

    # Summary over all finished questions, including those that exhausted retries.
    if finish_num:
        print(f"Final GR:{GR/finish_num} SR:{SR/finish_num} over {finish_num} questions")
finally:
    # Release the LLM client even if the loop raises or is interrupted.
    llm.close()
|
||
# print(f"result: {result}")
|
||
# print(f"y: {y}")
|
||
# print(f"count: {count}")
|
||
# print()
|
||
#
|
||
# if result == y:
|
||
# count += 1
|
||
|
||
|