GEMM

For framework integrations that already target DeepGEMM-style Python APIs, prefer the deep-gemm wrapper. Use the MATE APIs below when the wrapper does not cover your use case.

MoE GEMM

mate.gemm.ragged_m_moe_gemm_8bit(input_a: Tuple[torch.Tensor, torch.Tensor], input_b: Tuple[torch.Tensor, torch.Tensor], ragged_tokens_info: torch.Tensor, out: torch.Tensor, gemm_mode: Literal['per_token', 'psum_expert', 'per_expert'] | None = 'per_token', major_a_mode: Literal['M', 'K'] | None = 'K', major_b_mode: Literal['N', 'K'] | None = 'K', scale_granularity_mnk: Tuple[int, int, int] | None = None, num_mp: int | None = None, alignment_m: int | None = None, backend: Literal['auto', 'mubin', 'mutlass'] | None = 'auto')

Perform 8-bit GEMM operation for MoE (Mixture of Experts) with ragged tensor inputs.

This function computes matrix multiplication between 8-bit quantized tensors for MoE models where different experts may have variable numbers of tokens assigned to them.

Parameters:
  • input_a (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input A. fp8_tensor has shape (total_tokens, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (total_tokens, hidden_size // scale_granularity_k) and should be of fp32 type.

  • input_b (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input B. fp8_tensor has shape (num_expert, out_hidden_size, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (num_expert, out_hidden_size // scale_granularity_n, hidden_size // scale_granularity_k) and should be of fp32 type.

  • ragged_tokens_info (Tensor) – Metadata tensor whose meaning depends on gemm_mode. For per_token, it has shape (total_tokens,) and stores the expert index of each token, with -1 marking unused (padding) positions. For psum_expert, it has shape (num_expert,) and stores the cumulative token counts of the leading experts, i.e. the per-expert counts in prefix-sum form. For per_expert, it has shape (num_expert,) and stores the token count of each expert.

  • out (Tensor) – Output tensor with shape (total_tokens, out_hidden_size).

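  • major_a_mode (Optional[str]) – Major (stride-1) dimension of A, "M" or "K". Default is "K".

  • major_b_mode (Optional[str]) – Major (stride-1) dimension of B, "N" or "K". Default is "K".

  • gemm_mode (Optional[str]) – Encoding of ragged_tokens_info: "per_token", "psum_expert" or "per_expert", as described under ragged_tokens_info. Default is "per_token".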

  • scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Quantization granularity for total_tokens, out_hidden_size, hidden_size (m, n, k) dimensions respectively. Default is (1, 128, 128).

  • alignment_m (Optional[int]) – Alignment requirement for total_tokens (m) dimension. Must be 128 or 256. Default is 128.

  • num_mp (Optional[int]) – Suggested number of MPs to use. If None, it is taken from the device information.

Returns:

Result tensor with shape (total_tokens, out_hidden_size) containing the GEMM output in fp16 or bf16 data type.

Return type:

Tensor
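The three ragged_tokens_info encodings can be illustrated in plain Python. This is a minimal sketch, not part of MATE: the helper names are hypothetical, psum_expert is assumed to be an inclusive running sum, and per_token is assumed to pad each expert's contiguous rows up to a multiple of alignment_m with -1. An alignment of 4 is used only to keep the output readable; the API itself accepts 128 or 256.

```python
from typing import List


def per_expert_info(counts: List[int]) -> List[int]:
    # 'per_expert': token count of each expert, stored as-is.
    return list(counts)


def psum_expert_info(counts: List[int]) -> List[int]:
    # 'psum_expert' (assumed inclusive): entry e = tokens of experts 0..e.
    info: List[int] = []
    running = 0
    for n in counts:
        running += n
        info.append(running)
    return info


def per_token_info(counts: List[int], alignment_m: int = 128) -> List[int]:
    # 'per_token': expert index for every row of A, -1 for unused rows.
    # Assumption: each expert's rows are contiguous and padded to a
    # multiple of alignment_m; the padding rows are the -1 entries.
    info: List[int] = []
    for expert, n in enumerate(counts):
        padded = -(-n // alignment_m) * alignment_m  # round n up to alignment_m
        info.extend([expert] * n)
        info.extend([-1] * (padded - n))
    return info


counts = [3, 0, 5]  # tokens routed to experts 0, 1, 2
print(per_expert_info(counts))                # [3, 0, 5]
print(psum_expert_info(counts))               # [3, 3, 8]
print(per_token_info(counts, alignment_m=4))  # [0, 0, 0, -1, 2, 2, 2, 2, 2, -1, -1, -1]
```

Note that an expert with zero tokens contributes no rows to the per_token layout but still occupies an entry in the two per-expert layouts.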

mate.gemm.ragged_k_moe_gemm_8bit(input_a: Tuple[torch.Tensor, torch.Tensor], input_b: Tuple[torch.Tensor, torch.Tensor], ragged_tokens_info: torch.Tensor, out: torch.Tensor, gemm_mode: Literal['per_expert'] | None = 'per_expert', major_a_mode: Literal['M', 'K'] | None = 'M', major_b_mode: Literal['N', 'K'] | None = 'N', scale_granularity_mnk: Tuple[int, int, int] | None = None, num_mp: int | None = None)

Perform an 8-bit GEMM for MoE (Mixture of Experts) in which each expert's tokens form a ragged segment of the K dimension.

This function computes matrix multiplication between 8-bit quantized tensors for MoE models where different experts may have variable numbers of tokens.

Parameters:
  • input_a (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input A. fp8_tensor has shape (k, m) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (k // scale_granularity_k, m) and should be of fp32 type.

  • input_b (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input B. fp8_tensor has shape (k, n) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (k // scale_granularity_k, n) and should be of fp32 type.

  • ragged_tokens_info (Tensor) – Tensor indicating the actual number of tokens for each expert, with shape (num_expert,). Values represent token counts for each expert.

  • out (Tensor) – Output tensor with shape (num_expert, max_tokens, out_hidden_size). Must be of float type and must not be None.

  • gemm_mode (Optional[str]) – Encoding of ragged_tokens_info. Only "per_expert" (token count of each expert) is supported.

  • major_a_mode (Optional[str]) – Major dimension of A. Default is "M". Only TN m_grouped_gemm is supported on MP31.

  • major_b_mode (Optional[str]) – Major dimension of B. Default is "N".

  • scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Quantization granularity for the max_tokens, out_hidden_size and hidden_size (m, n, k) dimensions respectively. K-grouped GEMM only supports 1D1D scaling, so this must be (1, 1, 128).

  • num_mp (Optional[int]) – Suggested number of MPs to use. If None, it is taken from the device information.

Returns:

Result tensor with shape (num_expert, max_tokens, out_hidden_size) containing the GEMM output in float data type, computed as D = D + A * B for each expert.
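To make the K-grouped semantics concrete, the pure-Python reference below shows which rows of A and B each expert contributes to its accumulator. It is a sketch under assumptions this page does not spell out: expert segments are contiguous along K in expert order, FP8 values and their scales are already folded into plain floats, and each expert's accumulator is m x n (A stored as (k, m) is read as the logical m x k operand). k_grouped_reference is a hypothetical name, not a MATE function.

```python
from typing import List

Matrix = List[List[float]]


def k_grouped_reference(a: Matrix, b: Matrix, counts: List[int], d: List[Matrix]) -> List[Matrix]:
    # a: k rows of length m (the (k, m) layout of input_a), already dequantized.
    # b: k rows of length n (the (k, n) layout of input_b), already dequantized.
    # counts: token count per expert (ragged_tokens_info); expert e owns the
    #   next counts[e] rows of a and b (assumed contiguous, in expert order).
    # d: num_expert accumulators of shape m x n, updated in place (D = D + A * B).
    assert sum(counts) <= min(len(a), len(b)), "segments exceed the K dimension"
    m, n = len(a[0]), len(b[0])
    start = 0
    for e, tokens in enumerate(counts):
        for t in range(start, start + tokens):
            for i in range(m):
                for j in range(n):
                    d[e][i][j] += a[t][i] * b[t][j]
        start += tokens
    return d


a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # k = 3, m = 2
b = [[1.0], [1.0], [2.0]]                  # k = 3, n = 1
d = [[[0.0], [0.0]] for _ in range(2)]     # num_expert = 2
print(k_grouped_reference(a, b, [1, 2], d))  # [[[1.0], [2.0]], [[13.0], [16.0]]]
```

Expert 0 sees only the first token row, while expert 1 sums the outer products of the remaining two rows, which is why the token count per expert determines the reduction length rather than the output shape.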

mate.gemm.masked_moe_gemm_8bit(input_a: Tuple[torch.Tensor, torch.Tensor], input_b: Tuple[torch.Tensor, torch.Tensor], masked_tokens_info: torch.Tensor, out: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] | None = None, expect_tokens: int | None = None, enable_overlap: bool = False, signal: torch.Tensor | None = None, backend: Literal['auto', 'mubin', 'mutlass'] | None = 'auto')

Perform 8-bit GEMM operation for MoE (Mixture of Experts) with masked tensor inputs.

This function computes matrix multiplication between 8-bit quantized tensors for MoE models where different experts may have variable numbers of tokens, using a mask to indicate the actual number of tokens per expert.

Parameters:
  • input_a (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input A. fp8_tensor has shape (num_expert, max_tokens, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (num_expert, max_tokens, hidden_size // scale_granularity_k) and should be of fp32 type.

  • input_b (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input B. fp8_tensor has shape (num_expert, out_hidden_size, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (num_expert, out_hidden_size // scale_granularity_n, hidden_size // scale_granularity_k) and should be of fp32 type.

  • masked_tokens_info (Tensor) – Tensor indicating the actual number of tokens for each expert, with shape (num_expert,). Values represent token counts for each expert.

  • out (Tensor) – Output tensor with shape (num_expert, max_tokens, out_hidden_size). Should be of fp16 or bf16 type. If None, a new tensor will be created.

  • scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Quantization granularity for max_tokens, out_hidden_size, hidden_size (m, n, k) dimensions respectively. Default is (1, 128, 128).

  • expect_tokens (Optional[int]) – Expected number of tokens. If None, defaults to 0.

  • enable_overlap (Optional[bool]) – Whether to enable Single-Batch Overlap (SBO). Default is False.

  • signal (Optional[Tensor]) – Signal tensor with shape (num_expert * ceil_div(max_tokens, 64),) for SBO. Only used when enable_overlap is True; if None, a new tensor is created when needed.

Returns:

If enable_overlap is False, returns result tensor with shape (num_expert, max_tokens, out_hidden_size). If enable_overlap is True, returns a tuple containing:

  • result tensor with shape (num_expert, max_tokens, out_hidden_size)

  • signal tensor

  • block_m int

  • threshold int

Return type:

Union[Tensor, Tuple[Tensor, Tensor, int, int]]
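The shape bookkeeping for the masked layout fits in a few lines. The helper below is hypothetical plain Python that follows the parameter descriptions above with the default (1, 128, 128) granularity. It additionally assumes that each dimension divides evenly by its scale granularity and that the SBO signal length is num_expert * ceil_div(max_tokens, 64). Every expert is padded to max_tokens rows; only the first masked_tokens_info[e] rows of expert e carry real tokens.

```python
from typing import Dict, Tuple


def ceil_div(x: int, y: int) -> int:
    return -(-x // y)


def masked_moe_shapes(num_expert: int, max_tokens: int, hidden_size: int,
                      out_hidden_size: int,
                      granularity: Tuple[int, int, int] = (1, 128, 128)) -> Dict[str, Tuple[int, ...]]:
    # Hypothetical helper: expected tensor shapes for masked_moe_gemm_8bit,
    # following the parameter descriptions (granularity is (m, n, k)).
    _, gn, gk = granularity
    # Assumption: dimensions divide evenly by their scale granularity.
    assert hidden_size % gk == 0 and out_hidden_size % gn == 0
    return {
        "a": (num_expert, max_tokens, hidden_size),
        "a_scale": (num_expert, max_tokens, hidden_size // gk),
        "b": (num_expert, out_hidden_size, hidden_size),
        "b_scale": (num_expert, out_hidden_size // gn, hidden_size // gk),
        "out": (num_expert, max_tokens, out_hidden_size),
        # Assumption: the SBO signal length uses max_tokens as the M extent.
        "signal": (num_expert * ceil_div(max_tokens, 64),),
    }


shapes = masked_moe_shapes(num_expert=8, max_tokens=256, hidden_size=7168, out_hidden_size=4096)
for name, shape in shapes.items():
    print(name, shape)  # e.g. a_scale (8, 256, 56), b_scale (8, 32, 56), signal (32,)
```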

Dense GEMM

mate.gemm.bmm_fp16(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, out: torch.Tensor | None = None, backend: str = 'auto', c: torch.Tensor | None = None)

mate.gemm.bmm_fp8(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, out_dtype: torch.dtype, out: torch.Tensor | None = None, backend: str = 'auto', scale_granularity_mnk: Tuple[int, int, int] | None = None, output_scale: torch.Tensor | None = None, c: torch.Tensor | None = None, major_a_mode: Literal['K', 'M'] = 'K', major_b_mode: Literal['N', 'K'] = 'K')

Perform batched matrix multiplication with FP8 quantized tensors.

This function computes the batched matrix multiplication of two FP8 quantized tensors, applying scaling factors to produce a result in the specified output data type.

Parameters:
  • a (Tensor) – Input tensor A in FP8 format (e4m3/e5m2). Shape is (batch, m, k) when major_a_mode="K" and (batch, k, m) when major_a_mode="M". The declared major matrix dimension must have stride 1.

  • b (Tensor) – Input tensor B in FP8 format (e4m3/e5m2). Shape is (batch, n, k) by default with major_b_mode="K" and (batch, k, n) when major_b_mode="N". The declared major matrix dimension must have stride 1.

  • a_scale (Tensor) – Scaling factors for tensor A, with a shape that depends on scale_granularity_mnk. Should be of fp32 type.

  • b_scale (Tensor) – Scaling factors for tensor B, with a shape that depends on scale_granularity_mnk. Should be of fp32 type.

  • out_dtype (torch.dtype) – Data type for the output tensor. torch.bfloat16, torch.float16 and torch.float32 are supported.

  • out (Optional[Tensor]) – Pre-allocated output tensor with shape (batch, m, n). Default is None. If None, a new tensor will be allocated.

  • backend (str) – Backend to use for the operation. Supported backends are “mudnn” and “auto”. Default is “auto”.

  • scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Granularity of scaling for the m, n, and k dimensions respectively. (-1, -1, -1), (1, -1, -1), (1, 128, 128) and (1, 1, 128) are supported. If None, defaults to (-1, -1, -1).

  • c (Optional[Tensor]) – Optional FP32 accumulation tensor with shape (batch, m, n).

  • major_a_mode (str) – "K" treats A as (batch, m, k); "M" treats A as (batch, k, m) and asks MatMulLt to transpose A. Default is "K".

  • major_b_mode (str) – "K" treats B as (batch, n, k) and asks MatMulLt to transpose B. "N" treats B as (batch, k, n). Default is "K".

Returns:

Result tensor with shape (batch, m, n) in the specified output data type.

Return type:

Tensor
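The layout rules for a and b can be checked without a GPU. The sketch below uses hypothetical plain-Python helpers: it derives the storage shape for each major mode from the parameter descriptions above and shows that a contiguous tensor of that shape has stride 1 on its last dimension, which is the declared major dimension. On a real PyTorch tensor the equivalent check is a.stride(-1) == 1.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def row_major_strides(shape: Shape) -> Shape:
    # Element strides of a contiguous tensor, as torch.Tensor.stride() reports them.
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def bmm_fp8_shapes(batch: int, m: int, n: int, k: int,
                   major_a_mode: str = "K", major_b_mode: str = "K") -> Tuple[Shape, Shape, Shape]:
    # Storage shapes of a, b and out for the given major modes.
    if major_a_mode not in ("K", "M") or major_b_mode not in ("N", "K"):
        raise ValueError("unsupported major mode")
    a = (batch, m, k) if major_a_mode == "K" else (batch, k, m)
    b = (batch, n, k) if major_b_mode == "K" else (batch, k, n)
    return a, b, (batch, m, n)


a_shape, b_shape, out_shape = bmm_fp8_shapes(2, 64, 32, 128)
print(a_shape, row_major_strides(a_shape))  # (2, 64, 128) (8192, 128, 1)
print(b_shape, out_shape)                   # (2, 32, 128) (2, 64, 32)
print(bmm_fp8_shapes(2, 64, 32, 128, "M", "N")[:2])  # ((2, 128, 64), (2, 128, 32))
```

The output shape is (batch, m, n) in every mode; only the storage of the inputs changes.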

mate.gemm.gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, scale_major_mode: Literal['MN', 'K'] | None = None, mma_sm: int | None = None, scale_granularity_mnk: Tuple[int, int, int] | None = None, out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, backend: str = 'auto', output_scale: torch.Tensor | None = None)

Perform groupwise FP8 GEMM operation with scaling.

This function computes the matrix multiplication of two FP8 quantized tensors, applying scaling factors to produce a result in the specified output data type. It supports groupwise quantization with configurable scale granularity.

Parameters:
  • a (Tensor) – Input tensor A with shape (m, k) in FP8 format (e4m3/e5m2). Tensor must be contiguous.

  • b (Tensor) – Input tensor B with shape (n, k) in FP8 format (e4m3/e5m2). Tensor must be contiguous.

  • a_scale (Tensor) – Scaling factors for tensor A. Shape depends on scale_granularity_mnk parameter. Should be of fp32 type. Must be contiguous.

  • b_scale (Tensor) – Scaling factors for tensor B. Shape depends on scale_granularity_mnk parameter. Should be of fp32 type. Must be contiguous.

  • scale_major_mode (str) – Scale major mode “MN” or “K” for groupwise operations. Default is “K”.

  • mma_sm (Optional[int]) – MMA SM configuration. Currently only supports 1. Default is 1.

  • scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Granularity of scaling for m, n, and k dimensions respectively. Default is (1, 128, 128).

  • out (Optional[Tensor]) – Pre-allocated output tensor with shape (m, n). Should be bf16/fp16 when output_scale is None, or fp8_e4m3 when output_scale is provided. If None, a new tensor will be allocated.

  • out_dtype (Optional[torch.dtype]) – Data type for the output tensor when out is None. If out is provided, out.dtype is validated instead. Defaults to torch.bfloat16 without output_scale and fp8_e4m3 with output_scale.

  • backend (str) – Backend to use for the operation. Use "mudnn" when output_scale is None and "mubin" when output_scale is provided. "auto" selects the supported backend for the selected output path.

  • output_scale (Optional[torch.Tensor]) – Quantization scale tensor for FP8 output. If provided, the operation uses the mubin FP8-output path. If None, output is not quantized. Default is None.

Returns:

Result tensor with shape (m, n) in the specified output data type.

Return type:

Tensor
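A slow pure-Python reference can help when validating the output of gemm_fp8_nt_groupwise. It assumes the "K" scale major mode with a_scale laid out as (m, k // gk) and b_scale as (n // gn, k // gk) for scale_granularity_mnk = (1, gn, gk), and it treats the FP8 inputs as already-decoded floats. The "NT" layout is read from the shapes above: b is stored as (n, k), so out = A x B^T. This is a hypothetical checker, not the kernel's algorithm.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def gemm_nt_groupwise_reference(a: Matrix, b: Matrix, a_scale: Matrix, b_scale: Matrix,
                                granularity: Tuple[int, int, int] = (1, 128, 128)) -> Matrix:
    # a: m x k, b: n x k (NT layout), both already decoded from FP8 to floats.
    # Assumed K-major scale layouts: a_scale is m x (k // gk), b_scale is (n // gn) x (k // gk).
    gm, gn, gk = granularity
    assert gm == 1, "this sketch only covers per-row (1 x gk) scaling of A"
    m, n, k = len(a), len(b), len(a[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kk in range(k):
                # Dequantize each operand with the scale of the block it falls in.
                acc += (a[i][kk] * a_scale[i][kk // gk]) * (b[j][kk] * b_scale[j // gn][kk // gk])
            out[i][j] = acc
    return out


# Tiny example with granularity (1, 2, 2): m = 1, n = 2, k = 4.
a = [[1.0, 1.0, 1.0, 1.0]]
b = [[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 0.0]]
a_scale = [[2.0, 3.0]]   # one scale per 2 columns of K for the single row
b_scale = [[1.0, 10.0]]  # one scale per 2 x 2 block of B
print(gemm_nt_groupwise_reference(a, b, a_scale, b_scale, (1, 2, 2)))  # [[64.0, 32.0]]
```

Because both rows of B fall into the same 2-row block, they share one scale per K group; with the default granularity this corresponds to 128 x 128 blocks of B and 1 x 128 groups of A.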

DeepGEMM Lightning Indexer

mate.deep_gemm.get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_kv: int, num_mps: int = 0) → torch.Tensor

Compute the schedule metadata required by fp8_paged_mqa_logits().

Parameters:
  • context_lens (Tensor) – Context lengths of each query, shape (batch_size,) or (batch_size, next_n). Pass the same tensor that is later given to fp8_paged_mqa_logits().

  • block_kv (int) – Block size of the KV cache. Currently must be 64.

  • num_mps (int) – Number of MPs to use. 0 (the default) uses all MPs of the current device.

Returns:

Schedule metadata, shape (num_mps + 1, 2)

Return type:

Tensor
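block_kv sets how many KV positions each cache block holds. The sketch below assumes the usual paged-KV convention, which this page does not state explicitly: request i keeps its KV position p in physical block block_table[i][p // block_kv], at slot p % block_kv. Both helpers are hypothetical and exist only for illustration.

```python
from typing import List, Tuple


def num_kv_blocks(context_lens: List[int], block_kv: int = 64) -> List[int]:
    # Blocks needed per request: ceil(context_len / block_kv).
    return [-(-length // block_kv) for length in context_lens]


def locate_kv(block_table: List[List[int]], request: int, pos: int,
              block_kv: int = 64) -> Tuple[int, int]:
    # Assumed paged-KV mapping: (physical block id, slot inside that block).
    return block_table[request][pos // block_kv], pos % block_kv


print(num_kv_blocks([100, 64, 1]))         # [2, 1, 1]
print(locate_kv([[7, 3], [5, 0]], 0, 70))  # (3, 6)
```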

mate.deep_gemm.fp8_paged_mqa_logits(q: torch.Tensor, fused_kv_cache: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_table: torch.Tensor, schedule_meta: torch.Tensor, max_context_len: int, clean_logits: bool) → torch.Tensor

Compute FP8 MQA logits against a paged KV cache.

Parameters:
  • q (Tensor) – The FP8 query tensor with shape (batch_size, next_n, heads, index_dim)

  • fused_kv_cache (Tensor) – The FP8 kv cache with fp32 scale, shape (num_blocks, block_size, 1, index_dim + 4)

  • weights (Tensor) – The FP32 weight tensor for each query, shape (batch_size * next_n, heads)

  • context_lens (Tensor) –

    Context lengths tensor, supports two layouts:

    • 1D (batch_size,) — all next_n draft tokens of request i share the same context length context_lens[i]. The visible KV range for draft token j is implicitly [0, context_lens[i] - next_n + j].

    • 2D (batch_size, next_n) — each draft token has an independent context length context_lens[i, j], with visible KV range [0, context_lens[i, j] - 1]. Useful for tree-based speculative decoding (e.g. Medusa / EAGLE) where tokens on different branches see different KV prefixes.

    The shape is auto-detected; get_paged_mqa_logits_metadata must be called with the same context_lens tensor.

  • block_table (Tensor) – Block table tensor with shape (batch_size, max_blocks)

  • schedule_meta (Tensor) – Schedule metadata tensor with shape (num_mps + 1, 2), produced by get_paged_mqa_logits_metadata()

  • max_context_len (int) – Maximum context length

  • clean_logits (bool) – Whether to zero-fill logit positions that are out of the valid KV range

Returns:

FP32 logits, shape (batch_size * next_n, max_context_len)

Return type:

Tensor
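The two context_lens layouts can be compared with a small plain-Python sketch. It computes the last visible KV index for every logits row and the mask of positions that clean_logits=True would zero-fill. The row order i * next_n + j is an assumption, chosen to match the (batch_size * next_n, ...) shapes of weights and of the returned logits; the helper names are hypothetical.

```python
from typing import List, Sequence, Union

ContextLens = Union[Sequence[int], Sequence[Sequence[int]]]


def visible_kv_end(context_lens: ContextLens, next_n: int) -> List[int]:
    # Inclusive last visible KV index for logits row i * next_n + j (assumed order).
    ends: List[int] = []
    for entry in context_lens:
        for j in range(next_n):
            if isinstance(entry, (list, tuple)):  # 2D layout: (batch_size, next_n)
                ends.append(entry[j] - 1)
            else:                                 # 1D layout: (batch_size,)
                ends.append(entry - next_n + j)
    return ends


def valid_mask(ends: List[int], max_context_len: int) -> List[List[bool]]:
    # True inside [0, end]; with clean_logits=True the False positions are zero-filled.
    return [[pos <= end for pos in range(max_context_len)] for end in ends]


print(visible_kv_end([5, 3], 2))            # [3, 4, 1, 2]
print(visible_kv_end([[4, 5], [2, 3]], 2))  # [3, 4, 1, 2]
print(valid_mask([3], 6))                   # [[True, True, True, True, False, False]]
```

Here the 2D tensor reproduces the 1D case exactly; a tree-based drafter would instead give sibling draft tokens different lengths, which only the 2D layout can express.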