104 lines
3.5 KiB
Python
104 lines
3.5 KiB
Python
import argparse
|
|
import logging
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch.utils.data
|
|
from PIL import Image
|
|
|
|
from network.hardware.device import get_device
|
|
from network.inference.post_process import post_process_output
|
|
from network.utils.data.camera_data import CameraData
|
|
from network.utils.visualisation.plot import plot_results, save_results
|
|
from network.utils.dataset_processing.grasp import detect_grasps
|
|
|
|
# Configure the root logger at import time so the logging.info(...) progress
# messages emitted by this script are shown (INFO level and above).
logging.basicConfig(level=logging.INFO)
def parse_args(argv=None):
    """Parse the command-line options for single-image grasp evaluation.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, i.e. the original behaviour.

    Returns:
        argparse.Namespace with attributes ``network``, ``rgb_path``,
        ``depth_path``, ``use_depth``, ``use_rgb``, ``n_grasps``, ``save``
        and ``force_cpu``.
    """
    parser = argparse.ArgumentParser(description='Evaluate network')
    # Required: there is no sensible default model path, and a missing value
    # would otherwise only surface later as torch.load(None) failing.
    parser.add_argument('--network', type=str, required=True,
                        help='Path to saved network to evaluate')
    parser.add_argument('--rgb_path', type=str, default='cornell/08/pcd0845r.png',
                        help='RGB Image path')
    parser.add_argument('--depth_path', type=str, default='cornell/08/pcd0845d.tiff',
                        help='Depth Image path')
    parser.add_argument('--use-depth', type=int, default=1,
                        help='Use Depth image for evaluation (1/0)')
    parser.add_argument('--use-rgb', type=int, default=1,
                        help='Use RGB image for evaluation (1/0)')
    parser.add_argument('--n-grasps', type=int, default=1,
                        help='Number of grasps to consider per image')
    parser.add_argument('--save', type=int, default=0,
                        help='Save the results')
    parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False,
                        help='Force code to run in CPU mode')
    return parser.parse_args(argv)
def main():
    """Detect grasps on one RGB-D image pair, then show or save the results.

    Loads the RGB and depth images named on the command line and runs the
    saved network on them via ``net.predict``. The raw ``pos``/``cos``/``sin``/
    ``width`` outputs are converted with ``post_process_output``. The maps are
    then either saved (``--save 1``) or plotted and written to
    ``img_result.pdf``. Finally the first grasp returned by ``detect_grasps``
    is printed.

    Returns:
        Process exit status: 0 when at least one grasp was detected, 1 when
        ``detect_grasps`` returned none.
    """
    args = parse_args()

    # Load image. Context managers release the file handles deterministically.
    logging.info('Loading image...')
    with Image.open(args.rgb_path, 'r') as pic:
        rgb = np.array(pic)
    with Image.open(args.depth_path, 'r') as pic:
        # Append a trailing channel axis; presumably (H, W) -> (H, W, 1) for a
        # single-channel depth image -- confirm for other depth formats.
        depth = np.expand_dims(np.array(pic), axis=2)
    logging.info('Depth image shape: %s', depth.shape)

    # Resolve the compute device *before* loading the model so the checkpoint
    # can be mapped straight onto it.
    device = get_device(args.force_cpu)

    # Load Network
    logging.info('Loading model...')
    # map_location lets a model saved on a GPU load on a CPU-only run (--cpu)
    # and keeps the weights on the same device the input is moved to below.
    # NOTE(review): torch.load unpickles an arbitrary object -- only load
    # trusted model files. Assumes the result is a torch.nn.Module.
    net = torch.load(args.network, map_location=device)
    # Put layers such as dropout / batch-norm (if any) into inference mode.
    net.eval()
    logging.info('Done')

    # NOTE(review): confirm 244x244 is the intended size (possibly a typo for
    # 224); the alternative call below relies on CameraData's own defaults.
    img_data = CameraData(width=244, height=244, include_depth=args.use_depth, include_rgb=args.use_rgb)
    # img_data = CameraData(include_depth=args.use_depth, include_rgb=args.use_rgb)

    x, depth_img, rgb_img = img_data.get_data(rgb=rgb, depth=depth)

    # Debug preview of the pre-processed depth data (plt.show() typically
    # blocks until the window is closed). Indexing [0] assumes a leading
    # channel axis on depth_img -- confirm.
    plt.imshow(depth_img[0])
    plt.colorbar(label='Pixel value')
    plt.title('Depth image')
    plt.show()

    # Optional RGB preview:
    # plt.imshow(rgb_img)
    # plt.title('RGB image')
    # plt.show()

    with torch.no_grad():
        xc = x.to(device)
        pred = net.predict(xc)

    q_img, ang_img, width_img = post_process_output(pred['pos'], pred['cos'], pred['sin'], pred['width'])

    if args.save:
        save_results(
            rgb_img=img_data.get_rgb(rgb, False),
            depth_img=np.squeeze(img_data.get_depth(depth)),
            grasp_q_img=q_img,
            grasp_angle_img=ang_img,
            no_grasps=args.n_grasps,
            grasp_width_img=width_img
        )
    else:
        fig = plt.figure(figsize=(10, 10))
        plot_results(fig=fig,
                     rgb_img=img_data.get_rgb(rgb, False),
                     grasp_q_img=q_img,
                     grasp_angle_img=ang_img,
                     no_grasps=args.n_grasps,
                     grasp_width_img=width_img)
        fig.savefig('img_result.pdf')

    grasps = detect_grasps(q_img, ang_img, width_img=width_img, no_grasps=args.n_grasps)
    # Guard the empty case: grasps[0] would raise IndexError when nothing is
    # detected. len() works whether detect_grasps returns a list or an array.
    if len(grasps) == 0:
        logging.error('No grasps detected')
        return 1
    print(grasps[0])
    return 0


if __name__ == '__main__':
    raise SystemExit(main())