This commit is contained in:
Alexander Soare 2024-04-08 14:59:37 +01:00
parent 62b18a7607
commit 0b4c42f4ff
1 changed files with 2 additions and 2 deletions

View File

@ -641,7 +641,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
"""
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
# they would be range(0, H) and range(0, W). Keeping it at as to match the original code.
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
y_range = not_mask.cumsum(1, dtype=torch.float32)
x_range = not_mask.cumsum(2, dtype=torch.float32)
@ -659,7 +659,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed are (1, H, W, C // 2).
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)