IMPROVE: 增加目标状态数据,删除非holding谓词的nothing

This commit is contained in:
wuziji 2023-11-14 20:22:53 +08:00
parent a012fabf6d
commit 6adc7c750e
4 changed files with 1532 additions and 1655 deletions

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@
import os
import re
Object = ['Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk', 'MilkDrink', 'Milk', 'VacuumCup', 'Nothing']
Object = ['Softdrink', 'BottledDrink', 'Yogurt', 'ADMilk', 'MilkDrink', 'Milk', 'VacuumCup']
Cookable = ['Coffee', 'Water', 'Dessert']
@ -10,7 +10,7 @@ Place = ['Bar', 'WaterTable', 'CoffeeTable', 'Bar2', 'Table1', 'Table2', 'Table3
Entity = ['Robot', 'Customer']
Operable = ['AC', 'ACTemperature', 'HallLight', 'TubeLight', 'Curtain']
Operable = ['AC', 'ACTemperature', 'HallLight', 'TubeLight', 'Curtain', 'Chairs', 'Floor', 'Table']
import random
@ -96,7 +96,6 @@ def generate_goal_states(vln_num: int, vlm_num: int, opentask_num: int):
def enumerate_goal_states(total: int):
res = []
point_15 = int(total * .15)
@ -104,7 +103,7 @@ def enumerate_goal_states(total: int):
# goal states for VLN, .15
count_vln, list_vln = enumerate_predict(['Robot'], Place, 'at')
list_vln = ['{%s}'%i for i in list_vln]
list_vln = ['{%s}' % i for i in list_vln]
if count_vln < point_15:
list_vln *= point_15 // count_vln
for i in range(0, point_15 - len(list_vln)):
@ -131,7 +130,7 @@ def enumerate_goal_states(total: int):
res += list_vlm_2
# goal states for VLM-3, 0.1
count_vlm_3, list_vlm_3 = enumerate_predict(Object, None, 'hold')
count_vlm_3, list_vlm_3 = enumerate_predict(Object + ['Nothing'], None, 'hold')
list_vlm_3 = ['{%s}' % i for i in list_vlm_3]
if count_vlm_3 < point_10:
list_vlm_3 *= point_10 // count_vlm_3
@ -157,7 +156,7 @@ def enumerate_goal_states(total: int):
list_tmp.append('{%s,%s}' % (i, j))
if len(list_tmp) < point_10:
list_tmp *= point_10 // len(list_tmp)
list_tmp += list_tmp[0:point_10-len(list_tmp)]
list_tmp += list_tmp[0:point_10 - len(list_tmp)]
else:
list_tmp = list_tmp[:point_10]
res += list_tmp
@ -208,8 +207,12 @@ def enumerate_goal_states(total: int):
elif 'Is' in i and ('AC' in i or 'HallLight' in i or 'TubeLight' in i or 'Curtain' in i):
i = re.sub('0', 'Off', i)
i = re.sub('1', 'On', i)
elif 'Is' in i and ('Chairs' in i or 'Floor' in i or 'Table' in i):
i = re.sub('0', 'Dirty', i)
i = re.sub('1', 'Clean', i)
file.write(i + '\n')
file.write(i+'\n')
# generate_goal_states(30, 6, 6)
enumerate_goal_states(5000)

View File

@ -8,9 +8,9 @@ project_path = "./robowaiter"
ptml_path = os.path.join(project_path, 'robot/Default.ptml')
behavior_lib_path = os.path.join(project_path, 'behavior_lib')
robot = Robot(ptml_path,behavior_lib_path)
robot = Robot(ptml_path, behavior_lib_path)
# create task
task = task_map[TASK_NAME](robot)
task.reset()
task.run()
task.run()

File diff suppressed because it is too large Load Diff