SageAttention Compatibility Wrapper (sageattention)
sageattention is a compatibility wrapper that keeps the standard SageAttention
Python API while running on MATE's dense quantized attention operators on MUSA.
Overview
This wrapper is designed for projects that already target SageAttention-style Python APIs, allowing you to run on MUSA through MATE with minimal integration effort.
The current compatibility scope includes sageattn and
sageattn_qk_int8_pv_fp8_cuda_sm90.
Package and import
- Package name: sageattention
- Import path: sageattention
- Runtime backend: MATE dense quantized attention operators on MUSA
Requirements
Before using this wrapper, make sure that:
- MATE is installed and importable.
- TorchMUSA and the MUSA runtime environment are available.
- The target workload is configured to run on MUSA devices.
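Before wiring the wrapper into a project, a quick preflight check can confirm that the required packages are importable. This is a minimal standard-library sketch; the module names `torch_musa` and `mate` are assumptions about how TorchMUSA and MATE are imported, so adjust them to match your installation. It only checks importability and does not confirm that a MUSA device is present.

```python
import importlib.util

# Assumed top-level import names; replace them with the names used in your environment.
REQUIRED_MODULES = ("torch", "torch_musa", "mate")


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of `names` that cannot be found on the import path."""
    missing = []
    for name in names:
        # find_spec returns None when a top-level module cannot be located.
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    absent = missing_modules()
    if absent:
        print("Missing modules:", ", ".join(absent))
    else:
        print("All required modules are importable.")
```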
Build
Build a wheel from the wrappers/SageAttention directory (this requires the PyPA build package, for example installed with pip install build):
python -m build --wheel
The generated wheel will be placed under:
dist/
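Installation
Install from source in editable mode, from the wrappers/SageAttention directory. Because of --no-build-isolation, pip uses the packages already installed in the current environment, so build dependencies must be present beforehand: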
pip install --no-build-isolation -e .
Install a built wheel:
pip install dist/sageattention-*.whl
Import
Import the package directly:
import sageattention
Import individual APIs:
from sageattention import sageattn, sageattn_qk_int8_pv_fp8_cuda_sm90
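Projects that must also run where the wrapper is not installed can guard the import. This is a minimal sketch and not part of the wrapper: `resolve_attention` and its `fallback` argument are hypothetical names, and the fallback callable (for example one built on `torch.nn.functional.scaled_dot_product_attention`) is left to the caller.

```python
import importlib


def resolve_attention(fallback, module_name="sageattention", attr="sageattn"):
    """Hypothetical helper: return `module_name.attr` if importable, else `fallback`.

    Only ImportError is caught, so other failures raised while importing the
    module (for example a broken runtime setup) still surface to the caller.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return fallback
    # Fall back as well if the module exists but lacks the requested entry point.
    return getattr(module, attr, fallback)
```

A call site would then use `attn = resolve_attention(fallback=my_attention)` and call `attn(q, k, v, ...)`; the fallback must accept whatever keyword arguments (such as `tensor_layout`) the call sites pass.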
Public APIs
The wrapper currently exposes:
- sageattn: primary SageAttention-compatible dense attention entry point
- sageattn_qk_int8_pv_fp8_cuda_sm90: compatibility alias for the supported dense quantized attention path
Quick Start
Minimal dense attention example:
import torch
from sageattention import sageattn
device = "musa"
dtype = torch.bfloat16
q = torch.randn((1, 8, 128, 128), device=device, dtype=dtype)
k = torch.randn((1, 8, 128, 128), device=device, dtype=dtype)
v = torch.randn((1, 8, 128, 128), device=device, dtype=dtype)
out = sageattn(
q,
k,
v,
tensor_layout="HND",
is_causal=False,
qk_quant_dtype="int8",
)
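The example uses tensor_layout="HND", so each tensor has shape (batch, heads, seq_len, head_dim). Following the upstream SageAttention convention, "NHD" orders the same dimensions as (batch, seq_len, heads, head_dim). The hypothetical, dependency-free helper below maps a shape between the two layouts; with real tensors, the equivalent conversion is `q.transpose(1, 2)`.

```python
# Hypothetical helper: dimension order for each supported layout
# (upstream SageAttention convention).
LAYOUT_DIMS = {
    "HND": ("batch", "heads", "seq_len", "head_dim"),
    "NHD": ("batch", "seq_len", "heads", "head_dim"),
}


def convert_shape(shape, src, dst):
    """Reorder a 4-D shape from layout `src` to layout `dst`."""
    if src not in LAYOUT_DIMS or dst not in LAYOUT_DIMS:
        raise ValueError(f"unsupported layout: {src!r} -> {dst!r}")
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D shape, got {shape!r}")
    named = dict(zip(LAYOUT_DIMS[src], shape))
    return tuple(named[dim] for dim in LAYOUT_DIMS[dst])


def head_dim(shape, layout):
    """Return the head dimension (the last axis in both layouts)."""
    return shape[LAYOUT_DIMS[layout].index("head_dim")]


print(convert_shape((1, 8, 256, 64), "HND", "NHD"))  # (1, 256, 8, 64)
```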
FP8 output example (reusing q, k, and v from above):
out_fp8, out_scale = sageattn(
q,
k,
v,
tensor_layout="HND",
is_causal=False,
qk_quant_dtype="int8",
fp8_output=True,
)
# out_scale ends in a dimension of size 1, so it broadcasts over head_dim
out_dequant = out_fp8.to(torch.float32) * out_scale
With fp8_output=True, sageattn returns (out_fp8, out_scale), or
(out_fp8, out_scale, lse) when return_lse=True. Without FP8 output, it returns
out, or (out, lse) when return_lse=True.
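Because the return type depends on fp8_output and return_lse, call sites that toggle these flags may prefer a single unpacking step. The sketch below is hypothetical (`AttnResult` and `unpack_result` are not wrapper APIs) and encodes only the four return forms listed above; dequantization is still left to the caller.

```python
from typing import Any, NamedTuple, Optional


class AttnResult(NamedTuple):
    """Hypothetical container for a normalized sageattn result."""

    out: Any                  # output tensor (FP8 when fp8_output=True)
    out_scale: Optional[Any]  # float32 scale, present only with fp8_output=True
    lse: Optional[Any]        # log-sum-exp, present only with return_lse=True


def unpack_result(result, fp8_output=False, return_lse=False):
    """Normalize the four documented return forms into an AttnResult."""
    if fp8_output and return_lse:
        out, out_scale, lse = result
    elif fp8_output:
        out, out_scale = result
        lse = None
    elif return_lse:
        out, lse = result
        out_scale = None
    else:
        # Plain form: sageattn returns the output tensor itself.
        out, out_scale, lse = result, None, None
    return AttnResult(out, out_scale, lse)
```

For example, after `res = unpack_result(result, fp8_output=True)`, a caller can check `res.out_scale is not None` before computing `res.out.to(torch.float32) * res.out_scale`.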
Tests
Wrapper-level tests are available in:
tests/test_sageattn_interface.py
Run them from the wrappers/SageAttention directory:
pytest tests/test_sageattn_interface.py
Notes
- This wrapper currently supports only the dense SageAttention path.
- Input tensors must be on the same MUSA device, use torch.float16 or torch.bfloat16, and share the same dtype.
- Supported public tensor_layout values are "HND" and "NHD".
- Supported head dimensions are positive values up to 128.
- qk_quant_dtype supports "int8" and "fp8".
- The default quantization recipe is (128, 16, -1, 1). Passing quant_recipe overrides qk_quant_gran.
- Only qk_quant_gran="per_thread" is supported as a shortcut; express other granularities through an explicit, supported quant_recipe.
- fp8_output=True returns an FP8 tensor plus a torch.float32 out_scale tensor. out_scale has the same public tensor layout as the output, with a final scale dimension of size 1.
- Not supported in this wrapper package: varlen, KV-cache wrapper entry points, public INT8 wrapper entry points other than the SM90-compatible name, and low-level pre-quantized public APIs.
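The constraints above can be checked on the host before dispatching a call. This sketch is hypothetical (`check_inputs` is not a wrapper API) and framework-agnostic: it takes shapes, dtype names, and device strings as plain values, which with real tensors would come from `.shape`, `str(t.dtype)`, and `str(t.device)`. It also encodes the quantization precedence described above. The assumption that "per_thread" (or no granularity) resolves to the default recipe is ours, and the sketch cannot tell which explicit recipes the backend actually supports.

```python
SUPPORTED_LAYOUTS = {"HND", "NHD"}
SUPPORTED_DTYPES = {"torch.float16", "torch.bfloat16"}
SUPPORTED_QK_QUANT_DTYPES = {"int8", "fp8"}
DEFAULT_QUANT_RECIPE = (128, 16, -1, 1)
MAX_HEAD_DIM = 128


def check_inputs(shapes, dtypes, devices, tensor_layout="HND",
                 qk_quant_dtype="int8", qk_quant_gran=None, quant_recipe=None):
    """Hypothetical validator for q/k/v metadata against the documented notes.

    Returns the quantization recipe the call is assumed to use, or raises ValueError.
    """
    if tensor_layout not in SUPPORTED_LAYOUTS:
        raise ValueError(f"tensor_layout must be 'HND' or 'NHD', got {tensor_layout!r}")
    # All inputs must sit on one device, and that device must be a MUSA device.
    if len(set(devices)) != 1 or not devices[0].startswith("musa"):
        raise ValueError(f"q, k, v must share one MUSA device, got {devices}")
    # All inputs must share one dtype, either float16 or bfloat16.
    if len(set(dtypes)) != 1 or dtypes[0] not in SUPPORTED_DTYPES:
        raise ValueError(f"q, k, v must share float16 or bfloat16, got {dtypes}")
    for shape in shapes:
        # head_dim is the last axis in both HND and NHD.
        if len(shape) != 4 or not (0 < shape[-1] <= MAX_HEAD_DIM):
            raise ValueError(f"expected 4-D input with 0 < head_dim <= 128, got {shape}")
    if qk_quant_dtype not in SUPPORTED_QK_QUANT_DTYPES:
        raise ValueError(f"qk_quant_dtype must be 'int8' or 'fp8', got {qk_quant_dtype!r}")
    # An explicit recipe overrides qk_quant_gran entirely.
    if quant_recipe is not None:
        return tuple(quant_recipe)
    if qk_quant_gran not in (None, "per_thread"):
        raise ValueError("only qk_quant_gran='per_thread' is supported; pass quant_recipe instead")
    # Assumption: without an explicit recipe, the documented default applies.
    return DEFAULT_QUANT_RECIPE
```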