RoboWaiter/robowaiter/scene/ui/pyqt5.py

349 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import importlib
import os
from PyQt5.QtSvg import QGraphicsSvgItem, QSvgWidget
from PyQt5.QtWidgets import QApplication, QMainWindow, QListWidgetItem, QGraphicsView, QGraphicsScene, \
QGraphicsPixmapItem, QGraphicsProxyWidget
from PyQt5.QtCore import QTimer, QPoint, QRectF
from PyQt5 import QtCore
import sys
from robowaiter.scene.ui.window import Ui_MainWindow
from robowaiter.utils.basic import get_root_path
from PyQt5.QtCore import QThread
import queue
import numpy as np
from PyQt5.QtGui import QImage, QPixmap, QDrag, QPainter
from PyQt5.QtCore import Qt
from PyQt5.QtCore import QThread, pyqtSignal
from robowaiter.scene.ui.scene_ui import SceneUI
from robowaiter.scene.ui import scene_ui
root_path = get_root_path()
class TaskThread(QThread):
    """Background thread that builds a scene and steps it until interrupted.

    ``task`` is invoked inside the thread as
    ``task(scene_cls, robot_cls, scene_queue, ui_queue)``; the object it
    returns must provide ``_run()`` and ``step()``.
    """

    stop_signal = pyqtSignal()

    def __init__(self, task, scene_cls, robot_cls, scene_queue, ui_queue, *args, **kwargs):
        """Store the scene factory and its arguments; nothing runs until ``start()``.

        Extra positional/keyword arguments are forwarded to ``QThread``.
        """
        super().__init__(*args, **kwargs)
        self.task = task
        self.scene_cls = scene_cls
        self.robot_cls = robot_cls
        self.scene_queue = scene_queue
        self.ui_queue = ui_queue
        self.stoped = True

    def run(self):
        """Create the scene, call its ``_run()`` once, then step it in a loop.

        The loop ends once ``requestInterruption()`` has been called on this
        thread; the check happens between two ``step()`` calls.
        """
        self.scene = self.task(
            self.scene_cls,
            self.robot_cls,
            self.scene_queue,
            self.ui_queue,
        )
        self.scene._run()
        while True:
            if self.isInterruptionRequested():
                break
            self.scene.step()
def run_scene(scene_cls, robot_cls, scene_queue, ui_queue):
    """Instantiate a scene around a fresh robot, reset it and return it.

    Args:
        scene_cls: Callable invoked as ``scene_cls(robot, scene_queue, ui_queue)``.
        robot_cls: Zero-argument callable that produces the robot.
        scene_queue: Queue object handed to the scene as its second argument.
        ui_queue: Queue object handed to the scene as its third argument.

    Returns:
        The newly created scene, after ``reset()`` has been called on it.
    """
    robot = robot_cls()
    new_scene = scene_cls(robot, scene_queue, ui_queue)
    new_scene.reset()
    return new_scene
# try:
#     # main loop of the robot system
#
# except Exception as e:
#     # print the exception details to the command line
#     print("Robot system error:", str(e))
# Demo names; each one has a matching ``btn_<name>`` button in the window.
example_list = ("AEM", "VLN", "VLM", "GQA", "OT", "AT", "reset")

# Extra demos offered through the drop-down list instead of a button.
more_example_list = ("VLM_AC", "CafeDaily")

# Chinese display label for every entry of ``more_example_list``.
dic_more2zh = {
    "VLM_AC": "开空调并调节空调温度",
    "CafeDaily": "咖啡厅的一天",
}

# Labels in the same order as ``more_example_list`` so that a drop-down index
# maps straight back to the demo name.
more_example_list_zh = [dic_more2zh[name] for name in more_example_list]
class GraphicsView(QGraphicsView):
    """Graphics view that can be panned by dragging and zoomed with the wheel."""

    # Multiplicative zoom applied per wheel notch (inverse when zooming out).
    ZOOM_STEP = 1.15

    def __init__(self, parent=None):
        """Configure antialiased rendering, hand-drag panning and hidden scroll bars."""
        super().__init__(parent)
        self.setRenderHint(QPainter.Antialiasing)
        self.setOptimizationFlag(QGraphicsView.DontAdjustForAntialiasing, True)
        self.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
        self.setDragMode(QGraphicsView.ScrollHandDrag)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)

    def wheelEvent(self, event):
        """Zoom in on a positive vertical wheel delta and out on a negative one.

        A zero vertical delta (e.g. a purely horizontal scroll) leaves the zoom
        untouched instead of being treated as "zoom out".
        """
        vertical_delta = event.angleDelta().y()
        if vertical_delta == 0:
            return
        if vertical_delta > 0:
            factor = self.ZOOM_STEP
        else:
            factor = 1.0 / self.ZOOM_STEP
        self.scale(factor, factor)
