# Listing metadata: 128 lines, 4.7 KiB, Python
import numpy as np
|
|||
|
|
|||
|
from EXP.exp_tools import state_transition,collect_action_nodes,get_start,BTTest,goal_transfer_str,collect_cond_nodes,BTTest_Merge,BTTest_Merge_easy_medium_hard
|
|||
|
import copy
|
|||
|
import random
|
|||
|
import re
|
|||
|
from OptimalBTExpansionAlgorithm import Action,OptBTExpAlgorithm
|
|||
|
# Experiment configuration.
seed, multiple_num, iters_times = 1, 5, 10

# Seed the stdlib RNG before sampling; the seeded `random` module itself is
# handed to collect_action_nodes, presumably so the sampled action lists are
# reproducible across runs.
random.seed(seed)

# Sampled action lists; the main loop indexes this with
# 0 .. iters_times - 1 (one list per iteration).
iter_action_ls = collect_action_nodes(random, multiple_num, iters_times)
# Initial world state; every planning run and BT execution starts from it.
start_robowaiter = get_start()

# Total number of states (original note: "count total states") and the
# number of valid states, as returned by collect_cond_nodes. Not used
# further in this script except by the disabled debug prints below.
state_num, vaild_state_num = collect_cond_nodes()
# Debug output (disabled):
# print("meta states num: ", state_num)
# print("states num: ", vaild_state_num)

# Placeholder; the main loop reassigns it for each difficulty level.
goal_states = []
# Instruction/goal datasets, one file per difficulty. Paths are relative to
# the current working directory.
easy_data_set_file = "../dataset/easy_instr_goal.txt"
medium_data_set_file = "../dataset/medium_instr_goal.txt"
hard_data_set_file = "../dataset/hard_instr_goal.txt"


def _read_stripped(path):
    """Return the UTF-8 text of the file at *path* with surrounding whitespace stripped."""
    with open(path, 'r', encoding="utf-8") as handle:
        return handle.read().strip()


easy_data_set = _read_stripped(easy_data_set_file)
medium_data_set = _read_stripped(medium_data_set_file)
hard_data_set = _read_stripped(hard_data_set_file)
# One averaged cond-tick vector is appended here per difficulty level.
all_result = []
# Merge settings 0 .. max_merge_times - 1 are evaluated for every goal.
max_merge_times = 20

# Difficulty levels to evaluate: parallel lists of display labels and the
# corresponding dataset text.
parm_difficule_ls = ['Easy', 'Medium', 'Hard']
dataset_ls = [easy_data_set, medium_data_set, hard_data_set]

# To evaluate only the hard set, use instead:
# dataset_ls = [hard_data_set]
# parm_difficule_ls = ['Hard']
def _parse_goal_strings(dataset_text):
    """Return the goal string of every record in a dataset text.

    Records are separated by blank lines. Each record must be exactly two
    lines (an instruction line, then a goal line); the first occurrence-agnostic
    ``"Goal: "`` marker is removed from the goal line. A record with any other
    line count raises ``ValueError`` from the tuple unpacking.
    """
    goal_strs = []
    for section in re.split(r'\n\s*\n', dataset_text):
        _instruction, goal_line = section.strip().splitlines()
        goal_strs.append(goal_line.strip().replace("Goal: ", ""))
    return goal_strs


def _tick_until_done(bt, start_state, max_steps=500):
    """Execute ``bt`` from a deep copy of ``start_state`` and count condition ticks.

    The tree is ticked via ``bt.cost_tick_cond``; after every non-terminal tick
    the returned action ``obj`` is applied with ``state_transition``. Execution
    stops when a tick reports ``'success'`` or ``'failure'``, or after
    ``max_steps`` state transitions.

    Returns ``(final_state, total_cond_ticks)`` where ``total_cond_ticks`` is
    the sum of ``cond_times`` over every tick performed.
    """
    state = copy.deepcopy(start_state)
    steps = 0
    val, obj, _cost, _tick_time, cond_times = bt.cost_tick_cond(state, 0, 0, 0)
    total_cond_ticks = cond_times
    while val != 'success' and val != 'failure':
        state = state_transition(state, obj)
        val, obj, _cost, _tick_time, cond_times = bt.cost_tick_cond(state, 0, 0, 0)
        total_cond_ticks += cond_times
        if val == 'failure':
            print("bt fails at step", steps)
            break
        steps += 1
        if steps >= max_steps:
            break
    return state, total_cond_ticks


for index, dataset in enumerate(dataset_ls):

    print(f"\n----------- {parm_difficule_ls[index]} ----------\n")

    goal_states = _parse_goal_strings(dataset)

    # merge_cond_tick_total[m] collects one cond-tick count per (iteration, goal)
    # run for merge setting m. Runs that are skipped or that do not reach the
    # goal are recorded as NaN so all rows keep the same length; otherwise
    # np.array below would receive ragged lists and the axis=1 mean would fail.
    merge_cond_tick_total = [[] for _ in range(max_merge_times)]

    # For this difficulty, run every goal once per sampled action list.
    for iter_idx in range(iters_times):
        action_list = iter_action_ls[iter_idx]

        for goal_str in goal_states:
            goal = copy.deepcopy(goal_transfer_str(goal_str))
            algo = OptBTExpAlgorithm(verbose=False)
            algo.clear()
            algo.bt_merge = False
            algo.run_algorithm(start_robowaiter, goal, action_list)

            for merge_time in range(max_merge_times):
                # Merge subtrees according to merge_time, then run the merged
                # BT from the initial state, counting condition ticks.
                bt = algo.merge_subtree(merge_time)
                final_state, cond_ticks = _tick_until_done(bt, start_robowaiter)

                # The run succeeds if any one alternative in `goal` is a subset
                # of the final state.
                if not any(gg <= final_state for gg in goal):
                    print("error")
                    # Stop evaluating this goal (as before), but pad this and
                    # all remaining merge settings so the rows stay aligned.
                    for skipped in range(merge_time, max_merge_times):
                        merge_cond_tick_total[skipped].append(np.nan)
                    break

                merge_cond_tick_total[merge_time].append(cond_ticks)

        print("iter:", iter_idx, "cond:", merge_cond_tick_total)

    merge_cond_tick_total_np = np.array(merge_cond_tick_total, dtype=float)
    # nanmean ignores the NaN placeholders of failed runs.
    merge_cond_tick_averages = np.nanmean(merge_cond_tick_total_np, axis=1)
    print(merge_cond_tick_averages)

    all_result.append(merge_cond_tick_averages)
import time

import pandas as pd

# One row per evaluated difficulty (each row is that difficulty's vector of
# per-merge-setting averages). Label rows with the difficulty names when the
# two lists line up; otherwise fall back to pandas' default integer index.
row_labels = parm_difficule_ls if len(parm_difficule_ls) == len(all_result) else None
df = pd.DataFrame(all_result, index=row_labels)

# Compact local timestamp for the file name, e.g. "20240131_154502". The
# previous '%Y-%m-%d %H:%M:%S' format with '-' and ':' stripped still left a
# space in the middle of the file name.
time_str = time.strftime('%Y%m%d_%H%M%S', time.localtime())
csv_file_path = 'merged_result_easy_medium_hard_' + time_str + '.csv'
df.to_csv(csv_file_path, index=True)

# Message: "CSV file generated:"
print("CSV文件已生成:", csv_file_path)