RoboWaiter/BTExpansionCode/llm_test/LLM_Evaluation_Kit/llm_ask_tools.py

252 lines
10 KiB
Python
Raw Normal View History

2024-06-27 20:28:54 +08:00
import re
from btpg.algos.llm_client.tools import goal_transfer_str, act_str_process
from btpg.utils import ROOT_PATH
# 导入向量数据库检索的相关函数
from btpg.algos.llm_client.vector_database_env_goal import search_nearest_examples
from ordered_set import OrderedSet
def parse_llm_output(answer, goals=True):
    """Parse a sectioned LLM answer into actions, predicates and objects.

    The answer is split on the literal headers ``Optimal Actions:``,
    ``Vital Action Predicates:`` and ``Vital Objects:``. When ``goals`` is
    true, the text before ``Optimal Actions:`` (with ``Goals:`` removed) is
    additionally passed to ``goal_transfer_str``.

    Args:
        answer: Raw text returned by the LLM.
        goals: Whether to parse and return the goal section as well.

    Returns:
        ``(goal_set, priority_act_ls, key_predicate, key_objects)`` when
        ``goals`` is true, otherwise
        ``(priority_act_ls, key_predicate, key_objects)``. The three lists
        are de-duplicated in first-seen order; the predicate and object
        lists never contain empty names. If anything goes wrong while
        parsing, every element is ``None`` and the error is printed instead
        of being raised.
    """

    def split_names(section):
        # Remove spaces, split on commas and skip empty tokens: an empty
        # section or a trailing comma would otherwise yield '' as a name.
        return [name for name in section.replace(" ", "").split(",") if name]

    goal_set = set()
    priority_act_ls, key_predicate, key_objects = [], [], []
    try:
        if goals:
            goal_str = answer.split("Optimal Actions:")[0].replace("Goals:", "").strip()
            goal_set = goal_transfer_str(goal_str)
        act_str = answer.split("Optimal Actions:")[1].split("Vital Action Predicates:")[0].strip()
        predicate_str = answer.split("Vital Action Predicates:")[1].split("Vital Objects:")[0].strip()
        objects_str = answer.split("Vital Objects:")[1].strip()

        # OrderedSet drops duplicates while keeping the LLM's ordering.
        priority_act_ls = list(OrderedSet(act_str_process(act_str)))
        key_predicate = list(OrderedSet(split_names(predicate_str)))
        key_objects = list(OrderedSet(split_names(objects_str)))
    except Exception as e:
        # Deliberately best-effort: a missing header (IndexError) or any
        # error from the helper parsers is reported to the caller as Nones.
        goal_set, priority_act_ls, key_predicate, key_objects = None, None, None, None
        print(f"Failed to parse LLM output: {e}")

    if goals:
        return goal_set, priority_act_ls, key_predicate, key_objects
    return priority_act_ls, key_predicate, key_objects
def format_example(metadata):
    """Render one vector-database example as prompt text.

    Reads ``metadata['value']`` and emits one ``Label: content`` line per
    section, each terminated by a newline. ``Goals``, ``Optimal Actions``
    and ``Vital Objects`` are required keys; ``Vital Action Predicates``
    falls back to an empty string when absent. The example's
    ``Instruction`` field is not included in the output.
    """
    fields = metadata['value']
    sections = (
        ("Goals", fields['Goals']),
        ("Optimal Actions", fields['Optimal Actions']),
        ("Vital Action Predicates", fields.get('Vital Action Predicates', '')),
        ("Vital Objects", fields['Vital Objects']),
    )
    return "".join(f"{label}: {content}\n" for label, content in sections)
