diff --git a/robowaiter/behavior_tree/dataset/goal_states_generation.py b/robowaiter/behavior_tree/dataset/goal_states_generation.py index 550d86f..3cc2770 100644 --- a/robowaiter/behavior_tree/dataset/goal_states_generation.py +++ b/robowaiter/behavior_tree/dataset/goal_states_generation.py @@ -29,6 +29,34 @@ def single_predict_generation(oplist_1, oplist_2, predict_pattern) -> str: raise RuntimeError('Incorrect predict pattern!') +def enumerate_predict(oplist_1, oplist_2, predict_pattern) -> [int, list]: + count = 0 + res = [] + + match predict_pattern: + case 'at': + pattern = f'At(%s, %s)' + case 'is': + pattern = f'Is(%s, %s)' + case 'hold': + pattern = f'Holding(%s)' + case 'on': + pattern = f'On(%s, %s)' + case _: + raise RuntimeError('Incorrect predict pattern!') + + for str_1 in oplist_1: + if oplist_2: + for str_2 in oplist_2: + count += 1 + res.append({pattern % (str_1, str_2)}) + else: + count += 1 + res.append({pattern % str_1}) + + return count, res + + def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int): # res stores lists of sets, while each state represent in set. res = [] @@ -63,4 +91,21 @@ def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int): return res -generate_goal_states(30, 6, 6) +def enumerate_goal_states(): + # goal states for VLN + count_vln, list_vln = enumerate_predict(['Robot'], Place, 'at') + print(f'VLN 任务的目标状态数:{count_vln}') + + # goal states for VLM + count_vlm_1, list_vlm_1 = enumerate_predict(['Robot'], Place, 'at') + count_vlm_2, list_vlm_2 = enumerate_predict(Operable, ['0', '1'], 'is') + print(f'VLM 任务的目标状态数:{count_vlm_1 * count_vlm_2}') + + # goal states for open-task + count_opentask_1, list_opentask_1 = enumerate_predict(['Robot'], Place, 'at') + count_opentask_2, list_opentask_2 = enumerate_predict(Object, Place, 'on') + print(f'Open-task-1 任务的目标状态数:{count_opentask_1 * count_opentask_2}') + + +# generate_goal_states(30, 6, 6) +enumerate_goal_states()