80 lines
2.8 KiB
Python
80 lines
2.8 KiB
Python
import argparse
|
|
import logging
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch.utils.data
|
|
|
|
from hardware.camera import RealSenseCamera
|
|
from hardware.device import get_device
|
|
from inference.post_process import post_process_output
|
|
from utils.data.camera_data import CameraData
|
|
from utils.visualisation.plot import save_results, plot_results
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the real-time grasp-detection script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, exactly as before;
            passing an explicit list lets other code and tests call this
            function without touching the process arguments.

    Returns:
        argparse.Namespace with attributes ``network`` (str), ``use_depth``
        (int), ``use_rgb`` (int), ``n_grasps`` (int) and ``force_cpu`` (bool).
    """
    parser = argparse.ArgumentParser(description='Evaluate network')
    parser.add_argument('--network', type=str, default='saved_data/cornell_rgbd_iou_0.96',
                        help='Path to saved network to evaluate')
    # Modality switches are ints (1/0) rather than store_true flags so they
    # can be turned off explicitly from the command line.
    parser.add_argument('--use-depth', type=int, default=1,
                        help='Use Depth image for evaluation (1/0)')
    parser.add_argument('--use-rgb', type=int, default=1,
                        help='Use RGB image for evaluation (1/0)')
    parser.add_argument('--n-grasps', type=int, default=1,
                        help='Number of grasps to consider per image')
    parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False,
                        help='Force code to run in CPU mode')
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
    args = parse_args()

    # Connect to Camera
    logging.info('Connecting to camera...')
    # NOTE(review): the RealSense serial number is hard-coded; consider
    # exposing it as a command-line option.
    cam = RealSenseCamera(device_id=830112070066)
    cam.connect()
    cam_data = CameraData(include_depth=args.use_depth, include_rgb=args.use_rgb)

    # Get the compute device before loading, so the checkpoint can be mapped
    # onto it (otherwise a GPU-saved model fails to load with --cpu or on a
    # machine without CUDA, and would not live on the same device as `xc`).
    # Assumes get_device returns a torch.device / device string - TODO confirm.
    device = get_device(args.force_cpu)

    # Load Network
    logging.info('Loading model...')
    # torch.load unpickles arbitrary Python objects: only load checkpoints
    # from trusted sources.
    # NOTE(review): the model is used for inference without net.eval(); if it
    # contains BatchNorm/Dropout layers, consider switching it to eval mode.
    net = torch.load(args.network, map_location=device)
    logging.info('Done')

    # Last frame whose prediction completed, together with that prediction.
    # They stay None until the first iteration finishes, so the finally
    # clause can tell whether there is anything to save (previously these
    # names were unbound on an early failure, and the resulting NameError
    # masked the real exception).
    rgb = depth = None
    q_img = ang_img = width_img = None

    try:
        fig = plt.figure(figsize=(10, 10))
        while True:
            image_bundle = cam.get_image_bundle()
            frame_rgb = image_bundle['rgb']
            frame_depth = image_bundle['aligned_depth']
            x, _, _ = cam_data.get_data(rgb=frame_rgb, depth=frame_depth)
            with torch.no_grad():
                xc = x.to(device)
                pred = net.predict(xc)

                frame_q, frame_ang, frame_width = post_process_output(
                    pred['pos'], pred['cos'], pred['sin'], pred['width'])

                # Commit the frame and its predictions together so an
                # interrupt mid-iteration never pairs a new image with the
                # previous frame's grasp maps in save_results below.
                rgb, depth, q_img, ang_img, width_img = (
                    frame_rgb, frame_depth, frame_q, frame_ang, frame_width)

                plot_results(fig=fig,
                             rgb_img=cam_data.get_rgb(rgb, False),
                             depth_img=np.squeeze(cam_data.get_depth(depth)),
                             grasp_q_img=q_img,
                             grasp_angle_img=ang_img,
                             no_grasps=args.n_grasps,
                             grasp_width_img=width_img)
    finally:
        # Runs on normal exit, errors and Ctrl-C alike: persist the last
        # completed prediction, if there is one.
        if q_img is not None:
            save_results(
                rgb_img=cam_data.get_rgb(rgb, False),
                depth_img=np.squeeze(cam_data.get_depth(depth)),
                grasp_q_img=q_img,
                grasp_angle_img=ang_img,
                no_grasps=args.n_grasps,
                grasp_width_img=width_img
            )
        else:
            logging.info('No prediction completed; nothing to save.')
|