import functools
import torch
from typing import Tuple, Optional
from mate.api_logging import mate_api
from mate.gemm import (
ragged_m_moe_gemm_16bit,
masked_moe_gemm_16bit,
ragged_k_moe_gemm_16bit,
ragged_m_moe_gemm_8bit,
masked_moe_gemm_8bit,
gemm_fp8_nt_groupwise,
ragged_k_moe_gemm_8bit,
)
from mate.jit.deep_gemm_attention import (
get_deep_gemm_attention_module,
get_metadata_module,
)
from mate.jit.gemm.deep_gemm.hyperconnection import get_hyperconnection_module
from mate.jit.runtime import ffi_to_torch
@functools.cache
def _get_module():
return get_deep_gemm_attention_module()
def _resolve_num_mps(device: torch.device, num_mps: int) -> int:
if num_mps > 0:
return num_mps
device_index = device.index
if device_index is None:
device_index = torch.musa.current_device()
return torch.musa.get_device_properties(device_index).multi_processor_count
def m_grouped_bf16_gemm_nt_contiguous(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    m_indices: torch.Tensor,
    alignment_m: int = 128,
):
    """M-grouped BF16 GEMM (NT layout) over contiguously packed groups.

    Thin adapter onto ``ragged_m_moe_gemm_16bit``. Note that the backend takes
    ``m_indices`` *before* the output ``d``. ``d`` is passed as the output
    argument (presumably written in place; confirm against the backend), and
    this wrapper itself returns ``None``.

    Parameters
    ----------
    a, b : Tensor
        BF16 operands forwarded unchanged.
    d : Tensor
        Output tensor handed to the backend.
    m_indices : Tensor
        Group index information forwarded to the backend.
    alignment_m : int, default=128
        Forwarded as ``alignment_m``.
    """
    ragged_m_moe_gemm_16bit(a, b, m_indices, d, alignment_m=alignment_m)
def m_grouped_bf16_gemm_nt_masked(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
    compiled_dims: str = "nk",
    enable_overlap: bool = False,
    signal: torch.Tensor = None,
):
    """Masked M-grouped BF16 GEMM (NT layout).

    Forwards to ``masked_moe_gemm_16bit``, passing ``expected_m`` as
    ``expect_tokens``. ``compiled_dims`` is accepted for interface
    compatibility and is not used here.

    Returns
    -------
    The backend result with its first two entries dropped when
    ``enable_overlap`` is truthy; otherwise ``None``.
    """
    outputs = masked_moe_gemm_16bit(
        a,
        b,
        masked_m,
        d,
        expect_tokens=expected_m,
        enable_overlap=enable_overlap,
        signal=signal,
    )
    if enable_overlap:
        return outputs[2:]
    return None
def m_grouped_fp8_gemm_nt_contiguous(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    m_indices: torch.Tensor,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
    disable_ue8m0_cast: bool = True,
    alignment_m: int = 128,
):
    """M-grouped FP8 GEMM (NT layout) over contiguously packed groups.

    Forwards to ``ragged_m_moe_gemm_8bit``, passing ``recipe`` as
    ``scale_granularity_mnk``. ``compiled_dims`` is accepted for interface
    compatibility and is not used here.

    Raises
    ------
    Exception
        If ``disable_ue8m0_cast`` is falsy, because UE8M0 casting is not
        supported by this backend.
    """
    if not disable_ue8m0_cast:
        raise Exception("m_grouped_fp8_gemm_nt_contiguous UE8M0 cast is not supported!")
    ragged_m_moe_gemm_8bit(
        a,
        b,
        m_indices,
        d,
        alignment_m=alignment_m,
        scale_granularity_mnk=recipe,
    )
def m_grouped_fp8_gemm_nt_masked(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
    disable_ue8m0_cast: bool = True,
    enable_overlap: bool = False,
    signal: torch.Tensor = None,
):
    """Masked M-grouped FP8 GEMM (NT layout).

    Forwards to ``masked_moe_gemm_8bit``. ``recipe`` and ``expected_m`` are
    passed positionally after ``d``. ``compiled_dims`` is accepted for
    interface compatibility and is not used here.

    Returns
    -------
    The backend result with its first two entries dropped when
    ``enable_overlap`` is truthy; otherwise ``None``.

    Raises
    ------
    Exception
        If ``disable_ue8m0_cast`` is falsy, because UE8M0 casting is not
        supported by this backend.
    """
    if not disable_ue8m0_cast:
        raise Exception("m_grouped_fp8_gemm_nt_masked UE8M0 cast is not supported!")
    outputs = masked_moe_gemm_8bit(
        a,
        b,
        masked_m,
        d,
        recipe,
        expected_m,
        enable_overlap=enable_overlap,
        signal=signal,
    )
    if enable_overlap:
        return outputs[2:]
    return None
