209 lines
5.7 KiB
Python
209 lines
5.7 KiB
Python
|
import time
|
|||
|
|
|||
|
from llm_test.ERNIE_Bot_4 import LLMERNIE
|
|||
|
from sympy.parsing.sympy_parser import parse_expr
|
|||
|
from sympy import symbols, Not, Or, And, to_dnf
|
|||
|
from sympy import symbols, simplify_logic
|
|||
|
|
|||
|
import re
|
|||
|
|
|||
|
# Characters on which a goal expression is split into individual atoms:
# parentheses plus the "&" (and) and "|" (or) operators.
split_characters = r'[()&|]'

# Predicate names accepted as the first "_"-separated field of an atom.
predicate_list = {
    "RobotNear", "On", "Holding", "Exists", "IsClean", "Active", "Closed", "Low",
}

# Object names accepted as the remaining "_"-separated fields of an atom.
object_list = {
    'Coffee', 'Water', 'Dessert', 'Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk',
    'MilkDrink', 'Milk', 'VacuumCup', 'Chips', 'NFCJuice', 'Bernachon', 'SpringWater',
    'Bar', 'Bar2', 'WaterTable', 'CoffeeTable', 'Table1', 'Table2', 'Table3', 'BrightTable6',
    'Floor', 'Chairs',
    'AC', 'TubeLight', 'HallLight',
    'Curtain', 'ACTemperature',
}
|
|||
|
|
|||
|
|
|||
|
def format_check(result):
    """Validate the format of a goal expression string.

    The expression as a whole must be accepted by ``to_dnf``.  It is then
    split on ``split_characters`` and every non-empty fragment is checked:
    the fragment must itself be accepted by ``to_dnf``, must contain at
    least one underscore, its first ``_``-separated field must be in
    ``predicate_list`` and every remaining field must be in ``object_list``.

    Args:
        result: The goal expression to check, as a string.

    Returns:
        ``(True, None)`` when every check passes.
        ``(False, None)`` when ``to_dnf`` rejects the whole expression.
        ``(False, [wrong_format_set, wrong_predicate_set, wrong_object_set])``
        when at least one fragment fails; the three sets hold the offending
        fragments, predicate names and object names respectively.
    """
    # NOTE(review): to_dnf receives a raw string here; sympy's string parsing
    # is documented as eval-based, so `result` should be treated as untrusted
    # input -- confirm this is acceptable for LLM-generated answers.
    # NOTE(review): some sympy versions refuse simplify=True for expressions
    # with many variables unless force=True is given; such an expression would
    # be reported as a format error here -- confirm against the sympy version
    # in use.
    try:
        to_dnf(result, simplify=True)
    except Exception:
        # Only ordinary errors mean "badly formatted"; KeyboardInterrupt and
        # SystemExit are deliberately allowed to propagate.
        return False, None

    fragments = [part.strip() for part in re.split(split_characters, result)]
    fragments = [fragment for fragment in fragments if fragment]

    wrong_format_set = set()
    wrong_predicate_set = set()
    wrong_object_set = set()

    for sentence in fragments:
        try:
            to_dnf(sentence, simplify=True)
        except Exception:
            wrong_format_set.add(sentence)
            continue

        # A well-formed atom looks like Predicate_Object[_Object...].
        # NOTE(review): "~" is not among the split characters, so a negated
        # atom keeps its "~" and its predicate will be reported as unknown --
        # confirm whether negation is expected in answers.
        word_list = sentence.split("_")
        if len(word_list) <= 1:
            wrong_format_set.add(sentence)
            continue

        predicate, *arguments = word_list
        if predicate not in predicate_list:
            wrong_predicate_set.add(predicate)
        wrong_object_set.update(
            argument for argument in arguments if argument not in object_list
        )

    if wrong_format_set or wrong_predicate_set or wrong_object_set:
        return False, [wrong_format_set, wrong_predicate_set, wrong_object_set]
    return True, None
|
|||
|
|
|||
|
def word_correct(sentence):
    """Return whether ``to_dnf`` accepts ``sentence``.

    Args:
        sentence: Expression (or expression fragment) to test, as a string.

    Returns:
        ``True`` if ``to_dnf(sentence, simplify=True)`` completes without
        raising, ``False`` if it raises an ordinary exception.
    """
    try:
        to_dnf(sentence, simplify=True)
    except Exception:
        # KeyboardInterrupt / SystemExit are intentionally not swallowed.
        return False
    return True
|
|||
|
|
|||
|
|
|||
|
def get_feedback_prompt(prompt, result, error_list):
    """Produce the prompt to use when a question is asked again.

    This is currently a pass-through: ``result`` and ``error_list`` are
    accepted but not used, and ``prompt`` is handed back unchanged.

    Args:
        prompt: The base prompt text.
        result: The answer that failed the format check (unused).
        error_list: The error information returned by ``format_check`` --
            either ``None`` or ``[wrong_format_set, wrong_predicate_set,
            wrong_object_set]`` (unused).

    Returns:
        ``prompt``, unmodified.
    """
    feedback_prompt = prompt
    return feedback_prompt