def extract_llm_from_instr_goal(llm, default_prompt_file, environment, goals, instruction=None, cur_cond_set=None,
                                choose_database=False,
                                database_index_path=f"{ROOT_PATH}/../test/dataset/env_instruction_vectors.index",
                                verbose=False):
    """Ask the LLM which actions, predicates and objects matter for ``goals``.

    The prompt template is read from ``default_prompt_file``; optionally the
    nearest examples from the vector database are inserted at the
    ``[Examples]`` marker (or appended when the template has no marker).
    The question is sent to the LLM and the answer parsed with
    ``parse_llm_output(..., goals=False)``. If parsing fails the same
    question is re-sent, for at most four attempts in total.

    Args:
        llm: Client object; ``llm.request(message=...)`` produces the answer
            and the object is also handed to ``search_nearest_examples``.
        default_prompt_file: Path of the UTF-8 prompt template.
        environment: Not used by this function.
        goals: Iterable of goal strings; joined with ``" & "`` for the
            prompt and passed unchanged to the example search.
        instruction: Not used by this function.
        cur_cond_set: Not used by this function.
        choose_database: When true, retrieve the 5 nearest examples and add
            them to the prompt.
        database_index_path: Index path given to ``search_nearest_examples``.
        verbose: When true, print the final question.

    Returns:
        ``(priority_act_ls, key_predicates, key_objects, messages,
        distances, parsed_fail)``. The first three are ``None`` if no
        attempt could be parsed. ``messages`` is the conversation of the
        last attempt, ``distances`` comes from the example search (``None``
        without it) and ``parsed_fail`` is the number of failed attempts
        (0 when the first answer parsed, 4 when giving up).
    """
    with open(default_prompt_file, 'r', encoding="utf-8") as f:
        prompt = f.read().strip()

    RED = "\033[31m"
    RESET = "\033[0m"

    goals_str = ' & '.join(goals)
    distances = None
    question = None
    messages = []
    priority_act_ls, key_predicates, key_objects = None, None, None
    parsed_fail = -1

    # parse_llm_output reports failure as a tuple of Nones (never a bare
    # None), so the retry condition has to look at the parsed fields.
    while priority_act_ls is None:
        parsed_fail += 1  # 0 on the first attempt, then 1, 2, 3
        print(f"--- LLM: Goal={goals} Parsed Fail={parsed_fail} --- ")
        if parsed_fail > 3:
            print(f"{RED}----LLM: Goal={goals} Parsed Fail={parsed_fail} >3 break -----{RESET}")
            break

        if question is None:
            # Build the question only once: the inserted text itself begins
            # with "[Examples]", so redoing the replacement on a retry would
            # nest the examples inside the prompt again.
            if choose_database:
                nearest_examples, distances = search_nearest_examples(database_index_path, llm, goals, top_n=5)
                # Format the retrieved examples in the layout the prompt uses.
                example_texts = '\n'.join([format_example(ex) for ex in nearest_examples])
                example_texts = "[Examples]\n" + example_texts

                # Show the goals of all retrieved examples (in yellow).
                nearest_goals = [ex['value']['Goals'] for ex in nearest_examples]
                print("All Goals from nearest examples:")
                for g in nearest_goals:
                    print(f"\033[93m{g}\033[0m")

                example_marker = "[Examples]"
                if example_marker in prompt:
                    prompt = prompt.replace(example_marker, example_texts)
                else:
                    prompt = f"{prompt}\n{example_texts}"

            question = f"{prompt}\nGoals: {goals_str}"
            if verbose:
                print("============ Question ================\n", question)

        # Every attempt is a fresh single-turn conversation.
        messages = [{"role": "user", "content": question}]
        answer = llm.request(message=messages)
        messages.append({"role": "assistant", "content": answer})

        priority_act_ls, key_predicates, key_objects = parse_llm_output(answer, goals=False)
        if priority_act_ls is None:
            print(f"\033[91mFailed to parse LLM output for goals: {goals_str}\033[0m")

    return priority_act_ls, key_predicates, key_objects, messages, distances, parsed_fail
