Source code for mate.kda

from __future__ import annotations

from typing import Optional, Tuple, Union

import torch

from mate.api_logging import mate_api
from mate.jit.kda_ops import (
    get_kda_fused_ops_function_name,
    get_kda_fused_ops_module,
    make_kda_fused_ops_config,
)

# Activation dtypes accepted by ``chunk_kda``: ``q`` must be one of these, and
# ``k``, ``v``, ``g``, ``beta`` and ``output`` must share ``q``'s dtype.
_SUPPORTED_QKVA_DTYPES = (torch.float16, torch.bfloat16)


def _as_4d_varlen_input(
    x: torch.Tensor,
    *,
    name: str,
    cu_seqlens: Optional[torch.Tensor],
) -> tuple[torch.Tensor, bool]:
    if cu_seqlens is None:
        if x.ndim != 4:
            raise ValueError(f"{name} must be a 4D tensor [B, T, H, 128].")
        return x, False
    if x.ndim == 3:
        return x.unsqueeze(0), True
    if x.ndim == 4:
        if x.shape[0] != 1:
            raise ValueError(
                f"{name}.shape[0] must be 1 when cu_seqlens is provided; "
                "flatten variable-length input as [S, H, 128]."
            )
        return x, False
    raise ValueError(f"{name} must be [S, H, 128] or [1, S, H, 128] for varlen.")


def _check_state_dtype(
    x: torch.Tensor,
    *,
    name: str,
    value_dtype: torch.dtype,
) -> None:
    if x.dtype not in (value_dtype, torch.float32):
        raise TypeError(f"{name} must have dtype {value_dtype} or torch.float32.")