|
|||
|
|
|||
|
|
|||
|
# def get_feedback_prompt(prompt,result):
|
|||
|
# # split_sentences = re.split(split_characters, result)
|
|||
|
# # split_sentences = [s.strip() for s in split_sentences if s.strip()]
|
|||
|
# #
|
|||
|
# # wrong_format_set = set()
|
|||
|
# # wrong_predicate_set = set()
|
|||
|
# # wrong_object_set = set()
|
|||
|
# #
|
|||
|
# # for sentence in split_sentences:
|
|||
|
# # if sentence == "": continue
|
|||
|
# #
|
|||
|
# # try:
|
|||
|
# # goal_dnf = str(to_dnf(sentence, simplify=True))
|
|||
|
# # # 格式正确
|
|||
|
# # word_list = sentence.split("_")
|
|||
|
# # if len(word_list) <=1:
|
|||
|
# # wrong_format_set.add(sentence)
|
|||
|
# # continue
|
|||
|
# #
|
|||
|
# # predicate = word_list[0]
|
|||
|
# # if predicate not in predicate_list:
|
|||
|
# # wrong_predicate_set.add(predicate)
|
|||
|
# #
|
|||
|
# # for object in word_list[1:]:
|
|||
|
# # if object not in object_list:
|
|||
|
# # wrong_object_set.add(object)
|
|||
|
# #
|
|||
|
# # except:
|
|||
|
# # wrong_format_set.add(sentence)
|
|||
|
# #
|
|||
|
# # print(wrong_format_set)
|
|||
|
# # print(wrong_predicate_set)
|
|||
|
# # print(wrong_object_set)
|
|||
|
|
|||
|
|
|||
|
# a = "(On_S dfd oftdrink_Table3 & At_Chairs_Desk | & At_Chairs_Desk)"
|
|||
|
# print(feedback(a))
|
|||
|
#
|
|||
|
#
|
|||
|
#
|
|||
|
# exit()
|
|||
|
|
|||
|
|
|||
|
|
|||
|
data_set_file = "easy.txt"
prompt_file = "prompt.txt"
test_num = 2  # number of data-set sections submitted to the LLM
max_attempts = 5  # answers collected per question before giving up on it

with open(data_set_file, 'r', encoding="utf-8") as f:
    data_set = f.read().strip()

with open(prompt_file, 'r', encoding="utf-8") as f:
    prompt = f.read().strip()

# Sections are separated by blank lines; each one is unpacked below into a
# question line followed by the expected answer line.
sections = re.split(r'\n\s*\n', data_set)
count = 0  # not used by the active code in this script

llm = LLMERNIE()
question_list = []
correct_answer_list = []
# One list of collected answers per section, indexed by the question tag.
outputs_list = [[] for _ in sections]

try:
    # Submit all questions in one batch.
    for i, s in enumerate(sections[:test_num]):
        x, y = s.strip().splitlines()
        x = x.strip()
        y = y.strip()
        question_list.append(x)
        correct_answer_list.append(y)
        llm.ask(x, prompt=prompt, tag=i)

    total_num = len(question_list)
    finish_num = 0
    SR = 0  # answers that exactly match the expected answer
    GR = 0  # answers that pass format_check

    while finish_num < total_num:
        result = llm.get_result()
        if not result:
            # Nothing ready yet; poll again shortly.
            time.sleep(0.01)
            continue

        tag_id, question, answer = result
        outputs_list[tag_id].append(answer)

        # If the answer is badly formatted and fewer than max_attempts
        # answers have been collected for this question, ask again.
        format_correct, error_list = format_check(answer)

        if not format_correct:
            if len(outputs_list[tag_id]) < max_attempts:
                new_prompt = get_feedback_prompt(prompt, answer, error_list)
                # Re-ask with the feedback prompt (previously the freshly
                # built prompt was discarded and the base prompt re-sent).
                llm.ask(question, prompt=new_prompt, tag=tag_id)
                print(f"id: {tag_id} Retry:{len(outputs_list[tag_id])} A:{outputs_list[tag_id]}, Q: {question}")
            else:
                finish_num += 1
        else:
            GR += 1
            correct = answer == correct_answer_list[tag_id]
            if correct:
                SR += 1
            finish_num += 1
            print(f"Correct:{correct} GR:{GR/finish_num} SR:{SR/finish_num} A:{outputs_list[tag_id]}, Q: {question}")
finally:
    # Always release the LLM client, even if the run aborts with an error.
    llm.close()
|
|||
|
# print(f"result: {result}")
|
|||
|
# print(f"y: {y}")
|
|||
|
# print(f"count: {count}")
|
|||
|
# print()
|
|||
|
#
|
|||
|
# if result == y:
|
|||
|
# count += 1
|
|||
|
|
|||
|
|