attention interface

The AttentionInterface provides a centralized abstraction over attention implementations. It gives every backend the same API, lets you switch implementations at runtime, and makes it simple to register custom attention functions.

architecture

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

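# registered implementations; the exact list depends on your transformers version,
# and "eager" is absent because it is defined in each model's modeling file
list(ALL_ATTENTION_FUNCTIONS.keys())
# ['flash_attention_3', 'flash_attention_2', 'flex_attention', 'sdpa',
#  'paged|flash_attention_3', 'paged|flash_attention_2', 'paged|sdpa', 'paged|eager']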
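Conceptually, ALL_ATTENTION_FUNCTIONS is a mapping from implementation names to callables. Model files keep eager as the in-file default and otherwise fetch the function by name. The toy registry below is a minimal sketch of that lookup pattern. ToyAttentionRegistry and dispatch are hypothetical names, not transformers APIs.

```python
from typing import Callable


class ToyAttentionRegistry:
    """Hypothetical stand-in for a global name -> attention-function mapping."""

    def __init__(self) -> None:
        self._functions: dict[str, Callable] = {}

    def register(self, name: str, fn: Callable) -> None:
        self._functions[name] = fn

    def __getitem__(self, name: str) -> Callable:
        return self._functions[name]

    def keys(self):
        return self._functions.keys()


registry = ToyAttentionRegistry()
registry.register("sdpa", lambda *args, **kwargs: "sdpa called")
registry.register("paged|sdpa", lambda *args, **kwargs: "paged sdpa called")


def dispatch(attn_implementation: str, *args, **kwargs):
    # "eager" is never looked up: it is the default defined next to the model
    if attn_implementation == "eager":
        return "eager called"
    return registry[attn_implementation](*args, **kwargs)


print(sorted(registry.keys()))
print(dispatch("sdpa"))
print(dispatch("eager"))
```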

supported implementations

1. eager attention

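The reference implementation uses plain matrix multiplications and an explicit softmax, with no fused kernel. It stays in each model's modeling file rather than in ALL_ATTENTION_FUNCTIONS. Because it materializes the full attention-weight matrix, it supports output_attentions=True.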

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="eager"
)
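The eager path computes softmax(QK^T / sqrt(d)) V explicitly, so the weight matrix is available to return. The minimal PyTorch sketch below shows this and checks itself against SDPA. eager_attention is a hypothetical helper, not the function in the modeling files, and the tensor sizes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(query, key, value, attention_mask=None):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    scale = 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask  # additive mask: 0 keeps, large negative drops
    weights = torch.softmax(scores, dim=-1)
    # the explicit weights are what output_attentions=True can hand back
    return torch.matmul(weights, value), weights


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
out, weights = eager_attention(q, k, v)
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-5))
```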

2. sdpa (scaled dot product attention)

PyTorch's native torch.nn.functional.scaled_dot_product_attention, which can dispatch to fused kernels. It is the default when available (PyTorch >= 2.1.1) and is up to 2x faster than eager in bf16.

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="sdpa"
)
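The sdpa backend wraps torch.nn.functional.scaled_dot_product_attention. The small CPU sketch below calls that function directly. It shows that the built-in is_causal flag and an explicit boolean mask, where True marks positions that may be attended, produce the same result. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 6, 16
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# built-in causal masking
causal_out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# same result with an explicit boolean mask (True = position may be attended)
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
masked_out = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)

print(torch.allclose(causal_out, masked_out, atol=1e-5))
```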

3. flash attention 2

IO-aware exact attention from Dao-AILab, provided by the flash-attn package. It requires an fp16 or bf16 dtype and offers 2-4x faster training and inference with up to 50% memory reduction.

pip install flash-attn --no-build-isolation

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16
)
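flash_attention_2 only works when the flash-attn package is installed and the weights are fp16 or bf16. For that reason, it can help to pick the implementation defensively. The sketch below uses a hypothetical helper, choose_attn_implementation, that checks those conditions plus CUDA availability and falls back to sdpa otherwise.

```python
import importlib.util

import torch


def choose_attn_implementation(dtype: torch.dtype) -> str:
    """Hypothetical helper: prefer FlashAttention-2, fall back to sdpa."""
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    half_precision = dtype in (torch.float16, torch.bfloat16)
    if has_flash_attn and half_precision and torch.cuda.is_available():
        return "flash_attention_2"
    return "sdpa"


print(choose_attn_implementation(torch.float32))  # fp32 always falls back
print(choose_attn_implementation(torch.float16))
```

You can then pass the returned string as attn_implementation to from_pretrained, together with the matching torch_dtype.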

4. flash attention 3

The next major version, optimized for NVIDIA Hopper GPUs. It requires a Hopper GPU such as the H100.

import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_3",
    torch_dtype=torch.bfloat16
)
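Hopper GPUs such as the H100 report CUDA compute capability 9.x. One rough way to guard flash_attention_3 is therefore to check the device capability before loading. The sketch below assumes FlashAttention-3 targets Hopper only. supports_flash_attention_3 is a hypothetical helper, and it does not check whether the FlashAttention-3 package itself is installed.

```python
from typing import Optional, Tuple

import torch


def supports_flash_attention_3(capability: Optional[Tuple[int, int]] = None) -> bool:
    """Hypothetical check: True only for Hopper-class GPUs (compute capability 9.x)."""
    if capability is None:
        if not torch.cuda.is_available():
            return False
        capability = torch.cuda.get_device_capability()
    major, _minor = capability
    return major == 9


attn_implementation = "flash_attention_3" if supports_flash_attention_3() else "sdpa"
print(attn_implementation)
```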

5. FlexAttention

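PyTorch's FlexAttention API (torch.nn.attention.flex_attention, PyTorch >= 2.5), compiled into fused kernels with torch.compile. It supports custom score-modification (score_mod) and mask-modification (mask_mod) functions, as well as block-sparse computation via BlockMask. Dropout is not supported.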

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flex_attention"
)

custom attention registration

required function signature

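from typing import Optional

import torch
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def custom_attention(
    module: torch.nn.Module,                 # the attention layer calling this function
    query: torch.Tensor,                     # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,                       # (batch, num_kv_heads, kv_len, head_dim)
    value: torch.Tensor,                     # (batch, num_kv_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor],
    **kwargs,                                # models pass extra args (e.g. dropout, scaling)
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    returns: (attention_output, attention_weights); weights may be None
    """
    # placeholder computation: delegate to the built-in sdpa implementation
    attn_output, attn_weights = sdpa_attention_forward(
        module, query, key, value, attention_mask, **kwargs
    )
    return attn_output, attn_weights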

registration

from transformers import AttentionInterface, AutoModelForCausalLM

# register the function defined above under a new name
AttentionInterface.register("my_attention", custom_attention)

model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="my_attention"
)

custom attention example

from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def custom_attention_with_logging(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    log_stats: bool = False,
    **kwargs
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if log_stats:
        print(f"Query shape: {query.shape}")
    # delegate the actual computation to the built-in sdpa implementation
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)

AttentionInterface.register("logged_sdpa", custom_attention_with_logging)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="logged_sdpa"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# extra kwargs passed to the model's forward propagate to the attention function
outputs = model(input_ids, log_stats=True)

runtime attention switching

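import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)

# switch to sdpa
model.set_attn_implementation("sdpa")

# switch to FlashAttention-2 (requires flash-attn and fp16/bf16 weights)
model.set_attn_implementation("flash_attention_2")

# switch to a custom function (it must be registered with AttentionInterface first)
model.set_attn_implementation("my_attention")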
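Runtime switching works because the attention function is resolved by name on each forward call rather than baked into the weights. The self-contained toy module below illustrates the idea: changing the name swaps the attention computation while the projection weights stay untouched. ToyAttention and IMPLEMENTATIONS are hypothetical, and the module is single-head for brevity.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def eager_attention(q, k, v):
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
    return weights @ v


# hypothetical registry of named implementations
IMPLEMENTATIONS = {
    "eager": eager_attention,
    "sdpa": F.scaled_dot_product_attention,
}


class ToyAttention(nn.Module):
    def __init__(self, dim: int, attn_implementation: str = "eager"):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_implementation = attn_implementation

    def forward(self, x):
        # x: (batch, seq_len, dim) -> q, k, v: (batch, 1 head, seq_len, dim)
        q, k, v = self.qkv(x).unsqueeze(1).chunk(3, dim=-1)
        # the implementation is looked up by name on every call
        fn = IMPLEMENTATIONS[self.attn_implementation]
        return fn(q, k, v).squeeze(1)


torch.manual_seed(0)
layer = ToyAttention(dim=16)
x = torch.randn(2, 5, 16)
with torch.no_grad():
    out_eager = layer(x)
    layer.attn_implementation = "sdpa"  # switch without reloading weights
    out_sdpa = layer(x)
print(torch.allclose(out_eager, out_sdpa, atol=1e-5))
```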

multimodal attention configuration

Multimodal models can use a different attention implementation for each backbone. Pass a dict keyed by sub-config name:

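import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b",
    attn_implementation={
        "vision_config": "sdpa",
        "text_config": "flash_attention_2"
    },
    torch_dtype=torch.bfloat16,  # FlashAttention-2 needs fp16 or bf16
)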

AttentionMaskInterface

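Attention implementations expect masks in different formats, for example an additive float mask for eager, a boolean mask for sdpa, or a BlockMask for flex_attention. AttentionMaskInterface maps an implementation name to the function that builds its mask, so a custom implementation can register the format it expects: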

from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask

def custom_mask_function(*args, **kwargs):
    # reuse the sdpa mask format for the custom "my_attention" implementation
    return sdpa_mask(*args, **kwargs)

AttentionMaskInterface.register("my_attention", custom_mask_function)
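The sketch below shows why mask formats matter. It converts a boolean mask (True = attend) into an additive float mask (0 to keep, a large negative value to drop) and checks that SDPA produces the same output with either. bool_to_additive is a hypothetical helper that illustrates the kind of conversion involved; it is not the library's masking code.

```python
import torch
import torch.nn.functional as F


def bool_to_additive(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """True (attend) -> 0.0, False (masked) -> most negative finite value."""
    zero = torch.zeros((), dtype=dtype)
    neg = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
    return torch.where(mask, zero, neg)


torch.manual_seed(0)
seq_len = 5
bool_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()  # causal, True = attend
additive_mask = bool_to_additive(bool_mask, torch.float32)

q, k, v = (torch.randn(1, 2, seq_len, 8) for _ in range(3))
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_additive = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
print(torch.allclose(out_bool, out_additive, atol=1e-6))
```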

FlexAttention patterns

score_mod function

temperature = 2.0  # illustrative value: >1 flattens, <1 sharpens the distribution

def score_mod(score, batch, head, q_idx, k_idx):
    """modify attention score before softmax."""
    return score / temperature  # temperature scaling

mask_mod function

def mask_mod(batch, head, q_idx, k_idx):
    """return True to attend, False to mask."""
    return q_idx >= k_idx  # causal mask

BlockMask creation

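from torch.nn.attention.flex_attention import create_block_mask

batch_size, num_heads, seq_length = 2, 8, 1024  # example sizes

def causal_mask(b, h, q_idx, k_idx):
    return q_idx >= k_idx

block_mask = create_block_mask(
    causal_mask,
    B=batch_size,   # pass None to broadcast the mask over the batch
    H=num_heads,    # pass None to broadcast the mask over heads
    Q_LEN=seq_length,
    KV_LEN=seq_length,
    BLOCK_SIZE=128,
    device="cuda"
)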

common FlexAttention patterns

sliding window:

def sliding_window_mask(b, h, q, k, window_size=256):
    # element-wise & (not Python `and`), because q and k are tensors
    return (q >= k) & (q - k < window_size)

alibi:

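def alibi_score_mod(score, b, h, q, k, num_heads=8):
    # per-head ALiBi slope 2^(-8(h+1)/num_heads); num_heads=8 is illustrative
    slope = torch.exp2(-8.0 * (h + 1) / num_heads)
    return score - abs(q - k) * slope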
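You can sanity-check masks and biases like these by evaluating them on broadcast index grids outside FlexAttention. The sketch below assumes a window of 3 and 4 heads, with the standard ALiBi slopes 2^(-8(h+1)/num_heads). It also shows why the mask uses element-wise `&`: Python's `and` raises an error on multi-element tensors.

```python
import torch

WINDOW_SIZE = 3  # illustrative window
NUM_HEADS = 4    # illustrative head count


def sliding_window_mask(b, h, q, k):
    # element-wise & works on index tensors; Python `and` would raise here
    return (q >= k) & (q - k < WINDOW_SIZE)


def alibi_score_mod(score, b, h, q, k):
    slope = torch.exp2(-8.0 * (h + 1) / NUM_HEADS)  # per-head ALiBi slope
    return score - (q - k).abs() * slope


seq_len = 6
q_idx = torch.arange(seq_len).view(-1, 1)
k_idx = torch.arange(seq_len).view(1, -1)
window = sliding_window_mask(0, 0, q_idx, k_idx)
print(window.int())

heads = torch.arange(NUM_HEADS).view(-1, 1, 1)
bias = alibi_score_mod(torch.zeros(NUM_HEADS, seq_len, seq_len), 0, heads, q_idx, k_idx)
print(bias[0, 5])  # head 0, last query position
```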

performance benchmarks

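implementation            | speedup vs eager | memory use
--------------------------|------------------|-----------
sdpa (bf16)               | up to 2x         | moderate
sdpa (fp8)                | up to 3x         | low
FlashAttention-2          | 2-4x             | 50% less
FlexAttention (training)  | 2.4x             | higher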

selection guide

use case                    | recommended
----------------------------|------------------
default inference           | sdpa
memory constrained          | flash_attention_2
need attention weights      | eager
custom attention patterns   | flex_attention
H100 with fp8               | sdpa
training large models       | flash_attention_2