@mate_api
def chunk_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    A_log: Optional[torch.Tensor] = None,
    dt_bias: Optional[torch.Tensor] = None,
    lower_bound: float = -5.0,
    use_qk_l2norm_in_kernel: bool = True,
    output: Optional[torch.Tensor] = None,
    final_state: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run the fused chunk KDA kernel.

    Args:
        q: Query tensor with shape ``[B, T, Hqk, 128]`` for dense mode or
            ``[S, Hqk, 128]`` / ``[1, S, Hqk, 128]`` for varlen mode.
        k: Key tensor with the same shape and dtype as ``q``.
        v: Value tensor with shape ``[B, T, Hv, 128]`` or varlen equivalent.
            ``Hv`` must be a multiple of ``Hqk``.
        g: Gate input tensor with the same shape as ``v``.
        beta: Beta logits tensor with shape ``[B, T, Hv]`` or varlen
            equivalent (``[S, Hv]`` / ``[1, S, Hv]``).
        scale: Optional QK scaling factor. Defaults to ``128**-0.5``.
        initial_state: Optional recurrent state tensor of shape
            ``[N, Hv, 128, 128]``, where ``N`` is ``B`` in dense mode and
            ``len(cu_seqlens) - 1`` in varlen mode. Its dtype must be the
            dtype of ``q`` or ``torch.float32``.
        output_final_state: Whether to return the final recurrent state.
        cu_seqlens: Optional 1D cumulative sequence lengths (int32 or int64)
            for varlen mode.
        A_log: Optional per-head gate parameter tensor. Must be provided
            together with ``dt_bias``.
        dt_bias: Optional per-head, per-channel gate bias tensor.
        lower_bound: Gate lower bound used when gate parameters ``A_log`` and
            ``dt_bias`` are enabled. Defaults to ``-5.0``.
        use_qk_l2norm_in_kernel: Whether to normalize Q/K in the kernel.
        output: Optional preallocated output tensor with the shape of ``v``
            and the dtype of ``q``. When ``q`` is 3D varlen input, a 3D
            buffer is accepted as well.
        final_state: Optional preallocated final-state tensor with the same
            shape as ``initial_state`` would have. Its dtype must be the
            dtype of ``q`` or ``torch.float32`` and must match
            ``initial_state`` when both are given. Ignored unless
            ``output_final_state`` is True.

    Returns:
        The output tensor, or ``(output, final_state)`` when
        ``output_final_state`` is True. If ``q`` was passed as a 3D varlen
        tensor, the output is returned in 3D form as well.

    Raises:
        ValueError: On shape mismatches, unsupported head dimension, invalid
            ``cu_seqlens`` shape, or when only one of ``A_log``/``dt_bias``
            is given.
        TypeError: On unsupported or mismatched dtypes.
    """
    is_varlen = cu_seqlens is not None

    # Bring every activation to 4D; only q decides whether to squeeze later.
    q, squeeze_varlen = _as_4d_varlen_input(q, name="q", cu_seqlens=cu_seqlens)
    k, _ = _as_4d_varlen_input(k, name="k", cu_seqlens=cu_seqlens)
    v, _ = _as_4d_varlen_input(v, name="v", cu_seqlens=cu_seqlens)
    g, _ = _as_4d_varlen_input(g, name="g", cu_seqlens=cu_seqlens)
    if beta.ndim == 2 and is_varlen:
        beta = beta.unsqueeze(0)

    # ---- Activation shapes -------------------------------------------------
    if any(t.shape[-1] != 128 for t in (q, k, v, g)):
        raise ValueError("chunk_kda currently requires D=128.")
    if k.shape != q.shape:
        raise ValueError("k must have the same shape as q.")
    if q.shape[:2] != v.shape[:2] or q.shape[:2] != g.shape[:2]:
        raise ValueError("q, v and g must have matching [B, T] dimensions.")
    if g.shape != v.shape:
        raise ValueError("g must have the same shape as v.")
    # Guard against zero q/k heads, which would otherwise raise ZeroDivisionError.
    if q.shape[2] == 0 or v.shape[2] % q.shape[2] != 0:
        raise ValueError("GVA requires v/g heads to be divisible by q/k heads.")
    if beta.shape != v.shape[:3]:
        raise ValueError("beta must have shape [B, T, Hv].")

    # ---- Activation dtypes -------------------------------------------------
    if q.dtype not in _SUPPORTED_QKVA_DTYPES:
        raise TypeError("chunk_kda supports torch.float16 and torch.bfloat16 inputs.")
    if any(t.dtype != q.dtype for t in (k, v, g, beta)):
        raise TypeError("k, v, g and beta must have the same dtype as q.")

    if (A_log is None) != (dt_bias is None):
        raise ValueError("A_log and dt_bias must be provided together.")
    if initial_state is not None:
        _check_state_dtype(initial_state, name="initial_state", value_dtype=q.dtype)
    if scale is None:
        scale = q.shape[-1] ** -0.5

    # ---- Sequence count (drives the recurrent-state shape) -----------------
    if cu_seqlens is not None:
        if cu_seqlens.dtype not in (torch.int32, torch.int64):
            raise TypeError("cu_seqlens must have dtype torch.int32 or torch.int64.")
        # An empty tensor would yield a negative sequence count below.
        if cu_seqlens.ndim != 1 or cu_seqlens.numel() == 0:
            raise ValueError("cu_seqlens must be a non-empty 1D tensor.")
        cu_seqlens = cu_seqlens.contiguous()
        nseq = cu_seqlens.numel() - 1
    else:
        nseq = q.shape[0]
    state_shape = (nseq, v.shape[2], 128, 128)

    # The kernel reads initial_state directly, so a wrong shape must not reach it.
    if initial_state is not None and initial_state.shape != state_shape:
        raise ValueError(f"initial_state must have shape {list(state_shape)}.")

    # ---- Output buffer -----------------------------------------------------
    if output is None:
        output = torch.empty_like(v)
    elif squeeze_varlen and output.ndim == 3:
        # View into the caller's buffer, so kernel writes land in it.
        output = output.unsqueeze(0)
    if output.shape != v.shape:
        raise ValueError("output must have the same shape as v.")
    if output.dtype != q.dtype:
        raise TypeError("output must have the same dtype as q.")

    # ---- Final-state buffer ------------------------------------------------
    if not output_final_state:
        # A caller-supplied buffer is ignored when no final state is requested.
        final_state = None
    elif final_state is None:
        alloc_dtype = initial_state.dtype if initial_state is not None else q.dtype
        final_state = torch.empty(state_shape, device=q.device, dtype=alloc_dtype)
    else:
        _check_state_dtype(final_state, name="final_state", value_dtype=q.dtype)
        # The kernel writes into this buffer; a wrong shape would corrupt memory.
        if final_state.shape != state_shape:
            raise ValueError(f"final_state must have shape {list(state_shape)}.")

    if (
        initial_state is not None
        and final_state is not None
        and initial_state.dtype != final_state.dtype
    ):
        raise TypeError("initial_state and final_state must have the same dtype.")

    # Any state tensors present now share a single dtype.
    if initial_state is not None:
        state_dtype = initial_state.dtype
    elif final_state is not None:
        state_dtype = final_state.dtype
    else:
        state_dtype = q.dtype
    has_state = initial_state is not None or final_state is not None
    state_fp32 = has_state and state_dtype == torch.float32

    # ---- Kernel dispatch ---------------------------------------------------
    kda_config = make_kda_fused_ops_config(
        q.dtype,
        state_dtype=state_dtype,
        cu_seqlens_dtype=cu_seqlens.dtype if cu_seqlens is not None else None,
        has_state_in=initial_state is not None,
        has_state_out=final_state is not None,
        state_fp32=state_fp32,
        has_gate_params=A_log is not None,
        is_varlen=is_varlen,
        normalize_qk=bool(use_qk_l2norm_in_kernel),
    )
    kda_func_name = get_kda_fused_ops_function_name(kda_config)
    kda_kernel = get_kda_fused_ops_module(kda_config).get_function(kda_func_name)
    kda_kernel(
        q,
        k,
        v,
        g,
        beta,
        output,
        initial_state,
        final_state,
        cu_seqlens,
        A_log,
        dt_bias,
        float(scale),
        float(lower_bound),
        bool(use_qk_l2norm_in_kernel),
    )

    if squeeze_varlen:
        output = output.squeeze(0)
    if output_final_state:
        assert final_state is not None  # allocated or validated above
        return output, final_state
    return output