GEMM
MoE GEMM
- mate.gemm.ragged_m_moe_gemm_8bit(input_a: Tuple[torch.Tensor, torch.Tensor], input_b: Tuple[torch.Tensor, torch.Tensor], ragged_tokens_info: torch.Tensor, out: torch.Tensor, gemm_mode: Literal['per_token', 'psum_expert', 'per_expert'] | None = 'per_token', major_a_mode: Literal['M', 'K'] | None = 'K', major_b_mode: Literal['N', 'K'] | None = 'K', scale_granularity_mnk: Tuple[int, int, int] | None = None, num_mp: int | None = None, alignment_m: int | None = None)
Perform 8-bit GEMM operation for MoE (Mixture of Experts) with ragged tensor inputs.
This function computes matrix multiplication between 8-bit quantized tensors for MoE models where different experts may have variable numbers of tokens assigned to them.
- Parameters:
input_a (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input A. fp8_tensor has shape (total_tokens, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (total_tokens, hidden_size // scale_granularity_k) and should be of fp32 type.
input_b (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input B. fp8_tensor has shape (num_expert, out_hidden_size, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (num_expert, out_hidden_size // scale_granularity_n, hidden_size // scale_granularity_k) and should be of fp32 type.
ragged_tokens_info (Tensor) – Describes how tokens are distributed across experts. Its meaning depends on gemm_mode:
- If gemm_mode is per_token: Tensor with shape (total_tokens,) indicating which expert each token belongs to. Values are expert indices, with -1 for unused positions.
- If gemm_mode is psum_expert: Tensor with shape (num_expert,) indicating how many tokens the first few experts have in total (cumulative token counts).
- If gemm_mode is per_expert: Tensor with shape (num_expert,) indicating how many tokens each expert has.
out (Tensor) – Output tensor with shape (total_tokens, out_hidden_size).
major_a_mode (Optional[str]) – Indicates the major stride of A. Defaults to K.
major_b_mode (Optional[str]) – Indicates the major stride of B. Defaults to K.
gemm_mode (Optional[str]) – Selects how ragged_tokens_info is interpreted. Defaults to per_token.
scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Quantization granularity for the total_tokens, out_hidden_size, and hidden_size (m, n, k) dimensions respectively. Default is (1, 128, 128).
alignment_m (Optional[int]) – Alignment requirement for the total_tokens (m) dimension. Must be 128 or 256. Default is 128.
num_mp (Optional[int]) – Suggested number of MPs. If None, it is obtained from device info.
- Returns:
Result tensor with shape (total_tokens, out_hidden_size) containing the GEMM output in fp16 or bf16 data type.
- Return type:
Tensor
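The three gemm_mode values describe the same token-to-expert assignment in different forms. The sketch below converts between them using plain Python lists in place of tensors. The helper names are hypothetical and not part of mate. It assumes that tokens are grouped by expert with unused (-1) positions at the end, and that psum_expert holds an inclusive running sum; the text above does not state either point.

```python
from itertools import accumulate


def per_token_to_per_expert(expert_ids, num_expert):
    """Count tokens per expert; -1 marks an unused position and is skipped."""
    counts = [0] * num_expert
    for expert in expert_ids:
        if expert >= 0:
            counts[expert] += 1
    return counts


def per_expert_to_psum(counts):
    """Inclusive running sum of per-expert counts (inclusivity is an assumption)."""
    return list(accumulate(counts))


# per_token layout: one expert index per token row, -1 for padding rows.
per_token = [0, 0, 1, 2, 2, 2, -1, -1]

per_expert = per_token_to_per_expert(per_token, num_expert=3)
psum_expert = per_expert_to_psum(per_expert)

print(per_expert)
print(psum_expert)
```

Whichever form is passed as ragged_tokens_info, gemm_mode has to name the matching interpretation.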
- mate.gemm.ragged_k_moe_gemm_8bit(input_a: Tuple[torch.Tensor, torch.Tensor], input_b: Tuple[torch.Tensor, torch.Tensor], ragged_tokens_info: torch.Tensor, out: torch.Tensor, gemm_mode: Literal['per_expert'] | None = 'per_expert', major_a_mode: Literal['M', 'K'] | None = 'M', major_b_mode: Literal['N', 'K'] | None = 'N', scale_granularity_mnk: Tuple[int, int, int] | None = None, num_mp: int | None = None)
Perform 8-bit GEMM operation for MoE (Mixture of Experts) using the token count of each expert.
This function computes matrix multiplication between 8-bit quantized tensors for MoE models where different experts may have variable numbers of tokens.
- Parameters:
input_a (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input A. fp8_tensor has shape (k, m) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (k // scale_granularity_k, m) and should be of fp32 type.
input_b (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input B. fp8_tensor has shape (k, n) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (k // scale_granularity_k, n) and should be of fp32 type.
ragged_tokens_info (Tensor) – Tensor with shape (num_expert,) holding the actual number of tokens for each expert.
out (Tensor) – Output tensor with shape (num_expert, max_tokens, out_hidden_size). Must be of float type and must not be None.
gemm_mode (Optional[str]) – Selects how ragged_tokens_info is interpreted. Only per_expert is supported.
major_a_mode (Optional[str]) – Major mode of A. Defaults to M. Only TN m_grouped_gemm is supported on MP31.
major_b_mode (Optional[str]) – Major mode of B. Defaults to N.
scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Quantization granularity for the max_tokens, out_hidden_size, and hidden_size (m, n, k) dimensions respectively. K-grouped GEMM only supports 1D1D scaling, so this should be (1, 1, 128).
num_mp (Optional[int]) – Suggested number of MPs. If None, it is obtained from device info.
- Returns:
Result tensor with shape (num_expert, total_tokens, out_hidden_size) containing the GEMM output in float data type, representing D = D + A * B for each expert.
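The accumulation D = D + A * B can be written as a slow reference loop. The sketch below is one reading of the shapes above, not the library's implementation. It assumes that the k dimension of A and B is split into consecutive segments whose lengths come from ragged_tokens_info, and that each expert accumulates A_e^T @ B_e into its own slice of D. Scales are left out for brevity, and ragged_k_reference is a hypothetical name.

```python
def ragged_k_reference(a, b, tokens_per_expert, d):
    """Accumulate d[e] += A_e^T @ B_e, where A_e and B_e are expert e's rows of a and b.

    a: (k, m) nested list, b: (k, n) nested list, d: (num_expert, m, n) nested list.
    Assumption: rows of a and b are ordered expert by expert along k.
    """
    m, n = len(a[0]), len(b[0])
    k_start = 0
    for expert, count in enumerate(tokens_per_expert):
        for kk in range(k_start, k_start + count):
            for i in range(m):
                for j in range(n):
                    d[expert][i][j] += a[kk][i] * b[kk][j]
        k_start += count
    return d


a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # (k=3, m=2)
b = [[1.0], [1.0], [2.0]]                 # (k=3, n=1)
tokens_per_expert = [2, 1]                # expert 0 owns rows 0-1, expert 1 owns row 2

# d starts at zero here; a non-zero d would be accumulated into.
d = [[[0.0] for _ in range(2)] for _ in range(2)]  # (num_expert=2, m=2, n=1)
d = ragged_k_reference(a, b, tokens_per_expert, d)
print(d)
```

Because the result is accumulated, out has to be initialised by the caller, which is consistent with the requirement that it must not be None.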
- mate.gemm.masked_moe_gemm_8bit(input_a: Tuple[torch.Tensor, torch.Tensor], input_b: Tuple[torch.Tensor, torch.Tensor], masked_tokens_info: torch.Tensor, out: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] | None = None, expect_tokens: int | None = None, enable_overlap: bool = False, signal: torch.Tensor | None = None)
Perform 8-bit GEMM operation for MoE (Mixture of Experts) with masked tensor inputs.
This function computes matrix multiplication between 8-bit quantized tensors for MoE models where different experts may have variable numbers of tokens, using a mask to indicate the actual number of tokens per expert.
- Parameters:
input_a (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input A. fp8_tensor has shape (num_expert, max_tokens, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (num_expert, max_tokens, hidden_size // scale_granularity_k) and should be of fp32 type.
input_b (Tuple[Tensor, Tensor]) – Tuple containing (fp8_tensor, scale_tensor) for input B. fp8_tensor has shape (num_expert, out_hidden_size, hidden_size) and should be of fp8 (e4m3/e5m2) type. scale_tensor has shape (num_expert, out_hidden_size // scale_granularity_n, hidden_size // scale_granularity_k) and should be of fp32 type.
masked_tokens_info (Tensor) – Tensor with shape (num_expert,) holding the actual number of tokens for each expert.
out (Tensor) – Output tensor with shape (num_expert, max_tokens, out_hidden_size). Should be of fp16 or bf16 type. If None, a new tensor will be created.
scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Quantization granularity for the max_tokens, out_hidden_size, and hidden_size (m, n, k) dimensions respectively. Default is (1, 128, 128).
expect_tokens (Optional[int]) – Expected number of tokens. If None, defaults to 0.
enable_overlap (Optional[bool]) – Whether to enable Single-Batch Overlap (SBO). Default is False.
signal (Optional[Tensor]) – Signal tensor with shape (num_expert * ceil_div(max_m, 64)) for SBO. Required if enable_overlap is True. If None, a new tensor will be created if needed.
- Returns:
If enable_overlap is False, returns the result tensor with shape (num_expert, max_tokens, out_hidden_size). If enable_overlap is True, returns a tuple containing:
- result tensor with shape (num_expert, max_tokens, out_hidden_size)
- signal tensor
- block_m (int)
- threshold (int)
- Return type:
Union[Tensor, Tuple[Tensor, Tensor, int, int]]
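Two pieces of bookkeeping follow from the shapes above: the length of the SBO signal tensor and which rows of each expert's slab hold real tokens. The sketch below computes both with plain Python. The helper names are hypothetical, and it assumes valid rows occupy the leading positions of each expert's max_tokens slab, which the text above does not state.

```python
def ceil_div(x, y):
    """Integer division rounded up."""
    return (x + y - 1) // y


def signal_numel(num_expert, max_m, block=64):
    """Element count of the signal tensor: num_expert * ceil_div(max_m, 64)."""
    return num_expert * ceil_div(max_m, block)


def valid_row_mask(masked_tokens_info, max_tokens):
    """Per expert, mark the leading rows that hold real tokens (assumed layout)."""
    return [[row < count for row in range(max_tokens)]
            for count in masked_tokens_info]


numel = signal_numel(num_expert=4, max_m=200)
mask = valid_row_mask([2, 0, 3], max_tokens=4)

print(numel)
for expert_rows in mask:
    print(expert_rows)
```

Rows beyond an expert's token count are padding, so their output values should not be relied on.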
Dense GEMM
- mate.gemm.bmm_fp16(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, out: torch.Tensor | None = None, backend: str = 'auto', c: torch.Tensor | None = None)
- mate.gemm.bmm_fp8(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, out_dtype: torch.dtype, out: torch.Tensor | None = None, backend: str = 'auto', scale_granularity_mnk: Tuple[int, int, int] | None = None, output_scale: torch.Tensor | None = None)
Perform batched matrix multiplication with FP8 quantized tensors.
This function computes the batched matrix multiplication of two FP8 quantized tensors, applying scaling factors to produce a result in the specified output data type.
- Parameters:
a (Tensor) – Input tensor A with shape (batch, m, k) in FP8 format (e4m3/e5m2). The k dimension must be contiguous.
b (Tensor) – Input tensor B with shape (batch, k, n) in FP8 format (e4m3/e5m2). The k dimension must be contiguous.
a_scale (Tensor) – Scaling factors for tensor A, with shape depending on scale_granularity_mnk. Should be of fp32 type.
b_scale (Tensor) – Scaling factors for tensor B, with shape depending on scale_granularity_mnk. Should be of fp32 type.
out_dtype (torch.dtype) – Data type for the output tensor. Only torch.bfloat16 and torch.float16 are supported.
out (Optional[Tensor]) – Pre-allocated output tensor with shape (batch, m, n). Default is None. If None, a new tensor will be allocated.
backend (str) – Backend to use for the operation. Currently supported backends are "mudnn" and "auto". Default is "auto".
scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Granularity of scaling for batch, m, and n dimensions respectively. Only (-1, -1, -1) and (1, -1, -1) are supported. If None, defaults to (-1, -1, -1).
- Returns:
Result tensor with shape (batch, m, n) in the specified output data type.
- Return type:
Tensor
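As a rough mental model of "applying scaling factors", the sketch below dequantizes with one scalar scale per input and then does an ordinary batched matmul. This is an illustration only. It assumes that the (-1, -1, -1) granularity means a single scale for the whole tensor, which the text above does not spell out, and it uses Python floats in place of FP8 values. bmm_scaled_reference is a hypothetical name.

```python
def bmm_scaled_reference(a, b, a_scale, b_scale):
    """out[t] = (a_scale * b_scale) * (a[t] @ b[t]) for every batch index t.

    a: (batch, m, k) nested list, b: (batch, k, n) nested list.
    a_scale, b_scale: scalars (assumed per-tensor scaling).
    """
    out = []
    for a_mat, b_mat in zip(a, b):
        k, n = len(b_mat), len(b_mat[0])
        out.append([
            [a_scale * b_scale * sum(row[p] * b_mat[p][j] for p in range(k))
             for j in range(n)]
            for row in a_mat
        ])
    return out


a = [[[1.0, 2.0]]]      # (batch=1, m=1, k=2)
b = [[[3.0], [4.0]]]    # (batch=1, k=2, n=1)
result = bmm_scaled_reference(a, b, a_scale=0.5, b_scale=2.0)
print(result)
```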
- mate.gemm.gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, scale_major_mode: Literal['MN', 'K'] | None = None, mma_sm: int | None = None, scale_granularity_mnk: Tuple[int, int, int] | None = None, out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, backend: str = 'auto', output_scale: torch.Tensor | None = None)
Perform groupwise FP8 GEMM operation with scaling.
This function computes the matrix multiplication of two FP8 quantized tensors, applying scaling factors to produce a result in the specified output data type. It supports groupwise quantization with configurable scale granularity.
- Parameters:
a (Tensor) – Input tensor A with shape (m, k) in FP8 format (e4m3/e5m2). Tensor must be contiguous.
b (Tensor) – Input tensor B with shape (n, k) in FP8 format (e4m3/e5m2). Tensor must be contiguous.
a_scale (Tensor) – Scaling factors for tensor A. Shape depends on the scale_granularity_mnk parameter. Should be of fp32 type. Must be contiguous.
b_scale (Tensor) – Scaling factors for tensor B. Shape depends on the scale_granularity_mnk parameter. Should be of fp32 type. Must be contiguous.
scale_major_mode (str) – Scale major mode, "MN" or "K", for groupwise operations. Default is "K".
mma_sm (Optional[int]) – MMA SM configuration. Currently only supports 1. Default is 1.
scale_granularity_mnk (Optional[Tuple[int, int, int]]) – Granularity of scaling for m, n, and k dimensions respectively. Default is (1, 128, 128).
out (Optional[Tensor]) – Pre-allocated output tensor with shape (m, n). Should be bf16/fp16 when output_scale is None, or fp8_e4m3 when output_scale is provided. If None, a new tensor will be allocated.
out_dtype (Optional[torch.dtype]) – Data type for the output tensor when out is None. If out is provided, out.dtype is validated instead. Defaults to torch.bfloat16 without output_scale and fp8_e4m3 with output_scale.
backend (str) – Backend to use for the operation. Use "mudnn" when output_scale is None and "mubin" when output_scale is provided. "auto" selects the supported backend for the selected output path.
output_scale (Optional[torch.Tensor]) – Quantization scale tensor for FP8 output. If provided, the operation uses the mubin FP8-output path. If None, output is not quantized. Default is None.
- Returns:
Result tensor with shape (m, n) in the specified output data type.
- Return type:
Tensor
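Groupwise scaling is easiest to see in a slow reference loop. The sketch below is not the library's kernel. It assumes K-major scales with shapes (m // g_m, k // g_k) for a_scale and (n // g_n, k // g_k) for b_scale, computes a @ b.T (the "nt" layout), and applies the matching pair of scales to each k block. Python floats replace FP8 values, and gemm_nt_groupwise_reference is a hypothetical name.

```python
def gemm_nt_groupwise_reference(a, b, a_scale, b_scale, granularity=(1, 128, 128)):
    """Reference for out = dequant(a) @ dequant(b).T with per-group scales.

    a: (m, k), b: (n, k) nested lists.
    a_scale: (m // g_m, k // g_k), b_scale: (n // g_n, k // g_k) (assumed layout).
    """
    g_m, g_n, g_k = granularity
    m, n, k = len(a), len(b), len(a[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kb in range(0, k, g_k):
                # Unscaled partial dot product over one k block.
                partial = sum(a[i][p] * b[j][p] for p in range(kb, min(kb + g_k, k)))
                # Apply the scales of the groups this block belongs to.
                acc += a_scale[i // g_m][kb // g_k] * b_scale[j // g_n][kb // g_k] * partial
            out[i][j] = acc
    return out


# Tiny shapes with granularity (1, 2, 2) so the groups are easy to follow.
a = [[1.0, 1.0, 2.0, 2.0]]                          # (m=1, k=4)
b = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]    # (n=2, k=4)
a_scale = [[2.0, 0.5]]                              # (m // 1, k // 2)
b_scale = [[1.0, 4.0]]                              # (n // 2, k // 2)

result = gemm_nt_groupwise_reference(a, b, a_scale, b_scale, granularity=(1, 2, 2))
print(result)
```

With the default (1, 128, 128), each row of a carries one scale per 128 elements of k, and each 128 x 128 tile of b shares one scale.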
DeepGemm Lightning Indexer
- mate.deep_gemm.get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_kv: int, num_mps: int = 0) -> torch.Tensor
Get schedule metadata for paged MQA logits.
- Parameters:
context_lens (Tensor) – Context lengths of each query, shape (batch_size).
block_kv (int) – Block size of the KV cache. Must currently be 64.
num_mps (int) – Number of MPs to execute on. 0 means use all MPs of the current device.
- Returns:
Schedule metadata, shape (num_mps + 1, 2).
- Return type:
Tensor
- mate.deep_gemm.fp8_paged_mqa_logits(q: torch.Tensor, fused_kv_cache: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_table: torch.Tensor, schedule_meta: torch.Tensor, max_context_len: int, clean_logits: bool) -> torch.Tensor
Compute FP8 paged MQA logits.
- Parameters:
q (Tensor) – The FP8 query tensor with shape (batch_size, next_n, heads, index_dim).
fused_kv_cache (Tensor) – The FP8 kv cache with fp32 scale, shape (num_blocks, block_size, 1, index_dim + 4).
weights (Tensor) – The FP32 weight tensor for each query, shape (batch_size * next_n, heads).
context_lens (Tensor) – Context lengths tensor, supports two layouts:
- 1D (batch_size,): all next_n draft tokens of request i share the same context length context_lens[i]. The visible KV range for draft token j is implicitly [0, context_lens[i] - next_n + j].
- 2D (batch_size, next_n): each draft token has an independent context length context_lens[i, j], with visible KV range [0, context_lens[i, j] - 1]. Useful for tree-based speculative decoding (e.g. Medusa / EAGLE) where tokens on different branches see different KV prefixes.
The shape is auto-detected; get_paged_mqa_logits_metadata must be called with the same context_lens tensor.
block_table (Tensor) – Block table tensor with shape (batch_size, max_blocks).
schedule_meta (Tensor) – Schedule metadata tensor with shape (num_mps + 1, 2), produced by get_paged_mqa_logits_metadata().
max_context_len (int) – Maximum context length.
clean_logits (bool) – Whether to zero-fill logit positions that are out of the valid KV range.
- Returns:
FP32 logits, shape (batch_size * next_n, max_context_len).
- Return type:
Tensor
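The two context_lens layouts can be compared by computing the last visible KV index for each draft token. The sketch below follows the ranges stated above, using plain Python lists in place of tensors. The helper names are hypothetical and the code does not call mate.

```python
def visible_kv_last_index(context_lens, next_n):
    """Last visible KV index (inclusive) for every (request, draft token) pair."""
    result = []
    for entry in context_lens:
        if isinstance(entry, int):
            # 1D layout: draft token j sees [0, context_lens[i] - next_n + j].
            result.append([entry - next_n + j for j in range(next_n)])
        else:
            # 2D layout: draft token j sees [0, context_lens[i][j] - 1].
            result.append([length - 1 for length in entry])
    return result


def expand_1d_to_2d(context_lens, next_n):
    """Build the 2D layout that describes the same visible ranges as a 1D layout."""
    return [[c - next_n + j + 1 for j in range(next_n)] for c in context_lens]


ctx_1d = [10, 7]
last_1d = visible_kv_last_index(ctx_1d, next_n=2)

ctx_2d = expand_1d_to_2d(ctx_1d, next_n=2)
last_2d = visible_kv_last_index(ctx_2d, next_n=2)

print(last_1d)
print(ctx_2d)
print(last_2d)
```

The 1D form is the compact choice when draft tokens form a single chain. The 2D form is needed only when tokens of the same request see different KV prefixes.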