RoboWaiter/BTExpansionCode/EXP/exp3_Merge_easy_medium_hard...

128 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from EXP.exp_tools import state_transition,collect_action_nodes,get_start,BTTest,goal_transfer_str,collect_cond_nodes,BTTest_Merge,BTTest_Merge_easy_medium_hard
import copy
import random
import re
from OptimalBTExpansionAlgorithm import Action,OptBTExpAlgorithm
seed = 1
random.seed(seed)
multiple_num=5
iters_times= 10
iter_action_ls = collect_action_nodes(random,multiple_num,iters_times)
start_robowaiter = get_start()
# 计算state总数
state_num, vaild_state_num= collect_cond_nodes()
# print("meta states num: ",state_num)
# print("states num: ",vaild_state_num)
# print("act num: ",len(action_list))
goal_states = []
easy_data_set_file = "../dataset/easy_instr_goal.txt"
medium_data_set_file = "../dataset/medium_instr_goal.txt"
hard_data_set_file = "../dataset/hard_instr_goal.txt"
with open(easy_data_set_file, 'r', encoding="utf-8") as f:
easy_data_set = f.read().strip()
with open(medium_data_set_file, 'r', encoding="utf-8") as f:
medium_data_set = f.read().strip()
with open(hard_data_set_file, 'r', encoding="utf-8") as f:
hard_data_set = f.read().strip()
all_result=[]
max_merge_times=20
dataset_ls = [easy_data_set,medium_data_set,hard_data_set]
parm_difficule_ls = ['Easy','Medium','Hard']
# dataset_ls = [hard_data_set]
# parm_difficule_ls = ['Hard']
for index, dataset in enumerate(dataset_ls):
print(f"\n----------- {parm_difficule_ls[index]} ----------\n")
sections = re.split(r'\n\s*\n', dataset)
outputs_list = [[] for _ in range(len(sections))]
goal_set_ls = []
for i, s in enumerate(sections):
x, y = s.strip().splitlines()
x = x.strip()
y = y.strip().replace("Goal: ", "")
goal_set_ls.append(y)
goal_states = goal_set_ls
import time
# 针对一个difficult 跑10次
condticks_avg_ls=[]
merge_cond_tick_total = [[] for merge_time in range(max_merge_times)]
for iter in range(iters_times):
action_list = iter_action_ls[iter]
planning_time_ls=[]
planning_time_total=0
for count, goal_str in enumerate(goal_states):
goal = copy.deepcopy(goal_transfer_str(goal_str))
algo = OptBTExpAlgorithm(verbose=False)
algo.clear()
algo.bt_merge=False
start_time = time.time()
algo_right = algo.run_algorithm(start_robowaiter, goal, action_list)
end_time = time.time()
planning_time_ls.append(end_time - start_time)
planning_time_total += (end_time - start_time)
for merge_time in range(max_merge_times):
# 根据子树个数进行合并
bt = algo.merge_subtree(merge_time)
# 计算合并后的 cond tick
# 开始从初始状态运行行为树,测试
state = copy.deepcopy(start_robowaiter)
steps = 0
current_cond_tick_time = 0
val, obj, cost, tick_time, cond_times = bt.cost_tick_cond(state, 0, 0, 0) # tick行为树obj为所运行的行动
current_cond_tick_time += cond_times
while val != 'success' and val != 'failure': # 运行直到行为树成功或失败
state = state_transition(state, obj)
val, obj, cost, tick_time, cond_times = bt.cost_tick_cond(state, 0, 0, 0)
current_cond_tick_time += cond_times
if (val == 'failure'):
print("bt fails at step", steps)
error = True
break
steps += 1
if (steps >= 500): # 至多运行500步
break
# 检查执行后状态满不满足,只有 goal 里有一个满足就行
error = True
for gg in goal:
if gg <= state:
error = False
break
if error:
print("error")
break
# 结束从初始状态运行行为树,测试
merge_cond_tick_total[merge_time].append(current_cond_tick_time)
print("iter:", iter,"cond:",merge_cond_tick_total)
merge_cond_tick_total_np = np.array(merge_cond_tick_total)
merge_cond_tick_averages = np.mean(merge_cond_tick_total_np, axis=1)
print(merge_cond_tick_averages)
all_result.append(merge_cond_tick_averages)
import pandas as pd
df = pd.DataFrame(all_result)
import time
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()).replace("-","").replace(":","")
csv_file_path = 'merged_result_easy_medium_hard_'+time_str+'.csv'
df.to_csv(csv_file_path, index=True)
print("CSV文件已生成:", csv_file_path)