RoboWaiter/BTExpansionCode/llm_test/main_old.py

209 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from llm_test.ERNIE_Bot_4 import LLMERNIE
from sympy.parsing.sympy_parser import parse_expr
from sympy import symbols, Not, Or, And, to_dnf
from sympy import symbols, simplify_logic
import re
# Characters a goal expression is split on to isolate its atoms:
# parentheses and the boolean operators "&" and "|".
split_characters = r'[()&|]'

# Predicate names accepted as the first "_"-separated field of an atom.
predicate_list = {
    "RobotNear", "On", "Holding", "Exists",
    "IsClean", "Active", "Closed", "Low",
}

# Object names accepted as the remaining "_"-separated fields of an atom.
object_list = {
    # food and drinks
    'Coffee', 'Water', 'Dessert', 'Softdrink', 'BottledDrink', 'Yogurt',
    'ADMilk', 'MilkDrink', 'Milk', 'VacuumCup', 'Chips', 'NFCJuice',
    'Bernachon', 'SpringWater',
    # places
    'Bar', 'Bar2', 'WaterTable', 'CoffeeTable', 'Table1', 'Table2',
    'Table3', 'BrightTable6', 'Floor', 'Chairs',
    # devices
    'AC', 'TubeLight', 'HallLight', 'Curtain', 'ACTemperature',
}
def format_check(result):
    """Check that a goal expression is well-formed and uses known names.

    The whole expression must first be accepted by ``to_dnf``. It is then
    split on ``split_characters`` and every non-empty piece is checked as
    an atom of the form ``Predicate_Object[_Object...]``.

    Args:
        result: The goal expression text to validate.

    Returns:
        A ``(ok, errors)`` tuple:

        * ``(True, None)`` when every atom is valid.
        * ``(False, None)`` when ``result`` is not a string or the whole
          expression is rejected by ``to_dnf``.
        * ``(False, [wrong_format_set, wrong_predicate_set,
          wrong_object_set])`` when individual atoms are invalid.
    """
    # A non-string cannot be split below; report it as a format failure
    # instead of letting re.split raise TypeError.
    if not isinstance(result, str):
        return False, None

    # NOTE(review): to_dnf converts string input through sympy's sympify,
    # which the sympy docs say relies on eval. The text checked here is
    # model output, so confirm that is acceptable for this tool.
    try:
        to_dnf(result, simplify=True)
    except Exception:  # any parse/convert failure means "bad format"
        return False, None

    wrong_format_set = set()
    wrong_predicate_set = set()
    wrong_object_set = set()

    for piece in re.split(split_characters, result):
        sentence = piece.strip()
        if not sentence:
            continue

        # Each atom must parse on its own as well.
        try:
            to_dnf(sentence, simplify=True)
        except Exception:
            wrong_format_set.add(sentence)
            continue

        word_list = sentence.split("_")
        if len(word_list) <= 1:
            # No "_" at all: there is no predicate/object structure.
            wrong_format_set.add(sentence)
            continue

        predicate = word_list[0]
        if predicate not in predicate_list:
            wrong_predicate_set.add(predicate)

        wrong_object_set.update(
            name for name in word_list[1:] if name not in object_list
        )

    if wrong_format_set or wrong_predicate_set or wrong_object_set:
        return False, [wrong_format_set, wrong_predicate_set, wrong_object_set]
    return True, None
def word_correct(sentence):
    """Return True when ``to_dnf`` accepts ``sentence``, otherwise False.

    Args:
        sentence: Expression text to try converting.

    Returns:
        bool: False if ``to_dnf(sentence, simplify=True)`` raises any
        ``Exception``; True otherwise.
    """
    try:
        to_dnf(sentence, simplify=True)
    except Exception:  # not a bare except: let KeyboardInterrupt through
        return False
    return True
def get_feedback_prompt(prompt, result, error_list):
    """Return the prompt to use when re-asking after a failed format check.

    Placeholder implementation: ``result`` and ``error_list`` are accepted
    but not used, and ``prompt`` is handed back unchanged.

    Args:
        prompt: The prompt used for the previous attempt.
        result: The answer that failed the check (unused).
        error_list: Error details from the check (unused). A disabled line
            in the original unpacked it as
            ``wrong_format_set, wrong_predicate_set, wrong_object_set``.

    Returns:
        The ``prompt`` argument, unmodified.
    """
    return prompt
