"""Evaluate saved grasp networks on the validation split of a dataset."""
import argparse
import logging
import time

import numpy as np
import torch.utils.data

from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data import get_dataset
from utils.dataset_processing import evaluation, grasp
from utils.visualisation.plot import save_results

logging.basicConfig(level=logging.INFO)


def parse_args():
    """Parse and validate the command-line arguments of the evaluation script.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if ``--jacquard-output`` is combined with a dataset other
            than ``jacquard`` or with ``--augment``.
    """
    parser = argparse.ArgumentParser(description='Evaluate networks')

    # Network
    parser.add_argument('--network', metavar='N', type=str, nargs='+',
                        help='Path to saved networks to evaluate')

    # Dataset
    # The help text spells "jacquard" the same way the validation below does.
    parser.add_argument('--dataset', type=str,
                        help='Dataset Name ("cornell" or "jacquard")')
    parser.add_argument('--dataset-path', type=str,
                        help='Path to dataset')
    parser.add_argument('--use-depth', type=int, default=1,
                        help='Use Depth image for evaluation (1/0)')
    parser.add_argument('--use-rgb', type=int, default=1,
                        help='Use RGB image for evaluation (1/0)')
    parser.add_argument('--augment', action='store_true',
                        help='Whether data augmentation should be applied')
    parser.add_argument('--split', type=float, default=0.9,
                        help='Fraction of data for training (remainder is validation)')
    parser.add_argument('--ds-shuffle', action='store_true', default=False,
                        help='Shuffle the dataset')
    parser.add_argument('--ds-rotate', type=float, default=0.0,
                        help='Shift the start point of the dataset to use a different test/train split')
    parser.add_argument('--num-workers', type=int, default=8,
                        help='Dataset workers')

    # Evaluation
    parser.add_argument('--n-grasps', type=int, default=1,
                        help='Number of grasps to consider per image')
    parser.add_argument('--iou-threshold', type=float, default=0.25,
                        help='Threshold for IOU matching')
    parser.add_argument('--iou-eval', action='store_true',
                        help='Compute success based on IoU metric.')
    parser.add_argument('--jacquard-output', action='store_true',
                        help='Jacquard-dataset style output')

    # Misc.
    parser.add_argument('--vis', action='store_true',
                        help='Visualise the network output')
    parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False,
                        help='Force code to run in CPU mode')
    parser.add_argument('--random-seed', type=int, default=123,
                        help='Random seed for numpy')

    args = parser.parse_args()

    if args.jacquard_output and args.dataset != 'jacquard':
        raise ValueError('--jacquard-output can only be used with the --dataset jacquard option.')
    if args.jacquard_output and args.augment:
        raise ValueError('--jacquard-output can not be used with data augmentation.')

    return args


def main():
    """Evaluate every network passed via ``--network`` on the validation split.

    For each network this optionally computes the IoU success rate
    (``--iou-eval``), writes Jacquard-style grasp output next to the network
    file (``--jacquard-output``) and saves visualisations (``--vis``).
    """
    args = parse_args()

    # Get the compute device
    device = get_device(args.force_cpu)

    # Load Dataset
    logging.info('Loading {} Dataset...'.format(args.dataset.title()))
    Dataset = get_dataset(args.dataset)
    test_dataset = Dataset(args.dataset_path,
                           ds_rotate=args.ds_rotate,
                           random_rotate=args.augment,
                           random_zoom=args.augment,
                           include_depth=args.use_depth,
                           include_rgb=args.use_rgb)

    # Everything after the first `split` fraction of indices is validation data.
    indices = list(range(test_dataset.length))
    split = int(np.floor(args.split * test_dataset.length))
    if args.ds_shuffle:
        np.random.seed(args.random_seed)
        np.random.shuffle(indices)
    val_indices = indices[split:]
    val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)
    logging.info('Validation size: {}'.format(len(val_indices)))

    test_data = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        num_workers=args.num_workers,
        sampler=val_sampler
    )
    logging.info('Done')

    for network in args.network:
        logging.info('\nEvaluating model {}'.format(network))

        # Load Network
        # map_location keeps the model on the same device the inputs are moved
        # to below, so a checkpoint saved on GPU can be evaluated with --cpu.
        # NOTE(review): assumes get_device returns something torch.load accepts
        # as map_location (e.g. a torch.device) - confirm.
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        net = torch.load(network, map_location=device)

        results = {'correct': 0, 'failed': 0}

        if args.jacquard_output:
            jo_fn = network + '_jacquard_output.txt'
            # Create/truncate the output file; results are appended per image.
            with open(jo_fn, 'w'):
                pass

        start_time = time.time()

        with torch.no_grad():
            for x, y, didx, rot, zoom in test_data:
                xc = x.to(device)
                yc = [yi.to(device) for yi in y]
                lossd = net.compute_loss(xc, yc)

                q_img, ang_img, width_img = post_process_output(lossd['pred']['pos'],
                                                                lossd['pred']['cos'],
                                                                lossd['pred']['sin'],
                                                                lossd['pred']['width'])

                if args.iou_eval:
                    s = evaluation.calculate_iou_match(q_img, ang_img,
                                                       test_data.dataset.get_gtbb(didx, rot, zoom),
                                                       no_grasps=args.n_grasps,
                                                       grasp_width=width_img,
                                                       threshold=args.iou_threshold)
                    if s:
                        results['correct'] += 1
                    else:
                        results['failed'] += 1

                if args.jacquard_output:
                    grasps = grasp.detect_grasps(q_img, ang_img, width_img=width_img, no_grasps=1)
                    with open(jo_fn, 'a') as f:
                        for g in grasps:
                            f.write(test_data.dataset.get_jname(didx) + '\n')
                            # NOTE(review): 1024 / 300 presumably rescales from the
                            # network output size to the Jacquard image size - confirm.
                            f.write(g.to_jacquard(scale=1024 / 300) + '\n')

                if args.vis:
                    save_results(
                        rgb_img=test_data.dataset.get_rgb(didx, rot, zoom, normalise=False),
                        depth_img=test_data.dataset.get_depth(didx, rot, zoom),
                        grasp_q_img=q_img,
                        grasp_angle_img=ang_img,
                        no_grasps=args.n_grasps,
                        grasp_width_img=width_img
                    )

        # An empty validation split (e.g. --split 1.0) would otherwise divide by zero.
        num_images = len(test_data)
        if num_images:
            avg_time = (time.time() - start_time) / num_images
            logging.info('Average evaluation time per image: {}ms'.format(avg_time * 1000))
        else:
            logging.warning('Validation split is empty; no images were evaluated.')

        if args.iou_eval:
            total = results['correct'] + results['failed']
            if total:
                logging.info('IOU Results: %d/%d = %f',
                             results['correct'], total, results['correct'] / total)
            else:
                logging.warning('IOU Results: no images were evaluated.')

        if args.jacquard_output:
            logging.info('Jacquard output saved to {}'.format(jo_fn))

        # Release the model before loading the next one to keep peak memory down.
        del net
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()