import torch.nn as nn
import torch.nn.functional as F


class GraspModel(nn.Module):
    """
    An abstract model for grasp network in a common format.

    Subclasses implement ``forward`` and must return a 4-tuple
    ``(pos, cos, sin, width)``; ``compute_loss`` and ``predict`` unpack the
    forward output in exactly that order.
    """

    def __init__(self):
        super(GraspModel, self).__init__()

    def forward(self, x_in):
        """Run the network on ``x_in``; must be overridden by subclasses.

        :raises NotImplementedError: always, in this abstract base class.
        """
        raise NotImplementedError()

    def compute_loss(self, xc, yc):
        """Run the model on ``xc`` and score it against the targets ``yc``.

        :param xc: Input passed unchanged to ``self(xc)``.
        :param yc: Iterable of four targets, unpacked as
            ``(y_pos, y_cos, y_sin, y_width)``.
        :return: dict with
            ``'loss'``   - sum of the four smooth-L1 losses,
            ``'losses'`` - the individual terms (``p_loss``, ``cos_loss``,
            ``sin_loss``, ``width_loss``),
            ``'pred'``   - the raw predictions (``pos``, ``cos``, ``sin``,
            ``width``).
        """
        y_pos, y_cos, y_sin, y_width = yc
        pos_pred, cos_pred, sin_pred, width_pred = self(xc)

        p_loss = F.smooth_l1_loss(pos_pred, y_pos)
        cos_loss = F.smooth_l1_loss(cos_pred, y_cos)
        sin_loss = F.smooth_l1_loss(sin_pred, y_sin)
        width_loss = F.smooth_l1_loss(width_pred, y_width)

        return {
            'loss': p_loss + cos_loss + sin_loss + width_loss,
            'losses': {
                'p_loss': p_loss,
                'cos_loss': cos_loss,
                'sin_loss': sin_loss,
                'width_loss': width_loss,
            },
            'pred': {
                'pos': pos_pred,
                'cos': cos_pred,
                'sin': sin_pred,
                'width': width_pred,
            },
        }

    def predict(self, xc):
        """Run the model on ``xc`` and return its four outputs keyed by name.

        :param xc: Input passed unchanged to ``self(xc)``.
        :return: dict with keys ``'pos'``, ``'cos'``, ``'sin'``, ``'width'``.
        """
        pos_pred, cos_pred, sin_pred, width_pred = self(xc)
        return {
            'pos': pos_pred,
            'cos': cos_pred,
            'sin': sin_pred,
            'width': width_pred,
        }


class ResidualBlock(nn.Module):
    """
    A residual block: conv -> BN -> ReLU -> conv -> BN, plus a skip connection.

    :param in_channels: Number of channels of the block input.
    :param out_channels: Number of channels produced by the block.
    :param kernel_size: Kernel size of both convolutions (int or tuple).
        Odd sizes keep the spatial resolution unchanged, which the residual
        sum requires.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ResidualBlock, self).__init__()

        # "Same" padding for odd kernels so the main path keeps the spatial
        # size of the input (equals the former hard-coded 1 for kernel 3).
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size,
                               padding=padding)
        # Batch norm follows conv1, so it sees out_channels feature maps.
        self.bn1 = nn.BatchNorm2d(out_channels)
        # conv2 consumes conv1's output, i.e. out_channels inputs.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,
                               padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # When the channel count changes, project the skip path with a 1x1
        # convolution so it can be added to the main path. With equal
        # channels no extra module is created, so state-dict keys match the
        # original layout (conv1, bn1, conv2, bn2).
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = None

    def forward(self, x_in):
        """Apply the two conv/BN stages and add the (projected) input.

        :param x_in: Input feature map with ``in_channels`` channels.
        :return: Tensor with ``out_channels`` channels; no activation is
            applied after the residual sum.
        """
        x = self.bn1(self.conv1(x_in))
        x = F.relu(x)
        x = self.bn2(self.conv2(x))
        identity = x_in if self.shortcut is None else self.shortcut(x_in)
        return x + identity