87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
import numpy as np
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
|
|
from network.utils.dataset_processing import image
|
|
|
|
|
|
class CameraData:
    """
    Dataset wrapper for the camera data.

    Computes a centred square crop window once at construction time and uses
    it to turn raw camera frames into network input tensors.
    """

    def __init__(self,
                 width=640,
                 height=480,
                 output_size=224,
                 include_depth=True,
                 include_rgb=True
                 ):
        """
        :param width: Frame width in pixels; the crop window is centred in it
        :param height: Frame height in pixels; the crop window is centred in it
        :param output_size: Image output size in pixels (square)
        :param include_depth: Whether depth image is included
        :param include_rgb: Whether RGB image is included
        :raises ValueError: If neither depth nor RGB is enabled
        """
        # Truthiness (not ``is False``) so that falsy flags such as 0 or None
        # are rejected here instead of failing later inside get_data().
        if not include_depth and not include_rgb:
            raise ValueError('At least one of Depth or RGB must be specified.')

        self.output_size = output_size
        self.include_depth = include_depth
        self.include_rgb = include_rgb

        # Corners of an output_size x output_size window centred in the frame.
        left = (width - output_size) // 2
        top = (height - output_size) // 2
        right = (width + output_size) // 2
        bottom = (height + output_size) // 2

        # Stored as (row, column) pairs, i.e. (y, x).
        self.bottom_right = (bottom, right)
        self.top_left = (top, left)

    @staticmethod
    def numpy_to_torch(s):
        """
        Convert a numpy array to a float32 torch tensor.

        :param s: Numpy array; a 2-D array gets a leading axis of length 1
        :return: torch tensor holding the float32 copy of ``s``
        """
        if s.ndim == 2:
            s = np.expand_dims(s, 0)
        return torch.from_numpy(s.astype(np.float32))

    def get_depth(self, img):
        """
        Crop and normalise a depth frame and move its last axis to the front.

        :param img: Depth frame; must be 3-D for the (2, 0, 1) transpose
                    (presumably H x W x 1 - confirm against the caller)
        :return: Processed array with axes reordered as (2, 0, 1)
        """
        depth_img = image.Image(img)
        depth_img.crop(bottom_right=self.bottom_right, top_left=self.top_left)
        depth_img.normalise()
        # depth_img.resize((self.output_size, self.output_size))
        depth_img.img = depth_img.img.transpose((2, 0, 1))
        return depth_img.img

    def get_rgb(self, img, norm=True):
        """
        Crop (and optionally normalise) an RGB frame and move its last axis
        to the front.

        :param img: RGB frame; must be 3-D for the (2, 0, 1) transpose
        :param norm: Whether to call ``normalise()`` on the cropped image
        :return: Processed array with axes reordered as (2, 0, 1)
        """
        rgb_img = image.Image(img)
        rgb_img.crop(bottom_right=self.bottom_right, top_left=self.top_left)
        if norm:
            rgb_img.normalise()
        rgb_img.img = rgb_img.img.transpose((2, 0, 1))
        return rgb_img.img

    def get_data(self, rgb=None, depth=None):
        """
        Build the network input tensor from the enabled modalities.

        :param rgb: RGB frame; required when ``include_rgb`` is set
        :param depth: Depth frame; required when ``include_depth`` is set
        :return: Tuple ``(x, depth_img, rgb_img)``. ``x`` is a float32 torch
                 tensor; ``depth_img`` / ``rgb_img`` are the processed arrays,
                 or None for a modality that is not included
        :raises ValueError: If an enabled modality's frame is missing, or if
                            no modality is enabled
        """
        # Fail early with a clear message instead of passing None down into
        # the image pipeline.
        if self.include_depth and depth is None:
            raise ValueError('A depth image is required when include_depth is set.')
        if self.include_rgb and rgb is None:
            raise ValueError('An RGB image is required when include_rgb is set.')

        depth_img = None
        rgb_img = None

        # Load the depth image
        if self.include_depth:
            depth_img = self.get_depth(img=depth)

        # Load the RGB image
        if self.include_rgb:
            rgb_img = self.get_rgb(img=rgb)

        if self.include_depth and self.include_rgb:
            # Add a leading axis to each, then stack depth before RGB along
            # axis 1.
            x = self.numpy_to_torch(
                np.concatenate(
                    (np.expand_dims(depth_img, 0),
                     np.expand_dims(rgb_img, 0)),
                    1
                )
            )
        elif self.include_depth:
            # Kept as in the original: no extra leading axis on this path.
            x = self.numpy_to_torch(depth_img)
        elif self.include_rgb:
            x = self.numpy_to_torch(np.expand_dims(rgb_img, 0))
        else:
            # Only reachable if the flags were cleared after construction.
            raise ValueError('At least one of Depth or RGB must be specified.')

        return x, depth_img, rgb_img
|