import functools
import torch
from typing import List, Optional, Union, Tuple
from mate.api_logging import mate_api
from .jit.flash_attention_ops import get_flash_attention_ops_module
from .jit.mla_ops import get_mla_ops_module
from .jit.attention.fmha import (
_fmha_get_metadata as jit_fmha_get_metadata,
) # noqa: F401
from .jit.attention.fmha import _fmha_fwd as jit_fmha_fwd # noqa: F401
from .jit.attention.fmha.fmha_combine import _flash_attn_combine
from .execution_context import raise_complete_if_dry_run
@functools.cache
def _get_flash_attention_ops():
    """Return the flash-attention ops module.

    The result of ``get_flash_attention_ops_module()`` is memoized by
    ``functools.cache``, so repeated calls reuse the first successful result.
    """
    ops_module = get_flash_attention_ops_module()
    return ops_module
@functools.cache
def _get_mla_ops():
    """Return the MLA ops module.

    The result of ``get_mla_ops_module()`` is memoized by ``functools.cache``,
    so repeated calls reuse the first successful result.
    """
    ops_module = get_mla_ops_module()
    return ops_module
def _check_valid_asm_input(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    page_table,
    seqused_q,
    seqused_k,
    qv,
    window_size,
    learnable_sink,
    attention_chunk,
    softcap,
    cp_world_size=1,
):
    """Decide whether the inputs qualify for the "mubin" (asm) backend.

    Used by the ``backend="auto"`` selection. Returns ``True`` only when every
    condition below holds; otherwise returns ``False`` so the caller falls
    back to the "mutlass" backend.

    Conditions checked:

    * ``q``, ``k``, ``v`` report ``is_musa``, share one dtype that is fp16 or
      bf16, and share one rank that is 3 or 4.
    * Head dims (last axis of ``q`` / ``v``) are either 192/128 or equal and
      at most 128.
    * No paging, ``seqused_*``, ``qv``, learnable sink, softcap, attention
      chunking, sliding window or context parallelism is requested.
    * 3-D inputs: ``cu_seqlens_q`` / ``cu_seqlens_k`` are present, report
      ``is_musa`` and have the same number of elements; both ``max_seqlen_*``
      are given; ``k`` / ``v`` shapes agree with each other and with ``q``'s
      head dim.
    * 4-D inputs: ``k`` / ``v`` share ``q``'s first dimension and agree with
      each other on the two middle dimensions.

    Parameters
    ----------
    window_size : sequence of two items
        ``(left, right)``; each must be ``None`` or negative to qualify.
    cp_world_size : int
        Must be 1 to qualify.

    Returns
    -------
    bool
    """
    half_dtypes = (torch.float16, torch.bfloat16)

    # Device, dtype and rank requirements shared by q / k / v.
    if not all(t.is_musa for t in (q, k, v)):
        return False
    if q.dtype not in half_dtypes or k.dtype != q.dtype or v.dtype != q.dtype:
        return False
    rank = q.dim()
    if rank not in (3, 4) or k.dim() != rank or v.dim() != rank:
        return False

    # Supported head-dimension combinations.
    headdim_qk = q.shape[-1]
    headdim_v = v.shape[-1]
    is_192_128 = headdim_qk == 192 and headdim_v == 128
    is_128_128_or_less = headdim_qk == headdim_v and headdim_qk <= 128
    if not (is_192_128 or is_128_128_or_less):
        return False

    # Optional features that only the other backend handles.
    uses_unsupported_feature = (
        page_table is not None
        or seqused_q is not None
        or seqused_k is not None
        or qv is not None
        or learnable_sink is not None
        or softcap != 0.0
        or attention_chunk != 0
        or cp_world_size != 1
    )
    if uses_unsupported_feature:
        return False

    # Both window bounds must be disabled. The right bound is tested with a
    # strict "< 0" so that this predicate agrees with the assertion made on
    # the mubin path; accepting 0 here would select a backend that then
    # rejects the very same input.
    window_size_left, window_size_right = window_size
    if not (window_size_left is None or window_size_left < 0):
        return False
    if not (window_size_right is None or window_size_right < 0):
        return False

    if rank == 3:
        # Presence must be verified before touching tensor attributes,
        # otherwise a missing cu_seqlens tensor raises AttributeError instead
        # of simply disqualifying the input.
        if cu_seqlens_q is None or cu_seqlens_k is None:
            return False
        if max_seqlen_q is None or max_seqlen_k is None:
            return False
        if not (cu_seqlens_q.is_musa and cu_seqlens_k.is_musa):
            return False
        if cu_seqlens_k.numel() != cu_seqlens_q.numel():
            return False
        kv_lead = (k.shape[0], k.shape[1])
    else:
        kv_lead = (q.shape[0], k.shape[1], k.shape[2])

    k_ok = tuple(k.shape) == (*kv_lead, headdim_qk)
    v_ok = tuple(v.shape) == (*kv_lead, headdim_v)
    return k_ok and v_ok
