New: 增加枚举所有目标状态的实现

This commit is contained in:
wuziji 2023-11-14 09:23:59 +08:00
parent 355f1fac7b
commit f30745058c
1 changed files with 46 additions and 1 deletions

View File

@ -29,6 +29,34 @@ def single_predict_generation(oplist_1, oplist_2, predict_pattern) -> str:
raise RuntimeError('Incorrect predict pattern!')
def enumerate_predict(oplist_1, oplist_2, predict_pattern) -> [int, list]:
count = 0
res = []
match predict_pattern:
case 'at':
pattern = f'At(%s, %s)'
case 'is':
pattern = f'Is(%s, %s)'
case 'hold':
pattern = f'Holding(%s)'
case 'on':
pattern = f'On(%s, %s)'
case _:
raise RuntimeError('Incorrect predict pattern!')
for str_1 in oplist_1:
if oplist_2:
for str_2 in oplist_2:
count += 1
res.append({pattern % (str_1, str_2)})
else:
count += 1
res.append({pattern % str_1})
return count, res
def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int):
# res stores lists of sets, while each state represent in set.
res = []
@ -63,4 +91,21 @@ def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int):
return res
generate_goal_states(30, 6, 6)
def enumerate_goal_states():
# goal states for VLN
count_vln, list_vln = enumerate_predict(['Robot'], Place, 'at')
print(f'VLN 任务的目标状态数:{count_vln}')
# goal states for VLM
count_vlm_1, list_vlm_1 = enumerate_predict(['Robot'], Place, 'at')
count_vlm_2, list_vlm_2 = enumerate_predict(Operable, ['0', '1'], 'is')
print(f'VLM 任务的目标状态数:{count_vlm_1 * count_vlm_2}')
# goal states for open-task
count_opentask_1, list_opentask_1 = enumerate_predict(['Robot'], Place, 'at')
count_opentask_2, list_opentask_2 = enumerate_predict(Object, Place, 'on')
print(f'Open-task-1 任务的目标状态数:{count_opentask_1 * count_opentask_2}')
# generate_goal_states(30, 6, 6)
enumerate_goal_states()