KDA

For framework integrations that already target the FlashKDA Python APIs, start with the flash_kda wrapper package. Use the MATE API below when the wrapper does not cover your use case or when you need direct operator-level control.

chunk_kda is MATE’s fused chunked KDA operator on MUSA.

At a glance

  • Public API: mate.chunk_kda / mate.kda.chunk_kda

  • Device: MUSA

  • Input dtypes: torch.float16 and torch.bfloat16

  • Head dimension: currently fixed to 128

  • Sequence modes:
      ◦ Dense: [B, T, H, 128]
      ◦ Varlen: [S, H, 128] or [1, S, H, 128] with cu_seqlens

  • Optional recurrent state input/output

  • Optional preallocated output and final_state

Toolchain requirements

  • The repository-wide install baseline still applies, but the fused chunk KDA path must be built with MUSA SDK / MTCC 5.1.0 or newer.

  • Builds with the older MUSA SDK / MTCC 4.3.6 toolchain may fail to compile the KDA kernels.

Shape contract

For dense mode:

  • q / k: [B, T, Hqk, 128]

  • v / g: [B, T, Hv, 128]

  • beta: [B, T, Hv]

For varlen mode:

  • q / k: [S, Hqk, 128] or [1, S, Hqk, 128]

  • v / g: [S, Hv, 128] or [1, S, Hv, 128]

  • beta: [S, Hv] or [1, S, Hv]

  • cu_seqlens: cumulative sequence lengths with shape [num_seqs + 1]
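For illustration, the sketch below shows how cu_seqlens relates to per-sequence lengths when sequences are packed along S. It uses plain Python lists so the arithmetic is visible; in real use cu_seqlens would be a tensor on the MUSA device, and its required integer dtype is not stated here. The helper names build_cu_seqlens and sequence_spans are hypothetical.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Hypothetical helper: per-sequence lengths -> cumulative offsets.

    The result has len(seq_lens) + 1 entries and starts at 0; its last entry
    is the packed length S. Sequence i occupies rows cu[i]:cu[i + 1].
    """
    if any(n < 0 for n in seq_lens):
        raise ValueError("sequence lengths must be non-negative")
    return list(accumulate(seq_lens, initial=0))


def sequence_spans(cu_seqlens):
    """Return the (start, end) row range of each packed sequence."""
    return list(zip(cu_seqlens[:-1], cu_seqlens[1:]))


if __name__ == "__main__":
    cu = build_cu_seqlens([3, 5, 2])
    print(cu)                  # [0, 3, 8, 10]: num_seqs = 3, S = 10
    print(sequence_spans(cu))  # [(0, 3), (3, 8), (8, 10)]
```

Three sequences therefore give a cu_seqlens of shape [4], matching the [num_seqs + 1] rule.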

Additional constraints:

  • k.shape == q.shape

  • g.shape == v.shape

  • beta.shape == v.shape[:-1] (v.shape[:3] for 4D inputs, v.shape[:2] for 3D varlen inputs)

  • Hv must be divisible by Hqk

  • every tensor passed to the kernel must be contiguous in its last dimension (stride 1)
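The shape rules above can be checked up front before launching the kernel. The following is a minimal sketch on plain shape tuples (for example tuple(t.shape)); check_kda_shapes is a hypothetical helper, not part of MATE. Beyond the listed constraints, it assumes q and v share the same leading (batch/sequence) dimensions and the same varlen layout, which the shape contract implies but does not state as a separate rule.

```python
HEAD_DIM = 128  # head dimension is currently fixed to 128


def check_kda_shapes(q, k, v, g, beta, varlen=False):
    """Hypothetical check of the chunk_kda shape contract on shape tuples.

    Returns (Hqk, Hv); raises ValueError on the first violation found.
    """
    q, k, v, g, beta = map(tuple, (q, k, v, g, beta))
    if varlen:
        # Varlen accepts [S, H, 128] or [1, S, H, 128].
        if len(q) not in (3, 4) or (len(q) == 4 and q[0] != 1):
            raise ValueError(f"varlen q must be [S, H, 128] or [1, S, H, 128], got {q}")
    elif len(q) != 4:
        raise ValueError(f"dense q must be [B, T, H, 128], got {q}")
    if q[-1] != HEAD_DIM or v[-1] != HEAD_DIM:
        raise ValueError(f"last dimension must be {HEAD_DIM}")
    if k != q:
        raise ValueError("k.shape must equal q.shape")
    if g != v:
        raise ValueError("g.shape must equal v.shape")
    if beta != v[:-1]:
        raise ValueError("beta.shape must equal v.shape without the head dimension")
    if q[:-2] != v[:-2]:  # assumption: q and v share leading dimensions
        raise ValueError("q and v must agree on batch/sequence dimensions")
    hqk, hv = q[-2], v[-2]
    if hv % hqk != 0:
        raise ValueError(f"Hv ({hv}) must be divisible by Hqk ({hqk})")
    return hqk, hv


if __name__ == "__main__":
    # Dense: 4 query/key heads shared across 8 value heads.
    print(check_kda_shapes((2, 64, 4, 128), (2, 64, 4, 128),
                           (2, 64, 8, 128), (2, 64, 8, 128), (2, 64, 8)))
```

Contiguity is not covered here because it depends on strides; with torch tensors it could be checked separately with t.stride(-1) == 1.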

State tensors

  • initial_state is optional

  • final_state is optional unless output_final_state=True

  • state shape: [num_seqs, Hv, 128, 128], where num_seqs is B in dense mode and len(cu_seqlens) - 1 in varlen mode

  • state dtype: same as value dtype or torch.float32

  • if both initial_state and final_state are provided, their dtypes must match
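A minimal sketch of the state rules, useful when preallocating initial_state or final_state. Dtypes are passed as strings ("float16", "bfloat16", "float32") to keep the example dependency-free; both helper names are hypothetical. In dense mode the sketch assumes one state per batch element (num_seqs = B).

```python
HEAD_DIM = 128  # head dimension is currently fixed to 128


def expected_state_shape(v_shape, cu_seqlens=None):
    """Hypothetical helper returning [num_seqs, Hv, 128, 128].

    num_seqs is len(cu_seqlens) - 1 in varlen mode; in dense mode this
    sketch assumes one state per batch element, i.e. num_seqs = B.
    """
    hv = v_shape[-2]
    num_seqs = len(cu_seqlens) - 1 if cu_seqlens is not None else v_shape[0]
    return (num_seqs, hv, HEAD_DIM, HEAD_DIM)


def check_state_dtypes(value_dtype, initial_dtype=None, final_dtype=None):
    """Enforce: state dtype is the value dtype or float32, and the two match."""
    allowed = {value_dtype, "float32"}
    for name, dt in (("initial_state", initial_dtype), ("final_state", final_dtype)):
        if dt is not None and dt not in allowed:
            raise ValueError(f"{name} dtype {dt} must be {value_dtype} or float32")
    if initial_dtype is not None and final_dtype is not None and initial_dtype != final_dtype:
        raise ValueError("initial_state and final_state dtypes must match")


if __name__ == "__main__":
    print(expected_state_shape((2, 64, 8, 128)))                             # (2, 8, 128, 128)
    print(expected_state_shape((1, 10, 8, 128), cu_seqlens=[0, 3, 8, 10]))  # (3, 8, 128, 128)
    check_state_dtypes("bfloat16", "float32", "float32")  # passes silently
```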

Gate parameters

When gate parameters are enabled:

  • A_log and dt_bias must be provided together

  • A_log shape: [Hv]

  • dt_bias shape: [Hv, 128]

  • lower_bound defaults to -5.0
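The pairing and shape rules for the gate parameters can be sketched as a small pre-flight check. gate_params_enabled is a hypothetical helper operating on shape tuples; it only validates what the list above states and says nothing about how the kernel combines A_log, dt_bias, and lower_bound.

```python
HEAD_DIM = 128  # head dimension is currently fixed to 128


def gate_params_enabled(hv, a_log_shape=None, dt_bias_shape=None):
    """Hypothetical check mirroring the gate-parameter rules.

    Returns True when both A_log and dt_bias are given with valid shapes,
    False when neither is given, and raises ValueError otherwise.
    """
    if (a_log_shape is None) != (dt_bias_shape is None):
        raise ValueError("A_log and dt_bias must be provided together")
    if a_log_shape is None:
        return False  # gate parameters disabled
    if tuple(a_log_shape) != (hv,):
        raise ValueError(f"A_log must have shape [{hv}], got {list(a_log_shape)}")
    if tuple(dt_bias_shape) != (hv, HEAD_DIM):
        raise ValueError(f"dt_bias must have shape [{hv}, {HEAD_DIM}], got {list(dt_bias_shape)}")
    return True


if __name__ == "__main__":
    print(gate_params_enabled(8))                   # False
    print(gate_params_enabled(8, (8,), (8, 128)))   # True
```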

Outputs

  • default return: output

  • if output_final_state=True: returns (output, final_state)

API reference

mate.kda.chunk_kda(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float | None = None, initial_state: torch.Tensor | None = None, output_final_state: bool = False, cu_seqlens: torch.Tensor | None = None, A_log: torch.Tensor | None = None, dt_bias: torch.Tensor | None = None, lower_bound: float = -5.0, use_qk_l2norm_in_kernel: bool = True, output: torch.Tensor | None = None, final_state: torch.Tensor | None = None) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]

Run the fused chunk KDA kernel.

Parameters:
  • q – Query tensor with shape [B, T, Hqk, 128] for dense mode or [S, Hqk, 128] / [1, S, Hqk, 128] for varlen mode.

  • k – Key tensor with the same shape and dtype as q.

  • v – Value tensor with shape [B, T, Hv, 128] or varlen equivalent.

  • g – Gate input tensor with the same shape as v.

  • beta – Beta logits tensor with shape [B, T, Hv] or varlen equivalent.

  • scale – Optional QK scaling factor. Defaults to 128**-0.5.

  • initial_state – Optional recurrent state tensor.

  • output_final_state – Whether to return the final recurrent state.

  • cu_seqlens – Optional cumulative sequence lengths for varlen mode.

  • A_log – Optional per-head gate parameter tensor.

  • dt_bias – Optional per-head, per-channel gate bias tensor.

  • lower_bound – Gate lower bound used when gate parameters A_log and dt_bias are enabled. Defaults to -5.0.

  • use_qk_l2norm_in_kernel – Whether to L2-normalize q and k inside the kernel. Defaults to True.

  • output – Optional preallocated output tensor.

  • final_state – Optional preallocated final-state tensor.