# Source code for mate.deep_gemm

import torch
from typing import Tuple, Optional
from mate.api_logging import mate_api
from mate.mate_runtime import (
    get_num_mps as get_num_mps,
    resolve_num_mps,
    set_num_mps as set_num_mps,
)
from mate.gemm import (
    bmm_fp8,
    ragged_m_moe_gemm_16bit,
    masked_moe_gemm_16bit,
    ragged_k_moe_gemm_16bit,
    ragged_m_moe_gemm_8bit,
    masked_moe_gemm_8bit,
    gemm_fp8_nt_groupwise,
    ragged_k_moe_gemm_8bit,
)
from mate.jit.gemm.deep_gemm.gemm import (
    GEMM_TYPE_NORMAL,
    get_deep_gemm_gemm_module,
)
from mate.jit.gemm.deep_gemm.hyperconnection import get_hyperconnection_module
from mate.jit.gemm.deep_gemm.mqa_logits import get_mqa_logits_module
from mate.jit.gemm.deep_gemm.paged_mqa_logits import (
    get_paged_mqa_logits_metadata_module,
    get_paged_mqa_logits_module,
)
from mate.jit.runtime import ffi_to_torch


def m_grouped_bf16_gemm_nt_contiguous(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    m_indices: torch.Tensor,
    alignment_m: int = 128,
    backend: str = "auto",
):
    """Delegate a contiguous M-grouped 16-bit GEMM to ``ragged_m_moe_gemm_16bit``.

    The ``(a, b, d, m_indices)`` argument order of this entry point is
    rearranged into the ``(a, b, m_indices, d)`` positional order used by the
    backend call. ``alignment_m`` and ``backend`` are forwarded unchanged as
    keyword arguments. The backend's return value is discarded, so this
    function always returns ``None``.

    NOTE(review): ``d`` is presumably the output buffer filled in place by the
    backend -- confirm against ``ragged_m_moe_gemm_16bit``.
    """
    forwarded_options = {"alignment_m": alignment_m, "backend": backend}
    ragged_m_moe_gemm_16bit(a, b, m_indices, d, **forwarded_options)


def m_grouped_bf16_gemm_nt_masked(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
    compiled_dims: str = "nk",
    enable_overlap: bool = False,
    signal: Optional[torch.Tensor] = None,
    backend: str = "auto",
):
    """Delegate a masked M-grouped 16-bit GEMM to ``masked_moe_gemm_16bit``.

    Parameters
    ----------
    a, b, d, masked_m
        Forwarded positionally to the backend as ``(a, b, masked_m, d)``.
    expected_m
        Forwarded to the backend as ``expect_tokens``.
    compiled_dims
        Accepted but not used by this wrapper.
    enable_overlap, signal, backend
        Forwarded unchanged as keyword arguments. ``signal`` may be ``None``.

    Returns
    -------
    ``res[2:]`` of the backend result when ``enable_overlap`` is true,
    otherwise ``None``.

    NOTE(review): the layout of the backend result (and therefore what
    ``res[2:]`` contains) is defined by ``masked_moe_gemm_16bit`` -- confirm
    there.
    """
    res = masked_moe_gemm_16bit(
        a,
        b,
        masked_m,
        d,
        expect_tokens=expected_m,
        enable_overlap=enable_overlap,
        signal=signal,
        backend=backend,
    )

    # Only the overlap mode exposes extra outputs to the caller.
    if not enable_overlap:
        return None
    return res[2:]


