# RoboWaiter/BTExpansionCode/EXP/exp3_Merge_easy_medium_hard.py
# (128 lines, 4.7 KiB, Python — snapshot dated 2024-04-10 19:59:13 +08:00)
import numpy as np
from EXP.exp_tools import state_transition,collect_action_nodes,get_start,BTTest,goal_transfer_str,collect_cond_nodes,BTTest_Merge,BTTest_Merge_easy_medium_hard
import copy
import random
import re
from OptimalBTExpansionAlgorithm import Action,OptBTExpAlgorithm
# Fix the RNG seed so the sampled action sets are reproducible across runs.
seed = 1
random.seed(seed)
# Experiment knobs: action-set size multiplier and number of repeated runs.
multiple_num=5
iters_times= 10
# One independently sampled action list per iteration (iters_times entries).
iter_action_ls = collect_action_nodes(random,multiple_num,iters_times)
# Initial world state for the RoboWaiter domain.
start_robowaiter = get_start()
# Count the total number of states
state_num, vaild_state_num= collect_cond_nodes()
# print("meta states num: ",state_num)
# print("states num: ",vaild_state_num)
# print("act num: ",len(action_list))
goal_states = []

# Paths to the instruction/goal dataset files, one per difficulty level.
easy_data_set_file = "../dataset/easy_instr_goal.txt"
medium_data_set_file = "../dataset/medium_instr_goal.txt"
hard_data_set_file = "../dataset/hard_instr_goal.txt"


def _read_dataset(path):
    """Return the full text of *path* (UTF-8) stripped of surrounding whitespace."""
    with open(path, 'r', encoding="utf-8") as f:
        return f.read().strip()


easy_data_set = _read_dataset(easy_data_set_file)
medium_data_set = _read_dataset(medium_data_set_file)
hard_data_set = _read_dataset(hard_data_set_file)

all_result = []
# Number of subtree-merge granularities evaluated for every goal.
max_merge_times = 20
dataset_ls = [easy_data_set, medium_data_set, hard_data_set]
parm_difficule_ls = ['Easy', 'Medium', 'Hard']
# dataset_ls = [hard_data_set]
# parm_difficule_ls = ['Hard']
import time

# For each difficulty level: parse goals, then for every sampled action set
# plan a behavior tree per goal and measure condition-tick counts at each
# subtree-merge granularity.
for index, dataset in enumerate(dataset_ls):
    print(f"\n----------- {parm_difficule_ls[index]} ----------\n")
    # A dataset file is a sequence of blank-line-separated sections, each an
    # instruction line followed by a "Goal: ..." line; keep only the goal text.
    sections = re.split(r'\n\s*\n', dataset)
    goal_set_ls = []
    for s in sections:
        instr_line, goal_line = s.strip().splitlines()
        goal_set_ls.append(goal_line.strip().replace("Goal: ", ""))
    goal_states = goal_set_ls

    # merge_cond_tick_total[m] collects, for merge granularity m, the condition
    # tick counts observed across every (iteration, goal) combination.
    merge_cond_tick_total = [[] for _ in range(max_merge_times)]
    for iter_idx in range(iters_times):  # repeat the experiment iters_times times
        action_list = iter_action_ls[iter_idx]
        planning_time_total = 0
        for count, goal_str in enumerate(goal_states):
            goal = copy.deepcopy(goal_transfer_str(goal_str))
            algo = OptBTExpAlgorithm(verbose=False)
            algo.clear()
            algo.bt_merge = False  # merging is applied explicitly below
            start_time = time.time()
            algo.run_algorithm(start_robowaiter, goal, action_list)
            planning_time_total += time.time() - start_time
            for merge_time in range(max_merge_times):
                # Merge subtrees at this granularity, then simulate the tree
                # from the initial state and count condition-node ticks.
                bt = algo.merge_subtree(merge_time)
                state = copy.deepcopy(start_robowaiter)
                steps = 0
                current_cond_tick_time = 0
                # Tick the behavior tree; obj is the action chosen this tick.
                val, obj, cost, tick_time, cond_times = bt.cost_tick_cond(state, 0, 0, 0)
                current_cond_tick_time += cond_times
                while val != 'success' and val != 'failure':  # run until the tree succeeds or fails
                    state = state_transition(state, obj)
                    val, obj, cost, tick_time, cond_times = bt.cost_tick_cond(state, 0, 0, 0)
                    current_cond_tick_time += cond_times
                    if val == 'failure':
                        print("bt fails at step", steps)
                        break
                    steps += 1
                    if steps >= 500:  # cap the simulation at 500 steps
                        break
                # The run counts as a success if the final state satisfies any
                # one of the goal condition sets.
                error = True
                for gg in goal:
                    if gg <= state:
                        error = False
                        break
                if error:
                    print("error")
                    break
                merge_cond_tick_total[merge_time].append(current_cond_tick_time)
        print("iter:", iter_idx, "cond:", merge_cond_tick_total)
    # Average the tick counts per merge granularity for this difficulty level.
    merge_cond_tick_total_np = np.array(merge_cond_tick_total)
    merge_cond_tick_averages = np.mean(merge_cond_tick_total_np, axis=1)
    print(merge_cond_tick_averages)
    all_result.append(merge_cond_tick_averages)
import pandas as pd
import time

# Persist one averaged-ticks row per difficulty level to a timestamped CSV.
result_frame = pd.DataFrame(all_result)
timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
time_str = timestamp.replace("-", "").replace(":", "")
csv_file_path = 'merged_result_easy_medium_hard_' + time_str + '.csv'
result_frame.to_csv(csv_file_path, index=True)
print("CSV文件已生成:", csv_file_path)