252 lines
10 KiB
Python
252 lines
10 KiB
Python
|
import re
|
|||
|
from btpg.algos.llm_client.tools import goal_transfer_str, act_str_process
|
|||
|
from btpg.utils import ROOT_PATH
|
|||
|
# 导入向量数据库检索的相关函数
|
|||
|
from btpg.algos.llm_client.vector_database_env_goal import search_nearest_examples
|
|||
|
from ordered_set import OrderedSet
|
|||
|
|
|||
|
|
|||
|
def parse_llm_output(answer,goals=True):
|
|||
|
goal_set = set()
|
|||
|
priority_act_ls, key_predicate, key_objects = [], [], []
|
|||
|
|
|||
|
try:
|
|||
|
if goals:
|
|||
|
goal_str = answer.split("Optimal Actions:")[0].replace("Goals:", "").strip()
|
|||
|
goal_set = goal_transfer_str(goal_str)
|
|||
|
|
|||
|
act_str = answer.split("Optimal Actions:")[1].split("Vital Action Predicates:")[0].strip()
|
|||
|
predicate_str = answer.split("Vital Action Predicates:")[1].split("Vital Objects:")[0].strip()
|
|||
|
objects_str = answer.split("Vital Objects:")[1].strip()
|
|||
|
priority_act_ls = act_str_process(act_str)
|
|||
|
|
|||
|
# Remove all spaces, Split by comma to create a list
|
|||
|
key_predicate = predicate_str.replace(" ", "").split(",")
|
|||
|
key_objects = objects_str.replace(" ", "").split(",")
|
|||
|
|
|||
|
priority_act_ls = list(OrderedSet(priority_act_ls))
|
|||
|
key_predicate = list(OrderedSet(key_predicate))
|
|||
|
key_objects = list(OrderedSet(key_objects))
|
|||
|
except Exception as e:
|
|||
|
goal_set, priority_act_ls, key_predicate, key_objects = None,None,None,None
|
|||
|
print(f"Failed to parse LLM output: {e}")
|
|||
|
|
|||
|
|
|||
|
if goals:
|
|||
|
return goal_set,priority_act_ls,key_predicate,key_objects
|
|||
|
else:
|
|||
|
return priority_act_ls, key_predicate, key_objects
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def format_example(metadata):
|
|||
|
"""格式化向量数据库的示例数据为所需的格式"""
|
|||
|
example_value = metadata['value']
|
|||
|
return (
|
|||
|
# f"Instruction: {example_value['Instruction']}\n"
|
|||
|
f"Goals: {example_value['Goals']}\n"
|
|||
|
f"Optimal Actions: {example_value['Optimal Actions']}\n"
|
|||
|
f"Vital Action Predicates: {example_value.get('Vital Action Predicates', '')}\n"
|
|||
|
f"Vital Objects: {example_value['Vital Objects']}\n")
|
|||
|
|
|||
|
def extract_llm_from_instr_goal(llm,default_prompt_file,environment,goals,instruction=None,cur_cond_set=None,\
|
|||
|
choose_database=False,\
|
|||
|
database_index_path=f"{ROOT_PATH}/../test/dataset/env_instruction_vectors.index",verbose=False):
|
|||
|
with open(default_prompt_file, 'r', encoding="utf-8") as f:
|
|||
|
prompt = f.read().strip()
|
|||
|
|
|||
|
distances=None
|
|||
|
parsed_output =None
|
|||
|
parsed_fail=-1
|
|||
|
RED = "\033[31m"
|
|||
|
RESET = "\033[0m"
|
|||
|
|
|||
|
while parsed_output==None:
|
|||
|
|
|||
|
parsed_fail += 1 # 第一次是第0次。 0-1-2-3
|
|||
|
print(f"--- LLM: Goal={goals} Parsed Fail={parsed_fail} --- ")
|
|||
|
if parsed_fail > 3:
|
|||
|
print(f"{RED}----LLM: Goal={goals} Parsed Fail={parsed_fail} >3 break -----{RESET}")
|
|||
|
break
|
|||
|
|
|||
|
if choose_database:
|
|||
|
|
|||
|
# environment ?
|
|||
|
nearest_examples,distances = search_nearest_examples(database_index_path, llm, goals, top_n=5)
|
|||
|
# 使用自定义的格式函数将检索到的示例格式化为目标样式
|
|||
|
example_texts = '\n'.join([format_example(ex) for ex in nearest_examples])
|
|||
|
example_texts = "[Examples]\n" + example_texts
|
|||
|
|
|||
|
# 输出最近的所有goal
|
|||
|
nearest_goals = [ex['value']['Goals'] for ex in nearest_examples]
|
|||
|
print("All Goals from nearest examples:")
|
|||
|
for g in nearest_goals:
|
|||
|
print(f"\033[93m{g}\033[0m") # 打印黄色 print(goal)
|
|||
|
|
|||
|
# print("distances:",distances)
|
|||
|
# print("example_texts:\n",example_texts)
|
|||
|
# 替换 prompt 中的 [Examples] 部分
|
|||
|
example_marker = "[Examples]"
|
|||
|
if example_marker in prompt:
|
|||
|
prompt = prompt.replace(example_marker, example_texts)
|
|||
|
else:
|
|||
|
prompt = f"{prompt}\n{example_texts}"
|
|||
|
|
|||
|
# 构建完整的 prompt,包括检索的 Examples 和当前的指令
|
|||
|
goals_str =' & '.join(goals)
|
|||
|
# question = f"{prompt}\nInstruction: {instruction}\nGoals: {goals_str}"
|
|||
|
question = f"{prompt}\nGoals: {goals_str}"
|
|||
|
if verbose:
|
|||
|
print("============ Question ================\n",question)
|
|||
|
messages = []
|
|||
|
messages.append({"role": "user", "content": question})
|
|||
|
answer = llm.request(message=messages)
|
|||
|
messages.append({"role": "assistant", "content": answer})
|
|||
|
# if verbose:
|
|||
|
# print("============ Answer ================\n",answer)
|
|||
|
parsed_output = parse_llm_output(answer, goals=False)
|
|||
|
# if parsed_output is None:
|
|||
|
# print(f"\033[91mFailed to parse LLM output for goals: {goals_str}\033[0m")
|
|||
|
# return None, None, None, messages, distances
|
|||
|
|
|||
|
priority_act_ls, key_predicates, key_objects = parsed_output
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if priority_act_ls==None:
|
|||
|
print(f"\033[91mFailed to parse LLM output for goals: {goals_str}\033[0m")
|
|||
|
return priority_act_ls, key_predicates, key_objects, messages, distances,parsed_fail
|
|||
|
|
|||
|
def extract_llm_from_instr(llm,default_prompt_file,instruction,cur_cond_set,\
|
|||
|
choose_database=False,\
|
|||
|
index_path=f"{ROOT_PATH}/../test/dataset/env_instruction_vectors.index"):
|
|||
|
"""从向量数据库检索并生成初始 prompt"""
|
|||
|
|
|||
|
with open(default_prompt_file, 'r', encoding="utf-8") as f:
|
|||
|
prompt = f.read().strip()
|
|||
|
|
|||
|
if choose_database:
|
|||
|
# 补充:向量数据库检索,拼接上最相近的 Example cur_cond_set
|
|||
|
# cur_env_state = ', '.join(map(str, cur_cond_set))
|
|||
|
# cur_data = instuction + "\n[current environmental condition]\n" + cur_env_state # 可能还要再调整
|
|||
|
# cur_emb = llm.embedding(question=cur_data)
|
|||
|
# 导入向量数据库,找到最近的前5条。
|
|||
|
# 准备好的 30条数据 作为 向量数据库
|
|||
|
# example = ""
|
|||
|
# 将例子拼在后面
|
|||
|
# question+=example
|
|||
|
# 检索向量数据库以获取最近的 Examples
|
|||
|
nearest_examples,distances = search_nearest_examples(index_path, llm, instruction,top_n=3)
|
|||
|
# 使用自定义的格式函数将检索到的示例格式化为目标样式
|
|||
|
example_texts = '\n'.join([format_example(ex) for ex in nearest_examples])
|
|||
|
example_texts = "[Examples]\n" + example_texts
|
|||
|
print("distances:",distances)
|
|||
|
# print("example_texts:\n",example_texts)
|
|||
|
# 替换 prompt 中的 [Examples] 部分
|
|||
|
example_marker = "[Examples]"
|
|||
|
if example_marker in prompt:
|
|||
|
prompt = prompt.replace(example_marker, example_texts)
|
|||
|
else:
|
|||
|
prompt = f"{prompt}\n{example_texts}"
|
|||
|
|
|||
|
|
|||
|
# 构建完整的 prompt,包括检索的 Examples 和当前的指令
|
|||
|
question = f"{prompt}\n{instruction}"
|
|||
|
print("question:",question)
|
|||
|
messages = []
|
|||
|
messages.append({"role": "user", "content": question})
|
|||
|
answer = llm.request(message=messages)
|
|||
|
messages.append({"role": "assistant", "content": answer})
|
|||
|
print(answer)
|
|||
|
|
|||
|
goal_set, priority_act_ls, key_predicates, key_objects = parse_llm_output(answer)
|
|||
|
|
|||
|
print("goal",goal_set)
|
|||
|
print("act:",priority_act_ls)
|
|||
|
print("key_predicate",key_predicates)
|
|||
|
print("Vital Objects:",key_objects)
|
|||
|
|
|||
|
|
|||
|
# 提取目标中的所有物体
|
|||
|
objects = set()
|
|||
|
# 正则表达式用于找到括号中的内容
|
|||
|
pattern = re.compile(r'\((.*?)\)')
|
|||
|
# 遍历所有表达式,提取物体名称
|
|||
|
for expr in goal_set[0]:
|
|||
|
# 找到括号内的内容
|
|||
|
match = pattern.search(expr)
|
|||
|
if match:
|
|||
|
# 将括号内的内容按逗号分割并加入到集合中
|
|||
|
objects.update(match.group(1).split(','))
|
|||
|
key_objects += list(objects)
|
|||
|
key_objects = list(set(key_objects))
|
|||
|
|
|||
|
return goal_set, priority_act_ls, key_predicates, key_objects, messages
|
|||
|
|
|||
|
def act_tree_verbose(llm, messages, reflect_prompt):
|
|||
|
messages.append({"role": "user", "content": reflect_prompt})
|
|||
|
answer = llm.request(message=messages)
|
|||
|
messages.append({"role": "assistant", "content": answer})
|
|||
|
|
|||
|
print("============ Answer ================\n",answer)
|
|||
|
|
|||
|
goal_set, priority_act_ls, key_predicates, key_objects = parse_llm_output(answer)
|
|||
|
|
|||
|
print("goal",goal_set)
|
|||
|
print("act:",priority_act_ls)
|
|||
|
print("key_predicate",key_predicates)
|
|||
|
print("Vital Objects:",key_objects)
|
|||
|
|
|||
|
return goal_set, priority_act_ls, key_predicates, key_objects, messages
|
|||
|
|
|||
|
|
|||
|
def convert_conditions(conditions_set):
|
|||
|
# Initialize an empty list to store the formatted strings
|
|||
|
formatted_conditions = []
|
|||
|
|
|||
|
# Loop over each condition in the set
|
|||
|
for condition in conditions_set:
|
|||
|
# Remove the parentheses and split the condition into parts based on the first opening parenthesis
|
|||
|
base, args = condition.split("(")
|
|||
|
# Remove the closing parenthesis and replace commas with underscores in the arguments
|
|||
|
args = args.strip(")").replace(",", "_")
|
|||
|
# Concatenate the base and the arguments with an underscore and add to the list
|
|||
|
formatted_conditions.append(f"{base.strip()}_{args}")
|
|||
|
|
|||
|
formatted_conditions_str = " & ".join(formatted_conditions)
|
|||
|
return formatted_conditions_str
|
|||
|
|
|||
|
|
|||
|
def extract_llm_from_reflect(llm,messages,nearest_examples=None):
|
|||
|
|
|||
|
answer = llm.request(message=messages)
|
|||
|
messages.append({"role": "assistant", "content": answer})
|
|||
|
priority_act_ls, key_predicates, key_objects = parse_llm_output(answer,goals=False) # 返回的都是list
|
|||
|
|
|||
|
cyan = "\033[36m"
|
|||
|
reset = "\033[0m"
|
|||
|
print(f"{cyan}--- Reflect Just LLM ---{reset}")
|
|||
|
print(f"{cyan}priority_act_ls: {', '.join(priority_act_ls)}{reset}")
|
|||
|
print(f"{cyan}key_predicates: {', '.join(key_predicates)}{reset}")
|
|||
|
print(f"{cyan}key_objects: {', '.join(key_objects)}{reset}")
|
|||
|
|
|||
|
# 如果这里面把例子中的pred和obj也加进去
|
|||
|
if nearest_examples!=None:
|
|||
|
ex_preds=set()
|
|||
|
ex_objs=set()
|
|||
|
for ex in nearest_examples:
|
|||
|
ex_preds |= set(ex['value']['Vital Action Predicates'].replace(" ", "").split(","))
|
|||
|
ex_objs |= set(ex['value']['Vital Objects'].replace(" ", "").split(","))
|
|||
|
key_predicates = list(set(key_predicates) | ex_preds)
|
|||
|
key_objects = list(set(key_objects) | ex_objs)
|
|||
|
|
|||
|
pass
|
|||
|
|
|||
|
cyan = "\033[36m"
|
|||
|
reset = "\033[0m"
|
|||
|
print(f"{cyan}--- Reflect Answers ---{reset}")
|
|||
|
print(f"{cyan}priority_act_ls: {', '.join(priority_act_ls)}{reset}")
|
|||
|
print(f"{cyan}key_predicates: {', '.join(key_predicates)}{reset}")
|
|||
|
print(f"{cyan}key_objects: {', '.join(key_objects)}{reset}")
|
|||
|
|
|||
|
return priority_act_ls, key_predicates, key_objects, messages
|