def m_grouped_fp8_gemm_nt_contiguous(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    m_indices: torch.Tensor,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
    disable_ue8m0_cast: bool = True,
    alignment_m: int = 128,
    backend: str = "auto",
):
    """Delegate a contiguous M-grouped 8-bit GEMM to ``ragged_m_moe_gemm_8bit``.

    Parameters
    ----------
    a, b, d, m_indices
        Forwarded positionally to the backend as ``(a, b, m_indices, d)``.
    recipe
        Forwarded to the backend as ``scale_granularity_mnk``.
    compiled_dims
        Accepted but not used by this wrapper.
    disable_ue8m0_cast
        Must be true; the UE8M0 cast is not supported.
    alignment_m, backend
        Forwarded unchanged as keyword arguments.

    Raises
    ------
    NotImplementedError
        If ``disable_ue8m0_cast`` is false. (A subclass of ``Exception``, so
        callers that caught the previous generic ``Exception`` still work.)
    """
    if not disable_ue8m0_cast:
        raise NotImplementedError(
            "m_grouped_fp8_gemm_nt_contiguous UE8M0 cast is not supported!"
        )

    ragged_m_moe_gemm_8bit(
        a,
        b,
        m_indices,
        d,
        scale_granularity_mnk=recipe,
        alignment_m=alignment_m,
        backend=backend,
    )


def m_grouped_fp8_gemm_nt_masked(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
    disable_ue8m0_cast: bool = True,
    enable_overlap: bool = False,
    signal: Optional[torch.Tensor] = None,
    backend: str = "auto",
):
    """Delegate a masked M-grouped 8-bit GEMM to ``masked_moe_gemm_8bit``.

    Parameters
    ----------
    a, b, d, masked_m, recipe, expected_m
        Forwarded positionally to the backend as
        ``(a, b, masked_m, d, recipe, expected_m)``.
    compiled_dims
        Accepted but not used by this wrapper.
    disable_ue8m0_cast
        Must be true; the UE8M0 cast is not supported.
    enable_overlap, signal, backend
        Forwarded unchanged as keyword arguments. ``signal`` may be ``None``.

    Returns
    -------
    ``res[2:]`` of the backend result when ``enable_overlap`` is true,
    otherwise ``None``.

    Raises
    ------
    NotImplementedError
        If ``disable_ue8m0_cast`` is false. (A subclass of ``Exception``, so
        callers that caught the previous generic ``Exception`` still work.)
    """
    if not disable_ue8m0_cast:
        raise NotImplementedError(
            "m_grouped_fp8_gemm_nt_masked UE8M0 cast is not supported!"
        )

    res = masked_moe_gemm_8bit(
        a,
        b,
        masked_m,
        d,
        recipe,
        expected_m,
        enable_overlap=enable_overlap,
        signal=signal,
        backend=backend,
    )

    # Only the overlap mode exposes extra outputs to the caller.
    if not enable_overlap:
        return None
    return res[2:]


def k_grouped_fp8_gemm_tn_contiguous(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    ks: Optional[list[int]] = None,
    ks_tensor: Optional[torch.Tensor] = None,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
):
    """Delegate a contiguous K-grouped 8-bit GEMM to ``ragged_k_moe_gemm_8bit``.

    Parameters
    ----------
    a, b, d
        Forwarded positionally to the backend as ``(a, b, ks_tensor, d)``.
    ks
        Host-side group sizes. Only consulted when ``ks_tensor`` is ``None``;
        it is then converted to an ``int32`` tensor on ``d.device``.
    ks_tensor
        Device-side group sizes. Takes precedence over ``ks`` when given.
    recipe
        Forwarded to the backend as ``scale_granularity_mnk``.
    compiled_dims
        Accepted but not used by this wrapper.

    Returns
    -------
    The ``d`` tensor that was passed in.

    Raises
    ------
    ValueError
        If neither ``ks`` nor ``ks_tensor`` is provided. (A subclass of
        ``Exception``, so callers that caught the previous generic
        ``Exception`` still work.)
    """
    if ks_tensor is None:
        if ks is None:
            raise ValueError("Must give the ks whether on host or device!")
        # Place the group sizes on the same device as the output instead of
        # the bare "musa" device string, which resolves to the *current*
        # device and can differ from the one holding the operands.
        ks_tensor = torch.tensor(ks, device=d.device, dtype=torch.int32)

    ragged_k_moe_gemm_8bit(
        a,
        b,
        ks_tensor,
        d,
        scale_granularity_mnk=recipe,
    )

    return d


