140 lines
5.2 KiB
Python
140 lines
5.2 KiB
Python
|
|
import torch
|
|
|
|
from rsl_rl.modules.mlp import MlpModel
|
|
from rsl_rl.modules.utils import conv2d_output_shape
|
|
|
|
|
|
class Conv2dModel(torch.nn.Module):
    """2-D convolutional model component built on ``torch.nn.Conv2d``.

    Downsampling for strides > 1 is done either by the convolutions
    themselves or, with ``use_maxpool=True``, by stride-1 convolutions each
    followed by a ``torch.nn.MaxPool2d``. Requires the number of input
    channels, but not the input shape.

    Args:
        in_channels: Number of channels of the input image.
        channels: Output channels of each convolution layer, in order
            (list or tuple).
        kernel_sizes: Kernel size of each convolution layer.
        strides: Downsampling factor of each layer; used as the conv stride,
            or as the max-pool size when ``use_maxpool`` is True.
        paddings: Padding of each convolution layer; ``None`` means 0 for
            every layer.
        nonlinearity: Activation module class (Module, not Functional), or
            the name of a class in ``torch.nn``.
        use_maxpool: If True, convs use stride 1 and max-pooling downsamples.
        head_sizes: Not used by this class; kept so existing callers that
            pass it keep working.
        normlayer: Normalization module class, or the name of a class in
            ``torch.nn``. It is instantiated with the layer's output channel
            count and placed between the conv and the nonlinearity. ``None``
            disables normalization.
    """

    def __init__(
        self,
        in_channels,
        channels,
        kernel_sizes,
        strides,
        paddings=None,
        nonlinearity=torch.nn.ReLU,  # Module, not Functional.
        use_maxpool=False,  # if True: convs use stride 1, maxpool downsample.
        head_sizes=None,  # Unused here; kept for interface compatibility.
        normlayer=None,  # If None, will not be used
    ):
        super().__init__()
        # Accept tuples as well as lists: the list concatenation below would
        # otherwise raise a TypeError for a tuple of channels.
        channels = list(channels)
        if paddings is None:
            paddings = [0] * len(channels)
        # Allow layer types to be given by name, e.g. from a config file.
        if isinstance(nonlinearity, str):
            nonlinearity = getattr(torch.nn, nonlinearity)
        if isinstance(normlayer, str):
            normlayer = getattr(torch.nn, normlayer)
        assert len(channels) == len(kernel_sizes) == len(strides) == len(paddings), (
            "channels, kernel_sizes, strides and paddings must have equal length"
        )
        # Each layer consumes the previous layer's output channels.
        in_channels = [in_channels] + channels[:-1]
        ones = [1] * len(strides)
        if use_maxpool:
            maxp_strides = strides
            strides = ones
        else:
            maxp_strides = ones
        # All convolutions are created before any norm layer so that the
        # parameter-initialization order stays the same as before.
        conv_layers = [
            torch.nn.Conv2d(
                in_channels=ic,
                out_channels=oc,
                kernel_size=k,
                stride=s,
                padding=p,
            )
            for ic, oc, k, s, p in zip(
                in_channels, channels, kernel_sizes, strides, paddings
            )
        ]
        sequence = []
        for conv_layer, oc, maxp_stride in zip(conv_layers, channels, maxp_strides):
            sequence.append(conv_layer)
            if normlayer is not None:
                sequence.append(normlayer(oc))
            sequence.append(nonlinearity())
            if maxp_stride > 1:
                sequence.append(torch.nn.MaxPool2d(maxp_stride))  # No padding.
        self.conv = torch.nn.Sequential(*sequence)

    def forward(self, input):
        """Computes the convolution stack on the input; assumes correct shape
        already: [B,C,H,W]."""
        return self.conv(input)

    def conv_out_size(self, h, w, c=None):
        """Helper function to return the flattened output size (H * W * C)
        for a given input shape, without actually performing a forward pass
        through the model.

        ``c`` is only used when the stack contains no layer with an
        ``out_channels`` attribute; otherwise the last such layer wins.
        """
        h, w = self.conv_out_resolution(h, w)
        for child in self.conv.children():
            # Only conv layers expose ``out_channels``.
            c = getattr(child, "out_channels", c)
        return h * w * c

    def conv_out_resolution(self, h, w):
        """Helper function that returns the output resolution (H, W) for a
        given input resolution, without performing a forward pass."""
        shape_attrs = ("kernel_size", "stride", "padding")
        for child in self.conv.children():
            # Conv and max-pool layers change the resolution; norm and
            # nonlinearity layers lack these attributes and are skipped.
            if all(hasattr(child, attr) for attr in shape_attrs):
                h, w = conv2d_output_shape(
                    h, w, child.kernel_size, child.stride, child.padding
                )
        return h, w