def maybe_contiguous(x):
    """Return ``x`` with a unit-stride last dimension.

    ``None`` and tensors whose last-dimension stride is already 1 are returned
    unchanged (same object); any other tensor is replaced by
    ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
def _allocate_mla_decode_outputs(q: torch.Tensor, head_dim_v: int):
    """Allocate uninitialized output and log-sum-exp buffers for MLA decode.

    Parameters
    ----------
    q : Tensor
        Query tensor; supplies the shape prefix, dtype (for ``out``) and
        device of the buffers.
    head_dim_v : int
        Size of the last dimension of ``out``.

    Returns
    -------
    Tuple[Tensor, Tensor]
        ``(out, softmax_lse)``.

        * 4-D ``q``: ``out`` is ``(*q.shape[:-1], head_dim_v)`` and
          ``softmax_lse`` is ``q.shape[:-1]``.
        * otherwise: ``out`` is ``(q.shape[0], q.shape[1], head_dim_v)`` and
          ``softmax_lse`` is ``(q.shape[1], q.shape[0])``.

        ``softmax_lse`` is always float32.
    """
    if q.dim() == 4:
        leading = q.shape[:-1]
        out_shape = (*leading, head_dim_v)
        lse_shape = leading
    else:
        dim0, dim1 = q.shape[0], q.shape[1]
        out_shape = (dim0, dim1, head_dim_v)
        # Note the swapped order of the first two dimensions for the LSE.
        lse_shape = (dim1, dim0)

    out = torch.empty(out_shape, dtype=q.dtype, device=q.device)
    softmax_lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)
    return out, softmax_lse
def _prepare_mla_query_input(
    x: torch.Tensor, *, require_seq_dense: bool
) -> torch.Tensor:
    """Return ``x`` or a contiguous copy that satisfies the MLA stride rules.

    A copy is made when the last dimension is not unit-stride. When
    ``require_seq_dense`` is true, a copy is also made for 3-D / 4-D tensors
    whose stride on the third-from-last axis differs from
    ``shape[-2] * stride`` of the second-from-last axis. In every other case
    the input object itself is returned.
    """
    # Match Python-side materialization to the exact MLA backend stride contract.
    if x.stride(-1) != 1:
        return x.contiguous()

    rank = x.dim()
    if require_seq_dense and rank in (3, 4):
        # Axis 1 for 4-D input, axis 0 for 3-D input.
        seq_axis = rank - 3
        dense_stride = x.shape[-2] * x.stride(seq_axis + 1)
        if x.stride(seq_axis) != dense_stride:
            return x.contiguous()
    return x
def _flash_attn_forward(
    q,
    k,
    v,
    k_new,
    v_new,
    qv,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    cu_seqlens_k_new,
    seqused_q,
    seqused_k,
    max_seqlen_q,
    max_seqlen_k,
    page_table,
    kv_batch_idx,
    leftpad_k,
    rotary_cos,
    rotary_sin,
    seqlens_rotary,
    q_descale,
    k_descale,
    v_descale,
    softmax_scale,
    causal,
    window_size=(-1, -1),
    learnable_sink=None,
    attention_chunk=0,
    softcap=0.0,
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=-1,
    pack_gqa=None,
    sm_margin=0,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k: Optional[torch.Tensor] = None,
):
    """Normalize tensor strides and run the JIT FMHA forward kernel.

    Every optional tensor argument except ``qv``, ``out`` and the descale
    tensors is passed through ``maybe_contiguous``. ``v`` is treated
    separately: it is copied only when neither its last nor its
    third-from-last dimension has unit stride. All arguments are then handed
    to ``jit_fmha_fwd`` by keyword, with ``return_lse=True`` and ``lse=None``.

    Notable argument renames at the kernel boundary: ``qv`` -> ``q_v``,
    ``causal`` -> ``is_causal``, ``rotary_interleaved`` ->
    ``is_rotary_interleaved``, ``sm_margin`` -> ``mp_margin``, and
    ``window_size`` is split into ``window_size_left`` / ``window_size_right``.

    Returns
    -------
    tuple
        ``(out, softmax_lse, *rest)`` exactly as produced by ``jit_fmha_fwd``.
    """
    (
        q,
        k,
        k_new,
        v_new,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
    ) = map(
        maybe_contiguous,
        (
            q,
            k,
            k_new,
            v_new,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seqlens_k_new,
            seqused_q,
            seqused_k,
            page_table,
            kv_batch_idx,
            leftpad_k,
            rotary_cos,
            rotary_sin,
            seqlens_rotary,
        ),
    )

    # v may keep its layout when either the last or the third-from-last
    # dimension is unit-stride; only otherwise is a contiguous copy made.
    if v.stride(-1) != 1 and v.stride(-3) != 1:
        v = v.contiguous()

    window_size_left, window_size_right = window_size[0], window_size[1]

    kernel_outputs = jit_fmha_fwd(
        q=q,
        k=k,
        v=v,
        k_new=k_new,
        v_new=v_new,
        q_v=qv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_k_new=cu_seqlens_k_new,
        seqused_q=seqused_q,
        seqused_k=seqused_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        page_table=page_table,
        kv_batch_idx=kv_batch_idx,
        leftpad_k=leftpad_k,
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        seqlens_rotary=seqlens_rotary,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        softmax_scale=softmax_scale,
        is_causal=causal,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        attention_chunk=attention_chunk,
        learnable_sink=learnable_sink,
        softcap=softcap,
        is_rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        mp_margin=sm_margin,
        return_lse=True,
        lse=None,
        out=out,
        cp_world_size=cp_world_size,
        cp_rank=cp_rank,
        cp_tot_seqused_k=cp_tot_seqused_k,
    )
    out, softmax_lse, *rest = kernel_outputs
    return out, softmax_lse, *rest
def _flash_attn_varlen_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> None:
    """Invoke the ``dnn_mha_varlen_bwd`` kernel for a varlen attention backward.

    ``dout``, ``q``, ``k``, ``v`` and ``out`` are first passed through
    ``maybe_contiguous``; ``dq`` / ``dk`` / ``dv`` are forwarded untouched.

    Parameters
    ----------
    dq, dk, dv : Optional[Tensor]
        Gradient buffers handed to the kernel (presumably filled in place --
        this function itself returns nothing).
    alibi_slopes, rng_state
        Accepted for signature compatibility but not forwarded to the kernel.

    Returns
    -------
    None
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = (maybe_contiguous(t) for t in (dout, q, k, v, out))

    varlen_bwd = _get_flash_attention_ops().get_function("dnn_mha_varlen_bwd")
    # The three literal ``None`` arguments fill optional kernel slots whose
    # meaning is defined by the kernel signature (not visible in this file).
    varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        zero_tensors,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        None,
        None,
    )
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for flash attention with selectable forward backend.

    ``forward`` runs either the "mutlass" backend (``_flash_attn_forward``) or
    the "mubin" backend (``flash_atten_varlen_asm`` kernel); ``backward``
    computes gradients for ``q``, ``k`` and ``v`` only.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_k: Optional[int] = None,
        seqused_q: Optional[torch.Tensor] = None,
        seqused_k: Optional[torch.Tensor] = None,
        page_table: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        qv: Optional[torch.Tensor] = None,
        q_descale: Optional[torch.Tensor] = None,
        k_descale: Optional[torch.Tensor] = None,
        v_descale: Optional[torch.Tensor] = None,
        window_size: Union[Tuple, List, None] = (-1, -1),
        learnable_sink: Optional[torch.Tensor] = None,
        attention_chunk: Optional[int] = 0,
        softcap: float = 0.0,
        scheduler_metadata: Optional[torch.Tensor] = None,
        num_splits: int = -1,
        pack_gqa=None,
        deterministic: bool = False,
        sm_margin=0,
        return_softmax_lse: bool = False,
        backend: str = "auto",  # "auto", "mutlass", "mubin"
        cp_world_size: int = 1,
        cp_rank: int = 0,
        cp_tot_seqused_k: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
    ):
        """Run the attention forward pass on the selected backend.

        ``window_size=None`` is treated as ``(-1, -1)`` and
        ``attention_chunk=None`` as ``0``. With ``backend="auto"`` the mubin
        backend is chosen when ``_check_valid_asm_input`` accepts the inputs,
        otherwise mutlass. When ``softmax_scale`` is ``None`` it defaults to
        ``(headdim(q) + headdim(qv)) ** -0.5`` (``qv`` contributing 0 when
        absent).

        Returns
        -------
        Tensor or Tuple[Tensor, Tensor]
            ``out``, or ``(out, softmax_lse)`` when ``return_softmax_lse``.

        Raises
        ------
        ValueError
            If ``backend`` is not one of ``"auto"``, ``"mutlass"``, ``"mubin"``.
        AssertionError
            If the mubin backend is used with an unsupported option.
        """
        if window_size is None:
            window_size = (-1, -1)
        if attention_chunk is None:
            attention_chunk = 0

        select_backend = backend
        if select_backend == "auto":
            enable_mubin = _check_valid_asm_input(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                page_table=page_table,
                seqused_q=seqused_q,
                seqused_k=seqused_k,
                qv=qv,
                window_size=window_size,
                learnable_sink=learnable_sink,
                attention_chunk=attention_chunk,
                softcap=softcap,
                cp_world_size=cp_world_size,
            )
            select_backend = "mubin" if enable_mubin else "mutlass"

        if softmax_scale is None:
            qv_headdim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + qv_headdim) ** (-0.5)

        if select_backend == "mutlass":
            # Any extra kernel outputs beyond (out, lse) are not needed here.
            out, softmax_lse, *_ = _flash_attn_forward(
                q=q,
                k=k,
                v=v,
                k_new=None,
                v_new=None,
                qv=qv,
                out=out,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_k_new=None,
                seqused_q=seqused_q,
                seqused_k=seqused_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                page_table=page_table,
                kv_batch_idx=None,
                leftpad_k=None,
                rotary_cos=None,
                rotary_sin=None,
                seqlens_rotary=None,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=window_size,
                learnable_sink=learnable_sink,
                attention_chunk=attention_chunk,
                softcap=softcap,
                rotary_interleaved=True,
                scheduler_metadata=scheduler_metadata,
                num_splits=num_splits,
                pack_gqa=pack_gqa,
                sm_margin=sm_margin,
                cp_world_size=cp_world_size,
                cp_rank=cp_rank,
                cp_tot_seqused_k=cp_tot_seqused_k,
            )
        elif select_backend == "mubin":
            # In dry run, don't run mubin kernels
            raise_complete_if_dry_run()

            is_varlen = cu_seqlens_q is not None and cu_seqlens_k is not None
            window_size_left, window_size_right = window_size

            # Options the mubin kernel does not take; they must be disabled.
            assert seqused_q is None
            assert seqused_k is None
            assert window_size_left is None or window_size_left < 0
            assert window_size_right is None or window_size_right < 0
            assert qv is None
            assert learnable_sink is None
            assert softcap == 0.0
            assert attention_chunk == 0
            assert q.dtype in [torch.float16, torch.bfloat16]
            assert q.dtype == k.dtype and q.dtype == v.dtype
            assert page_table is None

            headdim_v = v.shape[-1]
            if is_varlen:
                assert max_seqlen_k is not None
                assert max_seqlen_q is not None
                # Packed layout: q is (total_seqlen, nr_heads, headdim).
                total_seqlen, nr_heads, _ = q.shape
                out_shape = (total_seqlen, nr_heads, headdim_v)
                lse_shape = (nr_heads, total_seqlen)
            else:
                # Batched layout: q is (batch, seq_q, nr_heads, headdim).
                batch, seq_q, nr_heads, _ = q.shape
                out_shape = (batch, seq_q, nr_heads, headdim_v)
                lse_shape = (batch, nr_heads, seq_q)

            if out is None:
                out = torch.empty(out_shape, dtype=q.dtype, device=q.device)
            softmax_lse = torch.empty(
                lse_shape, dtype=torch.float32, device=q.device
            )

            _get_flash_attention_ops().get_function("flash_atten_varlen_asm")(
                q,
                k,
                v,
                softmax_scale,
                out,
                softmax_lse,
                causal,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
            )
        else:
            raise ValueError(
                f"Only support backend 'mutlass', 'mubin' and 'auto'! Get unknown backend {select_backend}!"
            )

        # Only keep state for backward when some input actually needs a grad.
        is_grad = any(x.requires_grad for x in [q, k, v])
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k
            )
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.deterministic = deterministic

        return (out, softmax_lse) if return_softmax_lse else out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. ``q``, ``k`` and ``v``.

        Head dim 256 is routed to the tilelang varlen backward; every other
        head dim goes through ``_flash_attn_varlen_backward``. Extra incoming
        gradients in ``*args`` (e.g. for the returned LSE) are ignored.

        NOTE(review): the non-256 path converts the LSE from a 2-D
        ``(H, total_q)`` tensor using ``cu_seqlens_q``, so it only works for
        varlen inputs; the batched forward path produces a 3-D LSE and no
        ``cu_seqlens``, which this path cannot handle.

        Returns
        -------
        tuple
            One entry per ``forward`` input: ``(dq, dk, dv)`` followed by
            ``None`` for each of the remaining inputs.
        """
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

        if q.shape[-1] == 256:
            from .flash_attention.tilelang.flash_attention_varlen_bwd import (
                flashattn_varlen_bwd_interface,
            )

            dq, dk, dv = flashattn_varlen_bwd_interface(
                q,
                k,
                v,
                out,
                dout,
                softmax_lse,
                ctx.max_seqlen_q,
                ctx.max_seqlen_k,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                is_causal=ctx.causal,
                smscale=ctx.softmax_scale,
                dtype=None,
                block_M=64,
                block_N=64,
                threads=640,
                is_bhsd=False,
            )
        else:
            # dnn bwd needs a (b, h, max_q) lse but the fwd lse is
            # (h, total_q) currently, so it is re-laid-out here.
            def lse_varlen_to_padded(lse_flat, cu_seqlen_q, max_q, pad_value=0.0):
                """Scatter a packed ``[H, total_Q]`` LSE into ``[b, H, max_q]``.

                Positions past each sequence's length hold ``pad_value``.
                """
                H, total_Q = lse_flat.shape
                device = lse_flat.device
                seq_ids = torch.arange(max_q, device=device).unsqueeze(0)
                offsets = cu_seqlen_q[:-1].unsqueeze(1)
                indices = seq_ids + offsets
                seqlens = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
                valid_mask = seq_ids < seqlens.unsqueeze(1)
                # Out-of-range gather positions are clamped; they are masked
                # out by ``valid_mask`` right after.
                indices = torch.clamp(indices, 0, total_Q - 1)
                lse_gathered = lse_flat[:, indices]  # [H, b, max_q]
                lse_padded = torch.where(
                    valid_mask.unsqueeze(0),
                    lse_gathered,
                    torch.tensor(pad_value, device=device, dtype=lse_flat.dtype),
                )
                return lse_padded.permute(1, 0, 2).contiguous()

            softmax_lse = lse_varlen_to_padded(
                softmax_lse, cu_seqlens_q, ctx.max_seqlen_q
            )
            dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)

            # Pad the head dimension of dout up to a multiple of 8.
            head_size_og = dout.size(2)
            dout_padded = dout
            if head_size_og % 8 != 0:
                dout_padded = torch.nn.functional.pad(
                    dout, [0, 8 - head_size_og % 8]
                )

            _flash_attn_varlen_backward(
                dout_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_k,
                ctx.max_seqlen_q,
                ctx.max_seqlen_k,
                dropout_p=0.0,
                softmax_scale=ctx.softmax_scale,
                causal=ctx.causal,
                window_size_left=ctx.window_size[0],
                window_size_right=ctx.window_size[1],
                softcap=ctx.softcap,
                alibi_slopes=None,
                deterministic=ctx.deterministic,
                rng_state=None,
            )

        # ``forward`` takes 31 inputs after ``ctx``; only the first three
        # (q, k, v) receive gradients, the other 28 get ``None``.
        num_inputs_without_grad = 28
        return (dq, dk, dv, *([None] * num_inputs_without_grad))
# [docs]
@mate_api
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    window_size: Union[Tuple, List, None] = (-1, -1),
    learnable_sink: Optional[torch.Tensor] = None,
    attention_chunk: Optional[int] = 0,
    softcap: float = 0.0,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 0,
    pack_gqa=None,
    deterministic: bool = False,
    sm_margin=0,
    return_softmax_lse: bool = False,
    backend: str = "auto",  # "auto", "mutlass", "mubin"
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
):
    r"""
    FlashAttention3 compatible API: forward with varlen or non-varlen inputs.

    All arguments are forwarded positionally, unchanged, to
    ``FlashAttnVarlenFunc.apply``.

    Parameters
    ----------
    q : Tensor
        The query tensor with shape ``(batch_size, seqlen, nheads, headdim)`` if cu_seqlens_q is None,
        or ``(total_q, nheads, headdim)`` if cu_seqlens_q is not None.
    k : Tensor
        The key tensor with shape ``(batch_size, seqlen_k, nheads_k, headdim)`` if cu_seqlens_k is None,
        or ``(total_k, nheads_k, headdim)`` if cu_seqlens_k is not None.
    v : Tensor
        The value tensor with shape ``(batch_size, seqlen_k, nheads_k, headdim_v)`` if cu_seqlens_k is None,
        or ``(total_k, nheads_k, headdim_v)`` if cu_seqlens_k is not None.
    cu_seqlens_q : Optional[Tensor]
        The cumulative sequence length tensor for query, shape ``(batch_size + 1)``.
    cu_seqlens_k : Optional[Tensor]
        The cumulative sequence length tensor for key/value, shape ``(batch_size + 1)``.
    max_seqlen_q : Optional[int]
        The maximum sequence length for query; must be provided for a varlen forward.
    max_seqlen_k : Optional[int]
        The maximum sequence length for key/value.
    seqused_q : Optional[Tensor]
        Tensor with shape ``(batch_size)``.
        If given, only this many elements of each batch element's queries and outputs are used.
    seqused_k : Optional[Tensor]
        Tensor with shape ``(batch_size)``.
        If given, only this many elements of each batch element's keys and values are used.
    softmax_scale : Optional[float]
        The scaling of QK^T before applying softmax. Defaults to
        ``1 / sqrt(headdim)`` (``headdim`` including the last dimension of
        ``qv`` when ``qv`` is given).
    causal : bool
        Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size : Tuple[int, int]
        The size of the sliding window. If not (-1, -1), implements sliding window local attention.
        ``None`` is treated as ``(-1, -1)``.
    learnable_sink : Optional[Tensor]
        The Learnable Sink tensor for attention, shape ``(nheads, )``.
    attention_chunk : Optional[int]
        ``None`` is treated as ``0``.
    softcap : float
        Anything > 0 activates softcapping attention, applied as
        ``logits = softcap * tanh(logits / softcap)`` before the softmax.
        0.0 (default) disables softcapping.
    return_softmax_lse : bool
        Whether to return the logsumexp of the attention scores.
    backend : str
        The backend to use: ``"auto"``, ``"mutlass"`` or ``"mubin"``.
        It is recommended to use the default ``auto``.
    cp_world_size : int
        Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled).
        When > 1, the global sequence is assumed to be distributed across ranks using an
        interleaved token pattern, where rank ``r`` holds tokens at positions
        ``[r, r + cp_world_size, r + 2*cp_world_size, ...]``.
    cp_rank : int
        The rank of the current device within the CP group. Default 0.
    cp_tot_seqused_k : Optional[Tensor]
        The **global** (across all CP ranks) cumulative key sequence lengths, shape
        ``(batch_size + 1,)``, dtype ``int32``. Required when CP is enabled (``cp_world_size > 1``)
        so that each rank can correctly compute causal masking boundaries against the full
        key sequence. Ignored when ``cp_world_size == 1``.
    out : Optional[Tensor]
        Optional pre-allocated output tensor.

    Returns
    -------
    Union[Tensor, Tuple[Tensor, Tensor]]
        If :attr:`return_softmax_lse` is ``False``, the attention output, shape ``(total_q, nheads, headdim_v)``.
        If :attr:`return_softmax_lse` is ``True``, a tuple of two tensors:

        * The attention output, shape ``(total_q, nheads, headdim_v)``
        * The log sum exp value, shape ``(nheads, total_q)``
    """
    # Order must match the positional signature of FlashAttnVarlenFunc.forward.
    forward_args = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        seqused_q,
        seqused_k,
        page_table,
        softmax_scale,
        causal,
        qv,
        q_descale,
        k_descale,
        v_descale,
        window_size,
        learnable_sink,
        attention_chunk,
        softcap,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_softmax_lse,
        backend,
        cp_world_size,
        cp_rank,
        cp_tot_seqused_k,
        out,
    )
    return FlashAttnVarlenFunc.apply(*forward_args)
# [docs]
@mate_api
def flash_attn_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
):
    """Public wrapper around ``_flash_attn_combine``.

    The four arguments are forwarded positionally and the helper's result is
    returned unchanged.

    Parameters
    ----------
    out_partial : Tensor
        Partial attention outputs (presumably one per split -- confirm against
        ``_flash_attn_combine``).
    lse_partial : Tensor
        Partial log-sum-exp values matching ``out_partial``.
    out : Optional[Tensor]
        Optional destination tensor, passed through as-is.
    out_dtype : Optional[torch.dtype]
        Optional output dtype, passed through as-is.
    """
    combined = _flash_attn_combine(out_partial, lse_partial, out, out_dtype)
    return combined
# [docs]
@mate_api
def flash_attn_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Union[Tuple, List, None] = (
        -1,
        -1,
    ),  # -1 means infinite context window
    learnable_sink: Optional[torch.Tensor] = None,
    attention_chunk: Optional[int] = 0,
    softcap: float = 0.0,  # 0.0 means deactivated
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse: bool = False,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k: Optional[torch.Tensor] = None,
):
    r"""FlashAttention3 compatible API: forward with kv cache.

    Two execution paths exist. When ``qv`` has last dimension 512, ``q`` has
    last dimension 64 and ``scheduler_metadata`` is a tuple, the call is
    dispatched to the MLA decode kernel; otherwise it goes to the generic
    JIT FMHA forward.

    Parameters
    ----------
    q : Tensor
        The query tensor with shape ``(batch_size, seqlen, nheads, headdim)`` if cu_seqlens_q is None,
        or ``(total_q, nheads, headdim)`` if cu_seqlens_q is not None.
    k_cache : Tensor
        The key cache tensor with shape ``(batch_size_cache, seqlen_cache, nheads_k, headdim)`` if there's no page_table,
        or ``(num_blocks, page_block_size, nheads_k, headdim)`` if there's a page_table (i.e. paged KV cache).
        Its last dimension must be contiguous.
    v_cache : Tensor
        The value cache tensor with shape ``(batch_size_cache, seqlen_cache, nheads_k, headdim_v)`` if there's no page_table,
        or ``(num_blocks, page_block_size, nheads_k, headdim_v)`` if there's a page_table (i.e. paged KV cache).
        Its last dimension must be contiguous.
    k : Optional[Tensor]
        The key tensor with shape ``(batch_size, seqlen_new, nheads_k, headdim)`` if cu_seqlens_k_new is None,
        or ``(total_k_new, nheads_k, headdim)`` if cu_seqlens_k_new is not None.
        If k is not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens.
    v : Optional[Tensor]
        The value tensor with shape ``(batch_size, seqlen_new, nheads_k, headdim_v)`` if cu_seqlens_k_new is None,
        or ``(total_k_new, nheads_k, headdim_v)`` if cu_seqlens_k_new is not None.
        Similar to k.
    rotary_cos : Optional[Tensor]
        Tensor with shape ``(seqlen_ro, rotary_dim / 2)``. If not None, we apply rotary embedding to k and q.
        Only applicable if k and v are passed in. ``rotary_dim`` must be ``<= headdim`` and divisible by 16.
        ``rotary_cos`` must be on MUSA and have the same dtype as q.
    rotary_sin : Optional[Tensor]
        Tensor with shape ``(seqlen_ro, rotary_dim / 2)``. Similar to rotary_cos and must have the same shape and dtype.
    cache_seqlens : Union[int, Tensor]
        The sequence lengths of the KV cache, shape ``(batch_size)`` if it is a tensor.
        An ``int`` is expanded to an int32 tensor of length ``k_cache.shape[0]`` on ``k_cache``'s device.
    cache_batch_idx : Optional[Tensor]
        The int32 indices used to index into the KV cache, shape ``(batch_size,)``.
        The tensor must be on MUSA and contiguous.
        If the indices are not distinct, and k and v are provided, the values updated in the cache might come from any of the duplicate indices.
    cache_leftpad : Optional[Tensor]
        The int32 left padding offset where the KV cache starts for each batch, shape ``(batch_size,)``.
        The tensor must be on MUSA and contiguous. If None, assume 0.
    page_table : Optional[Tensor]
        The page table tensor with shape ``(batch_size, max_num_blocks_per_seq)``.
    cu_seqlens_q : Optional[Tensor]
        The cumulative sequence lengths of the query, shape ``(batch_size + 1)``.
    cu_seqlens_k_new : Optional[Tensor]
        The cumulative sequence lengths of the new KV, shape ``(batch_size + 1)``.
    rotary_seqlens : Optional[Tensor]
        Optional int32 tensor with shape ``(batch_size,)`` used as the rotary position length for each batch.
    softmax_scale : Optional[float]
        The scaling of QK^T before applying softmax. Defaults to
        ``1 / sqrt(headdim)`` (``headdim`` including the last dimension of
        ``qv`` when ``qv`` is given).
    causal : bool
        Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size : Tuple[int, int]
        The size of the sliding window. If not (-1, -1), implements sliding window local attention.
        ``None`` is treated as ``(-1, -1)``.
    learnable_sink : Optional[Tensor]
        The Learnable Sink tensor for attention, shape ``(nheads, )``.
    softcap : float
        Anything > 0 activates softcapping attention, applied as
        ``logits = softcap * tanh(logits / softcap)`` before the softmax.
        0.0 (default) disables softcapping.
    rotary_interleaved : bool
        If True, rotary embedding uses GPT-J style and combines dimensions 0 & 1, 2 & 3, etc. If False,
        rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
        (i.e. GPT-NeoX style).
    num_splits : int
        If > 1, split the key/value into this many chunks along the sequence.
        If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
        to automatically determine the number of splits.
        Don't change this unless you know what you are doing.
    return_softmax_lse : bool
        Whether to return the logsumexp of the attention scores.
    cp_world_size : int
        Total number of ranks in the Context Parallelism (CP) group. Default 1 (CP disabled).
        When > 1, the global sequence is assumed to be distributed across ranks using an
        interleaved token pattern, where rank ``r`` holds tokens at positions
        ``[r, r + cp_world_size, r + 2*cp_world_size, ...]``.
    cp_rank : int
        The rank of the current device within the CP group. Default 0.
    cp_tot_seqused_k : Optional[Tensor]
        The **global** (across all CP ranks) cumulative key sequence lengths, shape
        ``(batch_size + 1,)``, dtype ``int32``. Required when CP is enabled (``cp_world_size > 1``)
        so that each rank can correctly compute causal masking boundaries against the full
        key sequence. Ignored when ``cp_world_size == 1``.

    Returns
    -------
    Union[Tensor, Tuple[Tensor, Tensor]]
        If :attr:`return_softmax_lse` is ``False``, the attention output, shape ``(batch_size, seqlen, nheads, headdim_v)`` if cu_seqlens_q is None,
        or ``(total_q, nheads, headdim_v)`` if cu_seqlens_q is not None.
        If :attr:`return_softmax_lse` is ``True``, a tuple whose first two entries are:

        * The attention output, shape ``(batch_size, seqlen, nheads, headdim_v)`` if cu_seqlens_q is None,
          or ``(total_q, nheads, headdim_v)`` if cu_seqlens_q is not None
        * The log sum exp value, shape ``(batch_size, nheads, seqlen)`` if cu_seqlens_q is None,
          or ``(nheads, total_q)`` if cu_seqlens_q is not None

        Any additional values produced by the generic kernel are appended
        after these two.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"

    # Normalize "unset" values.
    if window_size is None:
        window_size = (-1, -1)
    if attention_chunk is None:
        attention_chunk = 0
    if softmax_scale is None:
        qv_headdim = qv.shape[-1] if qv is not None else 0
        softmax_scale = (q.shape[-1] + qv_headdim) ** (-0.5)

    # A scalar cache length is broadcast to one int32 entry per cache row.
    if isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],),
            cache_seqlens,
            dtype=torch.int32,
            device=k_cache.device,
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)

    is_mla_decode = qv is not None and qv.shape[-1] == 512 and q.shape[-1] == 64
    if is_mla_decode and isinstance(scheduler_metadata, tuple):
        # The asm MLA path (taken when qv.shape[-2] == 128) does not need a
        # sequence-dense layout; the other MLA path does.
        use_flash_mla_asm = qv.shape[-2] == 128
        require_seq_dense = not use_flash_mla_asm
        mla_qv = _prepare_mla_query_input(qv, require_seq_dense=require_seq_dense)
        mla_q = _prepare_mla_query_input(q, require_seq_dense=require_seq_dense)
        out, softmax_lse = _allocate_mla_decode_outputs(mla_q, mla_qv.shape[-1])

        mla_dispatch = _get_mla_ops().get_function(
            "dispatch_mla_impl_for_fa_interface"
        )
        # Note the kernel's argument order: qv before q, v_cache before k_cache.
        mla_dispatch(
            mla_qv,
            mla_q,
            v_cache,
            k_cache,
            cache_seqlens,
            page_table,
            softmax_scale,
            causal,
            cu_seqlens_q,
            max_seqlen_q,
            out,
            softmax_lse,
            scheduler_metadata[0],
            scheduler_metadata[1],
        )
        return (out, softmax_lse) if return_softmax_lse else out

    out, softmax_lse, *rest = jit_fmha_fwd(
        q=q,
        k=k_cache,
        v=v_cache,
        k_new=k,
        v_new=v,
        q_v=qv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=None,
        cu_seqlens_k_new=cu_seqlens_k_new,
        seqused_q=None,
        seqused_k=cache_seqlens,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=None,
        page_table=page_table,
        kv_batch_idx=cache_batch_idx,
        leftpad_k=cache_leftpad,
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        seqlens_rotary=rotary_seqlens,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        softmax_scale=softmax_scale,
        is_causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        learnable_sink=learnable_sink,
        softcap=softcap,
        is_rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        mp_margin=sm_margin,
        return_lse=return_softmax_lse,
        lse=None,
        out=None,
        cp_world_size=cp_world_size,
        cp_rank=cp_rank,
        cp_tot_seqused_k=cp_tot_seqused_k,
    )
    if return_softmax_lse:
        return (out, softmax_lse, *rest)
    return out