def k_grouped_fp8_gemm_tn_contiguous(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    ks: Optional[list[int]] = None,
    ks_tensor: Optional[torch.Tensor] = None,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
):
    """K-grouped FP8 GEMM (TN layout) over contiguously packed K groups.

    Parameters
    ----------
    a, b : tuple of Tensor
        FP8 operands, forwarded unchanged to ``ragged_k_moe_gemm_8bit``.
    d : Tensor
        Output tensor handed to the backend. It is also returned.
    ks : list of int, optional
        Per-group K sizes given on the host. Used only when ``ks_tensor`` is
        ``None``.
    ks_tensor : Tensor, optional
        Per-group K sizes already materialized as a tensor. Takes precedence
        over ``ks``.
    recipe : tuple of int, optional
        Forwarded as ``scale_granularity_mnk``.
    compiled_dims : str
        Accepted for interface compatibility and not used here.

    Returns
    -------
    Tensor
        ``d``.

    Raises
    ------
    ValueError
        If neither ``ks`` nor ``ks_tensor`` is provided.
    """
    if ks_tensor is None:
        if ks is None:
            raise ValueError("Must give the ks whether on host or device!")
        # Place ks next to the output rather than on the *current* MUSA device,
        # so operands living on a non-current device don't get a cross-device
        # ks tensor.
        ks_tensor = torch.tensor(ks, device=d.device, dtype=torch.int32)
    ragged_k_moe_gemm_8bit(a, b, ks_tensor, d, scale_granularity_mnk=recipe)
    return d
def k_grouped_bf16_gemm_tn_contiguous(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    ks: Optional[list[int]] = None,
    ks_tensor: Optional[torch.Tensor] = None,
    compiled_dims: str = "nk",
):
    """K-grouped BF16 GEMM (TN layout) over contiguously packed K groups.

    Parameters
    ----------
    a, b : Tensor
        BF16 operands, forwarded unchanged to ``ragged_k_moe_gemm_16bit``.
    d : Tensor
        Output tensor handed to the backend. It is also returned.
    ks : list of int, optional
        Per-group K sizes given on the host. Used only when ``ks_tensor`` is
        ``None``.
    ks_tensor : Tensor, optional
        Per-group K sizes already materialized as a tensor. Takes precedence
        over ``ks``.
    compiled_dims : str
        Accepted for interface compatibility and not used here.

    Returns
    -------
    Tensor
        ``d``.

    Raises
    ------
    ValueError
        If neither ``ks`` nor ``ks_tensor`` is provided.
    """
    if ks_tensor is None:
        if ks is None:
            raise ValueError("Must give the ks whether on host or device!")
        # Place ks next to the output rather than on the *current* MUSA device,
        # so operands living on a non-current device don't get a cross-device
        # ks tensor.
        ks_tensor = torch.tensor(ks, device=d.device, dtype=torch.int32)
    ragged_k_moe_gemm_16bit(a, b, ks_tensor, d)
    return d
# Legacy DeepGEMM API names: the dtype-first spellings are kept as plain
# aliases of the current functions so existing callers keep working.
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked
def fp8_gemm_nt(
    a: Tuple[torch.Tensor, torch.Tensor],
    b: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
    disable_ue8m0_cast: bool = True,
):
    """FP8 GEMM (NT layout) with group-wise scaling.

    Parameters
    ----------
    a, b : tuple of Tensor
        Operand pairs. Element ``0`` is forwarded as the operand and element
        ``1`` as its scale to ``gemm_fp8_nt_groupwise``.
    d : Tensor
        Passed to the backend as ``out``.
    c : Tensor, optional
        Must be ``None``; GEMM with an extra ``C`` input is not supported.
    recipe : tuple of int, optional
        Forwarded as ``scale_granularity_mnk``.
    compiled_dims, disable_ue8m0_cast
        Accepted for interface compatibility and not used here.

    Returns
    -------
    Whatever ``gemm_fp8_nt_groupwise`` returns.

    Raises
    ------
    AssertionError
        If ``c`` is not ``None``.
    """
    if c is not None:
        # Explicit raise instead of ``assert``: an ``assert`` is stripped under
        # ``python -O`` and ``c`` would then be silently ignored. The exception
        # type is kept as AssertionError for existing callers.
        raise AssertionError("Not support GEMM with C")
    # NOTE(review): the m_grouped FP8 wrappers reject disable_ue8m0_cast=False,
    # but it is silently ignored here -- confirm this is intended.
    return gemm_fp8_nt_groupwise(
        a[0], b[0], a[1], b[1], scale_granularity_mnk=recipe, out=d
    )
@mate_api
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_kv: int, num_mps: int = 0
) -> torch.Tensor:
    r"""Get schedule metadata for paged MQA logits.

    Parameters
    ----------
    context_lens: Tensor
        Context lengths, shape ``(batch_size,)`` or ``(batch_size, next_n)``.
        The same tensor must later be passed to ``fp8_paged_mqa_logits``.
    block_kv: int
        Block size of kv cache, **must be 64 now**.
    num_mps: int
        Number of MPs to schedule over. ``0`` (or any non-positive value)
        means all MPs of ``context_lens``'s device. If that device has no
        explicit index, the current MUSA device is used.

    Returns
    -------
    Tensor
        ``int32`` schedule metadata of shape ``(num_mps + 1, 2)``, allocated
        on ``context_lens``'s device.
    """
    num_mps = _resolve_num_mps(context_lens.device, num_mps)
    schedule_meta = torch.empty(
        (num_mps + 1, 2), device=context_lens.device, dtype=torch.int32
    )
    # The metadata module is looked up per batch size (leading dim of
    # context_lens), for both the 1D and 2D layouts.
    batch_size = context_lens.shape[0]
    get_metadata_module(batch_size).get_function("get_paged_mqa_logits_metadata")(
        context_lens, block_kv, schedule_meta
    )
    return schedule_meta
@mate_api
def fp8_paged_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_table: torch.Tensor,
    schedule_meta: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
) -> torch.Tensor:
    r"""FP8 paged MQA logits.

    Parameters
    ----------
    q: Tensor
        The FP8 query tensor with shape ``(batch_size, next_n, heads, index_dim)``.
    fused_kv_cache: Tensor
        The FP8 kv cache with fp32 scale, shape
        ``(num_blocks, block_size, 1, index_dim + 4)``.
    weights: Tensor
        The FP32 weight tensor for each query, shape ``(batch_size * next_n, heads)``.
    context_lens: Tensor
        Context lengths tensor. Two layouts are supported:

        - **1D** ``(batch_size,)``: all ``next_n`` draft tokens of request ``i``
          share the same context length ``context_lens[i]``. The visible KV
          range for draft token ``j`` is implicitly
          ``[0, context_lens[i] - next_n + j]``.
        - **2D** ``(batch_size, next_n)``: each draft token has an independent
          context length ``context_lens[i, j]``, with visible KV range
          ``[0, context_lens[i, j] - 1]``. This is useful for tree-based
          speculative decoding (e.g. Medusa / EAGLE), where tokens on different
          branches see different KV prefixes.

        The shape is auto-detected. ``get_paged_mqa_logits_metadata`` must be
        called with the same ``context_lens`` tensor.
    block_table: Tensor
        Block table tensor with shape ``(batch_size, max_blocks)``.
    schedule_meta: Tensor
        Schedule metadata tensor with shape ``(num_mps + 1, 2)``, produced by
        :func:`get_paged_mqa_logits_metadata`.
    max_context_len: int
        Maximum context length.
    clean_logits: bool
        Whether to zero-fill logit positions that are out of the valid KV range.

    Returns
    -------
    Tensor
        FP32 logits, shape ``(batch_size * next_n, max_context_len)``.
    """
    # The JIT kernel returns an FFI tensor; ffi_to_torch converts it back.
    return ffi_to_torch(
        _get_module().get_function("fp8_paged_mqa_logits")(
            q,
            fused_kv_cache,
            weights,
            context_lens,
            block_table,
            schedule_meta,
            max_context_len,
            clean_logits,
        )
    )
