170 lines
6.9 KiB
Python
170 lines
6.9 KiB
Python
import argparse
|
|
import logging
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch.utils.data
|
|
|
|
from hardware.device import get_device
|
|
from inference.post_process import post_process_output
|
|
from utils.data import get_dataset
|
|
from utils.dataset_processing import evaluation, grasp
|
|
from utils.visualisation.plot import save_results
|
|
|
|
# Configure the root logger at INFO level so the progress and result
# messages logged by this script (dataset loading, per-model metrics) appear.
logging.basicConfig(level='INFO')
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse and validate the command-line options for network evaluation.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). ``None`` (the default) keeps the
            original command-line behaviour; passing a list is useful for
            tests and programmatic use.

    Returns:
        argparse.Namespace: The parsed options.

    Raises:
        SystemExit: Raised by argparse when a required option (``--network``,
            ``--dataset``) is missing or a value cannot be parsed.
        ValueError: If ``--split`` is outside [0, 1], or if
            ``--jacquard-output`` is combined with a dataset other than
            ``jacquard`` or with ``--augment``.
    """
    parser = argparse.ArgumentParser(description='Evaluate networks')

    # Network
    # Required: the evaluation loop iterates over this list, so leaving it out
    # would otherwise crash later with a TypeError on ``None``.
    parser.add_argument('--network', metavar='N', type=str, nargs='+', required=True,
                        help='Path to saved networks to evaluate')

    # Dataset
    # Required: the dataset name is used (``.title()``, ``get_dataset``) before
    # anything else happens, so ``None`` would crash with an AttributeError.
    parser.add_argument('--dataset', type=str, required=True,
                        help='Dataset Name ("cornell" or "jacquard")')
    parser.add_argument('--dataset-path', type=str,
                        help='Path to dataset')
    parser.add_argument('--use-depth', type=int, default=1,
                        help='Use Depth image for evaluation (1/0)')
    parser.add_argument('--use-rgb', type=int, default=1,
                        help='Use RGB image for evaluation (1/0)')
    parser.add_argument('--augment', action='store_true',
                        help='Whether data augmentation should be applied')
    parser.add_argument('--split', type=float, default=0.9,
                        help='Fraction of data for training (remainder is validation)')
    parser.add_argument('--ds-shuffle', action='store_true', default=False,
                        help='Shuffle the dataset')
    parser.add_argument('--ds-rotate', type=float, default=0.0,
                        help='Shift the start point of the dataset to use a different test/train split')
    parser.add_argument('--num-workers', type=int, default=8,
                        help='Dataset workers')

    # Evaluation
    parser.add_argument('--n-grasps', type=int, default=1,
                        help='Number of grasps to consider per image')
    parser.add_argument('--iou-threshold', type=float, default=0.25,
                        help='Threshold for IOU matching')
    parser.add_argument('--iou-eval', action='store_true',
                        help='Compute success based on IoU metric.')
    parser.add_argument('--jacquard-output', action='store_true',
                        help='Jacquard-dataset style output')

    # Misc.
    parser.add_argument('--vis', action='store_true',
                        help='Visualise the network output')
    parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False,
                        help='Force code to run in CPU mode')
    parser.add_argument('--random-seed', type=int, default=123,
                        help='Random seed for numpy')

    args = parser.parse_args(argv)

    # --split is a fraction of the dataset; values outside [0, 1] would give
    # a nonsensical (e.g. silently empty) validation subset.
    if not 0.0 <= args.split <= 1.0:
        raise ValueError('--split must be between 0 and 1.')
    if args.jacquard_output and args.dataset != 'jacquard':
        raise ValueError('--jacquard-output can only be used with the --dataset jacquard option.')
    if args.jacquard_output and args.augment:
        raise ValueError('--jacquard-output can not be used with data augmentation.')

    return args
|
|
|
|
|
|
def _build_test_loader(args):
    """Build a DataLoader over the held-out (validation) part of the dataset.

    The first ``floor(args.split * length)`` indices (after an optional,
    seeded shuffle) are skipped as training data. The remaining indices are
    served one sample per batch, in random order, by a ``SubsetRandomSampler``.

    Args:
        args: Namespace returned by ``parse_args``.

    Returns:
        torch.utils.data.DataLoader: Loader whose items are unpacked by the
        caller as ``(x, y, didx, rot, zoom)``.
    """
    logging.info('Loading {} Dataset...'.format(args.dataset.title()))
    Dataset = get_dataset(args.dataset)
    test_dataset = Dataset(args.dataset_path,
                           ds_rotate=args.ds_rotate,
                           random_rotate=args.augment,
                           random_zoom=args.augment,
                           include_depth=args.use_depth,
                           include_rgb=args.use_rgb)

    indices = list(range(test_dataset.length))
    split = int(np.floor(args.split * test_dataset.length))
    if args.ds_shuffle:
        np.random.seed(args.random_seed)
        np.random.shuffle(indices)
    val_indices = indices[split:]
    val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)
    logging.info('Validation size: {}'.format(len(val_indices)))

    test_data = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        num_workers=args.num_workers,
        sampler=val_sampler
    )
    logging.info('Done')
    return test_data


def _evaluate_network(network, test_data, device, args):
    """Evaluate one saved network on ``test_data`` and log its metrics.

    Depending on ``args``, this counts IoU successes (``--iou-eval``), writes
    Jacquard-style grasp output next to the checkpoint (``--jacquard-output``),
    and saves visualisations (``--vis``). It always logs the average time per
    image.

    Args:
        network: Path to a checkpoint saved with ``torch.save(model)``.
        test_data: DataLoader produced by ``_build_test_loader``.
        device: Device returned by ``get_device``; inputs and weights go here.
        args: Namespace returned by ``parse_args``.
    """
    logging.info('\nEvaluating model {}'.format(network))

    # Load Network.
    # NOTE: torch.load unpickles the file; only evaluate trusted checkpoints.
    # map_location puts the weights on the selected device, so GPU-saved
    # models can still be evaluated with --cpu.
    net = torch.load(network, map_location=device)
    # Put layers such as dropout / batch-norm (if the model has any) into
    # inference mode; a checkpoint may have been saved while training.
    net.eval()

    results = {'correct': 0, 'failed': 0}

    jo_fn = network + '_jacquard_output.txt'
    # Open (and truncate) the Jacquard output once per network rather than
    # re-opening it in append mode for every image.
    jo_file = open(jo_fn, 'w') if args.jacquard_output else None
    try:
        start_time = time.perf_counter()

        with torch.no_grad():
            for x, y, didx, rot, zoom in test_data:
                xc = x.to(device)
                yc = [yi.to(device) for yi in y]
                lossd = net.compute_loss(xc, yc)

                pred = lossd['pred']
                q_img, ang_img, width_img = post_process_output(pred['pos'], pred['cos'],
                                                                pred['sin'], pred['width'])

                if args.iou_eval:
                    s = evaluation.calculate_iou_match(q_img, ang_img, test_data.dataset.get_gtbb(didx, rot, zoom),
                                                       no_grasps=args.n_grasps,
                                                       grasp_width=width_img,
                                                       threshold=args.iou_threshold
                                                       )
                    if s:
                        results['correct'] += 1
                    else:
                        results['failed'] += 1

                if jo_file is not None:
                    grasps = grasp.detect_grasps(q_img, ang_img, width_img=width_img, no_grasps=1)
                    for g in grasps:
                        jo_file.write(test_data.dataset.get_jname(didx) + '\n')
                        jo_file.write(g.to_jacquard(scale=1024 / 300) + '\n')

                if args.vis:
                    save_results(
                        rgb_img=test_data.dataset.get_rgb(didx, rot, zoom, normalise=False),
                        depth_img=test_data.dataset.get_depth(didx, rot, zoom),
                        grasp_q_img=q_img,
                        grasp_angle_img=ang_img,
                        no_grasps=args.n_grasps,
                        grasp_width_img=width_img
                    )

        elapsed = time.perf_counter() - start_time
    finally:
        if jo_file is not None:
            jo_file.close()

    # Guard the averages: an empty validation split (e.g. --split 1.0) would
    # otherwise raise ZeroDivisionError here.
    n_images = len(test_data)
    if n_images:
        avg_time = elapsed / n_images
        logging.info('Average evaluation time per image: {}ms'.format(avg_time * 1000))
    else:
        logging.warning('Validation set is empty; nothing was evaluated.')

    if args.iou_eval:
        total = results['correct'] + results['failed']
        if total:
            logging.info('IOU Results: %d/%d = %f' % (results['correct'],
                                                     total,
                                                     results['correct'] / total))

    if args.jacquard_output:
        logging.info('Jacquard output saved to {}'.format(jo_fn))


if __name__ == '__main__':
    args = parse_args()

    # Get the compute device
    device = get_device(args.force_cpu)

    # Load Dataset
    test_data = _build_test_loader(args)

    for network in args.network:
        _evaluate_network(network, test_data, device, args)
        # The model was local to _evaluate_network and is no longer referenced;
        # release cached GPU memory before loading the next checkpoint.
        torch.cuda.empty_cache()
|