修复bug
This commit is contained in:
parent
7bd1edeeb2
commit
30c8968ec0
|
@ -7,15 +7,14 @@ import grpc
|
||||||
|
|
||||||
from explore import Explore
|
from explore import Explore
|
||||||
|
|
||||||
sys.path.append('./')
|
sys.path.append('/')
|
||||||
sys.path.append('../')
|
sys.path.append('../navigate/')
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||||
|
|
||||||
import GrabSim_pb2_grpc
|
from robowaiter.proto import GrabSim_pb2_grpc, GrabSim_pb2
|
||||||
import GrabSim_pb2
|
|
||||||
|
|
||||||
channel = grpc.insecure_channel('localhost:30001', options=[
|
channel = grpc.insecure_channel('localhost:30001', options=[
|
||||||
('grpc.max_send_message_length', 1024 * 1024 * 1024),
|
('grpc.max_send_message_length', 1024 * 1024 * 1024),
|
|
@ -2,22 +2,7 @@ import py_trees as ptree
|
||||||
from robowaiter.behavior_lib._base.Act import Act
|
from robowaiter.behavior_lib._base.Act import Act
|
||||||
from robowaiter.llm_client.ask_llm import ask_llm
|
from robowaiter.llm_client.ask_llm import ask_llm
|
||||||
|
|
||||||
fixed_answers = {
|
|
||||||
"测试VLM:做一杯咖啡":
|
|
||||||
'''
|
|
||||||
测试VLM:做一杯咖啡
|
|
||||||
---
|
|
||||||
{"At(Coffee,Bar)"}
|
|
||||||
'''
|
|
||||||
,
|
|
||||||
"测试VLN:前往桌子":
|
|
||||||
'''
|
|
||||||
测试VLN:前往桌子
|
|
||||||
---
|
|
||||||
{"At(Robot,Table)"}
|
|
||||||
'''
|
|
||||||
,
|
|
||||||
}
|
|
||||||
class DealChat(Act):
|
class DealChat(Act):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -27,20 +12,32 @@ class DealChat(Act):
|
||||||
chat = self.scene.state['chat_list'].pop()
|
chat = self.scene.state['chat_list'].pop()
|
||||||
|
|
||||||
# 判断是否是测试
|
# 判断是否是测试
|
||||||
if chat in fixed_answers.keys():
|
# if chat in fixed_answers.keys():
|
||||||
sentence,goal = fixed_answers[chat].split("---")
|
# sentence,goal = fixed_answers[chat].split("---")
|
||||||
sentence = sentence.strip()
|
# sentence = sentence.strip()
|
||||||
goal = goal.strip()
|
# goal = goal.strip()
|
||||||
print(f'机器人回答:{sentence}')
|
# print(f'机器人回答:{sentence}')
|
||||||
|
# goal = eval(goal)
|
||||||
|
# print(f'goal:{goal}')
|
||||||
|
#
|
||||||
|
# self.create_sub_task(goal)
|
||||||
|
# else:
|
||||||
|
answer = ask_llm(chat)
|
||||||
|
answer_split = answer.split("---")
|
||||||
|
sentence = answer_split[0].strip()
|
||||||
|
goal = None
|
||||||
|
if len(answer_split) > 1:
|
||||||
|
goal = answer_split[1].strip()
|
||||||
|
|
||||||
|
print(f'{sentence}')
|
||||||
|
if goal:
|
||||||
goal = eval(goal)
|
goal = eval(goal)
|
||||||
print(f'goal:{goal}')
|
print(f'goal:{goal}')
|
||||||
|
|
||||||
self.create_sub_task(goal)
|
self.create_sub_task(goal)
|
||||||
else:
|
|
||||||
answer = ask_llm(chat)
|
if self.scene.show_bubble:
|
||||||
print(f"机器人回答:{answer}")
|
self.scene.chat_bubble(f"{answer}")
|
||||||
if self.scene.show_bubble:
|
|
||||||
self.scene.chat_bubble(f"机器人回答:{answer}")
|
|
||||||
|
|
||||||
return ptree.common.Status.RUNNING
|
return ptree.common.Status.RUNNING
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import urllib3
|
import urllib3
|
||||||
|
from robowaiter.utils import get_root_path
|
||||||
|
from robowaiter.llm_client.single_round import single_round
|
||||||
########################################
|
########################################
|
||||||
# 该文件实现了与大模型的简单通信
|
# 该文件实现了与大模型的简单通信
|
||||||
########################################
|
########################################
|
||||||
|
@ -9,30 +10,13 @@ import urllib3
|
||||||
# 忽略https的安全性警告
|
# 忽略https的安全性警告
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
root_path = get_root_path()
|
||||||
|
# load test questions
|
||||||
|
|
||||||
|
|
||||||
def ask_llm(question):
|
def ask_llm(question):
|
||||||
url = "https://45.125.46.134:25344/v1/chat/completions"
|
ans = single_round(question)
|
||||||
headers = {"Content-Type": "application/json"}
|
return ans
|
||||||
data = {
|
|
||||||
"model": "RoboWaiter",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": question
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=data, verify=False)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
return result['choices'][0]['message']['content'].strip()
|
|
||||||
else:
|
|
||||||
return "大模型请求失败:", response.status_code
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import urllib3
|
||||||
|
########################################
|
||||||
|
# 该文件实现了与大模型的简单通信
|
||||||
|
########################################
|
||||||
|
|
||||||
|
# 忽略https的安全性警告
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
|
||||||
|
def single_round(question):
|
||||||
|
url = "https://45.125.46.134:25344/v1/chat/completions"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
data = {
|
||||||
|
"model": "RoboWaiter",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": question
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, headers=headers, json=data, verify=False)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
return result['choices'][0]['message']['content'].strip()
|
||||||
|
else:
|
||||||
|
return "大模型请求失败:", response.status_code
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
question = '''
|
||||||
|
python中如何通过类名字符串的方式来代替isinstance的作用
|
||||||
|
'''
|
||||||
|
|
||||||
|
print(single_round(question))
|
|
@ -0,0 +1,14 @@
|
||||||
|
"测试VLM:做一杯咖啡":
|
||||||
|
'''
|
||||||
|
测试VLM:做一杯咖啡
|
||||||
|
---
|
||||||
|
{"At(Coffee,Bar)"}
|
||||||
|
'''
|
||||||
|
,
|
||||||
|
"测试VLN:前往桌子":
|
||||||
|
'''
|
||||||
|
测试VLN:前往桌子
|
||||||
|
---
|
||||||
|
{"At(Robot,Table)"}
|
||||||
|
'''
|
||||||
|
,
|
|
@ -1,7 +1,4 @@
|
||||||
selector
|
selector
|
||||||
// selector
|
|
||||||
// cond HasMap()
|
|
||||||
// act ExploreEnv()
|
|
||||||
{
|
{
|
||||||
sequence
|
sequence
|
||||||
{
|
{
|
||||||
|
@ -15,4 +12,4 @@ selector
|
||||||
{
|
{
|
||||||
act SubTaskPlaceHolder()
|
act SubTaskPlaceHolder()
|
||||||
|
|
||||||
} cond At(Talb,ea)}
|
} cond At(Talb,ea)}}
|
|
@ -1,5 +1,12 @@
|
||||||
|
import os
|
||||||
|
|
||||||
from robowaiter.utils import *
|
from robowaiter.utils import *
|
||||||
from robowaiter.utils import *
|
from robowaiter.utils import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_root_path():
|
||||||
|
return os.path.abspath(
|
||||||
|
os.path.join(__file__, "../..")
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue