Fix SpatialSoftmax input shape (#150)

This commit is contained in:
Alexander Soare 2024-05-08 14:57:29 +01:00 committed by GitHub
parent 47de07658c
commit f5de57b385
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 4 deletions

View File

@ -315,11 +315,13 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should use the
# height and width from `config.crop_shape`.
dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
with torch.inference_mode():
feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()