From 075f9ca6fe2b2569a433eb58eef137c70623226b Mon Sep 17 00:00:00 2001 From: wuziji <2193177243@qq.com> Date: Mon, 13 Nov 2023 23:06:22 +0800 Subject: [PATCH] =?UTF-8?q?New:=20=E6=96=B0=E5=A2=9E=E4=BA=86=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E7=94=9F=E6=88=90bt-expansion=E7=9B=AE=E6=A0=87?= =?UTF-8?q?=E7=8A=B6=E6=80=81=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- robowaiter/behavior_tree/dataset/__init__.py | 0 .../dataset/goal_states_generation.py | 66 +++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 robowaiter/behavior_tree/dataset/__init__.py create mode 100644 robowaiter/behavior_tree/dataset/goal_states_generation.py diff --git a/robowaiter/behavior_tree/dataset/__init__.py b/robowaiter/behavior_tree/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/robowaiter/behavior_tree/dataset/goal_states_generation.py b/robowaiter/behavior_tree/dataset/goal_states_generation.py new file mode 100644 index 0000000..550d86f --- /dev/null +++ b/robowaiter/behavior_tree/dataset/goal_states_generation.py @@ -0,0 +1,66 @@ +# the empty string '' represents robot holds nothing +Object = ['Coffee', 'Water', 'Dessert', 'Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk', 'MilkDrink', 'Milk', + 'VacuumCup', ''] + +Place = ['Bar', 'WaterTable', 'CoffeeTable', 'Bar2', 'Table1', 'Table2', 'Table3'] + +Entity = ['Robot', 'Customer'] + +Operable = ['AC', 'ACTemperature', 'HallLight', 'TubeLight', 'Curtain'] + +import random + + +def single_predict_generation(oplist_1, oplist_2, predict_pattern) -> str: + index_1 = random.randint(0, len(oplist_1) - 1) + if oplist_2: + index_2 = random.randint(0, len(oplist_2) - 1) + + match predict_pattern: + case 'at': + return f'At({oplist_1[index_1]}, {oplist_2[index_2]})' + case 'is': + return f'Is({oplist_1[index_1]}, {oplist_2[index_2]})' + case 'hold': + return f'Holding({oplist_1[index_1]})' + case 'on': + return f'On({oplist_1[index_1]}, {oplist_2[index_2]})' + case _: + raise RuntimeError('Incorrect predict pattern!') + + +def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int): + # res stores lists of sets, while each state represent in set. + res = [] + + # goal states for VLN + for i in range(vln_num): + res.append({single_predict_generation(['Robot'], Place, 'at')}) + + # goal states for VLM + for i in range(int(vlm_num)): + for j in range(int(vlm_num)): + res.append( + { + single_predict_generation(['Robot'], Place, 'at'), + single_predict_generation(Operable, ['0', '1'], 'is') + } + ) + + # goal states for Open-task-1 + for i in range(int(opentask_num)): + for j in range(int(opentask_num)): + res.append( + { + single_predict_generation(['Robot'], Place, 'at'), + single_predict_generation(Object, Place, 'on') + } + ) + + # print(res) + # print(len(res)) + + return res + + +generate_goal_states(30, 6, 6)