Attention
For framework integrations that already target FlashAttention-3 or FlashMLA
Python APIs, prefer the flash_attn_3 or flash_mla wrapper packages
first. Use the MATE APIs below when wrapper coverage is not enough.
FMHA
Forward dtype notes:
- flash_attn_varlen_func and flash_attn_with_kvcache accept torch.float16, torch.bfloat16, and torch.float8_e4m3fn inputs on the MATE FMHA forward path.
- For FP8 inputs, optional q_descale, k_descale, and v_descale tensors are torch.float32 scale factors with shape (batch_size, num_heads_kv).
- FP8 inputs produce torch.bfloat16 outputs by default.
- The FMHA forward path also supports an optional qv input, including FP8 inputs with qv.
- When both q and qv are FP8, q_descale applies to both query tensors.
- For best FP8 attention performance, use MUSA SDK 5.2.0 or newer when available.
Use flash_attn_combine to merge partial FMHA outputs and log-sum-exp buffers when your integration splits the attention computation.
Use get_scheduler_metadata to precompute the scheduler_metadata input for flash_attn_with_kvcache when you call the FMHA path directly.
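The (batch_size, num_heads_kv) layout of the FP8 descale tensors can be illustrated with a small pure-Python sketch. It assumes the common convention that a dequantized value is the FP8 value multiplied by its descale factor, and that each factor comes from the absolute maximum of the slice it covers. Neither convention is confirmed by this page, and descale_from_absmax is a hypothetical helper, not part of MATE.

```python
# Largest finite magnitude representable in torch.float8_e4m3fn.
FP8_E4M3_MAX = 448.0


def descale_from_absmax(absmax):
    """Map per-(batch, kv-head) absolute maxima to descale factors.

    absmax[b][h] is the largest |value| in the slice that batch b and
    KV head h will quantize. The result keeps the same
    (batch_size, num_heads_kv) layout. Assumption: dequantized value
    is approximately fp8_value * descale.
    """
    return [
        [m / FP8_E4M3_MAX if m > 0.0 else 1.0 for m in row]
        for row in absmax
    ]


# batch_size = 2, num_heads_kv = 2
descale = descale_from_absmax([[448.0, 224.0], [0.0, 896.0]])
```

In a real integration the nested list would become a torch.float32 tensor on the MUSA device before being passed as q_descale, k_descale, or v_descale.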
- mate.mha_interface.flash_attn_combine(out_partial: torch.Tensor, lse_partial: torch.Tensor, out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None)
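As a reference for what merging partial outputs with their log-sum-exp values means, the sketch below shows the standard split-attention merge for a single query row in pure Python. It is an illustration of the math only: the list-based layout and the combine_partials name are assumptions, and the actual tensor layout that flash_attn_combine expects is not documented on this page.

```python
import math


def combine_partials(out_partial, lse_partial):
    """Merge per-split attention outputs for one query row.

    out_partial: list of splits, each a list of headdim_v floats.
    lse_partial: list of floats, the log-sum-exp of each split's logits.
    Returns (merged_output, merged_lse).
    """
    lse_max = max(lse_partial)
    # Subtract the maximum before exponentiating for numerical stability.
    weights = [math.exp(lse - lse_max) for lse in lse_partial]
    total = sum(weights)
    merged_lse = lse_max + math.log(total)
    dim = len(out_partial[0])
    merged = [
        sum(w * o[d] for w, o in zip(weights, out_partial)) / total
        for d in range(dim)
    ]
    return merged, merged_lse


# Two splits with equal log-sum-exp contribute equally to the result.
merged, merged_lse = combine_partials([[1.0, 3.0], [3.0, 5.0]], [0.0, 0.0])
```

Each split is weighted by the share of softmax mass it saw, so the merged row matches what a single unsplit pass would produce.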
- mate.mha_interface.flash_attn_with_kvcache(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, k: torch.Tensor | None = None, v: torch.Tensor | None = None, qv: torch.Tensor | None = None, rotary_cos: torch.Tensor | None = None, rotary_sin: torch.Tensor | None = None, cache_seqlens: int | torch.Tensor | None = None, cache_batch_idx: torch.Tensor | None = None, cache_leftpad: torch.Tensor | None = None, page_table: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k_new: torch.Tensor | None = None, max_seqlen_q: int | None = None, rotary_seqlens: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, softmax_scale: float | None = None, causal: bool = False, window_size: Tuple | List | None = (-1, -1), learnable_sink: torch.Tensor | None = None, attention_chunk: int | None = 0, softcap: float = 0.0, rotary_interleaved: bool = True, scheduler_metadata: torch.Tensor | None = None, num_splits: int = 0, pack_gqa=None, sm_margin=0, return_softmax_lse: bool = False, cp_world_size: int = 1, cp_rank: int = 0, cp_tot_seqused_k: torch.Tensor | None = None)
FlashAttention-3 compatible API: forward pass with a KV cache.
- Parameters:
q (Tensor) – The query tensor with shape (batch_size, seqlen, nheads, headdim) if cu_seqlens_q is None, or (total_q, nheads, headdim) if cu_seqlens_q is not None.
k_cache (Tensor) – The key cache tensor with shape (batch_size_cache, seqlen_cache, nheads_k, headdim) if there’s no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there’s a page_table (i.e. paged KV cache).
v_cache (Tensor) – The value cache tensor with shape (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there’s no page_table, or (num_blocks, page_block_size, nheads_k, headdim_v) if there’s a page_table (i.e. paged KV cache).
k (Optional[Tensor]) – The key tensor with shape (batch_size, seqlen_new, nheads_k, headdim) if cu_seqlens_k_new is None, or (total_k_new, nheads_k, headdim) if cu_seqlens_k_new is not None. If k is not None, k is concatenated with k_cache, starting at the indices specified by cache_seqlens.
v (Optional[Tensor]) – The value tensor with shape (batch_size, seqlen_new, nheads_k, headdim_v) if cu_seqlens_k_new is None, or (total_k_new, nheads_k, headdim_v) if cu_seqlens_k_new is not None. Similar to k.
rotary_cos (Optional[Tensor]) – Tensor with shape (seqlen_ro, rotary_dim / 2). If not None, rotary embedding is applied to k and q. Only applicable if k and v are passed in. rotary_dim must be <= headdim and divisible by 16. rotary_cos must be on MUSA and have the same dtype as q.
rotary_sin (Optional[Tensor]) – Tensor with shape (seqlen_ro, rotary_dim / 2). Similar to rotary_cos and must have the same shape and dtype.
cache_seqlens (Union[int, Tensor]) – The sequence lengths of the KV cache, shape (batch_size,) if it is a tensor.
cache_batch_idx (Optional[Tensor]) – The int32 indices used to index into the KV cache, shape (batch_size,). The tensor must be on MUSA and contiguous. If the indices are not distinct, and k and v are provided, the values updated in the cache might come from any of the duplicate indices.
cache_leftpad (Optional[Tensor]) – The int32 left padding offset where the KV cache starts for each batch, shape (batch_size,). The tensor must be on MUSA and contiguous. If None, assume 0.
page_table (Optional[Tensor]) – The page table tensor with shape (batch_size, max_num_blocks_per_seq).
cu_seqlens_q (Optional[Tensor]) – The cumulative sequence lengths of the query, shape (batch_size + 1,).
cu_seqlens_k_new (Optional[Tensor]) – The cumulative sequence lengths of the new KV, shape (batch_size + 1,).
rotary_seqlens (Optional[Tensor]) – Optional int32 tensor with shape (batch_size,) used as the rotary position length for each batch.
softmax_scale (Optional[float]) – The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim).
causal (bool) – Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size (Tuple[int, int]) – The size of the sliding window. If not (-1, -1), implements sliding window local attention.
learnable_sink (Optional[Tensor]) – The learnable sink tensor for attention, shape (nheads,).
softcap (float) – Anything > 0 activates softcapping attention, applied as logits = softcap * tanh(logits / softcap) before the softmax. 0.0 (default) disables softcapping.
rotary_interleaved (bool) – If True, rotary embedding uses GPT-J style and combines dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding combines dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc. (i.e. GPT-NeoX style).
num_splits (int) – If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don’t split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. Don’t change this unless you know what you are doing.
return_softmax_lse (bool) – Whether to return the logsumexp of the attention scores.
cp_world_size (int) – Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled). When > 1, the global sequence is assumed to be distributed across ranks using an interleaved token pattern, where rank r holds tokens at positions [r, r + cp_world_size, r + 2*cp_world_size, ...].
cp_rank (int) – The rank of the current device within the CP group. Default 0.
cp_tot_seqused_k (Optional[Tensor]) – The global (across all CP ranks) cumulative key sequence lengths, shape (batch_size + 1,), dtype int32. Required when CP is enabled (cp_world_size > 1) so that each rank can correctly compute causal masking boundaries against the full key sequence. Ignored when cp_world_size == 1.
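The paged shapes of k_cache, v_cache, and page_table above can be read as follows. This is a minimal sketch that assumes the usual paged KV cache convention, where logical block i of a sequence is stored in physical block page_table[b][i]. The locate_token helper is hypothetical and only illustrates the indexing.

```python
def locate_token(page_table_row, page_block_size, token_pos):
    """Return (physical_block, offset) for a logical token position.

    page_table_row: one row of page_table, i.e. the physical block ids
    of a single sequence, in logical order (assumed convention).
    """
    logical_block = token_pos // page_block_size
    offset = token_pos % page_block_size
    return page_table_row[logical_block], offset


# A sequence whose three logical blocks live in physical blocks 7, 2, 9.
block_id, offset = locate_token([7, 2, 9], page_block_size=64, token_pos=70)
```

Under that convention, the key for token 70 would sit at k_cache[block_id, offset], with block_id and offset computed as shown.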
- Returns:
If return_softmax_lse is False, the attention output, shape (batch_size, seqlen, nheads, headdim_v) if cu_seqlens_q is None, or (total_q, nheads, headdim_v) if cu_seqlens_q is not None.
If return_softmax_lse is True, a tuple of two tensors:
The attention output, shape (batch_size, seqlen, nheads, headdim_v) if cu_seqlens_q is None, or (total_q, nheads, headdim_v) if cu_seqlens_q is not None.
The log-sum-exp value, shape (batch_size, nheads, seqlen) if cu_seqlens_q is None, or (nheads, total_q) if cu_seqlens_q is not None.
- Return type:
Union[Tensor, Tuple[Tensor, Tensor]]
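The interleaved token pattern described for cp_world_size and cp_rank can be enumerated directly. The sketch below only restates that pattern in pure Python; cp_local_positions is a hypothetical helper name, not a MATE function.

```python
def cp_local_positions(cp_rank, cp_world_size, global_seqlen):
    """Global token positions held by one Context Parallelism rank.

    Rank r holds positions r, r + cp_world_size, r + 2*cp_world_size, ...
    up to (but not including) global_seqlen.
    """
    return list(range(cp_rank, global_seqlen, cp_world_size))


# Four ranks sharing a 10-token sequence.
shards = [cp_local_positions(r, 4, 10) for r in range(4)]
```

Every global position appears on exactly one rank, which is why the global key lengths in cp_tot_seqused_k are needed to place each rank's causal mask correctly.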
- mate.mha_interface.get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, seqused_q: torch.Tensor | None = None, seqused_k: torch.Tensor | None = None, qkv_dtype=torch.bfloat16, headdim_v=None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k: torch.Tensor | None = None, cu_seqlens_k_new: torch.Tensor | None = None, cache_leftpad: torch.Tensor | None = None, page_size=None, max_seqlen_k_new=0, causal=False, window_size=(-1, -1), attention_chunk=0, has_softcap=False, num_splits=0, pack_gqa=None, has_qv=False, mp_margin=0)
Build scheduler metadata for flash_attn_with_kvcache.
Use this helper to precompute the tensor passed through the scheduler_metadata argument of flash_attn_with_kvcache. This is the direct top-level API exposed as mate.get_scheduler_metadata.
- Parameters:
batch_size (int) – Batch size for the scheduled attention workload.
max_seqlen_q (int) – Maximum query sequence length used by the target flash_attn_with_kvcache call.
max_seqlen_k (int) – Maximum key sequence length already present in the cache.
num_heads_q (int) – Number of query heads.
num_heads_kv (int) – Number of key / value heads.
headdim (int) – Query and key head dimension.
seqused_q (Optional[Tensor]) – Optional per-batch query lengths with shape (batch_size,).
seqused_k (Optional[Tensor]) – Optional per-batch key lengths with shape (batch_size,).
qkv_dtype (torch.dtype) – Data type of the scheduled QKV path. Default torch.bfloat16.
headdim_v (Optional[int]) – Value head dimension. Defaults to headdim.
cu_seqlens_q (Optional[Tensor]) – Optional cumulative query sequence lengths with shape (batch_size + 1,).
cu_seqlens_k (Optional[Tensor]) – Optional cumulative cached key sequence lengths with shape (batch_size + 1,).
cu_seqlens_k_new (Optional[Tensor]) – Optional cumulative new-KV sequence lengths with shape (batch_size + 1,).
cache_leftpad (Optional[Tensor]) – Optional per-batch left padding offsets for the KV cache.
page_size (Optional[int]) – Page size for paged KV-cache scheduling.
max_seqlen_k_new (int) – Maximum number of newly appended KV tokens.
causal (bool) – Whether the target attention call uses causal masking.
window_size (Tuple[int, int]) – Sliding-window attention bounds. (-1, -1) means full context.
attention_chunk (int) – Chunk size used by chunked attention scheduling.
has_softcap (bool) – Whether the target attention call enables softcapping.
num_splits (int) – Requested split count for key / value scheduling.
pack_gqa (Optional[bool]) – Optional GQA packing mode.
has_qv (bool) – Whether the target attention call includes the optional qv input.
mp_margin (int) – Number of MPs reserved for communication or other work.
- Returns:
Scheduler metadata tensor to pass to flash_attn_with_kvcache(..., scheduler_metadata=...).
- Return type:
Tensor
Notes
Keep the scheduling-related arguments aligned with the corresponding flash_attn_with_kvcache call so the generated metadata matches the actual workload.
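One way to keep the two calls aligned is to derive the scheduler arguments from the same keyword dictionary that will be passed to flash_attn_with_kvcache. The helper below is a hypothetical sketch, not part of MATE. It uses only argument names and defaults that appear in the two signatures on this page, and it assumes has_softcap corresponds to softcap > 0 and has_qv to a non-None qv.

```python
# Defaults copied from the flash_attn_with_kvcache signature on this page.
_SHARED_DEFAULTS = {
    "causal": False,
    "window_size": (-1, -1),
    "attention_chunk": 0,
    "num_splits": 0,
    "pack_gqa": None,
}


def scheduler_kwargs_from_attention_kwargs(attn_kwargs):
    """Hypothetical helper: build get_scheduler_metadata keyword
    arguments from the kwargs planned for flash_attn_with_kvcache."""
    sched = {
        name: attn_kwargs.get(name, default)
        for name, default in _SHARED_DEFAULTS.items()
    }
    # The scheduler takes boolean flags where the attention call
    # takes values (assumed mapping, see the lead-in).
    sched["has_softcap"] = attn_kwargs.get("softcap", 0.0) > 0.0
    sched["has_qv"] = attn_kwargs.get("qv") is not None
    return sched


attn_kwargs = {"causal": True, "softcap": 30.0, "window_size": (128, 0)}
sched_kwargs = scheduler_kwargs_from_attention_kwargs(attn_kwargs)
```

Shape-related arguments such as batch_size, max_seqlen_q, max_seqlen_k, and the head counts still have to be supplied separately, since they are read from tensors rather than passed as keywords.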
- mate.mha_interface.flash_attn_varlen_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_k: int | None = None, seqused_q: torch.Tensor | None = None, seqused_k: torch.Tensor | None = None, page_table: torch.Tensor | None = None, softmax_scale: float | None = None, causal: bool = False, qv: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, window_size: Tuple | List | None = (-1, -1), learnable_sink: torch.Tensor | None = None, attention_chunk: int | None = 0, softcap: float = 0.0, scheduler_metadata: torch.Tensor | None = None, num_splits: int = 0, pack_gqa=None, deterministic: bool = False, sm_margin=0, return_softmax_lse: bool = False, backend: str = 'auto', cp_world_size: int = 1, cp_rank: int = 0, cp_tot_seqused_k: torch.Tensor | None = None, out: torch.Tensor | None = None)
FlashAttention-3 compatible API: forward pass with varlen or non-varlen inputs.
- Parameters:
q (Tensor) – The query tensor with shape (batch_size, seqlen, nheads, headdim) if cu_seqlens_q is None, or (total_q, nheads, headdim) if cu_seqlens_q is not None.
k (Tensor) – The key tensor with shape (batch_size, seqlen_k, nheads_k, headdim) if cu_seqlens_k is None, or (total_k, nheads_k, headdim) if cu_seqlens_k is not None.
v (Tensor) – The value tensor with shape (batch_size, seqlen_k, nheads_k, headdim_v) if cu_seqlens_k is None, or (total_k, nheads_k, headdim_v) if cu_seqlens_k is not None.
cu_seqlens_q (Optional[Tensor]) – The cumulative sequence length tensor for query, shape (batch_size + 1,).
cu_seqlens_k (Optional[Tensor]) – The cumulative sequence length tensor for key/value, shape (batch_size + 1,).
max_seqlen_q (Optional[int]) – The maximum sequence length for query. Must be provided for varlen forward.
max_seqlen_k (Optional[int]) – The maximum sequence length for key/value.
seqused_q (Optional[Tensor]) – Tensor with shape (batch_size,). If given, only this many elements of each batch element’s queries and outputs are used.
seqused_k (Optional[Tensor]) – Tensor with shape (batch_size,). If given, only this many elements of each batch element’s keys and values are used.
softmax_scale (Optional[float]) – The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim).
causal (bool) – Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size (Tuple[int, int]) – The size of the sliding window. If not (-1, -1), implements sliding window local attention.
learnable_sink (Optional[Tensor]) – The learnable sink tensor for attention, shape (nheads,).
softcap (float) – Anything > 0 activates softcapping attention, applied as logits = softcap * tanh(logits / softcap) before the softmax. 0.0 (default) disables softcapping.
return_softmax_lse (bool) – Whether to return the logsumexp of the attention scores.
backend (str) – The backend to use. The default, auto, is recommended.
cp_world_size (int) – Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled). When > 1, the global sequence is assumed to be distributed across ranks using an interleaved token pattern, where rank r holds tokens at positions [r, r + cp_world_size, r + 2*cp_world_size, ...].
cp_rank (int) – The rank of the current device within the CP group. Default 0.
cp_tot_seqused_k (Optional[Tensor]) – The global (across all CP ranks) cumulative key sequence lengths, shape (batch_size + 1,), dtype int32. Required when CP is enabled (cp_world_size > 1) so that each rank can correctly compute causal masking boundaries against the full key sequence. Ignored when cp_world_size == 1.
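For the varlen layout, cu_seqlens_q and cu_seqlens_k are prefix sums of the per-sequence lengths, with a leading zero so that sequence i occupies rows cu[i] to cu[i + 1] of the packed tensor. A minimal pure-Python sketch, where build_cu_seqlens is a hypothetical helper name:

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Cumulative sequence lengths with a leading 0, length batch_size + 1."""
    return [0] + list(accumulate(seqlens))


seqlens = [3, 5, 2]                 # three packed sequences
cu_seqlens = build_cu_seqlens(seqlens)
total_q = cu_seqlens[-1]            # first dimension of the packed q tensor
max_seqlen_q = max(seqlens)         # value to pass as max_seqlen_q
```

In practice the list would be converted to an integer tensor on the MUSA device. int32 is the dtype FlashAttention uses for these tensors, but this page does not state the dtype MATE requires, so treat that as an assumption.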
- Returns:
If return_softmax_lse is False, the attention output, shape (total_q, nheads, headdim_v).
If return_softmax_lse is True, a tuple of two tensors:
The attention output, shape (total_q, nheads, headdim_v).
The log-sum-exp value, shape (nheads, total_q).
- Return type:
Union[Tensor, Tuple[Tensor, Tensor]]
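The softcap formula from the parameter list, logits = softcap * tanh(logits / softcap), can be checked on scalar values. This sketch applies it to one logit at a time in pure Python; apply_softcap is a hypothetical name, and the kernel applies the same formula to the whole score matrix before the softmax.

```python
import math


def apply_softcap(logit, softcap):
    """Softcap a single attention logit; softcap <= 0 leaves it unchanged."""
    if softcap <= 0.0:
        return logit
    return softcap * math.tanh(logit / softcap)


small = apply_softcap(0.1, 30.0)       # nearly unchanged
large = apply_softcap(1000.0, 30.0)    # squashed into [-30, 30]
disabled = apply_softcap(1000.0, 0.0)  # softcap disabled
```

Logits well below the cap pass through almost linearly, while large logits saturate at plus or minus softcap.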
MLA
- mate.flashmla.get_mla_metadata(cache_seqlens: torch.Tensor | None, num_q_tokens_per_head_k: int, num_heads_k: int, num_heads_q: int | None = None, is_fp8_kvcache: bool = False, topk: int | None = None, extra_topk: int | None = None, q: torch.Tensor | None = None, bs: int | None = None, topk_length: torch.Tensor | None = None, extra_topk_length: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Get metadata for MLA decoding.
- Parameters:
cache_seqlens (Tensor) – The sequence lengths of the KV cache with shape (batch_size,).
num_q_tokens_per_head_k (int) – Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k (int) – The number of k heads.
num_heads_q (Optional[int]) – The number of q heads. This argument is optional when sparse attention is not enabled.
is_fp8_kvcache (bool) – Whether the k_cache and v_cache are in fp8 format.
topk (Optional[int]) – If not None, sparse attention will be enabled, and only tokens in the indices array passed to flash_mla_with_kvcache_sm90 will be attended to.
extra_topk (Optional[int]) – Optional sparse extra-KV topk. This mirrors FlashMLA sparse decode metadata semantics without folding extra work into topk.
topk_length (Optional[Tensor]) – Optional per-batch sparse workload lengths for MODEL1 scheduled decode metadata. If extra_topk_length is provided, both lengths are summed in the C++ metadata kernel.
- Returns:
A tuple of two tensors:
tile_scheduler_metadata, shape (num_sm_parts, TileSchedulerMetaDataSize).
num_splits, shape (batch_size + 1,).
- Return type:
Tuple[Tensor, Tensor]
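The num_q_tokens_per_head_k argument follows directly from the formula in the parameter list. The sketch below evaluates it in pure Python; the function name is hypothetical, and the example head counts are illustrative values rather than settings taken from this page.

```python
def num_q_tokens_per_head_k(num_q_tokens_per_q_seq, num_heads_q, num_heads_k):
    """num_q_tokens_per_q_seq * num_heads_q // num_heads_k, as documented."""
    return num_q_tokens_per_q_seq * num_heads_q // num_heads_k


# One query token per sequence, 128 query heads sharing 1 KV head.
single_token = num_q_tokens_per_head_k(1, 128, 1)
# Two query tokens per sequence, 16 query heads, 1 KV head.
two_tokens = num_q_tokens_per_head_k(2, 16, 1)
```

The result is passed as the second positional argument of get_mla_metadata, alongside cache_seqlens and num_heads_k.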
- mate.flashmla.flash_mla_with_kvcache(q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, is_fp8_kvcache: bool = False, indices: torch.Tensor | None = None, attn_sink: torch.Tensor | None = None, extra_k_cache: torch.Tensor | None = None, extra_indices_in_kvcache: torch.Tensor | None = None, topk_length: torch.Tensor | None = None, extra_topk_length: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
MLA forward pass with a KV cache.
- Parameters:
q (Tensor) – The query tensor with shape (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache (Tensor) – The compressed KV cache tensor with shape (num_blocks, page_block_size, num_heads_k, head_dim).
block_table (Tensor) – The page table with shape (batch_size, max_num_blocks_per_seq).
cache_seqlens (Tensor) – The sequence lengths of the compressed KV cache with shape (batch_size,).
head_dim_v (int) – Head dimension of v.
tile_scheduler_metadata (Tensor) – The scheduler metadata with shape (num_sm_parts, TileSchedulerMetaDataSize), returned by get_mla_metadata.
num_splits (Tensor) – The num_splits tensor with shape (batch_size + 1,), returned by get_mla_metadata.
softmax_scale (Optional[float]) – The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
causal (bool) – Whether to apply causal attention mask.
is_fp8_kvcache (bool) – Whether the k_cache and v_cache are in fp8 format. For the format of the FP8 KV cache, please refer to README.md.
indices (Optional[Tensor]) – The token indices tensor with shape (batch_size, seq_len_q, topk). If not None, sparse attention will be enabled, and only tokens in the indices array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up indices, please refer to README.md.
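Because indices has a fixed last dimension of topk, a query that selects fewer tokens needs its row filled with invalid entries. The sketch below pads one row with -1, which the description above lists as an accepted invalid value. pad_indices is a hypothetical helper, and the README referenced above remains the authority on how indices should be built.

```python
def pad_indices(selected, topk):
    """Pad one row of selected token indices to length topk with -1."""
    if len(selected) > topk:
        raise ValueError("more selected tokens than topk slots")
    return list(selected) + [-1] * (topk - len(selected))


# A query that attends to three tokens when topk = 5.
row = pad_indices([4, 0, 17], topk=5)
```

Rows built this way can be stacked into the (batch_size, seq_len_q, topk) tensor expected by the indices argument.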
Notes
For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2, head_dim should be 576 while head_dim_v should be 512.
In FP8 + sparse mode, each token’s KV cache is 656 bytes. k_cache has shape (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. The first 512 bytes contain the quantized NoPE part with 512 float8_e4m3 values. The next 16 bytes contain four float32 scale factors, one for each group of 128 float8_e4m3 values. The final 128 bytes contain the RoPE part with 64 bfloat16 values; this part is left unquantized for accuracy.
- Returns:
A tuple of two tensors:
out, shape (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse, shape (batch_size, num_heads_q, seq_len_q).
- Return type:
Tuple[Tensor, Tensor]
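The 656-byte per-token layout described in the notes for FP8 + sparse mode can be expressed as byte offsets. The sketch below splits one token's raw bytes into its three regions with the standard struct module. It assumes little-endian float32 scale factors, which this page does not state, and split_token is a hypothetical helper.

```python
import struct

NOPE_BYTES = 512       # 512 float8_e4m3 values, 1 byte each
SCALE_BYTES = 4 * 4    # four float32 scale factors
ROPE_BYTES = 64 * 2    # 64 bfloat16 values, 2 bytes each
TOKEN_BYTES = NOPE_BYTES + SCALE_BYTES + ROPE_BYTES


def split_token(buf):
    """Split one token's KV cache bytes into (nope, scales, rope)."""
    if len(buf) != TOKEN_BYTES:
        raise ValueError("expected %d bytes per token" % TOKEN_BYTES)
    nope = buf[:NOPE_BYTES]
    # Assumption: scale factors are stored as little-endian float32.
    scales = struct.unpack("<4f", buf[NOPE_BYTES:NOPE_BYTES + SCALE_BYTES])
    rope = buf[NOPE_BYTES + SCALE_BYTES:]
    return nope, scales, rope


# A synthetic token: zeroed NoPE and RoPE parts, four known scales.
token = bytes(NOPE_BYTES) + struct.pack("<4f", 1.0, 0.5, 2.0, 4.0) + bytes(ROPE_BYTES)
nope, scales, rope = split_token(token)
```

Each of the four scale factors covers one group of 128 NoPE values, so scales[i] applies to nope[128 * i : 128 * (i + 1)].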