import torch
import torch.nn.functional as F  # noqa: N812
from packaging.version import Version

if Version(torch.__version__) > Version("2.5.0"):
    # Flex attention is only available from torch 2.5 onwards
    from torch.nn.attention.flex_attention import (
        _mask_mod_signature,
        _round_up_to_multiple,
        create_block_mask,
        create_mask,
        flex_attention,
    )


# @torch.compile(dynamic=False)
def flex_attention_forward(
    attention_mask: torch.Tensor,
    batch_size: int,
    head_dim: int,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    scaling=None,
):
    """
    This is defined outside of any class to keep torch.compile happy.
    """

    original_dtype = query_states.dtype
    num_att_heads = 8
    num_key_value_heads = 1
    num_key_value_groups = num_att_heads // num_key_value_heads
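
    # The head counts are hard-coded here: 8 query heads and a single key/value head.
    # The block below repeats that single KV head across all query-head groups (the
    # usual repeat_kv step for grouped-query attention), so keys and values end up
    # with the same number of heads as the queries.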
    key_states = key_states[:, :, :, None, :]
    key_states = key_states.expand(
        batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    key_states = key_states.reshape(
        batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )

    value_states = value_states[:, :, :, None, :]
    value_states = value_states.expand(
        batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    value_states = value_states.reshape(
        batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )
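
    # flex_attention expects a [batch, num_heads, seq_len, head_dim] layout, hence the
    # transposes below; the upcast to float32 is presumably for numerical stability.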
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    query_states = query_states.to(torch.float32)
    key_states = key_states.to(torch.float32)
    value_states = value_states.to(torch.float32)

    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, None, :, : key_states.shape[2]]

        if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
            causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
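
    # FlexAttention consumes masks as callables rather than dense tensors: the factory
    # below wraps the precomputed 4-D mask into a mask_mod function that is looked up
    # per (batch, head, q_idx, kv_idx) position.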
    def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
        def mask_mod(b, h, q_idx, kv_idx):
            # Danger zone: if b, h, q_idx, kv_idx exceed the shape, a device-side assert occurs.
            return precomputed_mask[b][h][q_idx][kv_idx]

        return mask_mod

    b_mask, h_mask, q_len, kv_len = causal_mask.shape  # The shape of your mask

    block_size = 128
    q_len_rounded = _round_up_to_multiple(q_len, block_size)
    kv_len_rounded = _round_up_to_multiple(kv_len, block_size)

    # *CRITICAL*: the mask must be padded up to the rounded block sizes, else we hit a
    # CUDA index error inside mask_mod.
    pad_q = q_len_rounded - q_len
    pad_k = kv_len_rounded - kv_len

    padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
    mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)

    mask_4d = create_mask(
        mod_fn=mask_mod_fn_orig,
        B=b_mask,
        H=h_mask,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        device=causal_mask.device,
        _compile=False,
    )
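
    # create_mask above re-materialises the padded mask as a dense 4-D tensor at the
    # rounded sizes; create_block_mask below then turns it into a BlockMask, which lets
    # the kernel skip blocks that are entirely masked out.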
    mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
    block_mask = create_block_mask(
        mask_mod=mask_mod_fn_padded,
        B=b_mask,
        H=h_mask,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        BLOCK_SIZE=block_size,
        device=causal_mask.device,
        _compile=False,
    )

    # The mask is applied inside the kernel, ideally more efficiently than score_mod.
    attn_output, attention_weights = flex_attention(
        query_states,
        key_states,
        value_states,
        block_mask=block_mask,
        enable_gqa=True,  # because we shaped query/key states for GQA
        scale=head_dim**-0.5 if scaling is None else scaling,
        return_lse=True,
    )
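
    # With return_lse=True, flex_attention returns an (output, logsumexp) pair; the
    # second element is the per-row log-sum-exp of the attention scores, not a full
    # attention-weights matrix, and it is unused here.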
    attn_output = attn_output.to(dtype=original_dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()  # [B, Q_LEN, H, head_dim]
    attn_output = attn_output.reshape(
        batch_size,
        -1,
        attn_output.shape[2] * attn_output.shape[3],  # merges [H, head_dim]
    )
    return attn_output
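

# Minimal usage sketch: one way flex_attention_forward might be called. The shapes
# below are assumptions inferred from the code above: queries as
# [batch, q_len, 8, head_dim] (8 query heads are hard-coded), keys/values as
# [batch, kv_len, 1, head_dim] (a single KV head), and attention_mask as a boolean
# [batch, q_len, kv_len] tensor. The sequence length is kept at a multiple of the
# 128 block size, and a CUDA device is assumed for the FlexAttention kernels.
if (
    __name__ == "__main__"
    and torch.cuda.is_available()
    and Version(torch.__version__) > Version("2.5.0")
):
    batch_size, seq_len, head_dim = 2, 128, 64
    device = "cuda"

    query = torch.randn(batch_size, seq_len, 8, head_dim, dtype=torch.bfloat16, device=device)
    key = torch.randn(batch_size, seq_len, 1, head_dim, dtype=torch.bfloat16, device=device)
    value = torch.randn(batch_size, seq_len, 1, head_dim, dtype=torch.bfloat16, device=device)

    # Lower-triangular causal mask, broadcast over the batch dimension.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    mask = mask[None].expand(batch_size, -1, -1)

    out = flex_attention_forward(mask, batch_size, head_dim, query, key, value)
    print(out.shape)  # expected: torch.Size([batch_size, seq_len, 8 * head_dim])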