ur5-robotic-grasping/network/inference/post_process.py

23 lines
901 B
Python

import torch
from skimage.filters import gaussian
def _to_numpy(tensor):
    """Detach *tensor*, move it to host memory, and return a squeezed NumPy array."""
    # Tensor.numpy() raises for tensors that require grad, so detach first.
    # squeeze() drops every size-1 dimension (presumably batch/channel of 1).
    return tensor.detach().cpu().numpy().squeeze()


def post_process_output(q_img, cos_img, sin_img, width_img, pixels_max_grasp):
    """
    Post-process the raw output of the network, convert to numpy arrays, apply filtering.

    :param q_img: Q (grasp quality) output of the network, as a torch Tensor
    :param cos_img: cos output of the network (cosine of twice the grasp angle,
        since the angle is recovered as atan2(sin, cos) / 2)
    :param sin_img: sin output of the network (sine of twice the grasp angle)
    :param width_img: Width output of the network; multiplied by
        ``pixels_max_grasp`` (presumably a normalized width -> pixels; confirm)
    :param pixels_max_grasp: Scale factor applied to the width output
    :return: Filtered Q output, Filtered Angle output (radians, computed as
        atan2(sin, cos) / 2), Filtered Width output, all as numpy arrays
    """
    q_img = _to_numpy(q_img)
    # Halve the atan2 result: the network encodes 2*theta in (sin, cos).
    ang_img = _to_numpy(torch.atan2(sin_img, cos_img) / 2.0)
    width_img = _to_numpy(width_img) * pixels_max_grasp

    # preserve_range=True keeps the original value scale instead of rescaling.
    q_img = gaussian(q_img, 2.0, preserve_range=True)
    # NOTE(review): smoothing the angle map directly blends values across the
    # +/- pi/2 wrap-around; smoothing sin/cos before atan2 would avoid that,
    # but it would change outputs, so it is left as-is here.
    ang_img = gaussian(ang_img, 2.0, preserve_range=True)
    width_img = gaussian(width_img, 1.0, preserve_range=True)

    return q_img, ang_img, width_img