def extract_llm_from_instr(llm, default_prompt_file, instruction, cur_cond_set,
                           choose_database=False,
                           index_path=f"{ROOT_PATH}/../test/dataset/env_instruction_vectors.index"):
    """Ask the LLM to turn an instruction into goals, actions and key symbols.

    The prompt template is read from ``default_prompt_file``; optionally the
    3 nearest examples from the vector database are inserted at the
    ``[Examples]`` marker (or appended when the template has no marker).
    The instruction is appended, the question is sent to the LLM and the
    answer is parsed with ``parse_llm_output``. Object names found inside
    the parentheses of the first goal's expressions are added to the key
    objects.

    Args:
        llm: Client object; ``llm.request(message=...)`` produces the answer
            and the object is also handed to ``search_nearest_examples``.
        default_prompt_file: Path of the UTF-8 prompt template.
        instruction: Instruction text; used for the example search and
            appended to the prompt.
        cur_cond_set: Not used by this function.
        choose_database: When true, retrieve examples and add them to the
            prompt.
        index_path: Index path given to ``search_nearest_examples``.

    Returns:
        ``(goal_set, priority_act_ls, key_predicates, key_objects,
        messages)``. When the answer cannot be parsed the first four are
        ``None``. ``key_objects`` is de-duplicated in first-seen order.
    """
    with open(default_prompt_file, 'r', encoding="utf-8") as f:
        prompt = f.read().strip()

    if choose_database:
        # Retrieve the nearest examples and format them like the prompt does.
        nearest_examples, distances = search_nearest_examples(index_path, llm, instruction, top_n=3)
        example_texts = '\n'.join([format_example(ex) for ex in nearest_examples])
        example_texts = "[Examples]\n" + example_texts
        print("distances:", distances)

        # Put the examples where the template asks for them, else at the end.
        example_marker = "[Examples]"
        if example_marker in prompt:
            prompt = prompt.replace(example_marker, example_texts)
        else:
            prompt = f"{prompt}\n{example_texts}"

    # Full prompt: template (with any examples) followed by the instruction.
    question = f"{prompt}\n{instruction}"
    print("question:", question)

    messages = [{"role": "user", "content": question}]
    answer = llm.request(message=messages)
    messages.append({"role": "assistant", "content": answer})
    print(answer)

    goal_set, priority_act_ls, key_predicates, key_objects = parse_llm_output(answer)
    print("goal", goal_set)
    print("act:", priority_act_ls)
    print("key_predicate", key_predicates)
    print("Vital Objects:", key_objects)

    if key_objects is None:
        # parse_llm_output signals failure with Nones; indexing goal_set or
        # extending key_objects would raise TypeError, so pass them through.
        return goal_set, priority_act_ls, key_predicates, key_objects, messages

    # Collect the argument names written inside "(...)" in the first goal.
    goal_objects = []
    pattern = re.compile(r'\((.*?)\)')
    if goal_set:
        for expr in goal_set[0]:
            match = pattern.search(expr)
            if match:
                goal_objects.extend(match.group(1).split(','))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    key_objects = list(dict.fromkeys(key_objects + goal_objects))

    return goal_set, priority_act_ls, key_predicates, key_objects, messages
def act_tree_verbose(llm, messages, reflect_prompt):
    """Continue a conversation with a reflection prompt and parse the reply.

    ``reflect_prompt`` is appended to ``messages`` as a user turn, the LLM
    reply is appended as an assistant turn (``messages`` is modified in
    place), and the reply is parsed with ``parse_llm_output``. The reply and
    every parsed field are printed.

    Returns:
        ``(goal_set, priority_act_ls, key_predicates, key_objects,
        messages)``.
    """
    messages.append(dict(role="user", content=reflect_prompt))
    reply = llm.request(message=messages)
    messages.append(dict(role="assistant", content=reply))
    print("============ Answer ================\n", reply)

    parsed = parse_llm_output(reply)
    goal_set, priority_act_ls, key_predicates, key_objects = parsed

    report = (
        ("goal", goal_set),
        ("act:", priority_act_ls),
        ("key_predicate", key_predicates),
        ("Vital Objects:", key_objects),
    )
    for label, value in report:
        print(label, value)

    return goal_set, priority_act_ls, key_predicates, key_objects, messages
