# Attention Interface
The `AttentionInterface` provides a centralized abstraction for attention methods, enabling a unified API, dynamic switching at runtime, and easy registration of custom attention functions.
## Architecture
```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# Available implementations
list(ALL_ATTENTION_FUNCTIONS.keys())
# ['flash_attention_3', 'flash_attention_2', 'flex_attention', 'sdpa',
#  'paged|flash_attention_3', 'paged|flash_attention_2', 'paged|sdpa', 'paged|eager']
```

## Supported Implementations
### 1. Eager Attention
Simple matrix multiplication with no optimizations. Implemented directly in the modeling files. Supports `output_attentions=True`.
```python
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="eager"
)
```
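Because eager attention materializes the full attention matrix, it is the implementation to use when you want the attention weights returned. A minimal sketch, assuming the same checkpoint and a standard tokenizer call:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="eager",
)

inputs = tokenizer("The attention interface is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# One tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
print(len(outputs.attentions), outputs.attentions[0].shape)
```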
### 2. SDPA (Scaled Dot Product Attention)

PyTorch's native `F.scaled_dot_product_attention`. The default for PyTorch >= 2.1.1. Up to 2x faster than eager attention in bf16.
```python
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="sdpa"
)
```
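SDPA dispatches to one of several fused kernels at runtime. If you want to constrain which backend is used, PyTorch exposes the `torch.nn.attention.sdpa_kernel` context manager (PyTorch 2.3+). The snippet below is a supplementary sketch of that standalone PyTorch feature, not part of the Transformers API:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
).to("cuda")

inputs = tokenizer("Hello", return_tensors="pt").to("cuda")

# Restrict F.scaled_dot_product_attention to the flash backend;
# a RuntimeError is raised if no allowed backend supports these inputs/dtypes.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs, max_new_tokens=20)
```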
### 3. Flash Attention 2

IO-aware attention from Dao-AILab. Requires fp16 or bf16 dtype. 2-4x faster training and inference with up to 50% memory reduction.
```bash
pip install flash-attn --no-build-isolation
```

```python
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16
)
```
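Since `flash-attn` is an optional dependency, it can be convenient to fall back to SDPA when the package is not installed. A hedged sketch using the `is_flash_attn_2_available` helper from `transformers.utils`:

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available

# Pick the fastest implementation available in the current environment
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation=attn_implementation,
    torch_dtype=torch.float16,
)
```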
### 4. Flash Attention 3

The latest version, with Hopper GPU optimizations. Requires an H100 or newer GPU.
```python
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_3",
    torch_dtype=torch.bfloat16
)
```
### 5. FlexAttention

PyTorch's flexible attention API with `torch.compile` support. Supports custom score/mask modification functions and block-sparse computation via `BlockMask`. Does not support dropout.
```python
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flex_attention"
)
```

## Custom Attention Registration
### Required Function Signature
```python
from typing import Optional

import torch

def custom_attention(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Returns: (attention_output, attention_weights)
    """
    # compute_attention is a placeholder for your own attention computation
    attn_output = compute_attention(query, key, value, attention_mask)
    return attn_output, None
```
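For illustration, here is a minimal sketch of a function that satisfies this signature using plain softmax attention written with PyTorch primitives. It assumes key/value already have the same number of heads as query (grouped-query models repeat KV heads before or inside the attention function) and is not an official reference implementation:

```python
from typing import Optional

import torch

def simple_softmax_attention(
    module: torch.nn.Module,
    query: torch.Tensor,   # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,     # (batch, num_heads, kv_len, head_dim)
    value: torch.Tensor,   # (batch, num_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if scaling is None:
        scaling = query.shape[-1] ** -0.5

    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        # additive mask: 0 where attending, large negative where masked
        attn_weights = attn_weights + attention_mask[..., : key.shape[-2]]
    attn_weights = torch.softmax(attn_weights, dim=-1)

    attn_output = torch.matmul(attn_weights, value)
    # return shape (batch, q_len, num_heads, head_dim), matching the built-in functions
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```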
### Registration

```python
from transformers import AttentionInterface

AttentionInterface.register("my_attention", my_attention_function)

model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="my_attention"
)
```

### Custom Attention Example
```python
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def custom_attention_with_logging(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    log_stats: bool = False,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if log_stats:
        print(f"Query shape: {query.shape}")
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)

AttentionInterface.register("logged_sdpa", custom_attention_with_logging)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="logged_sdpa"
)

# custom kwargs propagate to the attention function
outputs = model(input_ids, log_stats=True)  # input_ids from your tokenizer
```

## Runtime Attention Switching
```python
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Switch to SDPA
model.set_attn_implementation("sdpa")

# Switch to FlashAttention2
model.set_attn_implementation("flash_attention_2")

# Switch to a registered custom implementation
model.set_attn_implementation("my_custom_attention")
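```

To confirm which implementation is currently active, you can inspect the model config. `_attn_implementation` is the attribute Transformers uses internally, so treat this as a debugging aid rather than public API:

```python
model.set_attn_implementation("sdpa")
print(model.config._attn_implementation)  # "sdpa"
```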
## Multimodal Attention Configuration

Different attention implementations per backbone:
```python
model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b",
    attn_implementation={
        "vision_config": "sdpa",
        "text_config": "flash_attention_2"
    }
)
```

## AttentionMaskInterface
Handles automatic attention-mask format conversion between implementations:
```python
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask

def custom_mask_function(*args, **kwargs):
    return sdpa_mask(*args, **kwargs)

AttentionMaskInterface.register("my_attention", custom_mask_function)
```

## FlexAttention Patterns
### score_mod Function
```python
temperature = 2.0  # example scaling factor

def score_mod(score, batch, head, q_idx, k_idx):
    """Modify the attention score before softmax."""
    return score / temperature  # temperature scaling
```

### mask_mod Function
```python
def mask_mod(batch, head, q_idx, k_idx):
    """Return True to attend, False to mask."""
    return q_idx >= k_idx  # causal mask
```

### BlockMask Creation
```python
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, k_idx):
    return q_idx >= k_idx

# batch_size, num_heads, and seq_length are the dimensions of your inputs
block_mask = create_block_mask(
    causal_mask,
    B=batch_size,
    H=num_heads,
    Q_LEN=seq_length,
    KV_LEN=seq_length,
    BLOCK_SIZE=128,
    device="cuda"
)
```
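To make the pieces concrete, here is a self-contained sketch that builds a causal `BlockMask` and feeds it to `torch.nn.attention.flex_attention` directly (standalone PyTorch usage; the tensor shapes are illustrative and not tied to any particular model):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

batch_size, num_heads, seq_length, head_dim = 2, 8, 256, 64

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(
    causal_mask, B=batch_size, H=num_heads,
    Q_LEN=seq_length, KV_LEN=seq_length, device="cuda",
)

query = torch.randn(batch_size, num_heads, seq_length, head_dim, device="cuda")
key = torch.randn(batch_size, num_heads, seq_length, head_dim, device="cuda")
value = torch.randn(batch_size, num_heads, seq_length, head_dim, device="cuda")

# Blocks that are fully masked are skipped entirely (block-sparse computation)
output = flex_attention(query, key, value, block_mask=block_mask)
print(output.shape)  # (batch_size, num_heads, seq_length, head_dim)
```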
### Common FlexAttention Patterns

Sliding window:

```python
def sliding_window_mask(b, h, q, k, window_size=256):
    # use tensor ops (&) rather than Python `and` so the mask works under vmap
    return (q >= k) & (q - k < window_size)
```

ALiBi:
```python
def alibi_score_mod(score, b, h, q, k, slope):
    # slope is typically bound per-head via a closure or functools.partial
    return score - abs(q - k) * slope
```
## Performance Benchmarks

| Implementation | Speedup vs. eager | Memory usage |
|---|---|---|
| SDPA (bf16) | up to 2x | moderate |
| SDPA (fp8) | up to 3x | low |
| FlashAttention2 | 2-4x | ~50% less |
| FlexAttention (training) | 2.4x | higher |
## Selection Guide
| Use case | Recommended implementation |
|---|---|
| Default inference | `sdpa` |
| Memory constrained | `flash_attention_2` |
| Need attention weights | `eager` |
| Custom attention patterns | `flex_attention` |
| H100 with fp8 | `sdpa` |
| Training large models | `flash_attention_2` |