@mate_api
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seq_len_k_start: torch.Tensor,
    cu_seq_len_k_end: torch.Tensor,
    clean_logits: bool = False,
    max_seqlen_k: int = 0,
) -> torch.Tensor:
    r"""FP8 MQA logits over a non-paged KV tensor.

    Computes multi-query-attention logits for every query row against a flat
    (non-paged) KV tensor. Setting ``max_seqlen_k > 0`` switches to
    "compressed logits" mode, in which each output row only spans a window of
    ``max_seqlen_k`` positions.

    Parameters
    ----------
    q : torch.Tensor
        FP8 queries, shape ``(seq_len, heads, head_dim)``, dtype
        ``torch.float8_e4m3fn``.
    kv : tuple[torch.Tensor, torch.Tensor]
        Pair ``(kv_fp8, kv_scale)``:

        - ``kv_fp8``: FP8 keys of shape ``(seq_len_kv, head_dim)``, dtype
          ``torch.float8_e4m3fn``. MQA uses a single KV head.
        - ``kv_scale``: FP32 per-row scales of shape ``(seq_len_kv,)``, dtype
          ``torch.float32``.
    weights : torch.Tensor
        FP32 per-query weights of shape ``(seq_len, heads)``.
    cu_seq_len_k_start : torch.Tensor
        ``int32`` inclusive start of the valid KV range for each query row,
        shape ``(seq_len,)``.
    cu_seq_len_k_end : torch.Tensor
        ``int32`` exclusive end of the valid KV range for each query row,
        shape ``(seq_len,)``.
    clean_logits : bool, default=False
        Whether logits outside the valid KV range are cleaned. Must be
        ``False`` whenever ``max_seqlen_k > 0``.
    max_seqlen_k : int, default=0
        A positive value enables compressed mode and sets the output width.

    Returns
    -------
    torch.Tensor
        ``torch.float32`` logits shaped ``(seq_len, seq_len_kv)`` when
        ``max_seqlen_k == 0``, or ``(seq_len, max_seqlen_k)`` when
        ``max_seqlen_k > 0``.

    Raises
    ------
    ValueError
        If ``clean_logits`` is requested together with ``max_seqlen_k > 0``.
    """
    kv_fp8, kv_scale = kv
    compressed_mode = max_seqlen_k > 0
    if compressed_mode and clean_logits:
        raise ValueError("max_seqlen_k is not supported with clean_logits")
    kernel = _get_module().get_function("fp8_mqa_logits")
    # The kernel takes kv_scale after the range tensors, not next to kv_fp8.
    raw_logits = kernel(
        q,
        kv_fp8,
        weights,
        cu_seq_len_k_start,
        cu_seq_len_k_end,
        kv_scale,
        clean_logits,
        int(max_seqlen_k),
    )
    return ffi_to_torch(raw_logits)
@mate_api
def tf32_hc_prenorm_gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    d: torch.Tensor,
    sqr_sum: torch.Tensor,
    num_splits: Optional[int] = None,
) -> None:
    r"""TF32 HyperConnection prenorm GEMM, written into caller-provided outputs.

    Parameters
    ----------
    a : torch.Tensor
        ``torch.bfloat16`` input of shape ``(M, K)``.
    b : torch.Tensor
        ``torch.float32`` weight of shape ``(N, K)``.
    d : torch.Tensor
        ``torch.float32`` GEMM output: ``(M, N)`` without split-K (``num_splits``
        is ``None`` or ``<= 1``), else ``(num_splits, M, N)``.
    sqr_sum : torch.Tensor
        ``torch.float32`` row-wise squared-sum output: ``(M,)`` without split-K,
        else ``(num_splits, M)``.
    num_splits : int, default=None
        Split-K factor. Values above 1 make the kernel emit per-split partial
        results, which the caller is expected to reduce along dim 0.

    Returns
    -------
    None
    """
    rows = a.shape[0]
    cols = b.shape[0]
    if num_splits is None or num_splits <= 1:
        splits = 1
    else:
        splits = int(num_splits)
    # The module is looked up per (M, N, splits, MP count) of a's device.
    mp_count = _resolve_num_mps(a.device, 0)
    module = get_hyperconnection_module(rows, cols, splits, mp_count)
    module.get_function("tf32_hc_prenorm_gemm")(a, b, d, sqr_sum, splits)