diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 91cf6dd0..a7ba5442 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -315,11 +315,13 @@ class DiffusionRgbEncoder(nn.Module):

         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
+        # The dummy input should take the number of image channels from `config.input_shapes` and it should use the
+        # height and width from `config.crop_shape`.
+        dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
         with torch.inference_mode():
-            feat_map_shape = tuple(
-                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
-            )
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
+            dummy_feature_map = self.backbone(dummy_input)
+        feature_map_shape = tuple(dummy_feature_map.shape[1:])
+        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
         self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()
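
For reference, here is a minimal, self-contained sketch of the dry-run shape-inference pattern this patch adjusts. The stand-in `nn.Sequential` backbone and the concrete values of `in_channels` and `crop_shape` are illustrative assumptions, not part of the repository; the real code uses a torchvision backbone and sizes LeRobot's `SpatialSoftmax` pooling layer from the inferred shape.

```python
import torch
from torch import nn

# Illustrative stand-ins (assumed values, not from the repo):
in_channels = 3        # would come from config.input_shapes["observation.image"][0]
crop_shape = (84, 84)  # would come from config.crop_shape

# A tiny stand-in backbone; the real encoder uses a torchvision ResNet.
backbone = nn.Sequential(
    nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
)

# Dry run: push a dummy input through the backbone to discover the feature
# map shape, mirroring the patched code. The dummy takes its channel count
# from the input spec and its spatial size from the crop shape, because the
# encoder sees cropped images at train/eval time.
dummy_input = torch.zeros(size=(1, in_channels, *crop_shape))
with torch.inference_mode():
    dummy_feature_map = backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])

print(feature_map_shape)  # (32, 21, 21) for the assumed backbone and crop

# Downstream layers can now be sized from the inferred shape instead of
# hand-computing the backbone's downsampling factor.
flat_dim = feature_map_shape[0] * feature_map_shape[1] * feature_map_shape[2]
head = nn.Linear(flat_dim, 64)
```

The point of the fix is visible in how the dummy input is built: since cropping happens before the backbone, a dummy sized from the full `config.input_shapes` resolution (as in the removed lines) yields a feature-map shape that disagrees with what the pooling layer actually receives at runtime; sizing it from `config.crop_shape` keeps the dry run consistent with the real forward pass.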