typos
This commit is contained in:
parent
62b18a7607
commit
0b4c42f4ff
|
@ -641,7 +641,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
|
|||
"""
|
||||
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
||||
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
||||
# they would be range(0, H) and range(0, W). Keeping it at as to match the original code.
|
||||
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
|
||||
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_range = not_mask.cumsum(2, dtype=torch.float32)
|
||||
|
||||
|
@ -659,7 +659,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
|
|||
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
|
||||
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
||||
# pos_embed_x and pos_embed are (1, H, W, C // 2).
|
||||
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
||||
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
|
|
Loading…
Reference in New Issue