def convert_conditions(conditions_set):
    """Turn ``Pred(a,b)`` style conditions into one ``Pred_a_b & ...`` string.

    Each condition is split at its first ``(``. The part before it is
    stripped of surrounding whitespace; in the part after it, leading and
    trailing ``)`` are removed and commas become underscores. Both parts are
    joined with ``_``. A condition without any ``(`` is emitted as its
    stripped text instead of raising ``ValueError``.

    Args:
        conditions_set: Iterable of condition strings.

    Returns:
        The formatted conditions joined with ``" & "`` in iteration order
        (an empty string for an empty input).
    """
    formatted_conditions = []
    for condition in conditions_set:
        # partition splits on the first "(" only and never fails to unpack.
        base, paren, args = condition.partition("(")
        base = base.strip()
        if not paren:
            # No argument list: keep the bare predicate name.
            formatted_conditions.append(base)
            continue
        args = args.strip(")").replace(",", "_")
        formatted_conditions.append(f"{base}_{args}")
    return " & ".join(formatted_conditions)
def extract_llm_from_reflect(llm, messages, nearest_examples=None):
    """Request a reflection answer and parse its actions/predicates/objects.

    ``messages`` is sent to the LLM as-is, the reply is appended to it as an
    assistant turn (in place) and parsed with
    ``parse_llm_output(..., goals=False)``. When ``nearest_examples`` is
    given, the predicates and objects listed in those examples are merged
    into the parsed ones.

    Args:
        llm: Client object exposing ``request(message=...)``.
        messages: Conversation so far; mutated in place.
        nearest_examples: Optional iterable of example records whose
            ``['value']`` mapping holds comma-separated
            ``'Vital Action Predicates'`` (optional) and ``'Vital Objects'``.

    Returns:
        ``(priority_act_ls, key_predicates, key_objects, messages)``. The
        first three are ``None`` when the reply cannot be parsed. Merged
        lists keep the LLM's entries first, followed by new entries from
        the examples, without duplicates.
    """

    def split_names(text):
        # Remove spaces, split on commas, skip empty tokens.
        return [name for name in text.replace(" ", "").split(",") if name]

    answer = llm.request(message=messages)
    messages.append({"role": "assistant", "content": answer})

    priority_act_ls, key_predicates, key_objects = parse_llm_output(answer, goals=False)
    if priority_act_ls is None:
        # parse_llm_output signals failure with Nones; joining None below
        # would raise TypeError, so hand the failure to the caller instead.
        return priority_act_ls, key_predicates, key_objects, messages

    cyan = "\033[36m"
    reset = "\033[0m"

    print(f"{cyan}--- Reflect Just LLM ---{reset}")
    print(f"{cyan}priority_act_ls: {', '.join(priority_act_ls)}{reset}")
    print(f"{cyan}key_predicates: {', '.join(key_predicates)}{reset}")
    print(f"{cyan}key_objects: {', '.join(key_objects)}{reset}")

    # Also include the predicates and objects mentioned by the examples.
    if nearest_examples is not None:
        ex_preds, ex_objs = [], []
        for ex in nearest_examples:
            value = ex['value']
            # The predicates entry is optional, matching format_example.
            ex_preds.extend(split_names(value.get('Vital Action Predicates', '')))
            ex_objs.extend(split_names(value['Vital Objects']))
        # dict.fromkeys de-duplicates while keeping first-seen order.
        key_predicates = list(dict.fromkeys(key_predicates + ex_preds))
        key_objects = list(dict.fromkeys(key_objects + ex_objs))

    print(f"{cyan}--- Reflect Answers ---{reset}")
    print(f"{cyan}priority_act_ls: {', '.join(priority_act_ls)}{reset}")
    print(f"{cyan}key_predicates: {', '.join(key_predicates)}{reset}")
    print(f"{cyan}key_objects: {', '.join(key_objects)}{reset}")

    return priority_act_ls, key_predicates, key_objects, messages