KDA
For framework integrations that already target the FlashKDA Python APIs, prefer
the flash_kda wrapper package. Use the MATE API below when the wrapper's
coverage is insufficient or when you need direct operator-level control.
chunk_kda is MATE’s fused chunked KDA operator on MUSA.
At a glance
- Public API: mate.chunk_kda / mate.kda.chunk_kda
- Device: MUSA
- Input dtypes: torch.float16 and torch.bfloat16
- Head dimension: currently fixed to 128
- Sequence modes:
  - Dense: [B, T, H, 128]
  - Varlen: [S, H, 128] or [1, S, H, 128] with cu_seqlens
- Optional recurrent state input/output
- Optional preallocated output and final_state
Toolchain requirements
The repository-wide install baseline still applies, but the fused chunk KDA path must be built with MUSA SDK / MTCC 5.1.0 or newer; the 4.3.6 toolchain may fail to compile the KDA kernels.
Shape contract
For dense mode:
- q / k: [B, T, Hqk, 128]
- v / g: [B, T, Hv, 128]
- beta: [B, T, Hv]
For varlen mode:
- q / k: [S, Hqk, 128] or [1, S, Hqk, 128]
- v / g: [S, Hv, 128] or [1, S, Hv, 128]
- beta: [S, Hv] or [1, S, Hv]
- cu_seqlens: cumulative sequence lengths with shape [num_seqs + 1]
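In varlen mode the sequences are packed along S and delimited by cu_seqlens. The sketch below builds cu_seqlens from per-sequence lengths as plain Python integers. It assumes the common convention that the offsets start at 0 and end at S; the helper name make_cu_seqlens is hypothetical, and the integer dtype and device of the tensor you finally pass to chunk_kda are not specified on this page.

```python
from itertools import accumulate
from typing import List, Sequence


def make_cu_seqlens(seq_lens: Sequence[int]) -> List[int]:
    """Return cumulative offsets [0, l0, l0 + l1, ...] for packed varlen input.

    Hypothetical helper: the result has len(seq_lens) + 1 entries, matching the
    documented [num_seqs + 1] shape; the leading 0 is an assumed convention.
    """
    if any(n < 0 for n in seq_lens):
        raise ValueError("sequence lengths must be non-negative")
    return [0, *accumulate(seq_lens)]


# Three sequences of lengths 5, 3 and 8 packed along S = 16 tokens.
cu = make_cu_seqlens([5, 3, 8])
print(cu)  # [0, 5, 8, 16]
```

The last entry equals S, so it can double as a sanity check against the packed token dimension of q and v.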
Additional constraints:
- k.shape == q.shape
- g.shape == v.shape
- beta.shape == v.shape[:-1] (the [B, T, Hv], [S, Hv] or [1, S, Hv] shapes above)
- Hv must be divisible by Hqk
- the last dimension of every tensor passed to the kernel must be contiguous (stride 1)
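The shape rules above can be checked ahead of a kernel launch. The following is a minimal sketch that works on plain shape tuples rather than tensors, so it runs without MATE or a MUSA device; check_kda_shapes and HEAD_DIM are hypothetical names. It compares beta against v.shape[:-1], which equals v.shape[:3] for 4-D inputs and also covers the 3-D varlen layout [S, Hv]. It cannot check the contiguity rule, which depends on tensor strides.

```python
from typing import Sequence, Tuple

HEAD_DIM = 128  # chunk_kda currently fixes the head dimension to 128


def check_kda_shapes(q: Sequence[int], k: Sequence[int], v: Sequence[int],
                     g: Sequence[int], beta: Sequence[int],
                     varlen: bool = False) -> Tuple[int, int]:
    """Check the documented chunk_kda shape rules on plain shape tuples.

    Hypothetical helper; it does not call MATE. Returns (Hqk, Hv) on success
    and raises ValueError on the first violated rule.
    """
    q, k, v, g, beta = (tuple(s) for s in (q, k, v, g, beta))
    if k != q:
        raise ValueError(f"k.shape {k} must equal q.shape {q}")
    if g != v:
        raise ValueError(f"g.shape {g} must equal v.shape {v}")
    if beta != v[:-1]:
        raise ValueError(f"beta.shape {beta} must equal v.shape[:-1] {v[:-1]}")
    if len(q) != len(v):
        raise ValueError("q/k and v/g must have the same rank")
    if varlen:
        # Varlen: [S, H, 128] or [1, S, H, 128], sequences packed along S.
        if len(q) not in (3, 4):
            raise ValueError("varlen inputs must be 3-D or 4-D")
        if len(q) == 4 and q[0] != 1:
            raise ValueError("4-D varlen inputs need a leading dimension of 1")
    elif len(q) != 4:
        # Dense: [B, T, H, 128].
        raise ValueError("dense inputs must be 4-D")
    if q[-1] != HEAD_DIM or v[-1] != HEAD_DIM:
        raise ValueError(f"the last dimension must be {HEAD_DIM}")
    # Token dimensions (B, T or S) must agree between q/k and v/g.
    if q[:-2] != v[:-2]:
        raise ValueError(f"leading dims differ: {q[:-2]} vs {v[:-2]}")
    hqk, hv = q[-2], v[-2]
    if hv % hqk != 0:
        raise ValueError(f"Hv={hv} must be divisible by Hqk={hqk}")
    return hqk, hv


# Dense: 4 query/key heads shared by 8 value heads.
print(check_kda_shapes((2, 64, 4, 128), (2, 64, 4, 128),
                       (2, 64, 8, 128), (2, 64, 8, 128), (2, 64, 8)))  # (4, 8)
```

With real tensors you would pass tuple(t.shape) for each argument.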
State tensors
- initial_state is optional
- final_state is optional unless output_final_state=True
- state shape: [num_seqs, Hv, 128, 128]
- state dtype: the same as the value dtype, or torch.float32
- if both initial_state and final_state are provided, their dtypes must match
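The state rules can be expressed as two small checks. This is a sketch under stated assumptions: dtypes are given as plain names such as "bfloat16" instead of torch dtypes, and in dense mode num_seqs is assumed to equal the batch size B, which this page does not state explicitly. The helper names state_shape and check_state_dtypes are hypothetical.

```python
from typing import Optional, Sequence, Tuple

HEAD_DIM = 128


def state_shape(hv: int, batch: int = 0,
                cu_seqlens: Optional[Sequence[int]] = None) -> Tuple[int, int, int, int]:
    """Shape of initial_state / final_state: [num_seqs, Hv, 128, 128]."""
    # Varlen: one state per packed sequence, so num_seqs = len(cu_seqlens) - 1.
    # Dense: assumed one state per batch row, so num_seqs = B.
    num_seqs = len(cu_seqlens) - 1 if cu_seqlens is not None else batch
    return (num_seqs, hv, HEAD_DIM, HEAD_DIM)


def check_state_dtypes(value_dtype: str, initial_dtype: Optional[str] = None,
                       final_dtype: Optional[str] = None) -> None:
    """Apply the documented state dtype rules to dtype names like "bfloat16"."""
    allowed = {value_dtype, "float32"}
    for name, dtype in (("initial_state", initial_dtype), ("final_state", final_dtype)):
        if dtype is not None and dtype not in allowed:
            raise ValueError(f"{name} dtype {dtype} must be {value_dtype} or float32")
    if initial_dtype is not None and final_dtype is not None and initial_dtype != final_dtype:
        raise ValueError("initial_state and final_state dtypes must match")


print(state_shape(8, cu_seqlens=[0, 5, 8, 16]))  # (3, 8, 128, 128)
check_state_dtypes("bfloat16", "float32", "float32")  # passes: float32 on both sides
```

Keeping the state in float32 while the values are bfloat16 is allowed by these rules, as long as both state tensors agree.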
Gate parameters
When gate parameters are enabled:
- A_log and dt_bias must be provided together
- A_log shape: [Hv]
- dt_bias shape: [Hv, 128]
- lower_bound defaults to -5.0
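A minimal sketch of the gate-parameter rules, again on shape tuples so it runs without MATE. The helper name check_gate_params is hypothetical; it only validates presence and shapes and says nothing about how the kernel applies A_log, dt_bias or lower_bound.

```python
from typing import Optional, Sequence

HEAD_DIM = 128


def check_gate_params(hv: int,
                      a_log_shape: Optional[Sequence[int]] = None,
                      dt_bias_shape: Optional[Sequence[int]] = None) -> bool:
    """Return True if gate parameters are enabled, False if both are absent."""
    # A_log and dt_bias must be provided together (both or neither).
    if (a_log_shape is None) != (dt_bias_shape is None):
        raise ValueError("A_log and dt_bias must be provided together")
    if a_log_shape is None:
        return False
    if tuple(a_log_shape) != (hv,):
        raise ValueError(f"A_log shape must be ({hv},), got {tuple(a_log_shape)}")
    if tuple(dt_bias_shape) != (hv, HEAD_DIM):
        raise ValueError(f"dt_bias shape must be ({hv}, {HEAD_DIM}), got {tuple(dt_bias_shape)}")
    return True


print(check_gate_params(8, (8,), (8, 128)))  # True
```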
Outputs
- default return: output
- if output_final_state=True: returns (output, final_state)
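Because the return type depends on output_final_state, calling code that supports both modes has to unpack conditionally. The sketch below shows one way to normalize the two conventions to a pair; split_kda_result is a hypothetical helper, and it simply passes through whatever objects chunk_kda returns.

```python
from typing import Any, Optional, Tuple


def split_kda_result(result: Any, output_final_state: bool) -> Tuple[Any, Optional[Any]]:
    """Normalize chunk_kda's return value to (output, final_state_or_None)."""
    if output_final_state:
        # With output_final_state=True the call returns (output, final_state).
        output, final_state = result
        return output, final_state
    # Otherwise only the output tensor is returned.
    return result, None


print(split_kda_result("out", False))  # ('out', None)
```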
API reference
- mate.kda.chunk_kda(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float | None = None, initial_state: torch.Tensor | None = None, output_final_state: bool = False, cu_seqlens: torch.Tensor | None = None, A_log: torch.Tensor | None = None, dt_bias: torch.Tensor | None = None, lower_bound: float = -5.0, use_qk_l2norm_in_kernel: bool = True, output: torch.Tensor | None = None, final_state: torch.Tensor | None = None) → torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
Run the fused chunk KDA kernel.
- Parameters:
  - q – Query tensor with shape [B, T, Hqk, 128] for dense mode or [S, Hqk, 128] / [1, S, Hqk, 128] for varlen mode.
  - k – Key tensor with the same shape and dtype as q.
  - v – Value tensor with shape [B, T, Hv, 128] or its varlen equivalent.
  - g – Gate input tensor with the same shape as v.
  - beta – Beta logits tensor with shape [B, T, Hv] or its varlen equivalent.
  - scale – Optional QK scaling factor. Defaults to 128**-0.5.
  - initial_state – Optional recurrent state tensor with shape [num_seqs, Hv, 128, 128].
  - output_final_state – Whether to also return the final recurrent state.
  - cu_seqlens – Optional cumulative sequence lengths for varlen mode, with shape [num_seqs + 1].
  - A_log – Optional per-head gate parameter tensor.
  - dt_bias – Optional per-head, per-channel gate bias tensor.
  - lower_bound – Gate lower bound used when the gate parameters A_log and dt_bias are enabled. Defaults to -5.0.
  - use_qk_l2norm_in_kernel – Whether to L2-normalize Q and K inside the kernel. Defaults to True.
  - output – Optional preallocated output tensor.
  - final_state – Optional preallocated final-state tensor.