From f30745058c55f9dbe175258c88598083590bcda5 Mon Sep 17 00:00:00 2001 From: wuziji <2193177243@qq.com> Date: Tue, 14 Nov 2023 09:23:59 +0800 Subject: [PATCH] =?UTF-8?q?New:=20=E5=A2=9E=E5=8A=A0=E6=9E=9A=E4=B8=BE?= =?UTF-8?q?=E6=89=80=E6=9C=89=E7=9B=AE=E6=A0=87=E7=8A=B6=E6=80=81=E7=9A=84?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dataset/goal_states_generation.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/robowaiter/behavior_tree/dataset/goal_states_generation.py b/robowaiter/behavior_tree/dataset/goal_states_generation.py index 550d86f..3cc2770 100644 --- a/robowaiter/behavior_tree/dataset/goal_states_generation.py +++ b/robowaiter/behavior_tree/dataset/goal_states_generation.py @@ -29,6 +29,34 @@ def single_predict_generation(oplist_1, oplist_2, predict_pattern) -> str: raise RuntimeError('Incorrect predict pattern!') +def enumerate_predict(oplist_1, oplist_2, predict_pattern) -> [int, list]: + count = 0 + res = [] + + match predict_pattern: + case 'at': + pattern = f'At(%s, %s)' + case 'is': + pattern = f'Is(%s, %s)' + case 'hold': + pattern = f'Holding(%s)' + case 'on': + pattern = f'On(%s, %s)' + case _: + raise RuntimeError('Incorrect predict pattern!') + + for str_1 in oplist_1: + if oplist_2: + for str_2 in oplist_2: + count += 1 + res.append({pattern % (str_1, str_2)}) + else: + count += 1 + res.append({pattern % str_1}) + + return count, res + + def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int): # res stores lists of sets, while each state represent in set. res = [] @@ -63,4 +91,21 @@ def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int): return res -generate_goal_states(30, 6, 6) +def enumerate_goal_states(): + # goal states for VLN + count_vln, list_vln = enumerate_predict(['Robot'], Place, 'at') + print(f'VLN 任务的目标状态数:{count_vln}') + + # goal states for VLM + count_vlm_1, list_vlm_1 = enumerate_predict(['Robot'], Place, 'at') + count_vlm_2, list_vlm_2 = enumerate_predict(Operable, ['0', '1'], 'is') + print(f'VLM 任务的目标状态数:{count_vlm_1 * count_vlm_2}') + + # goal states for open-task + count_opentask_1, list_opentask_1 = enumerate_predict(['Robot'], Place, 'at') + count_opentask_2, list_opentask_2 = enumerate_predict(Object, Place, 'on') + print(f'Open-task-1 任务的目标状态数:{count_opentask_1 * count_opentask_2}') + + +# generate_goal_states(30, 6, 6) +enumerate_goal_states()