"""Run a saved grasp network on a single RGB / depth image pair.

The script loads the two images, feeds them through the network, then either
saves or plots the resulting grasp quality / angle / width maps and finally
prints the first grasp returned by ``detect_grasps``.
"""
import argparse
import logging

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
from PIL import Image

from network.hardware.device import get_device
from network.inference.post_process import post_process_output
from network.utils.data.camera_data import CameraData
from network.utils.visualisation.plot import plot_results, save_results
from network.utils.dataset_processing.grasp import detect_grasps

logging.basicConfig(level=logging.INFO)


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with the attributes ``network``, ``rgb_path``,
        ``depth_path``, ``use_depth``, ``use_rgb``, ``n_grasps``, ``save`` and
        ``force_cpu``.
    """
    parser = argparse.ArgumentParser(description='Evaluate network')
    # Required: without a path the script could only fail later, inside
    # torch.load, with a far less helpful error.
    parser.add_argument('--network', type=str, required=True,
                        help='Path to saved network to evaluate')
    parser.add_argument('--rgb_path', type=str, default='cornell/08/pcd0845r.png',
                        help='RGB Image path')
    parser.add_argument('--depth_path', type=str, default='cornell/08/pcd0845d.tiff',
                        help='Depth Image path')
    parser.add_argument('--use-depth', type=int, default=1,
                        help='Use Depth image for evaluation (1/0)')
    parser.add_argument('--use-rgb', type=int, default=1,
                        help='Use RGB image for evaluation (1/0)')
    parser.add_argument('--n-grasps', type=int, default=1,
                        help='Number of grasps to consider per image')
    parser.add_argument('--save', type=int, default=0,
                        help='Save the results')
    parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False,
                        help='Force code to run in CPU mode')

    args = parser.parse_args()
    return args


def main():
    """Load the images and the network, run inference and report the result."""
    args = parse_args()

    # Load image. The context managers close the underlying files as soon as
    # the pixel data has been copied into numpy arrays.
    logging.info('Loading image...')
    with Image.open(args.rgb_path, 'r') as pic:
        rgb = np.array(pic)
    with Image.open(args.depth_path, 'r') as pic:
        # Append a trailing axis to the depth array before handing it on.
        depth = np.expand_dims(np.array(pic), axis=2)
    print(depth.shape)

    # Get the compute device first so the checkpoint can be mapped onto it.
    device = get_device(args.force_cpu)

    # Load Network
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # presumably rejects a fully pickled model like this one - confirm against
    # the installed version.
    # map_location keeps the model on the same device as the input tensor, so
    # a GPU-saved checkpoint still loads when --cpu is given (and vice versa).
    logging.info('Loading model...')
    net = torch.load(args.network, map_location=device)
    logging.info('Done')

    # NOTE(review): width/height are hard-coded to 244 here instead of relying
    # on CameraData's own defaults - confirm this matches the input images.
    img_data = CameraData(width=244, height=244,
                          include_depth=args.use_depth,
                          include_rgb=args.use_rgb)

    x, depth_img, rgb_img = img_data.get_data(rgb=rgb, depth=depth)

    # Show the processed depth image; plt.show() is called before inference.
    plt.imshow(depth_img[0])
    plt.colorbar(label='Pixel value')
    plt.title('Depth image')
    plt.show()

    with torch.no_grad():
        xc = x.to(device)
        pred = net.predict(xc)

        q_img, ang_img, width_img = post_process_output(
            pred['pos'], pred['cos'], pred['sin'], pred['width'])

        if args.save:
            save_results(
                rgb_img=img_data.get_rgb(rgb, False),
                depth_img=np.squeeze(img_data.get_depth(depth)),
                grasp_q_img=q_img,
                grasp_angle_img=ang_img,
                no_grasps=args.n_grasps,
                grasp_width_img=width_img
            )
        else:
            fig = plt.figure(figsize=(10, 10))
            plot_results(fig=fig,
                         rgb_img=img_data.get_rgb(rgb, False),
                         grasp_q_img=q_img,
                         grasp_angle_img=ang_img,
                         no_grasps=args.n_grasps,
                         grasp_width_img=width_img)
            fig.savefig('img_result.pdf')

    grasps = detect_grasps(q_img, ang_img, width_img=width_img,
                           no_grasps=args.n_grasps)
    # len() rather than truthiness: the return type of detect_grasps is not
    # visible here, and indexing an empty result would raise IndexError.
    if len(grasps) == 0:
        logging.warning('No grasps detected')
        return

    grasp = grasps[0]
    print(grasp)


if __name__ == '__main__':
    main()