80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
import warnings
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
# Suppress every warning for the whole process. This executes at import time,
# so simply importing this module changes the global warning filter.
warnings.filterwarnings("ignore")
|
|
|
|
from .grasp import GraspRectangles, detect_grasps
|
|
|
|
|
|
def plot_output(fig, rgb_img, grasp_q_img, grasp_angle_img, depth_img=None, no_grasps=1, grasp_width_img=None):
    """
    Plot the output of a network on a 2x2 grid of subplots.

    Panel 1 shows the RGB image with the detected grasps drawn on top, panel 2
    the depth image with the same overlay (only when a depth image is given),
    panel 3 the Q map and panel 4 the angle map, the last two with a colorbar.

    :param fig: Figure to plot the output; it is cleared before drawing
    :param rgb_img: RGB Image
    :param grasp_q_img: Q output of network (displayed on a fixed 0..1 colour scale)
    :param grasp_angle_img: Angle output of network (displayed on a fixed -pi/2..pi/2 colour scale)
    :param depth_img: (optional) Depth Image; the depth panel is skipped when None
    :param no_grasps: Maximum number of grasps to plot
    :param grasp_width_img: (optional) Width output of network
    :return: None
    """
    gs = detect_grasps(grasp_q_img, grasp_angle_img, width_img=grasp_width_img, no_grasps=no_grasps)

    plt.ion()
    # Clear the figure we were handed rather than pyplot's "current" figure,
    # which is not guaranteed to be the same object.
    fig.clf()

    ax = fig.add_subplot(2, 2, 1)
    ax.imshow(rgb_img)
    for g in gs:
        g.plot(ax)
    ax.set_title('RGB')
    ax.axis('off')

    # Compare against None explicitly: the truth value of a multi-element
    # NumPy array is ambiguous and raises ValueError.
    if depth_img is not None:
        ax = fig.add_subplot(2, 2, 2)
        ax.imshow(depth_img, cmap='gray')
        for g in gs:
            g.plot(ax)
        ax.set_title('Depth')
        ax.axis('off')

    ax = fig.add_subplot(2, 2, 3)
    plot = ax.imshow(grasp_q_img, cmap='jet', vmin=0, vmax=1)
    ax.set_title('Q')
    ax.axis('off')
    # Attach the colorbar to this figure/axes explicitly instead of relying on
    # pyplot's implicit current-figure state.
    fig.colorbar(plot, ax=ax)

    ax = fig.add_subplot(2, 2, 4)
    plot = ax.imshow(grasp_angle_img, cmap='hsv', vmin=-np.pi / 2, vmax=np.pi / 2)
    ax.set_title('Angle')
    ax.axis('off')
    fig.colorbar(plot, ax=ax)

    # Give the GUI event loop a moment to run, then force a redraw of the figure.
    plt.pause(0.1)
    fig.canvas.draw()
|
|
|
|
|
|
def calculate_iou_match(grasp_q, grasp_angle, ground_truth_bbs, no_grasps=1, grasp_width=None, threshold=0.25):
    """
    Calculate grasp success using the IoU (Jacquard) metric (e.g. in https://arxiv.org/abs/1301.3592)

    Up to ``no_grasps`` grasps are detected from the network outputs. The result
    is a success as soon as one of them has a maximum IoU against the ground
    truth that is strictly greater than ``threshold``.
    NOTE(review): the metric's usual "within 30 degrees" angle criterion is not
    visible in this function; presumably it is enforced inside ``max_iou`` - confirm.

    :param grasp_q: Q outputs of network
    :param grasp_angle: Angle outputs of network
    :param ground_truth_bbs: Corresponding ground-truth BoundingBoxes, either a
        GraspRectangles instance or an array accepted by GraspRectangles.load_from_array
    :param no_grasps: Maximum number of grasps to consider per image.
    :param grasp_width: (optional) Width output from network
    :param threshold: Threshold for IOU matching. A grasp matches when its IOU > threshold
    :return: True if any detected grasp matches, otherwise False (including when no grasp is detected)
    """
    # Normalise the ground truth to a GraspRectangles container.
    if isinstance(ground_truth_bbs, GraspRectangles):
        gt_bbs = ground_truth_bbs
    else:
        gt_bbs = GraspRectangles.load_from_array(ground_truth_bbs)

    candidates = detect_grasps(grasp_q, grasp_angle, width_img=grasp_width, no_grasps=no_grasps)

    # Short-circuits on the first candidate that clears the threshold.
    return any(candidate.max_iou(gt_bbs) > threshold for candidate in candidates)
|