修复bug
This commit is contained in:
parent
7bd1edeeb2
commit
30c8968ec0
|
@ -7,15 +7,14 @@ import grpc
|
|||
|
||||
from explore import Explore
|
||||
|
||||
sys.path.append('./')
|
||||
sys.path.append('../')
|
||||
sys.path.append('/')
|
||||
sys.path.append('../navigate/')
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
|
||||
import GrabSim_pb2_grpc
|
||||
import GrabSim_pb2
|
||||
from robowaiter.proto import GrabSim_pb2_grpc, GrabSim_pb2
|
||||
|
||||
channel = grpc.insecure_channel('localhost:30001', options=[
|
||||
('grpc.max_send_message_length', 1024 * 1024 * 1024),
|
|
@ -2,22 +2,7 @@ import py_trees as ptree
|
|||
from robowaiter.behavior_lib._base.Act import Act
|
||||
from robowaiter.llm_client.ask_llm import ask_llm
|
||||
|
||||
fixed_answers = {
|
||||
"测试VLM:做一杯咖啡":
|
||||
'''
|
||||
测试VLM:做一杯咖啡
|
||||
---
|
||||
{"At(Coffee,Bar)"}
|
||||
'''
|
||||
,
|
||||
"测试VLN:前往桌子":
|
||||
'''
|
||||
测试VLN:前往桌子
|
||||
---
|
||||
{"At(Robot,Table)"}
|
||||
'''
|
||||
,
|
||||
}
|
||||
|
||||
class DealChat(Act):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -27,20 +12,32 @@ class DealChat(Act):
|
|||
chat = self.scene.state['chat_list'].pop()
|
||||
|
||||
# 判断是否是测试
|
||||
if chat in fixed_answers.keys():
|
||||
sentence,goal = fixed_answers[chat].split("---")
|
||||
sentence = sentence.strip()
|
||||
goal = goal.strip()
|
||||
print(f'机器人回答:{sentence}')
|
||||
# if chat in fixed_answers.keys():
|
||||
# sentence,goal = fixed_answers[chat].split("---")
|
||||
# sentence = sentence.strip()
|
||||
# goal = goal.strip()
|
||||
# print(f'机器人回答:{sentence}')
|
||||
# goal = eval(goal)
|
||||
# print(f'goal:{goal}')
|
||||
#
|
||||
# self.create_sub_task(goal)
|
||||
# else:
|
||||
answer = ask_llm(chat)
|
||||
answer_split = answer.split("---")
|
||||
sentence = answer_split[0].strip()
|
||||
goal = None
|
||||
if len(answer_split) > 1:
|
||||
goal = answer_split[1].strip()
|
||||
|
||||
print(f'{sentence}')
|
||||
if goal:
|
||||
goal = eval(goal)
|
||||
print(f'goal:{goal}')
|
||||
|
||||
self.create_sub_task(goal)
|
||||
else:
|
||||
answer = ask_llm(chat)
|
||||
print(f"机器人回答:{answer}")
|
||||
if self.scene.show_bubble:
|
||||
self.scene.chat_bubble(f"机器人回答:{answer}")
|
||||
|
||||
if self.scene.show_bubble:
|
||||
self.scene.chat_bubble(f"{answer}")
|
||||
|
||||
return ptree.common.Status.RUNNING
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
from robowaiter.utils import get_root_path
|
||||
from robowaiter.llm_client.single_round import single_round
|
||||
########################################
|
||||
# 该文件实现了与大模型的简单通信
|
||||
########################################
|
||||
|
@ -9,30 +10,13 @@ import urllib3
|
|||
# 忽略https的安全性警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
root_path = get_root_path()
|
||||
# load test questions
|
||||
|
||||
|
||||
def ask_llm(question):
|
||||
url = "https://45.125.46.134:25344/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": "RoboWaiter",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": question
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, verify=False)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result['choices'][0]['message']['content'].strip()
|
||||
else:
|
||||
return "大模型请求失败:", response.status_code
|
||||
ans = single_round(question)
|
||||
return ans
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
|
||||
import requests
|
||||
import urllib3
|
||||
########################################
|
||||
# 该文件实现了与大模型的简单通信
|
||||
########################################
|
||||
|
||||
# 忽略https的安全性警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
def single_round(question):
|
||||
url = "https://45.125.46.134:25344/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": "RoboWaiter",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个机器人服务员:RoboWaiter. 你的职责是为顾客提供对话及具身服务。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": question
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, verify=False)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result['choices'][0]['message']['content'].strip()
|
||||
else:
|
||||
return "大模型请求失败:", response.status_code
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
question = '''
|
||||
python中如何通过类名字符串的方式来代替isinstance的作用
|
||||
'''
|
||||
|
||||
print(single_round(question))
|
|
@ -0,0 +1,14 @@
|
|||
"测试VLM:做一杯咖啡":
|
||||
'''
|
||||
测试VLM:做一杯咖啡
|
||||
---
|
||||
{"At(Coffee,Bar)"}
|
||||
'''
|
||||
,
|
||||
"测试VLN:前往桌子":
|
||||
'''
|
||||
测试VLN:前往桌子
|
||||
---
|
||||
{"At(Robot,Table)"}
|
||||
'''
|
||||
,
|
|
@ -1,7 +1,4 @@
|
|||
selector
|
||||
// selector
|
||||
// cond HasMap()
|
||||
// act ExploreEnv()
|
||||
{
|
||||
sequence
|
||||
{
|
||||
|
@ -15,4 +12,4 @@ selector
|
|||
{
|
||||
act SubTaskPlaceHolder()
|
||||
|
||||
} cond At(Talb,ea)}
|
||||
} cond At(Talb,ea)}}
|
|
@ -1,5 +1,12 @@
|
|||
|
||||
import os
|
||||
|
||||
from robowaiter.utils import *
|
||||
from robowaiter.utils import *
|
||||
|
||||
|
||||
|
||||
|
||||
def get_root_path():
|
||||
return os.path.abspath(
|
||||
os.path.join(__file__, "../..")
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue