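# Source path: RoboWaiter/BTExpansionCode/llm_test/main_old.py
# Page info at copy time: 209 lines, 5.7 KiB, Python
# (copied from a repository web view: "Raw Normal View History")
#
# Timestamp shown on the page: 2024-04-10 19:59:13 +08:00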
import time
from llm_test.ERNIE_Bot_4 import LLMERNIE
from sympy.parsing.sympy_parser import parse_expr
from sympy import symbols, Not, Or, And, to_dnf
from sympy import symbols, simplify_logic
import re
# Regex character class used to split a goal expression into atoms:
# parentheses and the sympy boolean operators "&" (And) and "|" (Or).
split_characters = r'[()&|]'
# Predicate names accepted as the first "_"-separated part of an atom.
predicate_list = {"RobotNear", "On", "Holding", "Exists", "IsClean", "Active", "Closed", "Low"}
# Object names accepted after the predicate in an atom
# (e.g. "On_Coffee_Table1" -> predicate "On", objects "Coffee", "Table1").
object_list = {
    'Coffee', 'Water', 'Dessert', 'Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk',
    'MilkDrink', 'Milk', 'VacuumCup', 'Chips', 'NFCJuice', 'Bernachon', 'SpringWater',
    'Bar', 'Bar2', 'WaterTable', 'CoffeeTable', 'Table1', 'Table2', 'Table3', 'BrightTable6',
    'Floor', 'Chairs',
    'AC', 'TubeLight', 'HallLight',
    'Curtain', 'ACTemperature',
}
def format_check(result):
    """Check that an LLM answer is a well-formed goal expression.

    The answer must parse as a sympy boolean expression. Each atom is a piece
    left after splitting on ``split_characters``, ignoring any leading ``~``
    negation. Every atom must also parse on its own and must have the shape
    ``Predicate_Object[_Object...]``, where ``Predicate`` is in
    ``predicate_list`` and each object is in ``object_list``.

    Args:
        result: The raw answer string returned by the LLM.

    Returns:
        ``(True, None)`` if the answer is well formed.
        ``(False, None)`` if ``result`` is not a string or the whole answer
        does not parse.
        ``(False, [wrong_format_set, wrong_predicate_set, wrong_object_set])``
        otherwise. The three sets hold the malformed atoms, the unknown
        predicate names and the unknown object names.
    """
    if not isinstance(result, str):
        # e.g. None from a failed request; re.split would raise TypeError.
        return False, None

    # NOTE(security): to_dnf() converts strings with sympify, which evaluates
    # them with eval(). Only pass trusted text (here: model output).
    try:
        to_dnf(result, simplify=True)
    except Exception:
        return False, None

    wrong_format_set = set()
    wrong_predicate_set = set()
    wrong_object_set = set()
    for piece in re.split(split_characters, result):
        # "~" is sympy negation. Strip it so "~Active_AC" is checked as
        # "Active_AC". A bare "~" left from "~(...)" becomes empty and is skipped.
        atom = piece.strip().lstrip("~").strip()
        if not atom:
            continue
        try:
            # Each atom must also be a valid expression on its own.
            to_dnf(atom, simplify=True)
        except Exception:
            wrong_format_set.add(atom)
            continue
        word_list = atom.split("_")
        if len(word_list) <= 1:
            # No "_" separator: the predicate or its objects are missing.
            wrong_format_set.add(atom)
            continue
        predicate, *objects = word_list
        if predicate not in predicate_list:
            wrong_predicate_set.add(predicate)
        for obj in objects:
            if obj not in object_list:
                wrong_object_set.add(obj)

    if wrong_format_set or wrong_predicate_set or wrong_object_set:
        return False, [wrong_format_set, wrong_predicate_set, wrong_object_set]
    return True, None
def word_correct(sentence):
    """Return True if ``sentence`` can be converted with sympy's ``to_dnf``.

    Args:
        sentence: A boolean expression string, e.g. ``"On_Coffee_Table1 & Active_AC"``.

    Returns:
        True if ``to_dnf(sentence, simplify=True)`` succeeds, False if it raises.
    """
    # NOTE(security): to_dnf() converts strings with sympify, which uses
    # eval(); do not pass untrusted text.
    try:
        to_dnf(sentence, simplify=True)
    except Exception:
        return False
    return True
def get_feedback_prompt(prompt, result, error_list):
    """Build the prompt to use when re-asking a question after a bad answer.

    This is a placeholder. The rejected answer and its error details are
    accepted but not used yet, so the base prompt is returned unchanged.

    Args:
        prompt: The base prompt text.
        result: The answer that was rejected (currently unused).
        error_list: ``[wrong_format_set, wrong_predicate_set, wrong_object_set]``
            as produced by ``format_check``, or None (currently unused).

    Returns:
        The prompt text for the retry request.
    """
    # TODO: turn `result` and the three error sets in `error_list` into
    # feedback text appended to the prompt.
    feedback_prompt = prompt
    return feedback_prompt
# def get_feedback_prompt(prompt,result):
# # split_sentences = re.split(split_characters, result)
# # split_sentences = [s.strip() for s in split_sentences if s.strip()]
# #
# # wrong_format_set = set()
# # wrong_predicate_set = set()
# # wrong_object_set = set()
# #
# # for sentence in split_sentences:
# # if sentence == "": continue
# #
# # try:
# # goal_dnf = str(to_dnf(sentence, simplify=True))
# # # 格式正确
# # word_list = sentence.split("_")
# # if len(word_list) <=1:
# # wrong_format_set.add(sentence)
# # continue
# #
# # predicate = word_list[0]
# # if predicate not in predicate_list:
# # wrong_predicate_set.add(predicate)
# #
# # for object in word_list[1:]:
# # if object not in object_list:
# # wrong_object_set.add(object)
# #
# # except:
# # wrong_format_set.add(sentence)
# #
# # print(wrong_format_set)
# # print(wrong_predicate_set)
# # print(wrong_object_set)
# a = "(On_S dfd oftdrink_Table3 & At_Chairs_Desk | & At_Chairs_Desk)"
# print(feedback(a))
#
#
#
# exit()
data_set_file = "easy.txt"
prompt_file = "prompt.txt"
test_num = 2  # number of dataset sections to evaluate
max_answers = 5  # stop retrying a question once this many answers were received

with open(data_set_file, 'r', encoding="utf-8") as f:
    data_set = f.read().strip()
with open(prompt_file, 'r', encoding="utf-8") as f:
    prompt = f.read().strip()

# Sections are separated by blank lines. Each holds a question line followed
# by its expected answer line.
sections = re.split(r'\n\s*\n', data_set)
count = 0
llm = LLMERNIE()
question_list = []
correct_answer_list = []
outputs_list = [[] for _ in range(len(sections))]

try:
    # Submit all questions up front. Each request is tagged with its index so
    # the asynchronous results below can be matched back to their question.
    for i, s in enumerate(sections[:test_num]):
        lines = s.strip().splitlines()
        if len(lines) != 2:
            raise ValueError(
                f"section {i} of {data_set_file} must contain exactly 2 lines "
                f"(question, answer), got {len(lines)}"
            )
        x, y = lines[0].strip(), lines[1].strip()
        question_list.append(x)
        correct_answer_list.append(y)
        llm.ask(x, prompt=prompt, tag=i)

    total_num = len(question_list)
    finish_num = 0
    SR = 0  # answers exactly equal to the reference answer
    GR = 0  # answers that passed format_check
    while finish_num < total_num:
        result = llm.get_result()
        if not result:
            time.sleep(0.01)
            continue
        task_id, question, answer = result
        outputs_list[task_id].append(answer)
        # If the answer is malformed and fewer than max_answers answers have
        # been received, re-ask using the feedback prompt.
        format_correct, error_list = format_check(answer)
        if not format_correct:
            if len(outputs_list[task_id]) < max_answers:
                new_prompt = get_feedback_prompt(prompt, answer, error_list)
                # Bug fix: the retry previously sent `prompt`, discarding new_prompt.
                llm.ask(question, prompt=new_prompt, tag=task_id)
                print(f"id: {task_id} Retry:{len(outputs_list[task_id])} A:{outputs_list[task_id]}, Q: {question}")
            else:
                finish_num += 1
        else:
            GR += 1
            correct = answer == correct_answer_list[task_id]
            if correct:
                SR += 1
            finish_num += 1
            print(f"Correct:{correct} GR:{GR/finish_num} SR:{SR/finish_num} A:{outputs_list[task_id]}, Q: {question}")
finally:
    # Always release the client, even if reading the dataset or a request fails.
    llm.close()
# print(f"result: {result}")
# print(f"y: {y}")
# print(f"count: {count}")
# print()
#
# if result == y:
# count += 1