def k_grouped_bf16_gemm_tn_contiguous(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    ks: Optional[list[int]] = None,
    ks_tensor: Optional[torch.Tensor] = None,
    compiled_dims: str = "nk",
):
    """Delegate a contiguous K-grouped 16-bit GEMM to ``ragged_k_moe_gemm_16bit``.

    Parameters
    ----------
    a, b, d
        Forwarded positionally to the backend as ``(a, b, ks_tensor, d)``.
    ks
        Host-side group sizes. Only consulted when ``ks_tensor`` is ``None``;
        it is then converted to an ``int32`` tensor on ``d.device``.
    ks_tensor
        Device-side group sizes. Takes precedence over ``ks`` when given.
    compiled_dims
        Accepted but not used by this wrapper.

    Returns
    -------
    The ``d`` tensor that was passed in.

    Raises
    ------
    ValueError
        If neither ``ks`` nor ``ks_tensor`` is provided. (A subclass of
        ``Exception``, so callers that caught the previous generic
        ``Exception`` still work.)
    """
    if ks_tensor is None:
        if ks is None:
            raise ValueError("Must give the ks whether on host or device!")
        # Place the group sizes on the same device as the output instead of
        # the bare "musa" device string, which resolves to the *current*
        # device and can differ from the one holding the operands.
        ks_tensor = torch.tensor(ks, device=d.device, dtype=torch.int32)

    ragged_k_moe_gemm_16bit(
        a,
        b,
        ks_tensor,
        d,
    )

    return d


# Legacy DeepGEMM API names. These are plain aliases: each one is bound to the
# very same function object as the corresponding `m_grouped_*` definition.
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked


