From 393c1320bdd0eb9037e859006b78d8b0eb47b929 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 14:57:06 +0100 Subject: [PATCH] revision --- .../common/policies/diffusion/modeling_diffusion.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 9abec4ec..a7ba5442 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -315,13 +315,13 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. + # The dummy input should take the number of image channels from `config.input_shapes` and it should use the + # height and width from `config.crop_shape`. + dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape)) with torch.inference_mode(): - feat_map_shape = tuple( - self.backbone( - torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape)) - ).shape[1:] - ) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints) + dummy_feature_map = self.backbone(dummy_input) + feature_map_shape = tuple(dummy_feature_map.shape[1:]) + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU()