DeepGEMM Compatibility Wrapper (deep-gemm)
deep-gemm is a compatibility wrapper package that preserves the
deep_gemm import path while running on MUSA through MATE GEMM operators.
Overview
This wrapper is designed for projects that already target DeepGEMM-style Python APIs. It lets existing integrations run on MUSA through MATE with minimal code changes.
The current compatibility scope covers grouped GEMM, dense BF16 and FP8 GEMM, FP8 einsum, HyperConnection prenorm GEMM, and MQA (multi-query attention) logits APIs.
Package and import
Package name: deep-gemm
Import path: deep_gemm
Runtime backend: MATE GEMM and logits operators on MUSA
Requirements
Before using this wrapper, make sure that:
MATE is installed and importable.
TorchMUSA is installed and the MUSA runtime environment is configured.
The target workload is configured to run on MUSA devices.
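The checklist above can be verified with a short script before touching any GEMM code. This is a minimal sketch, not part of the wrapper: is_importable and check_requirements are hypothetical helpers, the module names mate and torch_musa are taken from the package names used in this document, and the assumption that importing torch_musa exposes torch.musa.is_available() (mirroring torch.cuda) is labeled in the code.

```python
import importlib
from typing import Dict


def is_importable(module_name: str) -> bool:
    """Return True if `module_name` can actually be imported."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


def check_requirements() -> Dict[str, bool]:
    """Report which deep-gemm runtime prerequisites appear to be satisfied."""
    status = {
        "mate": is_importable("mate"),
        "torch_musa": is_importable("torch_musa"),
        "musa_device": False,
    }
    if status["torch_musa"]:
        import torch

        # Assumption: importing torch_musa registers torch.musa, which
        # exposes is_available() in the same way torch.cuda does.
        status["musa_device"] = bool(torch.musa.is_available())
    return status


for requirement, ok in check_requirements().items():
    print(f"{requirement}: {'ok' if ok else 'MISSING'}")
```

Importing the modules, rather than only locating them, matches the "installed and importable" wording: a package that is present but fails to load (for example, because of a missing runtime library) is reported as missing.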
Build
Build a wheel from the wrappers/DeepGEMM directory; this uses the PyPA build package (pip install build):
python -m build --wheel
The generated wheel will be placed under:
dist/
Installation
Install from source in editable mode, from the wrappers/DeepGEMM directory:
pip install --no-build-isolation -e .
Install a built wheel:
pip install dist/deep_gemm-*.whl
If you previously installed the legacy mate-deep-gemm package, uninstall it
before installing deep-gemm so the environment does not keep stale wrapper
metadata.
pip uninstall -y mate-deep-gemm
pip install dist/deep_gemm-*.whl
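To confirm that the swap from the legacy package worked, the installed distributions can be inspected with the standard library. This is a minimal sketch: installed_version and check_wrapper_install are hypothetical helpers, and the distribution names deep-gemm and mate-deep-gemm are the ones given above.

```python
from importlib import metadata
from typing import List, Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_wrapper_install() -> List[str]:
    """List problems with the deep-gemm installation (empty list means OK)."""
    problems = []
    if installed_version("deep-gemm") is None:
        problems.append("deep-gemm is not installed")
    if installed_version("mate-deep-gemm") is not None:
        problems.append(
            "legacy mate-deep-gemm is still installed; "
            "run: pip uninstall -y mate-deep-gemm"
        )
    return problems


problems = check_wrapper_install()
print("\n".join(problems) if problems else "deep-gemm install looks clean")
```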
Import
Import the package directly:
import deep_gemm
Import individual APIs:
from deep_gemm import (
    bf16_gemm_nt,
    m_grouped_bf16_gemm_nt_contiguous,
    m_grouped_bf16_gemm_nt_masked,
    m_grouped_fp8_gemm_nt_contiguous,
    m_grouped_fp8_gemm_nt_masked,
    fp8_gemm_nt,
    fp8_einsum,
    tf32_hc_prenorm_gemm,
    get_paged_mqa_logits_metadata,
    fp8_paged_mqa_logits,
    fp8_mqa_logits,
)
Public APIs
Dense BF16 GEMM:
bf16_gemm_nt
Grouped GEMM:
m_grouped_bf16_gemm_nt_contiguous
m_grouped_bf16_gemm_nt_masked
m_grouped_fp8_gemm_nt_contiguous
m_grouped_fp8_gemm_nt_masked
Legacy aliases: fp8_m_grouped_gemm_nt_masked, bf16_m_grouped_gemm_nt_masked
Dense FP8 GEMM:
fp8_gemm_nt
fp8_einsum
HyperConnection prenorm GEMM:
tf32_hc_prenorm_gemm
MQA logits APIs:
get_paged_mqa_logits_metadata
fp8_paged_mqa_logits
fp8_mqa_logits
Utility helpers re-exported from deep_gemm.utils:
bench, bench_kineto, calc_diff
get_num_sms, set_num_sms
get_tc_util, set_tc_util
get_mk_alignment_for_contiguous_layout
get_col_major_tma_aligned_tensor
get_mn_major_tma_aligned_tensor
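Because the point of the wrapper is to preserve this surface, a quick check that every listed name resolves to a callable can catch a stale or partial install. This is a minimal sketch: missing_apis, TOP_LEVEL_APIS and UTILITY_APIS are hypothetical names, and it assumes the legacy aliases are exposed at the top level of deep_gemm and the helpers are reachable on the deep_gemm.utils submodule.

```python
from typing import Iterable, List

# Names imported from the top-level deep_gemm package (legacy aliases assumed top-level).
TOP_LEVEL_APIS = (
    "bf16_gemm_nt",
    "m_grouped_bf16_gemm_nt_contiguous",
    "m_grouped_bf16_gemm_nt_masked",
    "m_grouped_fp8_gemm_nt_contiguous",
    "m_grouped_fp8_gemm_nt_masked",
    "fp8_m_grouped_gemm_nt_masked",
    "bf16_m_grouped_gemm_nt_masked",
    "fp8_gemm_nt",
    "fp8_einsum",
    "tf32_hc_prenorm_gemm",
    "get_paged_mqa_logits_metadata",
    "fp8_paged_mqa_logits",
    "fp8_mqa_logits",
)

# Helpers listed as re-exported from deep_gemm.utils.
UTILITY_APIS = (
    "bench", "bench_kineto", "calc_diff",
    "get_num_sms", "set_num_sms",
    "get_tc_util", "set_tc_util",
    "get_mk_alignment_for_contiguous_layout",
    "get_col_major_tma_aligned_tensor",
    "get_mn_major_tma_aligned_tensor",
)


def missing_apis(module: object, names: Iterable[str]) -> List[str]:
    """Return the names that are absent from `module` or not callable."""
    return [name for name in names if not callable(getattr(module, name, None))]
```

On a machine with the wrapper installed, run import deep_gemm.utils, then call missing_apis(deep_gemm, TOP_LEVEL_APIS) and missing_apis(deep_gemm.utils, UTILITY_APIS); two empty lists mean every listed name resolved to a callable.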
Testing helpers are implemented in mate.testing.deep_gemm and are also
available from the upstream-style path:
deep_gemm.testing.bench
deep_gemm.testing.bench_kineto
deep_gemm.testing.calc_diff
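calc_diff is the usual way to compare a kernel's output against a higher-precision reference. The sketch below restates, in pure Python, the similarity metric that upstream DeepGEMM's calc_diff computes, 1 - 2<x, y> / (|x|^2 + |y|^2); that MATE's implementation in mate.testing.deep_gemm uses the same definition is an assumption, and calc_diff_reference is a hypothetical name used only for illustration.

```python
from typing import Sequence


def calc_diff_reference(x: Sequence[float], y: Sequence[float]) -> float:
    """Pure-Python restatement of upstream DeepGEMM's calc_diff metric.

    Returns 1 - 2*<x, y> / (|x|^2 + |y|^2): 0.0 for identical inputs,
    1.0 for orthogonal inputs, 2.0 when y == -x.
    """
    if len(x) != len(y):
        raise ValueError("inputs must have the same length")
    dot = sum(a * b for a, b in zip(x, y))
    denominator = sum(a * a + b * b for a, b in zip(x, y))
    if denominator == 0:
        raise ValueError("the metric is undefined when both inputs are all zeros")
    return 1.0 - 2.0 * dot / denominator
```

Unlike a plain max-abs error, this metric is insensitive to overall scale, which makes it convenient for comparing low-precision (BF16 or FP8) GEMM outputs against an FP32 or FP64 reference.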
Contiguous Grouped GEMM Alignment
get_mk_alignment_for_contiguous_layout() returns the M-axis padding alignment
used by DeepGEMM-compatible contiguous grouped GEMM wrappers. It defaults to
128 and can be overridden with:
export MATE_DEEPGEMM_MK_ALIGNMENT=256
Only 128 and 256 are supported. Use the returned value when padding each
expert segment and building m_indices for
m_grouped_{fp8,bf16}_gemm_nt_contiguous. Set the environment variable before
starting Python: the helper reads the value on first use and caches it, so
later changes within the same process have no effect.
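The alignment rules above can be sketched in pure Python. mk_alignment_from_env is a hypothetical stand-in that mirrors the documented behavior (default 128, only 128 or 256, read once and cached); it is not MATE's code, and raising ValueError on other values is this sketch's choice. build_m_indices pads each expert segment up to the alignment and marks padding rows with -1, which follows upstream DeepGEMM's test code and is an assumption for this wrapper.

```python
import functools
import os
from typing import List

_DEFAULT_ALIGNMENT = 128
_SUPPORTED_ALIGNMENTS = (128, 256)


@functools.lru_cache(maxsize=None)
def mk_alignment_from_env() -> int:
    """Hypothetical stand-in for get_mk_alignment_for_contiguous_layout().

    Reads MATE_DEEPGEMM_MK_ALIGNMENT on the first call only; lru_cache keeps
    the result, so later edits to os.environ are ignored.
    """
    value = int(os.environ.get("MATE_DEEPGEMM_MK_ALIGNMENT", _DEFAULT_ALIGNMENT))
    if value not in _SUPPORTED_ALIGNMENTS:
        # This sketch rejects other values; the real helper's behavior may differ.
        raise ValueError(f"alignment must be 128 or 256, got {value}")
    return value


def align_up(x: int, alignment: int) -> int:
    """Round x up to the next multiple of alignment."""
    return (x + alignment - 1) // alignment * alignment


def build_m_indices(group_sizes: List[int], alignment: int) -> List[int]:
    """Per-row expert ids for a contiguous grouped layout.

    Each expert segment is padded to a multiple of `alignment`. Real rows
    carry their expert id; padding rows are -1 (assumed convention).
    """
    m_indices: List[int] = []
    for expert_id, rows in enumerate(group_sizes):
        m_indices.extend([expert_id] * rows)
        m_indices.extend([-1] * (align_up(rows, alignment) - rows))
    return m_indices
```

For example, experts with 100, 300 and 0 tokens at alignment 128 occupy 128 + 384 + 0 = 512 rows, and expert 1's rows start at offset 128; the activation tensor must use the same padded offsets. In real code, take the alignment from deep_gemm.get_mk_alignment_for_contiguous_layout() and convert the list with torch.tensor(..., dtype=torch.int32); the int32 dtype also follows upstream DeepGEMM and is an assumption here.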
Quick Start
Minimal import example:
import deep_gemm
An example script is provided at:
examples/run_deep_gemm.py
Examples
Run the bundled example:
python examples/run_deep_gemm.py
Notes
This wrapper preserves the DeepGEMM-style Python surface, but execution is provided by MATE on MUSA.
get_paged_mqa_logits_metadata(..., block_kv, ...) currently requires block_kv == 64.
The example script currently demonstrates the FP8 grouped GEMM path.
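Because block_kv is fixed at 64 for now, it can help to validate it up front and size the paged KV cache accordingly. This is a minimal sketch: check_block_kv and kv_blocks_per_sequence are hypothetical helpers, and the idea that each sequence needs one block-table entry per block_kv-token page is the usual paged-KV convention, assumed here rather than taken from this wrapper.

```python
from typing import List

# Documented constraint of get_paged_mqa_logits_metadata in this wrapper.
REQUIRED_BLOCK_KV = 64


def check_block_kv(block_kv: int) -> None:
    """Fail early, before calling the wrapper, if block_kv is unsupported."""
    if block_kv != REQUIRED_BLOCK_KV:
        raise ValueError(
            f"block_kv={block_kv} is unsupported; "
            f"get_paged_mqa_logits_metadata currently requires {REQUIRED_BLOCK_KV}"
        )


def kv_blocks_per_sequence(
    context_lens: List[int], block_kv: int = REQUIRED_BLOCK_KV
) -> List[int]:
    """Number of block_kv-sized KV-cache pages each sequence needs (ceil division)."""
    check_block_kv(block_kv)
    return [(n + block_kv - 1) // block_kv for n in context_lens]
```

With block_kv = 64, a 4,097-token context spans 65 pages, so its block table would need at least 65 entries.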