# def get_feedback_prompt(prompt,result):
# # split_sentences = re.split(split_characters, result)
# # split_sentences = [s.strip() for s in split_sentences if s.strip()]
# #
# # wrong_format_set = set()
# # wrong_predicate_set = set()
# # wrong_object_set = set()
# #
# # for sentence in split_sentences:
# # if sentence == "": continue
# #
# # try:
# # goal_dnf = str(to_dnf(sentence, simplify=True))
# # # 格式正确
# # word_list = sentence.split("_")
# # if len(word_list) <=1:
# # wrong_format_set.add(sentence)
# # continue
# #
# # predicate = word_list[0]
# # if predicate not in predicate_list:
# # wrong_predicate_set.add(predicate)
# #
# # for object in word_list[1:]:
# # if object not in object_list:
# # wrong_object_set.add(object)
# #
# # except:
# # wrong_format_set.add(sentence)
# #
# # print(wrong_format_set)
# # print(wrong_predicate_set)
# # print(wrong_object_set)
# a = "(On_S dfd oftdrink_Table3 & At_Chairs_Desk | & At_Chairs_Desk)"
# print(feedback(a))
#
#
#
# exit()
# ---------------------------------------------------------------------------
# Driver: ask the LLM the first `test_num` questions of the data set, re-ask
# answers that fail format_check, and print running GR / SR ratios.
# ---------------------------------------------------------------------------
data_set_file = "easy.txt"
prompt_file = "prompt.txt"
test_num = 2
# A question is finished once its answer passes format_check or this many
# answers have been collected for it.
max_attempts = 5

with open(data_set_file, 'r', encoding="utf-8") as f:
    data_set = f.read().strip()
with open(prompt_file, 'r', encoding="utf-8") as f:
    prompt = f.read().strip()

# Records are separated by blank lines; each record is a question line
# followed by its reference answer line.
sections = re.split(r'\n\s*\n', data_set)
count = 0

question_list = []
correct_answer_list = []
outputs_list = [[] for _ in range(len(sections))]

llm = LLMERNIE()
try:
    # Submit all questions up front; the section index is passed as the tag
    # and comes back as the first field of each result.
    for i, section in enumerate(sections[:test_num]):
        lines = section.strip().splitlines()
        if len(lines) != 2:
            raise ValueError(
                f"section {i} of {data_set_file} must contain exactly 2 lines "
                f"(question, answer), got {len(lines)}"
            )
        x = lines[0].strip()
        y = lines[1].strip()
        question_list.append(x)
        correct_answer_list.append(y)
        llm.ask(x, prompt=prompt, tag=i)

    total_num = len(question_list)
    finish_num = 0
    SR = 0  # answers that equal the reference answer exactly
    GR = 0  # answers that pass format_check

    while finish_num < total_num:
        result = llm.get_result()
        if not result:
            # Nothing available yet; wait briefly before polling again.
            time.sleep(0.01)
            continue

        task_id, question, answer = result
        outputs_list[task_id].append(answer)

        format_correct, error_list = format_check(answer)
        if not format_correct:
            if len(outputs_list[task_id]) < max_attempts:
                # Re-ask with the feedback prompt (the original built it but
                # then passed the unmodified prompt).
                new_prompt = get_feedback_prompt(prompt, answer, error_list)
                llm.ask(question, prompt=new_prompt, tag=task_id)
                print(f"id: {task_id} Retry:{len(outputs_list[task_id])} A:{outputs_list[task_id]}, Q: {question}")
            else:
                # Out of attempts: finished, but counted in neither GR nor SR.
                finish_num += 1
            continue

        GR += 1
        correct = answer == correct_answer_list[task_id]
        if correct:
            SR += 1
        finish_num += 1
        print(f"Correct:{correct} GR:{GR/finish_num} SR:{SR/finish_num} A:{outputs_list[task_id]}, Q: {question}")
finally:
    # Close the client even if submission or polling raised.
    llm.close()
# print(f"result: {result}")
# print(f"y: {y}")
# print(f"count: {count}")
# print()
#
# if result == y:
# count += 1