|
|
|
|
class Conv2dHeadModel(torch.nn.Module):
    """Model component composed of a ``Conv2dModel`` component followed by
    a fully-connected ``MlpModel`` head. Requires full input image shape to
    instantiate the MLP head.

    Args:
        image_shape: Input image shape as ``(C, H, W)``.
        channels: Output channels of each convolution layer.
        kernel_sizes: Kernel size of each convolution layer.
        strides: Stride (or max-pool size) of each convolution layer.
        hidden_sizes: Hidden layer sizes passed to ``MlpModel``; an int or a
            sequence of ints.
        output_size: Output size passed to ``MlpModel``. If ``None``, the
            reported output size is the last hidden size.
        paddings: Padding of each convolution layer; ``None`` means 0.
        nonlinearity: Activation module class, or the name of a class in
            ``torch.nn``; used by both the conv stack and the MLP head.
        use_maxpool: If True, convs use stride 1 and max-pooling downsamples.
        normlayer: Normalization layer class (or ``torch.nn`` class name)
            forwarded to the conv stack; ``None`` disables normalization.

    If both ``hidden_sizes`` and ``output_size`` are falsy, no MLP is built
    and the flattened convolution features are returned unchanged.
    """

    def __init__(
        self,
        image_shape,
        channels,
        kernel_sizes,
        strides,
        hidden_sizes,
        output_size=None,  # if None: nonlinearity applied to output.
        paddings=None,
        nonlinearity=torch.nn.ReLU,
        use_maxpool=False,
        normlayer=None,  # if None, will not be used
    ):
        super().__init__()
        # Allow the activation to be given by name, e.g. from a config file.
        if isinstance(nonlinearity, str):
            nonlinearity = getattr(torch.nn, nonlinearity)
        c, h, w = image_shape
        self.conv = Conv2dModel(
            in_channels=c,
            channels=channels,
            kernel_sizes=kernel_sizes,
            strides=strides,
            paddings=paddings,
            nonlinearity=nonlinearity,
            use_maxpool=use_maxpool,
            # Forward the caller's choice; it used to be hard-coded to None,
            # which silently ignored the ``normlayer`` argument.
            normlayer=normlayer,
        )
        conv_out_size = self.conv.conv_out_size(h, w)
        if hidden_sizes or output_size:
            self.head = MlpModel(
                conv_out_size,
                hidden_sizes,
                output_size=output_size,
                nonlinearity=nonlinearity,
            )
            if output_size is not None:
                self._output_size = output_size
            elif isinstance(hidden_sizes, int):
                self._output_size = hidden_sizes
            else:
                self._output_size = hidden_sizes[-1]
        else:
            # A real module (rather than a lambda) keeps the model picklable
            # and visible to ``modules()``; it holds no parameters.
            self.head = torch.nn.Identity()
            self._output_size = conv_out_size

    def forward(self, input):
        """Compute the convolution and fully connected head on the input;
        assumes correct input shape: [B,C,H,W]."""
        features = self.conv(input)
        # flatten() also handles non-contiguous conv outputs, where
        # ``view`` would raise.
        return self.head(torch.flatten(features, start_dim=1))

    @property
    def output_size(self):
        """Returns the final output size after MLP head."""
        return self._output_size
|