New: 完成目标状态数据集的构建
This commit is contained in:
parent
f7a2b9b5d7
commit
5cfd97ba17
File diff suppressed because it is too large
Load Diff
|
@ -1,6 +1,10 @@
|
|||
# the empty string '' represents robot holds nothing
|
||||
Object = ['Coffee', 'Water', 'Dessert', 'Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk', 'MilkDrink', 'Milk',
|
||||
'VacuumCup', '']
|
||||
import os
|
||||
import re
|
||||
|
||||
Object = ['Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk', 'MilkDrink', 'Milk', 'VacuumCup', 'Nothing']
|
||||
|
||||
Cookable = ['Coffee', 'Water', 'Dessert']
|
||||
|
||||
Place = ['Bar', 'WaterTable', 'CoffeeTable', 'Bar2', 'Table1', 'Table2', 'Table3']
|
||||
|
||||
|
@ -35,13 +39,13 @@ def enumerate_predict(oplist_1, oplist_2, predict_pattern) -> [int, list]:
|
|||
|
||||
match predict_pattern:
|
||||
case 'at':
|
||||
pattern = f'At(%s, %s)'
|
||||
pattern = f'At(%s,%s)'
|
||||
case 'is':
|
||||
pattern = f'Is(%s, %s)'
|
||||
pattern = f'Is(%s,%s)'
|
||||
case 'hold':
|
||||
pattern = f'Holding(%s)'
|
||||
case 'on':
|
||||
pattern = f'On(%s, %s)'
|
||||
pattern = f'On(%s,%s)'
|
||||
case _:
|
||||
raise RuntimeError('Incorrect predict pattern!')
|
||||
|
||||
|
@ -49,10 +53,10 @@ def enumerate_predict(oplist_1, oplist_2, predict_pattern) -> [int, list]:
|
|||
if oplist_2:
|
||||
for str_2 in oplist_2:
|
||||
count += 1
|
||||
res.append({pattern % (str_1, str_2)})
|
||||
res.append(pattern % (str_1, str_2))
|
||||
else:
|
||||
count += 1
|
||||
res.append({pattern % str_1})
|
||||
res.append(pattern % str_1)
|
||||
|
||||
return count, res
|
||||
|
||||
|
@ -85,27 +89,127 @@ def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int):
|
|||
}
|
||||
)
|
||||
|
||||
# print(res)
|
||||
# print(len(res))
|
||||
print(res)
|
||||
print(len(res))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def enumerate_goal_states():
|
||||
# goal states for VLN
|
||||
def enumerate_goal_states(total: int):
|
||||
|
||||
res = []
|
||||
|
||||
point_15 = int(total * .15)
|
||||
point_10 = int(total * .10)
|
||||
|
||||
# goal states for VLN, .15
|
||||
count_vln, list_vln = enumerate_predict(['Robot'], Place, 'at')
|
||||
print(f'VLN 任务的目标状态数:{count_vln}')
|
||||
list_vln = ['{%s}'%i for i in list_vln]
|
||||
if count_vln < point_15:
|
||||
list_vln *= point_15 // count_vln
|
||||
for i in range(0, point_15 - len(list_vln)):
|
||||
list_vln.append(single_predict_generation(['Robot'], Place, 'at'))
|
||||
# print(f'VLN 任务的目标状态数:{count_vln}')
|
||||
res += list_vln
|
||||
|
||||
# goal states for VLM
|
||||
count_vlm_1, list_vlm_1 = enumerate_predict(['Robot'], Place, 'at')
|
||||
# goal states for VLM-1, 0.15
|
||||
count_vlm_1, list_vlm_1 = enumerate_predict(Object, Place, 'on')
|
||||
list_vlm_1 = ['{%s}' % i for i in list_vlm_1]
|
||||
if count_vlm_1 < point_15:
|
||||
list_vlm_1 *= point_15 // count_vlm_1
|
||||
for i in range(0, point_15 - len(list_vlm_1)):
|
||||
list_vlm_1.append(single_predict_generation(Object, Place, 'on'))
|
||||
res += list_vlm_1
|
||||
|
||||
# goal states for VLM-2, 0.15
|
||||
count_vlm_2, list_vlm_2 = enumerate_predict(Operable, ['0', '1'], 'is')
|
||||
print(f'VLM 任务的目标状态数:{count_vlm_1 * count_vlm_2}')
|
||||
list_vlm_2 = ['{%s}' % i for i in list_vlm_2]
|
||||
if count_vlm_2 < point_15:
|
||||
list_vlm_2 *= point_15 // count_vlm_2
|
||||
for i in range(0, point_15 - len(list_vlm_2)):
|
||||
list_vlm_2.append(single_predict_generation(Operable, ['0', '1'], 'is'))
|
||||
res += list_vlm_2
|
||||
|
||||
# goal states for open-task
|
||||
count_opentask_1, list_opentask_1 = enumerate_predict(['Robot'], Place, 'at')
|
||||
count_opentask_2, list_opentask_2 = enumerate_predict(Object, Place, 'on')
|
||||
print(f'Open-task-1 任务的目标状态数:{count_opentask_1 * count_opentask_2}')
|
||||
# goal states for VLM-3, 0.1
|
||||
count_vlm_3, list_vlm_3 = enumerate_predict(Object, None, 'hold')
|
||||
list_vlm_3 = ['{%s}' % i for i in list_vlm_3]
|
||||
if count_vlm_3 < point_10:
|
||||
list_vlm_3 *= point_10 // count_vlm_3
|
||||
for i in range(0, point_10 - len(list_vlm_3)):
|
||||
list_vlm_3.append(single_predict_generation(Object, None, 'hold'))
|
||||
res += list_vlm_3
|
||||
|
||||
# goal states for OT, 0.15
|
||||
count_ot, list_ot = enumerate_predict(Cookable, Place, 'on')
|
||||
list_ot = ['{%s}' % i for i in list_ot]
|
||||
if count_ot < point_15:
|
||||
list_ot *= point_15 // count_ot
|
||||
for i in range(0, point_15 - len(list_ot)):
|
||||
list_ot.append(single_predict_generation(Cookable, Place, 'on'))
|
||||
res += list_ot
|
||||
|
||||
# goal states for compound-1, 0.1
|
||||
count_1, list_1 = enumerate_predict(['Robot'], Place, 'at')
|
||||
count_2, list_2 = enumerate_predict(Object, Place, 'on')
|
||||
list_tmp = []
|
||||
for i in list_1:
|
||||
for j in list_2:
|
||||
list_tmp.append('{%s,%s}' % (i, j))
|
||||
if len(list_tmp) < point_10:
|
||||
list_tmp *= point_10 // len(list_tmp)
|
||||
list_tmp += list_tmp[0:point_10-len(list_tmp)]
|
||||
else:
|
||||
list_tmp = list_tmp[:point_10]
|
||||
res += list_tmp
|
||||
|
||||
# goal states for compound-2, 0.1
|
||||
count_1, list_1 = enumerate_predict(['Robot'], Place, 'at')
|
||||
count_2, list_2 = enumerate_predict(Operable, ['0', '1'], 'is')
|
||||
list_tmp = []
|
||||
for i in list_1:
|
||||
for j in list_2:
|
||||
list_tmp.append('{%s,%s}' % (i, j))
|
||||
if len(list_tmp) < point_10:
|
||||
list_tmp *= point_10 // len(list_tmp)
|
||||
list_tmp += list_tmp[0:point_10 - len(list_tmp)]
|
||||
else:
|
||||
list_tmp = list_tmp[:point_10]
|
||||
res += list_tmp
|
||||
|
||||
# goal states for compound-3, 0.1
|
||||
count_1, list_1 = enumerate_predict(Cookable, Place, 'on')
|
||||
count_2, list_2 = enumerate_predict(Operable, ['0', '1'], 'is')
|
||||
list_tmp = []
|
||||
for i in list_1:
|
||||
for j in list_2:
|
||||
list_tmp.append('{%s,%s}' % (i, j))
|
||||
if len(list_tmp) < point_10:
|
||||
list_tmp *= point_10 // len(list_tmp)
|
||||
list_tmp += list_tmp[0:point_10 - len(list_tmp)]
|
||||
else:
|
||||
list_tmp = list_tmp[:point_10]
|
||||
res += list_tmp
|
||||
|
||||
# # goal states for VLM-1, 0.15
|
||||
# count_vlm_1, list_vlm_1 = enumerate_predict(['Robot'], Place, 'at')
|
||||
# count_vlm_2, list_vlm_2 = enumerate_predict(Operable, ['0', '1'], 'is')
|
||||
# print(f'VLM 任务的目标状态数:{count_vlm_1 * count_vlm_2}')
|
||||
#
|
||||
# # goal states for open-task
|
||||
# count_opentask_1, list_opentask_1 = enumerate_predict(['Robot'], Place, 'at')
|
||||
# count_opentask_2, list_opentask_2 = enumerate_predict(Object, Place, 'on')
|
||||
# print(f'Open-task-1 任务的目标状态数:{count_opentask_1 * count_opentask_2}')
|
||||
|
||||
with open(os.path.join('./goal_states.txt'), 'w+') as file:
|
||||
for i in res:
|
||||
if 'Is' in i and 'ACTemperature' in i:
|
||||
i = re.sub('0', 'Up', i)
|
||||
i = re.sub('1', 'Down', i)
|
||||
elif 'Is' in i and ('AC' in i or 'HallLight' in i or 'TubeLight' in i or 'Curtain' in i):
|
||||
i = re.sub('0', 'Off', i)
|
||||
i = re.sub('1', 'On', i)
|
||||
|
||||
file.write(i+'\n')
|
||||
|
||||
# generate_goal_states(30, 6, 6)
|
||||
enumerate_goal_states()
|
||||
enumerate_goal_states(5000)
|
||||
|
|
Loading…
Reference in New Issue