import argparse import logging import matplotlib.pyplot as plt import numpy as np import torch.utils.data from hardware.camera import RealSenseCamera from hardware.device import get_device from inference.post_process import post_process_output from utils.data.camera_data import CameraData from utils.visualisation.plot import save_results, plot_results logging.basicConfig(level=logging.INFO) def parse_args(): parser = argparse.ArgumentParser(description='Evaluate network') parser.add_argument('--network', type=str, default='saved_data/cornell_rgbd_iou_0.96', help='Path to saved network to evaluate') parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for evaluation (1/0)') parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for evaluation (1/0)') parser.add_argument('--n-grasps', type=int, default=1, help='Number of grasps to consider per image') parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False, help='Force code to run in CPU mode') args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() # Connect to Camera logging.info('Connecting to camera...') cam = RealSenseCamera(device_id=830112070066) cam.connect() cam_data = CameraData(include_depth=args.use_depth, include_rgb=args.use_rgb) # Load Network logging.info('Loading model...') net = torch.load(args.network) logging.info('Done') # Get the compute device device = get_device(args.force_cpu) try: fig = plt.figure(figsize=(10, 10)) while True: image_bundle = cam.get_image_bundle() rgb = image_bundle['rgb'] depth = image_bundle['aligned_depth'] x, depth_img, rgb_img = cam_data.get_data(rgb=rgb, depth=depth) with torch.no_grad(): xc = x.to(device) pred = net.predict(xc) q_img, ang_img, width_img = post_process_output(pred['pos'], pred['cos'], pred['sin'], pred['width']) plot_results(fig=fig, rgb_img=cam_data.get_rgb(rgb, False), depth_img=np.squeeze(cam_data.get_depth(depth)), grasp_q_img=q_img, grasp_angle_img=ang_img, no_grasps=args.n_grasps, grasp_width_img=width_img) finally: save_results( rgb_img=cam_data.get_rgb(rgb, False), depth_img=np.squeeze(cam_data.get_depth(depth)), grasp_q_img=q_img, grasp_angle_img=ang_img, no_grasps=args.n_grasps, grasp_width_img=width_img )