import torch
from skimage.filters import gaussian


def post_process_output(q_img, cos_img, sin_img, width_img, pixels_max_grasp):
    """
    Post-process the raw output of the network, convert to numpy arrays, apply filtering.

    :param q_img: Q output of network (as torch Tensors)
    :param cos_img: cos output of network
    :param sin_img: sin output of network
    :param width_img: Width output of network
    :param pixels_max_grasp: Factor the width output is multiplied by before
        filtering (presumably the maximum grasp width in pixels - confirm
        against the caller)
    :return: Filtered Q output, Filtered Angle output, Filtered Width output
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad, so
    # without it this function only works inside torch.no_grad().
    q_img = q_img.detach().cpu().numpy().squeeze()

    # NOTE(review): the halved atan2 suggests the network predicts the sine and
    # cosine of twice the angle - confirm against the training code.
    ang_img = (torch.atan2(sin_img, cos_img) / 2.0).detach().cpu().numpy().squeeze()

    width_img = width_img.detach().cpu().numpy().squeeze() * pixels_max_grasp

    # preserve_range=True keeps the original value range instead of letting
    # skimage rescale the image.
    q_img = gaussian(q_img, 2.0, preserve_range=True)
    ang_img = gaussian(ang_img, 2.0, preserve_range=True)
    width_img = gaussian(width_img, 1.0, preserve_range=True)

    return q_img, ang_img, width_img