Attention

For framework integrations that already target FlashAttention-3 or FlashMLA Python APIs, prefer the flash_attn_3 or flash_mla wrapper packages first. Use the MATE APIs below when wrapper coverage is not enough.

FMHA

Forward path notes:

  • flash_attn_varlen_func and flash_attn_with_kvcache accept torch.float16, torch.bfloat16, and torch.float8_e4m3fn inputs on the MATE FMHA forward path.

  • For FP8 inputs, optional q_descale, k_descale, and v_descale tensors are torch.float32 scale factors with shape (batch_size, num_heads_kv).

  • FP8 inputs produce torch.bfloat16 outputs by default.

  • The FMHA forward path also accepts an optional qv input, including when the inputs are FP8.

  • When both q and qv are FP8, q_descale applies to both query tensors.

  • For best FP8 attention performance, use MUSA SDK 5.2.0 or newer when available.

  • Use flash_attn_combine to merge partial FMHA outputs and log-sum-exp buffers when your integration splits the attention computation.

  • Use get_scheduler_metadata to precompute the scheduler_metadata input for flash_attn_with_kvcache when you call the FMHA path directly.
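The descale factors act as multipliers that map stored FP8 values back to their real scale. The pure-Python sketch below rests on two assumptions. First, it uses the usual FlashAttention-3 convention that real value ≈ FP8 value × descale. Second, it uses the standard grouped-query mapping, where query head h shares KV head h // (num_heads // num_heads_kv); a (batch_size, num_heads_kv) descale shape suggests this mapping. Both helper names are hypothetical, and the kernel applies the scales internally.

```python
def kv_head_for_query_head(h: int, num_heads: int, num_heads_kv: int) -> int:
    """Map a query head to the KV head it shares under grouped-query attention.

    Assumption: query heads are grouped contiguously per KV head.
    """
    if num_heads % num_heads_kv != 0:
        raise ValueError("num_heads must be a multiple of num_heads_kv")
    group_size = num_heads // num_heads_kv
    return h // group_size


def scaled_dot(q_fp8, k_fp8, q_descale, k_descale, b, h, num_heads, num_heads_kv):
    """Dequantized q.k for batch b and query head h (illustrative only).

    q_fp8, k_fp8: lists of floats standing in for FP8-stored values.
    q_descale, k_descale: nested lists with shape (batch_size, num_heads_kv).
    Assumption: descale multiplies the stored value.
    """
    hk = kv_head_for_query_head(h, num_heads, num_heads_kv)
    qs = q_descale[b][hk]
    ks = k_descale[b][hk]
    # Descaling is linear, so both factors can be applied once to the dot product.
    return qs * ks * sum(a * c for a, c in zip(q_fp8, k_fp8))
```

With num_heads=8 and num_heads_kv=2, query heads 0 to 3 read scale column 0 and heads 4 to 7 read column 1.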

mate.mha_interface.flash_attn_combine(out_partial: torch.Tensor, lse_partial: torch.Tensor, out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None)
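Merging split attention results is normally done with the log-sum-exp rescaling used by split-KV attention. The pure-Python sketch below shows that math for a single query row and head. It assumes the natural-log LSE convention. It illustrates the technique and is not MATE's kernel. It also ignores the tensor layouts of out_partial and lse_partial, which this page does not specify. combine_partials is a hypothetical helper.

```python
import math


def combine_partials(out_partial, lse_partial):
    """Merge split-KV partial attention results for one query row and head.

    out_partial: list of num_splits vectors, each softmax-normalized over its own key chunk.
    lse_partial: list of num_splits natural-log log-sum-exp values of each chunk's scores.
    Returns (out, lse) as if the softmax had been taken over all keys at once.
    """
    m = max(lse_partial)  # subtract the max for numerical stability
    if m == -math.inf:  # every split was empty or fully masked
        return [0.0] * len(out_partial[0]), -math.inf
    lse = m + math.log(sum(math.exp(l - m) for l in lse_partial))
    # Each split's weight is its share of the total softmax denominator.
    weights = [math.exp(l - lse) for l in lse_partial]
    dim = len(out_partial[0])
    out = [sum(w * o[d] for w, o in zip(weights, out_partial)) for d in range(dim)]
    return out, lse
```

The result matches single-pass attention because each partial output is reweighted by exp(lse_i - lse), its share of the global softmax denominator.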
mate.mha_interface.flash_attn_with_kvcache(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, k: torch.Tensor | None = None, v: torch.Tensor | None = None, qv: torch.Tensor | None = None, rotary_cos: torch.Tensor | None = None, rotary_sin: torch.Tensor | None = None, cache_seqlens: int | torch.Tensor | None = None, cache_batch_idx: torch.Tensor | None = None, cache_leftpad: torch.Tensor | None = None, page_table: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k_new: torch.Tensor | None = None, max_seqlen_q: int | None = None, rotary_seqlens: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, softmax_scale: float | None = None, causal: bool = False, window_size: Tuple | List | None = (-1, -1), learnable_sink: torch.Tensor | None = None, attention_chunk: int | None = 0, softcap: float = 0.0, rotary_interleaved: bool = True, scheduler_metadata: torch.Tensor | None = None, num_splits: int = 0, pack_gqa=None, sm_margin=0, return_softmax_lse: bool = False, cp_world_size: int = 1, cp_rank: int = 0, cp_tot_seqused_k: torch.Tensor | None = None)

FlashAttention-3 compatible API: forward with KV cache.

Parameters:
  • q (Tensor) – The query tensor with shape (batch_size, seqlen, nheads, headdim) if cu_seqlens_q is None, or (total_q, nheads, headdim) if cu_seqlens_q is not None

  • k_cache (Tensor) – The key cache tensor with shape (batch_size_cache, seqlen_cache, nheads_k, headdim) if there’s no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there’s a page_table (i.e. paged KV cache)

  • v_cache (Tensor) – The value cache tensor with shape (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there’s no page_table, or (num_blocks, page_block_size, nheads_k, headdim_v) if there’s a page_table (i.e. paged KV cache)

  • k (Optional[Tensor]) – The key tensor with shape (batch_size, seqlen_new, nheads_k, headdim) if cu_seqlens_k_new is None, or (total_k_new, nheads_k, headdim) if cu_seqlens_k_new is not None. If k is not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens.

  • v (Optional[Tensor]) – The value tensor with shape (batch_size, seqlen_new, nheads_k, headdim_v) if cu_seqlens_k_new is None, or (total_k_new, nheads_k, headdim_v) if cu_seqlens_k_new is not None. Similar to k.

  • rotary_cos (Optional[Tensor]) – Tensor with shape (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be <= headdim and divisible by 16. rotary_cos must be on MUSA and have the same dtype as q.

  • rotary_sin (Optional[Tensor]) – Tensor with shape (seqlen_ro, rotary_dim / 2). Similar to rotary_cos and must have the same shape and dtype.

  • cache_seqlens (Union[int, Tensor]) – The sequence lengths of the KV cache; shape (batch_size,) when passed as a tensor.

  • cache_batch_idx (Optional[Tensor]) – The int32 indices used to index into the KV cache, shape (batch_size,). The tensor must be on MUSA and contiguous. If the indices are not distinct, and k and v are provided, the values updated in the cache might come from any of the duplicate indices.

  • cache_leftpad (Optional[Tensor]) – The int32 left padding offset where the KV cache starts for each batch, shape (batch_size,). The tensor must be on MUSA and contiguous. If None, assume 0.

  • page_table (Optional[Tensor]) – The page table tensor with shape (batch_size, max_num_blocks_per_seq)

  • cu_seqlens_q (Optional[Tensor]) – The cumulative sequence lengths of the query, shape (batch_size + 1).

  • cu_seqlens_k_new (Optional[Tensor]) – The cumulative sequence lengths of the new KV, shape (batch_size + 1).

  • rotary_seqlens (Optional[Tensor]) – Optional int32 tensor with shape (batch_size,) used as the rotary position length for each batch.

  • softmax_scale (Optional[float]) – The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim).

  • causal (bool) – Whether to apply causal attention mask (e.g., for auto-regressive modeling).

  • window_size (Tuple[int, int]) – The size of the sliding window. If not (-1, -1), implements sliding window local attention.

  • learnable_sink (Optional[Tensor]) – The learnable sink tensor for attention, shape (nheads,).

  • softcap (float) – Any value > 0 activates softcapping attention, applied as logits = softcap * tanh(logits / softcap) before the softmax. 0.0 (default) disables softcapping.

  • rotary_interleaved (bool) – If True, rotary embedding uses GPT-J style and combines dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 (i.e. GPT-NeoX style).

  • num_splits (int) – If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don’t split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. Don’t change this unless you know what you are doing.

  • return_softmax_lse (bool) – Whether to return the logsumexp of the attention scores.

  • cp_world_size (int) – Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled). When > 1, the global sequence is assumed to be distributed across ranks using an interleaved token pattern, where rank r holds tokens at positions [r, r + cp_world_size, r + 2*cp_world_size, ...].

  • cp_rank (int) – The rank of the current device within the CP group. Default 0.

  • cp_tot_seqused_k (Optional[Tensor]) – The global (across all CP ranks) cumulative key sequence lengths, shape (batch_size + 1,), dtype int32. Required when CP is enabled (cp_world_size > 1) so that each rank can correctly compute causal masking boundaries against the full key sequence. Ignored when cp_world_size == 1.
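The two rotary_interleaved layouts differ only in which dimensions are paired before rotation. The sketch below rotates one head vector at a single position. It assumes the conventional pairwise rotation (x_a, x_b) -> (x_a*cos - x_b*sin, x_a*sin + x_b*cos), with cos and sin taken from one row of rotary_cos / rotary_sin. apply_rotary is a hypothetical helper. The toy rotary_dim of 4 used below ignores the divisible-by-16 requirement.

```python
def apply_rotary(x, cos_row, sin_row, interleaved):
    """Rotate one head vector x at a single position (illustrative only).

    cos_row / sin_row: one row of rotary_cos / rotary_sin, length rotary_dim // 2.
    interleaved=True pairs dims (0, 1), (2, 3), ... (GPT-J style).
    interleaved=False pairs dims (i, i + rotary_dim // 2) (GPT-NeoX style).
    Dimensions at or beyond rotary_dim pass through unchanged.
    """
    half = len(cos_row)
    rotary_dim = 2 * half
    if rotary_dim > len(x):
        raise ValueError("rotary_dim must be <= headdim")
    out = list(x)
    for i in range(half):
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        c, s = cos_row[i], sin_row[i]
        # Read from the unrotated x so each pair uses its original values.
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out
```

Take x = [1, 2, 3, 4, 5] with cos = [0, 1] and sin = [1, 0]. The GPT-J layout rotates the pair (0, 1) by 90 degrees. The NeoX layout rotates the pair (0, 2) instead. Dimension 4 is outside rotary_dim and stays unchanged in both cases.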

Returns:

If return_softmax_lse is False, the attention output, shape (batch_size, seqlen, nheads, headdim_v) if cu_seqlens_q is None, or (total_q, nheads, headdim_v) if cu_seqlens_q is not None

If return_softmax_lse is True, a tuple of two tensors:

  • The attention output, shape (batch_size, seqlen, nheads, headdim_v) if cu_seqlens_q is None, or (total_q, nheads, headdim_v) if cu_seqlens_q is not None

  • The log sum exp value, shape (batch_size, nheads, seqlen) if cu_seqlens_q is None, or (nheads, total_q) if cu_seqlens_q is not None

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]
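With a page_table, k_cache and v_cache are laid out as (num_blocks, page_block_size, nheads_k, headdim). The sketch below shows how a logical token position resolves to a cache slot. It assumes the usual paged-attention convention: page_table[b][j] holds the physical block index of the j-th logical block of sequence b. This page implies that convention but does not spell it out. The helpers are hypothetical, and strings stand in for per-slot key vectors.

```python
def paged_slot(page_table, b, t, page_block_size):
    """Return (block_id, offset) in k_cache / v_cache for token t of sequence b."""
    logical_block, offset = divmod(t, page_block_size)
    return page_table[b][logical_block], offset


def gather_keys(k_cache, page_table, b, seqlen, page_block_size):
    """Collect the first seqlen cached entries of sequence b in logical order."""
    keys = []
    for t in range(seqlen):
        block_id, offset = paged_slot(page_table, b, t, page_block_size)
        keys.append(k_cache[block_id][offset])
    return keys
```

Because only page_table encodes the order, physical blocks for one sequence need not be contiguous. For example, page_table row [2, 0] reads block 2 first and block 0 second.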

mate.mha_interface.get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, seqused_q: torch.Tensor | None = None, seqused_k: torch.Tensor | None = None, qkv_dtype=torch.bfloat16, headdim_v=None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k: torch.Tensor | None = None, cu_seqlens_k_new: torch.Tensor | None = None, cache_leftpad: torch.Tensor | None = None, page_size=None, max_seqlen_k_new=0, causal=False, window_size=(-1, -1), attention_chunk=0, has_softcap=False, num_splits=0, pack_gqa=None, has_qv=False, mp_margin=0)

Build scheduler metadata for flash_attn_with_kvcache.

Use this helper to precompute the tensor passed through the scheduler_metadata argument of flash_attn_with_kvcache. It is also exposed at the top level as mate.get_scheduler_metadata.

Parameters:
  • batch_size (int) – Batch size for the scheduled attention workload.

  • max_seqlen_q (int) – Maximum query sequence length used by the target flash_attn_with_kvcache call.

  • max_seqlen_k (int) – Maximum key sequence length already present in the cache.

  • num_heads_q (int) – Number of query heads.

  • num_heads_kv (int) – Number of key / value heads.

  • headdim (int) – Query and key head dimension.

  • seqused_q (Optional[Tensor]) – Optional per-batch query lengths with shape (batch_size,).

  • seqused_k (Optional[Tensor]) – Optional per-batch key lengths with shape (batch_size,).

  • qkv_dtype (torch.dtype) – Data type of the scheduled QKV path. Default torch.bfloat16.

  • headdim_v (Optional[int]) – Value head dimension. Defaults to headdim.

  • cu_seqlens_q (Optional[Tensor]) – Optional cumulative query sequence lengths with shape (batch_size + 1,).

  • cu_seqlens_k (Optional[Tensor]) – Optional cumulative cached key sequence lengths with shape (batch_size + 1,).

  • cu_seqlens_k_new (Optional[Tensor]) – Optional cumulative new-KV sequence lengths with shape (batch_size + 1,).

  • cache_leftpad (Optional[Tensor]) – Optional per-batch left padding offsets for the KV cache.

  • page_size (Optional[int]) – Page size for paged KV-cache scheduling.

  • max_seqlen_k_new (int) – Maximum number of newly appended KV tokens.

  • causal (bool) – Whether the target attention call uses causal masking.

  • window_size (Tuple[int, int]) – Sliding-window attention bounds. (-1, -1) means full context.

  • attention_chunk (int) – Chunk size used by chunked attention scheduling.

  • has_softcap (bool) – Whether the target attention call enables softcapping.

  • num_splits (int) – Requested split count for key / value scheduling.

  • pack_gqa (Optional[bool]) – Optional GQA packing mode.

  • has_qv (bool) – Whether the target attention call includes the optional qv input.

  • mp_margin (int) – Number of MPs reserved for communication or other work.

Returns:

Scheduler metadata tensor to pass to flash_attn_with_kvcache(..., scheduler_metadata=...).

Return type:

Tensor

Notes

Keep the scheduling-related arguments aligned with the corresponding flash_attn_with_kvcache call so the generated metadata matches the actual workload.
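One way to keep the two calls aligned is to derive every length argument from a single source. The sketch below is hypothetical. It assumes the per-sequence cache lengths given to flash_attn_with_kvcache as cache_seqlens are also the values to give get_scheduler_metadata as seqused_k. It also assumes the same cu_seqlens_q and max_seqlen_q are used on both sides. This page does not state that mapping explicitly. Plain lists stand in for the int32 device tensors used in practice.

```python
from itertools import accumulate


def shared_schedule_args(q_lens, cache_lens):
    """Compute length arguments once so both calls describe the same workload.

    q_lens: new query tokens per sequence.
    cache_lens: cached KV tokens per sequence.
    """
    if len(q_lens) != len(cache_lens):
        raise ValueError("q_lens and cache_lens need one entry per sequence")
    return {
        "batch_size": len(q_lens),
        "max_seqlen_q": max(q_lens),
        "max_seqlen_k": max(cache_lens),
        # Exclusive prefix sum: sequence i occupies rows cu[i]:cu[i + 1].
        "cu_seqlens_q": [0, *accumulate(q_lens)],
        # Assumption: mirrors cache_seqlens in the flash_attn_with_kvcache call.
        "seqused_k": list(cache_lens),
    }
```

Building both calls from one dict like this means a later change to the batch updates the metadata and the attention call together.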

mate.mha_interface.flash_attn_varlen_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_k: int | None = None, seqused_q: torch.Tensor | None = None, seqused_k: torch.Tensor | None = None, page_table: torch.Tensor | None = None, softmax_scale: float | None = None, causal: bool = False, qv: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, window_size: Tuple | List | None = (-1, -1), learnable_sink: torch.Tensor | None = None, attention_chunk: int | None = 0, softcap: float = 0.0, scheduler_metadata: torch.Tensor | None = None, num_splits: int = 0, pack_gqa=None, deterministic: bool = False, sm_margin=0, return_softmax_lse: bool = False, backend: str = 'auto', cp_world_size: int = 1, cp_rank: int = 0, cp_tot_seqused_k: torch.Tensor | None = None, out: torch.Tensor | None = None)

FlashAttention-3 compatible API: forward with varlen or non-varlen inputs.

Parameters:
  • q (Tensor) – The query tensor with shape (batch_size, seqlen, nheads, headdim) if cu_seqlens_q is None, or (total_q, nheads, headdim) if cu_seqlens_q is not None.

  • k (Tensor) – The key tensor with shape (batch_size, seqlen_k, nheads_k, headdim) if cu_seqlens_k is None, or (total_k, nheads_k, headdim) if cu_seqlens_k is not None.

  • v (Tensor) – The value tensor with shape (batch_size, seqlen_k, nheads_k, headdim_v) if cu_seqlens_k is None, or (total_k, nheads_k, headdim_v) if cu_seqlens_k is not None.

  • cu_seqlens_q (Optional[Tensor]) – The cumulative sequence length tensor for query, shape (batch_size + 1)

  • cu_seqlens_k (Optional[Tensor]) – The cumulative sequence length tensor for key/value, shape (batch_size + 1)

  • max_seqlen_q (Optional[int]) – The maximum sequence length for query; must be provided for varlen forward.

  • max_seqlen_k (Optional[int]) – The maximum sequence length for key/value

  • seqused_q (Optional[Tensor]) – Tensor with shape (batch_size,). If given, only this many elements of each batch element’s queries and outputs are used.

  • seqused_k (Optional[Tensor]) – Tensor with shape (batch_size,). If given, only this many elements of each batch element’s keys and values are used.

  • softmax_scale (Optional[float]) – The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim).

  • causal (bool) – Whether to apply causal attention mask (e.g., for auto-regressive modeling).

  • window_size (Tuple[int, int]) – The size of the sliding window. If not (-1, -1), implements sliding window local attention.

  • learnable_sink (Optional[Tensor]) – The learnable sink tensor for attention, shape (nheads,).

  • softcap (float) – Any value > 0 activates softcapping attention, applied as logits = softcap * tanh(logits / softcap) before the softmax. 0.0 (default) disables softcapping.

  • return_softmax_lse (bool) – Whether to return the logsumexp of the attention scores.

  • backend (str) – The backend to use. The default, auto, is recommended.

  • cp_world_size (int) – Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled). When > 1, the global sequence is assumed to be distributed across ranks using an interleaved token pattern, where rank r holds tokens at positions [r, r + cp_world_size, r + 2*cp_world_size, ...].

  • cp_rank (int) – The rank of the current device within the CP group. Default 0.

  • cp_tot_seqused_k (Optional[Tensor]) – The global (across all CP ranks) cumulative key sequence lengths, shape (batch_size + 1,), dtype int32. Required when CP is enabled (cp_world_size > 1) so that each rank can correctly compute causal masking boundaries against the full key sequence. Ignored when cp_world_size == 1.
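The softmax_scale, softcap, causal, and window_size parameters all act on the QK^T logits before the softmax. The sketch below rests on several assumptions. Softcapping is applied to the already-scaled logits. The causal and window masks follow the FlashAttention convention of alignment to the bottom-right corner of the score matrix, so the last query lines up with the last key. Causal is treated as equivalent to window_size[1] = 0, and -1 means unbounded on that side. This page does not state the alignment explicitly. attention_logits is a hypothetical helper.

```python
import math


def attention_logits(scores, seqlen_q, seqlen_k, softmax_scale, causal=False,
                     window_size=(-1, -1), softcap=0.0):
    """Apply scale, softcap, and causal / sliding-window masking to raw QK^T scores.

    scores: seqlen_q x seqlen_k nested list of raw q.k dot products.
    Masked positions become -inf so they vanish after the softmax.
    """
    left, right = window_size
    if causal:
        right = 0  # assumption: causal == no look-ahead, bottom-right aligned
    shift = seqlen_k - seqlen_q  # aligns the last query with the last key
    out = []
    for i in range(seqlen_q):
        row = []
        for j in range(seqlen_k):
            x = scores[i][j] * softmax_scale
            if softcap > 0:
                x = softcap * math.tanh(x / softcap)
            too_far_right = right >= 0 and j > i + shift + right
            too_far_left = left >= 0 and j < i + shift - left
            row.append(-math.inf if (too_far_right or too_far_left) else x)
        out.append(row)
    return out
```

With two queries and four keys, causal masking hides only the last key from the first query, because that query is aligned with key position 2.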

Returns:

If return_softmax_lse is False, the attention output, shape (total_q, nheads, headdim_v)

If return_softmax_lse is True, a tuple of two tensors:

  • The attention output, shape (total_q, nheads, headdim_v)

  • The log sum exp value, shape (nheads, total_q)

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]
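The interleaved Context Parallelism layout described for cp_world_size places global position r + i * cp_world_size at local index i on rank r. The sketch below only reproduces that stated pattern. The helper names are hypothetical. A local query's causal boundary therefore depends on its global position rather than its local index. That is consistent with the requirement to pass cp_tot_seqused_k when CP is enabled.

```python
def cp_local_positions(global_len, cp_world_size, cp_rank):
    """Global token positions held by cp_rank: r, r + W, r + 2W, ... below global_len."""
    if not 0 <= cp_rank < cp_world_size:
        raise ValueError("cp_rank must be in [0, cp_world_size)")
    return list(range(cp_rank, global_len, cp_world_size))


def cp_shard(tokens, cp_world_size, cp_rank):
    """Slice a full sequence into the interleaved shard that cp_rank holds."""
    return tokens[cp_rank::cp_world_size]
```

For a 10-token sequence and cp_world_size = 4, rank 1 holds positions [1, 5, 9] and rank 3 holds [3, 7]. Together the four shards cover every position exactly once.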

MLA

mate.flashmla.get_mla_metadata(cache_seqlens: torch.Tensor | None, num_q_tokens_per_head_k: int, num_heads_k: int, num_heads_q: int | None = None, is_fp8_kvcache: bool = False, topk: int | None = None, extra_topk: int | None = None, q: torch.Tensor | None = None, bs: int | None = None, topk_length: torch.Tensor | None = None, extra_topk_length: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]

Get metadata for MLA decoding.

Parameters:
  • cache_seqlens (Tensor) – The sequence lengths of the KV cache with shape (batch_size)

  • num_q_tokens_per_head_k (int) – Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.

  • num_heads_k (int) – The number of k heads.

  • num_heads_q (Optional[int]) – The number of q heads. Optional when sparse attention is not enabled.

  • is_fp8_kvcache (bool) – Whether the k_cache and v_cache are in fp8 format.

  • topk (Optional[int]) – If not None, sparse attention is enabled, and only tokens in the indices array passed to flash_mla_with_kvcache are attended to.

  • extra_topk (Optional[int]) – Optional sparse extra-KV topk. This mirrors FlashMLA sparse decode metadata semantics without folding extra work into topk.

  • topk_length (Optional[Tensor]) – Optional per-batch sparse workload lengths for MODEL1 scheduled decode metadata. If extra_topk_length is provided, both lengths are summed in the C++ metadata kernel.

Returns:

A tuple of two tensors:

  • tile_scheduler_metadata, shape (num_sm_parts, TileSchedulerMetaDataSize)

  • num_splits, shape (batch_size + 1)

Return type:

Tuple[Tensor, Tensor]
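The num_q_tokens_per_head_k formula can be derived from the shape of q as documented for flash_mla_with_kvcache. The sketch below assumes num_q_tokens_per_q_seq equals seq_len_q, the number of query tokens per sequence. mla_metadata_counts is a hypothetical helper.

```python
def mla_metadata_counts(q_shape, num_heads_k):
    """Derive get_mla_metadata's head and token counts from q's shape.

    q_shape: (batch_size, seq_len_q, num_heads_q, head_dim).
    Assumption: num_q_tokens_per_q_seq == seq_len_q.
    """
    _batch_size, seq_len_q, num_heads_q, _head_dim = q_shape
    if num_heads_q % num_heads_k != 0:
        raise ValueError("num_heads_q must be a multiple of num_heads_k")
    return {
        # Query rows that share one KV head.
        "num_q_tokens_per_head_k": seq_len_q * num_heads_q // num_heads_k,
        "num_heads_k": num_heads_k,
        "num_heads_q": num_heads_q,
    }
```

For example, single-token decode with 128 query heads and one KV head gives 128. Two query tokens per sequence with 16 heads gives 32.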

mate.flashmla.flash_mla_with_kvcache(q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, is_fp8_kvcache: bool = False, indices: torch.Tensor | None = None, attn_sink: torch.Tensor | None = None, extra_k_cache: torch.Tensor | None = None, extra_indices_in_kvcache: torch.Tensor | None = None, topk_length: torch.Tensor | None = None, extra_topk_length: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]

MLA forward with KV cache.

Parameters:
  • q (Tensor) – The query tensor with shape (batch_size, seq_len_q, num_heads_q, head_dim).

  • k_cache (Tensor) – The compressed kv cache tensor with shape (num_blocks, page_block_size, num_heads_k, head_dim).

  • block_table (Tensor) – The page table with shape (batch_size, max_num_blocks_per_seq).

  • cache_seqlens (Tensor) – The sequence lengths of the ckv cache with shape (batch_size).

  • head_dim_v (int) – Head dimension of v.

  • tile_scheduler_metadata (Tensor) – The scheduler metadata with shape (num_sm_parts, TileSchedulerMetaDataSize), returned by get_mla_metadata.

  • num_splits (Tensor) – The num_splits tensor with shape (batch_size + 1), returned by get_mla_metadata.

  • softmax_scale (Optional[float]) – The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).

  • causal (bool) – Whether to apply causal attention mask.

  • is_fp8_kvcache (bool) – Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md

  • indices (Optional[Tensor]) – The token indices tensor with shape (batch_size, seq_len_q, topk). If not None, sparse attention will be enabled, and only tokens in the indices array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up indices, please refer to README.md.
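The rule that invalid indices are -1 or >= total_seq_len_kv allows a fixed topk width even when fewer tokens are selected. The sketch below shows padding one query's selection to topk and filtering the entries that count as valid under that rule. How a valid index maps into the paged cache is described in README.md and is not modeled here. Both helpers are hypothetical.

```python
def pad_topk_indices(selected, topk):
    """Pad one query's selected KV token indices to length topk with -1."""
    if len(selected) > topk:
        raise ValueError("more selected tokens than topk")
    return list(selected) + [-1] * (topk - len(selected))


def valid_indices(row, total_seq_len_kv):
    """Entries that are attended to: 0 <= idx < total_seq_len_kv."""
    return [i for i in row if 0 <= i < total_seq_len_kv]
```

Padding with -1 keeps the indices tensor rectangular at (batch_size, seq_len_q, topk) while letting each query select a different number of tokens.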

Notes

For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2, head_dim should be 576 while head_dim_v should be 512.

In FP8 + sparse mode, each token’s KV cache is 656 bytes. k_cache has shape (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. The first 512 bytes contain the quantized NoPE part with 512 float8_e4m3 values. The next 16 bytes contain four float32 scale factors, one for each group of 128 float8_e4m3 values. The final 128 bytes contain the RoPE part with 64 bfloat16 values; this part is left unquantized for accuracy.
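The 656-byte figure can be checked against the layout above: 512 + 4 × 4 + 64 × 2 = 656. The standard-library sketch below splits one token record into its three parts. It assumes little-endian storage and consecutive placement of the four scales as described. The FP8 NoPE bytes are returned raw because Python has no float8_e4m3 type. split_token_record and scale_for_nope_index are hypothetical helpers.

```python
import struct

NOPE_BYTES = 512      # 512 float8_e4m3 values, 1 byte each
SCALE_BYTES = 4 * 4   # four float32 scales, one per group of 128 FP8 values
ROPE_BYTES = 64 * 2   # 64 bfloat16 values, left unquantized
TOKEN_BYTES = NOPE_BYTES + SCALE_BYTES + ROPE_BYTES  # 656


def split_token_record(record: bytes):
    """Split one token's FP8 sparse KV entry into (nope_raw, scales, rope).

    Assumption: little-endian byte order.
    """
    if len(record) != TOKEN_BYTES:
        raise ValueError("expected a 656-byte record")
    nope_raw = record[:NOPE_BYTES]
    scales = struct.unpack("<4f", record[NOPE_BYTES:NOPE_BYTES + SCALE_BYTES])
    rope_u16 = struct.unpack("<64H", record[NOPE_BYTES + SCALE_BYTES:])
    # bfloat16 is the upper 16 bits of a float32: shift left and reinterpret.
    rope = [struct.unpack("<f", struct.pack("<I", u << 16))[0] for u in rope_u16]
    return nope_raw, scales, rope


def scale_for_nope_index(i: int) -> int:
    """Index of the scale factor that applies to NoPE element i (groups of 128)."""
    return i // 128
```

The 512 NoPE values plus the 64 RoPE values give the 576-wide head_dim mentioned for DeepSeek V3.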

Returns:

A tuple of two tensors:

  • out, shape (batch_size, seq_len_q, num_heads_q, head_dim_v).

  • softmax_lse, shape (batch_size, num_heads_q, seq_len_q).

Return type:

Tuple[Tensor, Tensor]