Fix SpatialSoftmax input shape (#150)
This commit is contained in:
parent
47de07658c
commit
f5de57b385
|
@ -315,11 +315,13 @@ class DiffusionRgbEncoder(nn.Module):
|
|||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should use the
|
||||
# height and width from `config.crop_shape`.
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(
|
||||
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
|
||||
)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
|
Loading…
Reference in New Issue