def bf16_gemm_nt(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    compiled_dims: str = "nk",
    backend: str = "auto",
):
    """Run a plain (non-grouped) BF16 GEMM through the deep_gemm JIT module.

    Parameters
    ----------
    a, b, d
        Passed to the JIT kernel in this order. The kernel configuration is
        selected with ``config_m=a.shape[0]``.
    c
        Must be ``None``; a GEMM with ``C`` is not supported.
    compiled_dims
        Accepted but not used by this wrapper.
    backend
        Only ``"auto"`` and ``"mutlass"`` are accepted.

    Returns
    -------
    The ``d`` tensor that was passed in.

    Raises
    ------
    AssertionError
        If ``c`` is not ``None``.
    ValueError
        If ``backend`` is neither ``"auto"`` nor ``"mutlass"``.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O` (which would make `c` silently ignored). The exception
    # type is kept as AssertionError for backward compatibility.
    if c is not None:
        raise AssertionError("Not support GEMM with C")
    if backend not in ("auto", "mutlass"):
        raise ValueError(f"bf16_gemm_nt only supports mutlass backend, got {backend}")

    dispatch_name, mod = get_deep_gemm_gemm_module(
        kind="bf16",
        gemm_type=GEMM_TYPE_NORMAL,
        config_m=a.shape[0],
    )
    # NOTE(review): the meaning of the `None, 0` positional arguments is
    # defined by the JIT kernel signature -- confirm there.
    mod.get_function(dispatch_name)(
        a,
        b,
        d,
        None,
        0,
        resolve_num_mps(a.device),
    )
    return d


def fp8_gemm_nt(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
    disable_ue8m0_cast: bool = True,
    backend: str = "auto",
):
    """Run a plain (non-grouped) FP8 GEMM on the selected backend.

    Parameters
    ----------
    a, b
        Pairs indexed as ``x[0]`` and ``x[1]``.
        NOTE(review): presumably ``(fp8_tensor, scale_tensor)`` -- confirm
        against the backends.
    d
        Passed as ``out=`` to ``gemm_fp8_nt_groupwise`` or as an argument to
        the JIT kernel.
    c
        Must be ``None``; a GEMM with ``C`` is not supported.
    recipe
        Forwarded as ``scale_granularity_mnk`` on the ``"auto"``/``"mudnn"``
        path; not used on the ``"mutlass"`` path.
    compiled_dims
        Accepted but not used by this wrapper.
    disable_ue8m0_cast
        Must be true; the UE8M0 cast is not supported.
    backend
        ``"auto"`` or ``"mudnn"`` dispatch to ``gemm_fp8_nt_groupwise``;
        ``"mutlass"`` dispatches to the deep_gemm JIT module.

    Returns
    -------
    The result of ``gemm_fp8_nt_groupwise`` on the ``"auto"``/``"mudnn"``
    path, otherwise the ``d`` tensor that was passed in.

    Raises
    ------
    AssertionError
        If ``c`` is not ``None``.
    NotImplementedError
        If ``disable_ue8m0_cast`` is false. (A subclass of ``Exception``, so
        callers that caught the previous generic ``Exception`` still work.)
    ValueError
        If ``backend`` is not one of the supported names.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O` (which would make `c` silently ignored). The exception
    # type is kept as AssertionError for backward compatibility.
    if c is not None:
        raise AssertionError("Not support GEMM with C")
    if not disable_ue8m0_cast:
        raise NotImplementedError("fp8_gemm_nt UE8M0 cast is not supported!")

    if backend in ("auto", "mudnn"):
        return gemm_fp8_nt_groupwise(
            a[0],
            b[0],
            a[1],
            b[1],
            scale_granularity_mnk=recipe,
            out=d,
            backend=backend,
        )
    if backend != "mutlass":
        raise ValueError(f"Unsupported fp8_gemm_nt backend: {backend}")

    dispatch_name, mod = get_deep_gemm_gemm_module(
        kind="fp8",
        gemm_type=GEMM_TYPE_NORMAL,
        config_m=a[0].shape[0],
    )
    # NOTE(review): the meaning of the `None, 0` positional arguments is
    # defined by the JIT kernel signature -- confirm there.
    mod.get_function(dispatch_name)(
        a[0],
        a[1],
        b[0],
        b[1],
        d,
        None,
        0,
        resolve_num_mps(a[0].device),
    )
    return d


def _validate_fp8_einsum_pair(
    name: str, value: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    if not isinstance(value, tuple) or len(value) != 2:
        raise TypeError(f"{name} must be a tuple of (fp8_tensor, scale_tensor)")
    tensor, scale = value
    if tensor.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
        raise ValueError(f"{name}[0] must be an FP8 tensor")
    if scale.dtype != torch.float32:
        raise ValueError(f"{name}[1] must be a float32 scale tensor")
    if tensor.device != scale.device:
        raise ValueError(f"{name}[0] and {name}[1] must be on the same device")
    return tensor, scale


@mate_api
def fp8_einsum(
    expr: str,
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    recipe: Tuple[int, int, int] = (1, 128, 128),
) -> None:
    r"""DeepGEMM-compatible FP8 einsum, implemented on top of ``bmm_fp8``.

    Supported expressions are ``"bhr,hdr->bhd"``, ``"bhd,hdr->bhr"``, and
    ``"bhd,bhr->hdr"``. The quantization recipe must be ``(1, 128, 128)`` or
    ``(1, 1, 128)``.

    ``a`` and ``b`` are ``(fp8_tensor, scale_tensor)`` tuples; ``d`` is passed
    to ``bmm_fp8`` as ``out=`` (permuted to put the ``h`` axis first where the
    expression requires it). When ``c`` is given, both ``c`` and ``d`` must be
    ``float32``. All tensors must share one device.

    Raises ``TypeError`` / ``ValueError`` for malformed operands, a bad
    recipe, mismatched devices or shapes, or an unsupported expression.
    Always returns ``None``.
    """
    # Normalise the recipe to a 3-tuple of plain ints so it compares by value.
    granularity_values = tuple(int(x) for x in recipe)
    if len(granularity_values) != 3:
        raise ValueError("fp8_einsum recipe must be a 3-tuple")
    gran_m, gran_n, gran_k = granularity_values
    recipe = (gran_m, gran_n, gran_k)
    if recipe not in ((1, 128, 128), (1, 1, 128)):
        raise ValueError("fp8_einsum only supports recipe (1, 128, 128) or (1, 1, 128)")

    a_fp8, a_scale = _validate_fp8_einsum_pair("a", a)
    b_fp8, b_scale = _validate_fp8_einsum_pair("b", b)

    if a_fp8.device != b_fp8.device or a_fp8.device != d.device:
        raise ValueError("a, b and d must be on the same device")
    if c is not None:
        if c.device != d.device:
            raise ValueError("c must be on the same device as d")
        if c.dtype != torch.float32 or d.dtype != torch.float32:
            raise ValueError("fp8_einsum with c expects fp32 c and d tensors")

    # Swaps the first two axes, e.g. [b, h, x] -> [h, b, x].
    swap_bh = (1, 0, 2)
    all_rank3 = a_fp8.dim() == 3 and b_fp8.dim() == 3 and d.dim() == 3

    if expr == "bhr,hdr->bhd":
        if not all_rank3:
            raise ValueError("fp8_einsum('bhr,hdr->bhd') expects 3D tensors")
        batch, heads, r_dim = a_fp8.shape
        h_b, d_dim, r_b = b_fp8.shape
        shapes_agree = (
            heads == h_b
            and r_dim == r_b
            and tuple(d.shape) == (batch, heads, d_dim)
        )
        if not shapes_agree:
            raise ValueError("expected a[b,h,r], b[h,d,r] and d[b,h,d]")
        bmm_fp8(
            a_fp8.permute(*swap_bh),
            b_fp8,
            a_scale.permute(*swap_bh),
            b_scale,
            d.dtype,
            out=d.permute(*swap_bh),
            scale_granularity_mnk=recipe,
            c=None if c is None else c.permute(*swap_bh),
            major_a_mode="K",
            major_b_mode="K",
        )
    elif expr == "bhd,hdr->bhr":
        if not all_rank3:
            raise ValueError("fp8_einsum('bhd,hdr->bhr') expects 3D tensors")
        batch, heads, d_dim = a_fp8.shape
        h_b, d_b, r_dim = b_fp8.shape
        shapes_agree = (
            heads == h_b
            and d_dim == d_b
            and tuple(d.shape) == (batch, heads, r_dim)
        )
        if not shapes_agree:
            raise ValueError("expected a[b,h,d], b[h,d,r] and d[b,h,r]")
        bmm_fp8(
            a_fp8.permute(*swap_bh),
            b_fp8,
            a_scale.permute(*swap_bh),
            b_scale,
            d.dtype,
            out=d.permute(*swap_bh),
            scale_granularity_mnk=recipe,
            c=None if c is None else c.permute(*swap_bh),
            major_a_mode="K",
            major_b_mode="N",
        )
    elif expr == "bhd,bhr->hdr":
        if not all_rank3:
            raise ValueError("fp8_einsum('bhd,bhr->hdr') expects 3D tensors")
        batch, heads, d_dim = a_fp8.shape
        b_batch, h_b, r_dim = b_fp8.shape
        shapes_agree = (
            batch == b_batch
            and heads == h_b
            and tuple(d.shape) == (heads, d_dim, r_dim)
        )
        if not shapes_agree:
            raise ValueError("expected a[b,h,d], b[b,h,r] and d[h,d,r]")
        # `d` and `c` already have `h` as the leading axis here, so only the
        # operands and their scales are permuted.
        bmm_fp8(
            a_fp8.permute(*swap_bh),
            b_fp8.permute(*swap_bh),
            a_scale.permute(*swap_bh),
            b_scale.permute(*swap_bh),
            d.dtype,
            out=d,
            scale_granularity_mnk=recipe,
            c=c,
            major_a_mode="M",
            major_b_mode="N",
        )
    else:
        raise ValueError(f"Unsupported fp8_einsum expression: {expr}")


@mate_api
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_kv: int, num_mps: int = 0
) -> torch.Tensor:
    r"""Get metadata for paged MQA logits

    Parameters
    ----------
    context_lens: Tensor
        Context lengths of each query, shape ``(batch_size)``
    block_kv: int
        Block size of kv cache, **must be 64 now**.
    num_mps: int
        Number of MP to execute. 0 means use all MPs of the current device

    Returns
    -------
    Tensor
        Schedule metadata, shape ``(num_mps + 1, 2)``, dtype ``torch.int32``,
        allocated on ``context_lens.device``
    """
    num_mps = resolve_num_mps(context_lens.device, num_mps)
    # Uninitialised buffer handed to the kernel together with the inputs.
    schedule_meta = torch.empty(
        (num_mps + 1, 2), device=context_lens.device, dtype=torch.int32
    )
    batch_size = context_lens.shape[0]
    metadata_kernel = get_paged_mqa_logits_metadata_module(batch_size).get_function(
        "get_paged_mqa_logits_metadata"
    )
    metadata_kernel(context_lens, block_kv, schedule_meta)
    return schedule_meta


@mate_api
def fp8_paged_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_table: torch.Tensor,
    schedule_meta: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
) -> torch.Tensor:
    r"""FP8 Paged MQA logits

    Parameters
    ----------
    q: Tensor
        The FP8 query tensor with shape
        ``(batch_size, next_n, heads, index_dim)``
    fused_kv_cache: Tensor
        The FP8 kv cache with fp32 scale, shape
        ``(num_blocks, block_size, 1, index_dim + 4)``
    weights: Tensor
        The FP32 weight tensor for each query, shape
        ``(batch_size * next_n, heads)``
    context_lens: Tensor
        Context lengths tensor, supports two layouts:

        - **1D** ``(batch_size,)`` — all ``next_n`` draft tokens of request
          ``i`` share the same context length ``context_lens[i]``. The visible
          KV range for draft token ``j`` is implicitly
          ``[0, context_lens[i] - next_n + j]``.
        - **2D** ``(batch_size, next_n)`` — each draft token has an
          independent context length ``context_lens[i, j]``, with visible KV
          range ``[0, context_lens[i, j] - 1]``. Useful for tree-based
          speculative decoding (e.g. Medusa / EAGLE) where tokens on different
          branches see different KV prefixes.

        The shape is auto-detected; ``get_paged_mqa_logits_metadata`` must be
        called with the same ``context_lens`` tensor.
    block_table: Tensor
        Block table tensor with shape ``(batch_size, max_blocks)``
    schedule_meta: Tensor
        Schedule metadata tensor with shape ``(num_mps + 1, 2)``, produced by
        :func:`get_paged_mqa_logits_metadata`
    max_context_len: int
        Maximum context length
    clean_logits: bool
        Whether to zero-fill logit positions that are out of the valid KV
        range

    Returns
    -------
    Tensor
        FP32 logits, shape ``(batch_size * next_n, max_context_len)``
    """
    # The kernel module is specialised on these static properties of the
    # inputs.
    _, next_n, num_heads, head_dim = q.shape[0], q.shape[1], q.shape[2], q.shape[3]
    block_kv = fused_kv_cache.shape[1]
    is_context_lens_2d = context_lens.dim() == 2

    return ffi_to_torch(
        get_paged_mqa_logits_module(
            next_n,
            num_heads,
            head_dim,
            block_kv,
            is_context_lens_2d,
        ).get_function("fp8_paged_mqa_logits")(
            q,
            fused_kv_cache,
            weights,
            context_lens,
            block_table,
            schedule_meta,
            max_context_len,
            clean_logits,
        )
    )


@mate_api
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seq_len_k_start: torch.Tensor,
    cu_seq_len_k_end: torch.Tensor,
    clean_logits: bool = False,
    max_seqlen_k: int = 0,
) -> torch.Tensor:
    r"""FP8 MQA logits.

    This operator computes MQA (multi-query attention) logits for a query
    sequence against a *non-paged* KV tensor. It supports both full logits and
    "compressed logits" mode (when ``max_seqlen_k > 0``), where the output
    width is limited to a window size.

    Parameters
    ----------
    q : torch.Tensor
        FP8 query tensor with shape ``(seq_len, heads, head_dim)`` and dtype
        ``torch.float8_e4m3fn``.
    kv : tuple[torch.Tensor, torch.Tensor]
        A tuple ``(kv_fp8, kv_scale)``:

        - ``kv_fp8``: FP8 KV tensor with shape ``(seq_len_kv, head_dim)`` and
          dtype ``torch.float8_e4m3fn``. (MQA uses a single KV head.)
        - ``kv_scale``: FP32 scale tensor with shape ``(seq_len_kv,)`` and
          dtype ``torch.float32``.
    weights : torch.Tensor
        FP32 weight tensor with shape ``(seq_len, heads)`` and dtype
        ``torch.float32``.
    cu_seq_len_k_start : torch.Tensor
        Per-row valid KV start offsets (inclusive) for each query row, with
        shape ``(seq_len,)`` and dtype ``torch.int32``.
    cu_seq_len_k_end : torch.Tensor
        Per-row valid KV end offsets (exclusive) for each query row, with
        shape ``(seq_len,)`` and dtype ``torch.int32``.
    clean_logits : bool, default=False
        Whether to clean logits outside valid KV range. Must be ``False`` when
        ``max_seqlen_k > 0``.
    max_seqlen_k : int, default=0
        If > 0, enables compressed logits mode. The output width becomes
        ``max_seqlen_k`` (a windowed logits range per row). In this mode,
        ``clean_logits`` must be ``False``.

    Returns
    -------
    torch.Tensor
        FP32 logits tensor with shape:

        - ``(seq_len, seq_len_kv)`` if ``max_seqlen_k == 0``
        - ``(seq_len, max_seqlen_k)`` if ``max_seqlen_k > 0``

        and dtype ``torch.float32``.

    Raises
    ------
    ValueError
        If ``max_seqlen_k > 0`` is combined with ``clean_logits=True``.
    """
    kv_fp8, kv_scale = kv
    compressed_logits = max_seqlen_k > 0
    if compressed_logits and clean_logits:
        raise ValueError("max_seqlen_k is not supported with clean_logits")

    # The kernel module is specialised on these static properties of the
    # inputs.
    seq_len, num_heads, head_dim = q.shape[0], q.shape[1], q.shape[2]
    seq_len_kv = kv_fp8.shape[0]
    num_mps = resolve_num_mps(q.device)

    return ffi_to_torch(
        get_mqa_logits_module(
            seq_len,
            seq_len_kv,
            num_heads,
            head_dim,
            compressed_logits,
            num_mps,
        ).get_function("fp8_mqa_logits")(
            q,
            kv_fp8,
            weights,
            cu_seq_len_k_start,
            cu_seq_len_k_end,
            kv_scale,
            clean_logits,
            int(max_seqlen_k),
        )
    )


@mate_api
def tf32_hc_prenorm_gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    sqr_sum: torch.Tensor,
    num_splits: Optional[int] = None,
) -> None:
    r"""TF32 HyperConnection prenorm GEMM.

    Parameters
    ----------
    a : torch.Tensor
        Input tensor with shape ``(M, K)`` and dtype ``torch.bfloat16``.
    b : torch.Tensor
        Weight tensor with shape ``(N, K)`` and dtype ``torch.float32``.
    d : torch.Tensor
        Output GEMM tensor with dtype ``torch.float32``. Shape is ``(M, N)``
        when ``num_splits`` is ``None`` or ``<= 1``; otherwise
        ``(num_splits, M, N)``.
    sqr_sum : torch.Tensor
        Output row-wise squared-sum tensor with dtype ``torch.float32``. Shape
        is ``(M,)`` when ``num_splits`` is ``None`` or ``<= 1``; otherwise
        ``(num_splits, M)``.
    num_splits : int, default=None
        Optional split-K factor. When greater than 1, the kernel writes
        per-split partial outputs and callers should reduce them along dim 0.

    Returns
    -------
    None
    """
    m = a.shape[0]
    n = b.shape[0]
    # `None` and any value <= 1 both select the single-split kernel.
    if num_splits is None or num_splits <= 1:
        num_splits = 1
    else:
        num_splits = int(num_splits)
    num_mps = resolve_num_mps(a.device)

    get_hyperconnection_module(m, n, num_splits, num_mps).get_function(
        "tf32_hc_prenorm_gemm"
    )(
        a,
        b,
        d,
        sqr_sum,
        num_splits,
    )