ur5-robotic-grasping/network/inference/models/grasp_model.py

68 lines
1.9 KiB
Python

import torch.nn as nn
import torch.nn.functional as F
class GraspModel(nn.Module):
    """
    Common base class for grasp networks.

    Subclasses implement ``forward`` and must return four outputs, in the
    order (pos, cos, sin, width). The loss and prediction helpers below rely
    on that ordering.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_in):
        """Run the network on ``x_in``; must be overridden by subclasses."""
        raise NotImplementedError()

    def compute_loss(self, xc, yc):
        """
        Run the model on ``xc`` and score it against the targets ``yc``.

        :param xc: Network input, passed straight to the model call.
        :param yc: Four targets ordered as (pos, cos, sin, width).
        :return: Dict with the summed loss under ``'loss'``, the individual
            smooth-L1 terms under ``'losses'`` and the raw network outputs
            under ``'pred'``.
        """
        target_pos, target_cos, target_sin, target_width = yc
        out_pos, out_cos, out_sin, out_width = self(xc)

        # One smooth-L1 term per output head.
        loss_pos = F.smooth_l1_loss(out_pos, target_pos)
        loss_cos = F.smooth_l1_loss(out_cos, target_cos)
        loss_sin = F.smooth_l1_loss(out_sin, target_sin)
        loss_width = F.smooth_l1_loss(out_width, target_width)

        per_head_losses = {
            'p_loss': loss_pos,
            'cos_loss': loss_cos,
            'sin_loss': loss_sin,
            'width_loss': loss_width,
        }
        predictions = {
            'pos': out_pos,
            'cos': out_cos,
            'sin': out_sin,
            'width': out_width,
        }
        total_loss = loss_pos + loss_cos + loss_sin + loss_width
        return {
            'loss': total_loss,
            'losses': per_head_losses,
            'pred': predictions,
        }

    def predict(self, xc):
        """
        Run the model on ``xc`` and label its four outputs.

        :param xc: Network input, passed straight to the model call.
        :return: Dict mapping ``'pos'``, ``'cos'``, ``'sin'`` and ``'width'``
            to the corresponding network output.
        """
        out_pos, out_cos, out_sin, out_width = self(xc)
        return dict(pos=out_pos, cos=out_cos, sin=out_sin, width=out_width)
class ResidualBlock(nn.Module):
    """
    A residual block: conv -> batch norm -> ReLU -> conv -> batch norm,
    added to a skip connection from the block input.

    :param in_channels: Number of channels of the block input.
    :param out_channels: Number of channels produced by the block.
    :param kernel_size: Kernel size of both convolutions (int or pair of
        ints). Odd sizes keep the spatial size unchanged, which the residual
        addition requires.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ResidualBlock, self).__init__()

        # Padding of k // 2 preserves the spatial size for odd kernels; for
        # the default kernel_size=3 this is the previously hard-coded 1.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        # Batch norm follows conv1, so it sees out_channels feature maps.
        self.bn1 = nn.BatchNorm2d(out_channels)
        # conv2 consumes the out_channels output of the first stage.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip path needs a 1x1 projection only when the channel count
        # changes; with equal channels no extra parameters are created, so
        # state_dict keys for that configuration are unchanged.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = None

    def forward(self, x_in):
        """
        Apply the two conv/batch-norm stages and add the skip connection.

        :param x_in: Input tensor accepted by ``nn.Conv2d`` with
            ``in_channels`` channels.
        :return: Tensor with ``out_channels`` channels.
        """
        x = self.bn1(self.conv1(x_in))
        x = F.relu(x)
        x = self.bn2(self.conv2(x))

        # getattr keeps instances restored without running __init__ (e.g.
        # whole-module pickles saved before `shortcut` existed) working with
        # the identity skip they always used.
        shortcut = getattr(self, 'shortcut', None)
        skip = x_in if shortcut is None else shortcut(x_in)
        return x + skip