Attention
FMHA
- mate.mha_interface.flash_attn_with_kvcache(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, k: torch.Tensor | None = None, v: torch.Tensor | None = None, qv: torch.Tensor | None = None, rotary_cos: torch.Tensor | None = None, rotary_sin: torch.Tensor | None = None, cache_seqlens: int | torch.Tensor | None = None, cache_batch_idx: torch.Tensor | None = None, cache_leftpad: torch.Tensor | None = None, page_table: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k_new: torch.Tensor | None = None, max_seqlen_q: int | None = None, rotary_seqlens: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, softmax_scale: float | None = None, causal: bool = False, window_size: Tuple | List | None = (-1, -1), learnable_sink: torch.Tensor | None = None, attention_chunk: int | None = 0, softcap: float = 0.0, rotary_interleaved: bool = True, scheduler_metadata: torch.Tensor | None = None, num_splits: int = 0, pack_gqa=None, sm_margin=0, return_softmax_lse: bool = False, cp_world_size: int = 1, cp_rank: int = 0, cp_tot_seqused_k: torch.Tensor | None = None)
FlashAttention3-compatible API: forward pass with a KV cache.
- Parameters:
q (Tensor) – The query tensor with shape (batch_size, seqlen, nheads, headdim) if cu_seqlens_q is None, or (total_q, nheads, headdim) if cu_seqlens_q is not None.
k_cache (Tensor) – The key cache tensor with shape (batch_size_cache, seqlen_cache, nheads_k, headdim) if there is no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there is a page_table (i.e. paged KV cache).
v_cache (Tensor) – The value cache tensor with shape (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there is no page_table, or (num_blocks, page_block_size, nheads_k, headdim_v) if there is a page_table (i.e. paged KV cache).
k (Optional[Tensor]) – The key tensor with shape (batch_size, seqlen_new, nheads_k, headdim) if cu_seqlens_k_new is None, or (total_k_new, nheads_k, headdim) if cu_seqlens_k_new is not None. If k is not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens.
v (Optional[Tensor]) – The value tensor with shape (batch_size, seqlen_new, nheads_k, headdim_v) if cu_seqlens_k_new is None, or (total_k_new, nheads_k, headdim_v) if cu_seqlens_k_new is not None. Similar to k.
rotary_cos (Optional[Tensor]) – Tensor with shape (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be <= headdim and divisible by 16. rotary_cos must be on MUSA and have the same dtype as q.
rotary_sin (Optional[Tensor]) – Tensor with shape (seqlen_ro, rotary_dim / 2). Similar to rotary_cos; it must have the same shape and dtype.
cache_seqlens (Union[int, Tensor]) – The sequence lengths of the KV cache, shape (batch_size,) if it is a tensor.
cache_batch_idx (Optional[Tensor]) – The int32 indices used to index into the KV cache, shape (batch_size,). The tensor must be on MUSA and contiguous. If the indices are not distinct, and k and v are provided, the values updated in the cache might come from any of the duplicate indices.
cache_leftpad (Optional[Tensor]) – The int32 left padding offset where the KV cache starts for each batch, shape (batch_size,). The tensor must be on MUSA and contiguous. If None, assume 0.
page_table (Optional[Tensor]) – The page table tensor with shape (batch_size, max_num_blocks_per_seq).
cu_seqlens_q (Optional[Tensor]) – The cumulative sequence lengths of the query, shape (batch_size + 1,).
cu_seqlens_k_new (Optional[Tensor]) – The cumulative sequence lengths of the new KV, shape (batch_size + 1,).
rotary_seqlens (Optional[Tensor]) – Optional int32 tensor with shape (batch_size,) used as the rotary position length for each batch.
softmax_scale (Optional[float]) – The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim).
causal (bool) – Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size (Tuple[int, int]) – The size of the sliding window. If not (-1, -1), implements sliding window local attention.
learnable_sink (Optional[Tensor]) – The learnable sink tensor for attention, shape (nheads,).
softcap (float) – Anything > 0 activates softcapping attention, applied as logits = softcap * tanh(logits / softcap) before the softmax. 0.0 (default) disables softcapping.
rotary_interleaved (bool) – If True, rotary embedding uses GPT-J style and combines dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 (i.e. GPT-NeoX style).
num_splits (int) – If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don’t split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. Don’t change this unless you know what you are doing.
return_softmax_lse (bool) – Whether to return the logsumexp of the attention scores.
cp_world_size (int) – Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled). When > 1, the global sequence is assumed to be distributed across ranks using an interleaved token pattern, where rank r holds tokens at positions [r, r + cp_world_size, r + 2*cp_world_size, ...].
cp_rank (int) – The rank of the current device within the CP group. Default 0.
cp_tot_seqused_k (Optional[Tensor]) – The global (across all CP ranks) cumulative key sequence lengths, shape (batch_size + 1,), dtype int32. Required when CP is enabled (cp_world_size > 1) so that each rank can correctly compute causal masking boundaries against the full key sequence. Ignored when cp_world_size == 1.
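The paged layout described for k_cache, v_cache and page_table can be illustrated with a small pure-Python sketch. It assumes the usual paged-attention convention, in which entry i of a sequence's page table row names the physical block that stores logical block i; the helper name locate_token is hypothetical and is not part of mate.

```python
def locate_token(page_table_row, page_block_size, position):
    """Map a logical KV position to (physical block id, offset in block).

    Assumption: page_table_row[i] is the physical block holding the i-th
    logical block of this sequence (the common paged-KV convention).
    """
    logical_block = position // page_block_size
    offset = position % page_block_size
    return page_table_row[logical_block], offset


# Hypothetical page table row for one sequence, with 4 tokens per block.
page_table_row = [7, 2, 9]
page_block_size = 4

# Logical position 5 falls in logical block 1 (physical block 2), slot 1.
block, offset = locate_token(page_table_row, page_block_size, position=5)
print(block, offset)  # 2 1
```

Under that assumption, the key for this token would be read from k_cache[block, offset], which is why the paged cache shape starts with (num_blocks, page_block_size, ...).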
- Returns:
If return_softmax_lse is False, the attention output, shape (batch_size, seqlen, nheads, headdim_v) if cu_seqlens_q is None, or (total_q, nheads, headdim_v) if cu_seqlens_q is not None.
If return_softmax_lse is True, a tuple of two tensors:
The attention output, shape (batch_size, seqlen, nheads, headdim_v) if cu_seqlens_q is None, or (total_q, nheads, headdim_v) if cu_seqlens_q is not None.
The log-sum-exp value, shape (batch_size, nheads, seqlen) if cu_seqlens_q is None, or (nheads, total_q) if cu_seqlens_q is not None.
- Return type:
Union[Tensor, Tuple[Tensor, Tensor]]
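The two pairing schemes named under rotary_interleaved can be made concrete with a short sketch that only lists which dimension indices are rotated together. The function rotary_pairs is a hypothetical helper written for illustration; it does not call into mate and does not apply the rotation itself.

```python
def rotary_pairs(rotary_dim, interleaved):
    """Return the index pairs that a rotary embedding rotates together."""
    half = rotary_dim // 2
    if interleaved:
        # GPT-J style: neighbouring dimensions (0, 1), (2, 3), ...
        return [(2 * i, 2 * i + 1) for i in range(half)]
    # GPT-NeoX style: dimension i is paired with i + rotary_dim / 2.
    return [(i, i + half) for i in range(half)]


# rotary_dim = 16 satisfies the "divisible by 16" requirement stated above.
print(rotary_pairs(16, interleaved=True)[:3])   # [(0, 1), (2, 3), (4, 5)]
print(rotary_pairs(16, interleaved=False)[:3])  # [(0, 8), (1, 9), (2, 10)]
```

Each pair shares one column of rotary_cos and rotary_sin, which is consistent with those tensors having rotary_dim / 2 columns.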
- mate.mha_interface.flash_attn_varlen_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_k: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_k: int | None = None, seqused_q: torch.Tensor | None = None, seqused_k: torch.Tensor | None = None, page_table: torch.Tensor | None = None, softmax_scale: float | None = None, causal: bool = False, qv: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, window_size: Tuple | List | None = (-1, -1), learnable_sink: torch.Tensor | None = None, attention_chunk: int | None = 0, softcap: float = 0.0, scheduler_metadata: torch.Tensor | None = None, num_splits: int = 0, pack_gqa=None, deterministic: bool = False, sm_margin=0, return_softmax_lse: bool = False, backend: str = 'auto', cp_world_size: int = 1, cp_rank: int = 0, cp_tot_seqused_k: torch.Tensor | None = None, out: torch.Tensor | None = None)
FlashAttention3-compatible API: forward pass with varlen or non-varlen inputs.
- Parameters:
q (Tensor) – The query tensor with shape (batch_size, seqlen, nheads, headdim) if cu_seqlens_q is None, or (total_q, nheads, headdim) if cu_seqlens_q is not None.
k (Tensor) – The key tensor with shape (batch_size, seqlen_k, nheads_k, headdim) if cu_seqlens_k is None, or (total_k, nheads_k, headdim) if cu_seqlens_k is not None.
v (Tensor) – The value tensor with shape (batch_size, seqlen_k, nheads_k, headdim_v) if cu_seqlens_k is None, or (total_k, nheads_k, headdim_v) if cu_seqlens_k is not None.
cu_seqlens_q (Optional[Tensor]) – The cumulative sequence length tensor for query, shape (batch_size + 1,).
cu_seqlens_k (Optional[Tensor]) – The cumulative sequence length tensor for key/value, shape (batch_size + 1,).
max_seqlen_q (Optional[int]) – The maximum sequence length for query. Must be provided for a varlen forward.
max_seqlen_k (Optional[int]) – The maximum sequence length for key/value.
seqused_q (Optional[Tensor]) – Tensor with shape (batch_size,). If given, only this many elements of each batch element’s queries and outputs are used.
seqused_k (Optional[Tensor]) – Tensor with shape (batch_size,). If given, only this many elements of each batch element’s keys and values are used.
softmax_scale (Optional[float]) – The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim).
causal (bool) – Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size (Tuple[int, int]) – The size of the sliding window. If not (-1, -1), implements sliding window local attention.
learnable_sink (Optional[Tensor]) – The learnable sink tensor for attention, shape (nheads,).
softcap (float) – Anything > 0 activates softcapping attention, applied as logits = softcap * tanh(logits / softcap) before the softmax. 0.0 (default) disables softcapping.
return_softmax_lse (bool) – Whether to return the logsumexp of the attention scores.
backend (str) – The backend to use. It is recommended to use the default, auto.
cp_world_size (int) – Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled). When > 1, the global sequence is assumed to be distributed across ranks using an interleaved token pattern, where rank r holds tokens at positions [r, r + cp_world_size, r + 2*cp_world_size, ...].
cp_rank (int) – The rank of the current device within the CP group. Default 0.
cp_tot_seqused_k (Optional[Tensor]) – The global (across all CP ranks) cumulative key sequence lengths, shape (batch_size + 1,), dtype int32. Required when CP is enabled (cp_world_size > 1) so that each rank can correctly compute causal masking boundaries against the full key sequence. Ignored when cp_world_size == 1.
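The interleaved token pattern described for cp_world_size can be sketched in plain Python. The helper cp_local_positions is hypothetical and only enumerates which global positions a rank would hold; it says nothing about how mate shards tensors internally.

```python
def cp_local_positions(seqlen, cp_world_size, cp_rank):
    """Global token positions held by cp_rank under the interleaved pattern:
    r, r + cp_world_size, r + 2 * cp_world_size, ..."""
    return list(range(cp_rank, seqlen, cp_world_size))


seqlen, cp_world_size = 10, 4
for rank in range(cp_world_size):
    print(rank, cp_local_positions(seqlen, cp_world_size, rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

Because each rank sees only a strided slice of the keys, a rank cannot infer the full key length from its local tensors, which is the reason given above for passing cp_tot_seqused_k when CP is enabled.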
- Returns:
If return_softmax_lse is False, the attention output, shape (total_q, nheads, headdim_v).
If return_softmax_lse is True, a tuple of two tensors:
The attention output, shape (total_q, nheads, headdim_v).
The log-sum-exp value, shape (nheads, total_q).
- Return type:
Union[Tensor, Tuple[Tensor, Tensor]]
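For the varlen path, the relationship between per-sequence lengths, cu_seqlens_q, total_q and max_seqlen_q can be sketched as follows. The sketch assumes the FlashAttention convention that the cumulative tensor starts at 0, which is consistent with the (batch_size + 1,) shape documented above; build_cu_seqlens is a hypothetical helper and uses plain lists rather than int32 tensors.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Cumulative sequence lengths with a leading 0 (batch_size + 1 entries)."""
    return [0] + list(accumulate(seqlens))


seqlens_q = [3, 5, 2]                        # three packed sequences
cu_seqlens_q = build_cu_seqlens(seqlens_q)   # [0, 3, 8, 10]
total_q = cu_seqlens_q[-1]                   # rows in the packed q tensor
max_seqlen_q = max(seqlens_q)                # value to pass as max_seqlen_q

# Tokens of sequence i occupy rows cu_seqlens_q[i]:cu_seqlens_q[i + 1].
start, end = cu_seqlens_q[1], cu_seqlens_q[2]
print(cu_seqlens_q, total_q, max_seqlen_q, (start, end))
```

In a real call the list would be converted to an int32 tensor on the device before being passed as cu_seqlens_q or cu_seqlens_k.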
MLA
- mate.flashmla.get_mla_metadata(cache_seqlens: torch.Tensor | None, num_q_tokens_per_head_k: int, num_heads_k: int, num_heads_q: int | None = None, is_fp8_kvcache: bool = False, topk: int | None = None, extra_topk: int | None = None, q: torch.Tensor | None = None, bs: int | None = None, topk_length: torch.Tensor | None = None, extra_topk_length: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Get metadata for MLA decoding.
- Parameters:
cache_seqlens (Tensor) – The sequence lengths of the KV cache with shape (batch_size,).
num_q_tokens_per_head_k (int) – Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k (int) – The number of k heads.
num_heads_q (Optional[int]) – The number of q heads. This argument is optional when sparse attention is not enabled.
is_fp8_kvcache (bool) – Whether the k_cache and v_cache are in fp8 format.
topk (Optional[int]) – If not None, sparse attention will be enabled, and only tokens in the indices array passed to flash_mla_with_kvcache_sm90 will be attended to.
extra_topk (Optional[int]) – Optional sparse extra-KV topk. This mirrors FlashMLA sparse decode metadata semantics without folding extra work into topk.
topk_length (Optional[Tensor]) – Optional per-batch sparse workload lengths for MODEL1 scheduled decode metadata. If extra_topk_length is provided, both lengths are summed in the C++ metadata kernel.
- Returns:
A tuple of two tensors:
tile_scheduler_metadata, shape (num_sm_parts, TileSchedulerMetaDataSize).
num_splits, shape (batch_size + 1,).
- Return type:
Tuple[Tensor, Tensor]
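The formula given for num_q_tokens_per_head_k is simple arithmetic, and a worked example may help when filling in the argument. The function below is a hypothetical helper that mirrors the documented expression; the head counts in the example are illustrative values, not requirements of mate.

```python
def num_q_tokens_per_head_k(num_q_tokens_per_q_seq, num_heads_q, num_heads_k):
    """num_q_tokens_per_q_seq * num_heads_q // num_heads_k, as documented."""
    return num_q_tokens_per_q_seq * num_heads_q // num_heads_k


# Illustrative decode step: 1 query token per sequence, 128 query heads
# sharing a single k head.
print(num_q_tokens_per_head_k(1, 128, 1))  # 128

# Illustrative multi-token step: 2 query tokens per sequence, 16 q heads
# over 2 k heads.
print(num_q_tokens_per_head_k(2, 16, 2))   # 16
```

The result counts how many query rows each k head has to serve per sequence.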
- mate.flashmla.flash_mla_with_kvcache(q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, is_fp8_kvcache: bool = False, indices: torch.Tensor | None = None, attn_sink: torch.Tensor | None = None, extra_k_cache: torch.Tensor | None = None, extra_indices_in_kvcache: torch.Tensor | None = None, topk_length: torch.Tensor | None = None, extra_topk_length: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
MLA forward pass with a KV cache.
- Parameters:
q (Tensor) – The query tensor with shape (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache (Tensor) – The compressed KV cache tensor with shape (num_blocks, page_block_size, num_heads_k, head_dim).
block_table (Tensor) – The page table with shape (batch_size, max_num_blocks_per_seq).
cache_seqlens (Tensor) – The sequence lengths of the compressed KV cache with shape (batch_size,).
head_dim_v (int) – Head dimension of v.
tile_scheduler_metadata (Tensor) – The scheduler metadata with shape (num_sm_parts, TileSchedulerMetaDataSize), returned by get_mla_metadata.
num_splits (Tensor) – The num_splits tensor with shape (batch_size + 1,), returned by get_mla_metadata.
softmax_scale (Optional[float]) – The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
causal (bool) – Whether to apply causal attention mask.
is_fp8_kvcache (bool) – Whether the k_cache and v_cache are in fp8 format. For the format of the FP8 KV cache, please refer to README.md.
indices (Optional[Tensor]) – The token indices tensor with shape (batch_size, seq_len_q, topk). If not None, sparse attention will be enabled, and only tokens in the indices array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up indices, please refer to README.md.
For DeepSeek V3, V3.1, and V3.2: head_dim should be 576 while head_dim_v should be 512. In FP8+sparse mode, each token’s KV cache is 656 bytes, structured as:
The shape of the tensor k_cache is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
First 512 bytes: The “quantized NoPE” part, containing 512 float8_e4m3 values.
Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
Last 128 bytes: The “RoPE” part, containing 64 bfloat16 values. This part is not quantized for accuracy.
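The 656-byte per-token layout described for FP8+sparse mode can be checked with a small byte-level sketch. It only splits one token's bytes into the three regions and looks up which scale applies to a given NoPE element; it does not decode float8_e4m3 or bfloat16 values. The names split_token and scale_index are hypothetical, and little-endian float32 scales are an assumption rather than something stated in this reference.

```python
import struct

NOPE_BYTES = 512       # 512 float8_e4m3 values, 1 byte each
SCALE_BYTES = 4 * 4    # 4 float32 scale factors
ROPE_BYTES = 64 * 2    # 64 bfloat16 values, 2 bytes each
TOKEN_BYTES = NOPE_BYTES + SCALE_BYTES + ROPE_BYTES  # 656


def split_token(raw):
    """Split one token's raw bytes into (nope, scales, rope)."""
    if len(raw) != TOKEN_BYTES:
        raise ValueError(f"expected {TOKEN_BYTES} bytes, got {len(raw)}")
    nope = raw[:NOPE_BYTES]
    # Assumption: scales are stored as little-endian float32.
    scales = struct.unpack("<4f", raw[NOPE_BYTES:NOPE_BYTES + SCALE_BYTES])
    rope = raw[NOPE_BYTES + SCALE_BYTES:]
    return nope, scales, rope


def scale_index(element):
    """Each scale covers a run of 128 consecutive NoPE values."""
    return element // 128


# Synthetic token: zeroed payloads and four recognisable scales.
raw = bytes(NOPE_BYTES) + struct.pack("<4f", 0.5, 1.0, 2.0, 4.0) + bytes(ROPE_BYTES)
nope, scales, rope = split_token(raw)
print(TOKEN_BYTES, len(nope), len(rope))  # 656 512 128
print(scales[scale_index(300)])           # 2.0
```

The region sizes add up to 512 + 16 + 128 = 656 bytes, matching the figure quoted above.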
- Returns:
A tuple of two tensors:
out, shape (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse, shape (batch_size, num_heads_q, seq_len_q).
- Return type:
Tuple[Tensor, Tensor]
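When sparse attention is enabled, each row of indices has a fixed width of topk, so rows with fewer selected tokens need padding. The sketch below pads one row with -1, one of the two invalid markers the indices description allows. The helper pad_indices is hypothetical and works on a plain list; a real call would stack such rows into a tensor of shape (batch_size, seq_len_q, topk).

```python
def pad_indices(selected, topk, invalid=-1):
    """Pad one row of selected KV token indices to exactly topk entries."""
    if len(selected) > topk:
        raise ValueError("more selected tokens than topk")
    return list(selected) + [invalid] * (topk - len(selected))


# One query token that selected only 3 KV tokens, with topk = 8.
row = pad_indices([4, 17, 9], topk=8)
print(row)  # [4, 17, 9, -1, -1, -1, -1, -1]
```

Any value >= total_seq_len_kv would serve equally well as the padding marker according to the parameter description.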