See also: FlashAttention3 Forward Compatibility.
FlashAttention-3 Compatibility Wrapper (flash_attn_3)¶
flash_attn_3 is a compatibility wrapper package that preserves the official
FlashAttention-3 package surface while running on MUSA through MATE attention
operators.
Overview¶
This wrapper is designed for projects that already target FlashAttention-3
style Python APIs. It helps run existing integrations on MUSA through MATE
with minimal code changes. The current compatibility target is the
flash_attn_3 interface surface.
Package and import¶
Package name: flash_attn_3
Public import path: flash_attn_interface
Internal package path: flash_attn_3
Runtime backend: MATE attention operators on MUSA
For the current compatibility scope and known limitations, see
docs/source/wrappers/flash_attention_forward_compatibility.md.
Requirements¶
Before using this wrapper, make sure the following are available:
MATE is installed and importable.
TorchMUSA is installed and the MUSA runtime environment is configured.
The target workload is configured to run on MUSA devices.
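The first two requirements can be checked up front, before any wrapper code runs. The following is a minimal sketch and not part of the wrapper: `find_missing_modules` is a hypothetical helper, and the import names `mate` and `torch_musa` are assumptions about how MATE and TorchMUSA are exposed in your environment.

```python
import importlib.util


def find_missing_modules(module_names):
    """Return the names in module_names that cannot be located for import."""
    missing = []
    for name in module_names:
        # find_spec returns None when a top-level module is not installed.
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    # Assumed import names; adjust them to match your installation.
    required = ["mate", "torch_musa", "flash_attn_interface"]
    missing = find_missing_modules(required)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

This only confirms that the modules can be located. It does not verify that a MUSA device is visible or that the runtime environment is configured correctly.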
Build¶
Build a wheel from the wrappers/flash-attention directory:
python -m build --wheel
The generated wheel will be placed under dist/.
Installation¶
Install from source:
pip install --no-build-isolation -e .
Install a built wheel:
pip install dist/flash_attn_3-*.whl
If you previously installed the legacy mate-flash-attention package, uninstall it before installing flash_attn_3 so the
environment does not keep stale wrapper metadata.
pip uninstall -y mate-flash-attention
pip install dist/flash_attn_3-*.whl
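To confirm which of the two distributions is present after installation, the package metadata can be queried from Python. This is a sketch using only the standard library; `installed_version` is a hypothetical helper, and the distribution names are the ones used in the commands above.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    legacy = installed_version("mate-flash-attention")
    if legacy is not None:
        print(f"Legacy mate-flash-attention {legacy} is still installed; uninstall it first.")
    print("flash_attn_3 version:", installed_version("flash_attn_3"))
```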
Import¶
Import the package directly:
import flash_attn_interface
Import individual APIs:
from flash_attn_interface import (
    flash_attn_func,
    flash_attn_varlen_func,
    flash_attn_with_kvcache,
    get_scheduler_metadata,
)
Public APIs¶
The wrapper currently exposes:
flash_attn_func: FlashAttention-compatible dense QKV entry
flash_attn_varlen_func: FlashAttention-compatible varlen / ragged FMHA entry
flash_attn_with_kvcache: FlashAttention-compatible KV-cache entry, including paged KV cache use cases
get_scheduler_metadata: helper for split-KV scheduler metadata preparation
Quick Start¶
Minimal varlen FMHA example:
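import torch
from flash_attn_interface import flash_attn_varlen_func

device = "musa"
dtype = torch.bfloat16

# Grouped-query attention: 32 query heads share 8 KV heads.
num_heads_q = 32
num_heads_kv = 8
head_dim = 128

# Cumulative sequence lengths for a batch of two packed sequences:
# query lengths 32 and 64 (96 tokens), key/value lengths 64 and 96 (160 tokens).
cu_seqlens_q = torch.tensor([0, 32, 96], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 64, 160], device=device, dtype=torch.int32)

# Packed (unpadded) tensors of shape (total_tokens, num_heads, head_dim).
q = torch.randn((96, num_heads_q, head_dim), device=device, dtype=dtype)
k = torch.randn((160, num_heads_kv, head_dim), device=device, dtype=dtype)
v = torch.randn((160, num_heads_kv, head_dim), device=device, dtype=dtype)

out = flash_attn_varlen_func(
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=64,  # longest query sequence in the batch
    max_seqlen_k=96,  # longest key/value sequence in the batch
    causal=False,
)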
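The `cu_seqlens_*` and `max_seqlen_*` values in the example above are all derived from the per-sequence lengths. The sketch below shows that derivation in plain Python; `build_cu_seqlens` is a hypothetical helper and not part of the wrapper.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Return cumulative offsets [0, l0, l0 + l1, ...] for packed sequences."""
    return [0, *accumulate(seqlens)]


# Per-sequence lengths matching the example above.
seqlens_q = [32, 64]
seqlens_k = [64, 96]

cu_q = build_cu_seqlens(seqlens_q)  # [0, 32, 96]
cu_k = build_cu_seqlens(seqlens_k)  # [0, 64, 160]

# The last offset is the total number of packed tokens (first tensor dimension).
total_q, total_k = cu_q[-1], cu_k[-1]

# max_seqlen_* is the longest individual sequence, not the packed total.
max_seqlen_q, max_seqlen_k = max(seqlens_q), max(seqlens_k)
```

The resulting lists can be wrapped with `torch.tensor(..., dtype=torch.int32)` on the target device, as in the example.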
Paged KV cache skeleton with scheduler metadata:
from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata

# Skeleton only: batch_size, cache_seqlens, page_table, the paged caches, and the
# other inputs must be prepared by the caller.
metadata = get_scheduler_metadata(
    batch_size=batch_size,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_k,
    num_heads_q=num_heads_q,
    num_heads_kv=num_heads_kv,
    headdim=head_dim_qk,
    headdim_v=head_dim_v,
    cache_seqlens=cache_seqlens,
    cu_seqlens_q=cu_seqlens_q,
    page_size=page_size,
    num_splits=0,
    pack_gqa=True,
)

# Pass the precomputed metadata to the KV-cache call.
out, lse, *_ = flash_attn_with_kvcache(
    q=q_unpad,
    k_cache=k_cache_paged,
    v_cache=v_cache_paged,
    cache_seqlens=cache_seqlens,
    page_table=page_table,
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=max_seqlen_q,
    scheduler_metadata=metadata,
    return_softmax_lse=True,
)
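The skeleton leaves `page_table` and `page_size` to the caller. As a rough illustration of how these relate to `cache_seqlens`, the sketch below builds a rectangular page table in plain Python. It assumes the upstream FlashAttention convention of one row per sequence holding physical page indices; `pages_needed` and `build_page_table` are hypothetical helpers, and the page size of 64 is illustrative rather than a value confirmed for MATE.

```python
def pages_needed(seqlen, page_size):
    """Number of fixed-size pages required to hold seqlen cached tokens."""
    return -(-seqlen // page_size)  # ceiling division


def build_page_table(cache_seqlens, page_size, pad_value=0):
    """Assign consecutive physical pages to each sequence, row by row."""
    counts = [pages_needed(n, page_size) for n in cache_seqlens]
    width = max(counts, default=0)
    table, next_page = [], 0
    for count in counts:
        row = list(range(next_page, next_page + count))
        next_page += count
        # Pad shorter rows so the table is rectangular.
        table.append(row + [pad_value] * (width - count))
    return table, next_page


# Two cached sequences of 64 and 96 tokens with an assumed page size of 64.
table, num_pages = build_page_table(cache_seqlens=[64, 96], page_size=64)
```

A real allocator may hand out pages in any order. The only property this sketch relies on is that each row lists enough pages to cover the corresponding entry of `cache_seqlens`.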
Tests¶
Wrapper-level tests are available in:
tests/test_flash_attn.py
tests/test_interface_unit.py
Run them from the wrappers/flash-attention directory:
pytest tests/test_flash_attn.py tests/test_interface_unit.py
Notes¶
This wrapper follows the FlashAttention-3 style Python package layout, and the current compatibility target is flash_attn_3.
The flash_attn top-level package is intentionally not shipped, so Transformers FA2 / FA4 package checks keep failing.
Actual feature coverage is documented in docs/source/wrappers/flash_attention_forward_compatibility.md.
For the authoritative operator behavior, refer to the corresponding MATE APIs under mate.mha_interface.