class UI(QMainWindow, Ui_MainWindow):
    """Main window that talks to a scene running in a ``TaskThread``.

    Requests to the scene are put on ``scene_queue`` as tuples whose first
    element is a name such as ``"customer_say"`` (how the scene interprets them
    is defined outside this class).  Messages coming back on ``ui_queue`` are
    tuples ``(method_name, *args)``; a timer dispatches each one to the method
    of this class with that name.
    """

    scene = None
    # Class-level fallback only: __init__ and reset() give each instance its
    # own dict so that histories are never shared between instances.
    history_dict = {
        "System": ""
    }

    def __init__(self, robot_cls):
        """Build the window, start the scene thread and run the Qt event loop.

        This call does not return normally: it ends with ``sys.exit`` once the
        event loop has finished.

        Args:
            robot_cls: Zero-argument callable producing the robot given to the
                scene.
        """
        app = QApplication(sys.argv)
        MainWindow = QMainWindow()
        super(UI, self).__init__(MainWindow)

        # Per-instance chat logs; "Global" must exist before the first
        # new_history() message arrives from the scene thread.
        self.history_dict = {"Global": "", "System": ""}

        # Thread-safe queues shared with the scene thread.
        self.scene_queue = queue.Queue()
        self.ui_queue = queue.Queue()

        self.scene_cls = SceneUI
        self.robot_cls = robot_cls

        self.setupUi(MainWindow)
        MainWindow.setWindowTitle("测试V1.0")

        # "Say" button.
        self.btn_say.clicked.connect(self.btn_say_on_click)

        # Swap the designer-created view for the zoomable GraphicsView.  A
        # widget removed from a layout stays alive and visible, so the old one
        # is scheduled for deletion explicitly.
        placeholder_view = self.img_view_bt
        self.verticalLayout_2.removeWidget(placeholder_view)
        placeholder_view.deleteLater()
        self.img_view_bt = GraphicsView(self.centralwidget)
        self.img_view_bt.setObjectName("img_view_bt")
        self.verticalLayout_2.addWidget(self.img_view_bt)

        # Drop-down list with the additional demos.
        self.cb_task.setCurrentIndex(-1)
        self.cb_task.setPlaceholderText("请选择更多任务")
        self.cb_task.addItems(more_example_list_zh)
        self.cb_task.currentIndexChanged.connect(self.cb_selectionchange)

        # One button per demo in example_list.
        for example in example_list:
            button = getattr(self, f"btn_{example}")
            button.clicked.connect(self.create_example_click(example))

        self.list_customer.itemClicked.connect(self.show_customer_history)

        # Poll the UI queue every 100 ms; the timer is kept on the instance so
        # it cannot be garbage-collected while the window is alive.
        self._queue_timer = QTimer()
        self._queue_timer.setInterval(100)
        self._queue_timer.timeout.connect(self.handle_queue_messages)
        self._queue_timer.start()

        MainWindow.show()
        self._start_scene_thread()

        exit_code = app.exec_()
        # Ask the worker to stop so the QThread is not destroyed while it is
        # still running; the wait is bounded in case a scene step blocks.
        self.thread.requestInterruption()
        self.thread.wait(3000)
        sys.exit(exit_code)

    def _start_scene_thread(self):
        """Create and start a TaskThread bound to the current pair of queues."""
        self.thread = TaskThread(
            run_scene,
            self.scene_cls,
            self.robot_cls,
            self.scene_queue,
            self.ui_queue,
        )
        self.thread.start()

    def cb_selectionchange(self, i):
        """Run the demo selected in the drop-down list.

        Args:
            i: Index reported by the combo box.  Out-of-range values, including
                -1 for "no selection", are ignored; without the check -1 would
                silently pick the last demo.
        """
        if not 0 <= i < len(more_example_list):
            return
        self.create_example_click(more_example_list[i])()

    def get_semantic_info(self, semantic_info_str):
        """Replace the content of ``textBrowser_5`` with ``semantic_info_str``."""
        self.textBrowser_5.clear()
        self.textBrowser_5.append(semantic_info_str)

    def new_history(self, customer_name, chat):
        """Append one chat message to the customer's log and to the global log.

        Args:
            customer_name: Key of the per-customer log; created on demand.
            chat: Mapping with a ``"role"`` key, optionally ``"content"``,
                ``"function_call"`` and (for role ``"function"``) ``"name"``.
                Messages with a truthy ``"function_call"`` are skipped.
        """
        # Checked first: for these messages the content is not needed and must
        # not be touched (it may be None).
        if chat.get("function_call"):
            return

        role = chat["role"]
        if role == "function":
            new_chat = f'Robot:\n 工具调用-{chat["name"]}\n'
        else:
            content = (chat.get("content") or "").strip()
            speaker = customer_name if role == "user" else "Robot"
            new_chat = f'{speaker}:\n {content}\n'

        entry = new_chat + "\n"
        # .get() so that an unknown customer, or a message arriving before
        # reset(), starts a fresh log instead of raising KeyError.
        for key in (customer_name, "Global"):
            self.history_dict[key] = self.history_dict.get(key, "") + entry

        items = self.list_customer.findItems(customer_name, Qt.MatchExactly)
        if items:
            index = self.list_customer.indexFromItem(items[0])
            self.list_customer.setCurrentIndex(index)

        # The history pane always shows the global log after a new message.
        self.edit_local_history.clear()
        self.edit_local_history.append(self.history_dict["Global"])

    def reset(self):
        """Clear all chat logs and rebuild the customer list with its two fixed entries."""
        self.history_dict = {
            "Global": "",
            "System": "",
        }
        self.edit_local_history.clear()
        self.list_customer.clear()
        for name in ("Global", "System"):
            self.list_customer.addItem(QListWidgetItem(name))

    def create_example_click(self, name):
        """Return a zero-argument handler that restarts the scene and runs demo ``name``.

        The handler resets the UI state, stops the current scene thread, starts
        a new one on fresh queues and finally enqueues ``("run_<name>",)``.
        """
        def btn_example_on_click():
            self.reset()

            self.thread.requestInterruption()
            self.thread.wait()  # wait until the old thread has exited

            # Fresh queues so messages of the previous run are not replayed.
            self.scene_queue = queue.Queue()
            self.ui_queue = queue.Queue()
            self._start_scene_thread()

            self.scene_func((f"run_{name}",))

        return btn_example_on_click

    def add_walker(self, name):
        """Register a customer: start an empty log and add a list entry for ``name``."""
        self.history_dict[name] = ""
        self.list_customer.addItem(QListWidgetItem(name))

    def show_customer_history(self, item):
        """Show the stored log of the clicked list entry (empty if none is stored)."""
        name = item.text()
        self.edit_local_history.clear()
        self.edit_local_history.append(self.history_dict.get(name, ""))

    def btn_say_on_click(self):
        """Send the text of ``edit_say`` to the scene.

        Text ending in ``")"`` is enqueued as a ``new_set_goal`` request, any
        other text as something said by ``"System"``.  Empty input is ignored.
        """
        question = self.edit_say.text()
        if not question:
            return
        if question.endswith(")"):
            print(f"设置目标:{question}")
            self.scene_func(("new_set_goal", question))
        else:
            self.scene_func(("customer_say", "System", question))

    def scene_func(self, args):
        """Put the request tuple ``args`` on the scene queue."""
        self.scene_queue.put(args)

    def handle_queue_messages(self):
        """Dispatch every pending UI-queue message to the method it names.

        Each message is ``(method_name, *args)``.  Messages naming something
        that is not a callable attribute of this object are reported and
        skipped.
        """
        while True:
            try:
                message = self.ui_queue.get_nowait()
            except queue.Empty:
                break
            function_name = message[0]
            handler = getattr(self, function_name, None)
            if not callable(handler):
                print(f"UI: ignoring message for unknown handler {function_name!r}")
                continue
            handler(*message[1:])

    def setupUi(self, MainWindow):
        """Set up the generated ``Ui_MainWindow`` widgets on ``MainWindow``."""
        Ui_MainWindow.setupUi(self, MainWindow)

    def _show_pixmap(self, control_name, pixmap):
        """Scale ``pixmap`` to the widget named ``control_name`` and display it there."""
        control = getattr(self, control_name, None)
        if control is None:
            print(f"UI: no widget named {control_name!r}; image dropped")
            return
        control.setPixmap(self.scale_pixmap_to_label(pixmap, control))

    def draw_from_file(self, control_name, file_name):
        """Load an image file and show it in the widget named ``control_name``.

        ``"img_view_bt"`` is drawn into the graphics view, ``"img_label_obj"``
        is deliberately skipped, and any other name is treated as a widget with
        ``setPixmap``.
        """
        if control_name == "img_label_obj":
            return

        pixmap = QPixmap(file_name)
        if control_name != "img_view_bt":
            self._show_pixmap(control_name, pixmap)
            return

        # Reuse one scene: creating a new scene parented to the view on every
        # call would keep all the old scenes (and their pixmaps) alive.
        scene = self.img_view_bt.scene()
        if scene is None:
            scene = QGraphicsScene(self.img_view_bt)
            self.img_view_bt.setScene(scene)
        else:
            scene.clear()
        scene.addItem(QGraphicsPixmapItem(pixmap))
        # An implicit scene rect only ever grows, so pin it to the new image.
        scene.setSceneRect(QRectF(pixmap.rect()))

    def draw_img(self, control_name, img):
        """Decode encoded image bytes ``img`` and show them in ``control_name``."""
        pixmap = QPixmap.fromImage(QImage.fromData(img))
        self._show_pixmap(control_name, pixmap)

    def draw_canvas(self, control_name, canvas):
        """Render the widget ``canvas`` to a pixmap and show it in ``control_name``."""
        # QWidget.grab() is the Qt5 way to render a widget into a pixmap;
        # QApplication has no grabWidget().
        pixmap = canvas.grab()
        self._show_pixmap(control_name, pixmap)

    def ndarray_to_qimage(self, array: np.ndarray) -> QImage:
        """Convert an image array to a ``QImage`` that owns its pixel data.

        Args:
            array: ``(H, W)`` or ``(H, W, C)`` array.  ``C == 3`` is read as
                RGB, ``C == 4`` as RGBA and anything else as 8-bit grayscale.
                NOTE(review): the data is assumed to be uint8 - confirm at the
                call sites.

        Returns:
            A detached copy, so it stays valid after ``array`` is released.
        """
        array = np.ascontiguousarray(array)
        if array.ndim == 2:
            height, width = array.shape
            channels = 1
        else:
            height, width, channels = array.shape

        if channels == 3:
            image_format = QImage.Format_RGB888
        elif channels == 4:
            image_format = QImage.Format_RGBA8888
        else:
            image_format = QImage.Format_Grayscale8

        # Passing the row stride explicitly avoids QImage's default assumption
        # of 32-bit aligned scan lines.
        bytes_per_line = array.strides[0]
        image = QImage(array.data, width, height, bytes_per_line, image_format)
        return image.copy()

    def scale_pixmap_to_label(self, pixmap, label):
        """Return ``pixmap`` scaled to fit inside ``label``, keeping its aspect ratio."""
        label_width = label.width()
        label_height = label.height()

        # Fit the width first; fall back to fitting the height if too tall.
        scaled_pixmap = pixmap.scaledToWidth(label_width, Qt.SmoothTransformation)
        if scaled_pixmap.height() > label_height:
            scaled_pixmap = pixmap.scaledToHeight(label_height, Qt.SmoothTransformation)
        return scaled_pixmap

    def set_scene(self, scene):
        """Remember ``scene`` on this window."""
        self.scene = scene
if __name__ == "__main__":
    # Script entry point: constructing UI builds the window and runs the Qt
    # event loop inside UI.__init__.
    # NOTE(review): UI.__init__ declares a required ``robot_cls`` parameter, so
    # calling UI() with no argument raises TypeError as written - the robot
    # class needs to be passed here.
    # app = QApplication(sys.argv)
    # MainWindow = QMainWindow()
    ui = UI()
    # ui.setupUi(MainWindow)
    # MainWindow.show()
    # sys.exit(app.exec_())