Source code for mate.hyperconnection

"""HyperConnection operators.

The main MHC pre path is organized as:

1. prenorm GEMM + row square-sum;
2. fused MHC pre stage, which applies RMS scaling, splits the mixes into the
   pre/post/comb branches, runs Sinkhorn normalization on the comb branch, and
   produces the layer input.

The returned tensor order follows the SGLang MHC pre convention:
``(post_mix, comb_mix, layer_input)``.
"""

from typing import Optional, Tuple

import torch

from mate.api_logging import mate_api
from mate.hyperconnection_kernels.tilelang import mhc_pre as mhc_pre_tilelang

# Backend selector string accepted by ``mhc_prenorm_gemm_sqrsum``.
MHC_PRENORM_BACKEND_DEEPGEMM: str = "deepgemm"
# Backend selector string accepted by ``mhc_pre_big_fuse``.
MHC_BIG_FUSE_BACKEND_TILELANG: str = "tilelang"

# Public names exported by ``from mate.hyperconnection import *``.
__all__ = [
    "MHC_BIG_FUSE_BACKEND_TILELANG",
    "MHC_PRENORM_BACKEND_DEEPGEMM",
    "mhc_pre",
    "mhc_pre_big_fuse",
    "mhc_prenorm_gemm_sqrsum",
]


def _check_same_device(reference: torch.Tensor, **tensors: torch.Tensor) -> None:
    """Raise ``ValueError`` if any keyword tensor is not on ``reference``'s device.

    Keyword tensors are checked in call order; the first mismatch is reported
    by its keyword name.
    """
    expected = reference.device
    offender = next(
        ((label, t) for label, t in tensors.items() if t.device != expected),
        None,
    )
    if offender is None:
        return
    label, t = offender
    raise ValueError(
        f"Expected {label} to be on device {expected}, got {t.device}."
    )


def _validate_mhc_inputs(
    residual: torch.Tensor,
    hc_fn: torch.Tensor,
    mhc_scale: torch.Tensor,
    mhc_base: torch.Tensor,
) -> Tuple[Tuple[int, ...], int, int, int]:
    """Check ``mhc_pre`` arguments and derive the flattened problem sizes.

    Returns ``(outer_shape, num_tokens, mhc_mult, hidden_size)``:
    ``outer_shape`` is every residual dim before the last two and
    ``num_tokens`` is their product (1 when there are none).

    Raises ``ValueError`` on the first failing check, in this order:
    residual rank, dtypes, ``hc_fn``/``mhc_scale``/``mhc_base`` shapes,
    then device agreement with ``residual``.
    """
    if residual.dim() < 2:
        raise ValueError(
            f"residual must have shape [..., mhc_mult, hidden_size], got {residual.shape}."
        )

    # dtype contract: bf16 residual, fp32 weight / scale / bias.
    if residual.dtype != torch.bfloat16:
        raise ValueError(f"residual must be bfloat16, got {residual.dtype}.")
    if hc_fn.dtype != torch.float32:
        raise ValueError(f"hc_fn must be float32, got {hc_fn.dtype}.")
    if not (mhc_scale.dtype == torch.float32 and mhc_base.dtype == torch.float32):
        raise ValueError("mhc_scale and mhc_base must be float32.")

    *outer_dims, mhc_mult, hidden_size = residual.shape
    # Number of per-token mixing coefficients: mhc_mult * (2 + mhc_mult).
    num_mixes = mhc_mult * (mhc_mult + 2)

    expected_fn_shape = (num_mixes, mhc_mult * hidden_size)
    if hc_fn.shape != expected_fn_shape:
        raise ValueError(
            f"hc_fn must have shape {expected_fn_shape}, "
            f"got {tuple(hc_fn.shape)}."
        )
    if mhc_scale.shape != (3,):
        raise ValueError(
            f"mhc_scale must have shape [3], got {tuple(mhc_scale.shape)}."
        )
    if mhc_base.shape != (num_mixes,):
        raise ValueError(
            f"mhc_base must have shape [{num_mixes}], got {tuple(mhc_base.shape)}."
        )

    _check_same_device(
        residual, hc_fn=hc_fn, mhc_scale=mhc_scale, mhc_base=mhc_base
    )
    # torch.Size.numel() multiplies the dims together (1 for an empty Size).
    return tuple(outer_dims), residual.shape[:-2].numel(), mhc_mult, hidden_size


def _normalize_prenorm_partials(
    gemm_out_mul: torch.Tensor,
    gemm_out_sqrsum: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Give unsplit prenorm outputs a leading split axis of size 1.

    A 2-D ``gemm_out_mul`` and a 1-D ``gemm_out_sqrsum`` are unsqueezed at
    dim 0; tensors of any other rank pass through untouched. The two
    arguments are handled independently.
    """
    mul = gemm_out_mul.unsqueeze(0) if gemm_out_mul.dim() == 2 else gemm_out_mul
    sqrsum = (
        gemm_out_sqrsum.unsqueeze(0) if gemm_out_sqrsum.dim() == 1 else gemm_out_sqrsum
    )
    return mul, sqrsum


def _default_split_k(num_tokens: int) -> int:
    """Pick the prenorm split-K factor: 32 when ``num_tokens <= 64``, else 16."""
    if num_tokens > 64:
        return 16
    return 32


@mate_api
def mhc_prenorm_gemm_sqrsum(
    residual_flat: torch.Tensor,
    hc_fn: torch.Tensor,
    *,
    backend: str = MHC_PRENORM_BACKEND_DEEPGEMM,
    split_k: Optional[int] = None,
    return_partials: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Run MHC prenorm GEMM and row square-sum.

    Parameters
    ----------
    residual_flat:
        Tensor with shape ``(M, mhc_mult, hidden_size)`` or ``(M, K)`` and
        dtype ``torch.bfloat16``. Every dim after the first is flattened
        into ``K``.
    hc_fn:
        Weight tensor with shape ``(N, K)`` and dtype ``torch.float32``.
    backend:
        Prenorm backend selector (surrounding whitespace and case are
        ignored). Only ``"deepgemm"`` is supported.
    split_k:
        Prenorm backend split-K factor. If omitted, uses ``32`` for
        decode-like ``M <= 64`` and ``16`` otherwise. Values ``<= 1`` take
        the unsplit path; values ``<= 0`` additionally pass
        ``num_splits=None`` to the backend.
    return_partials:
        When true, return tensors with leading split dimension
        ``[S, M, ...]`` (``S == 1`` on the unsplit path).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Float32 ``(mixes, sqrsum)`` with shapes ``[M, N]`` / ``[M]``, or
        ``[S, M, N]`` / ``[S, M]`` when ``return_partials`` is true.
        An empty batch (``M == 0``) yields empty tensors without launching
        the backend kernel.

    Raises
    ------
    ValueError
        On dtype, rank, shape, device, or backend mismatches.
    """
    backend = backend.strip().lower()
    if residual_flat.dtype != torch.bfloat16:
        raise ValueError(f"residual_flat must be bfloat16, got {residual_flat.dtype}.")
    if hc_fn.dtype != torch.float32:
        raise ValueError(f"hc_fn must be float32, got {hc_fn.dtype}.")
    if residual_flat.dim() < 2:
        raise ValueError(
            "residual_flat must have shape (M, mhc_mult, hidden_size) or (M, K), "
            f"got {tuple(residual_flat.shape)}."
        )

    # flatten(1) instead of reshape(M, -1): reshape cannot infer the -1 when
    # M == 0, so an empty batch would raise before reaching the kernel.
    x_flat = residual_flat.flatten(1)
    num_tokens, hc_hidden_size = x_flat.shape
    if hc_fn.dim() != 2 or hc_fn.shape[1] != hc_hidden_size:
        raise ValueError(
            f"hc_fn must have shape [N, {hc_hidden_size}], got {tuple(hc_fn.shape)}."
        )
    mhc_mult3 = hc_fn.shape[0]
    _check_same_device(x_flat, hc_fn=hc_fn)
    if backend != MHC_PRENORM_BACKEND_DEEPGEMM:
        raise ValueError(
            f"Unsupported MHC prenorm backend={backend!r}; expected 'deepgemm'."
        )

    import mate.deep_gemm as deep_gemm

    num_splits = _default_split_k(num_tokens) if split_k is None else int(split_k)
    device = x_flat.device
    weight = hc_fn.contiguous()

    if num_splits <= 1:
        # Unsplit path: the backend writes final [M, N] / [M] outputs directly.
        d_out = torch.empty((num_tokens, mhc_mult3), dtype=torch.float32, device=device)
        s_out = torch.empty((num_tokens,), dtype=torch.float32, device=device)
        if num_tokens > 0:
            deep_gemm.tf32_hc_prenorm_gemm(
                x_flat,
                weight,
                d_out,
                s_out,
                num_splits=None if num_splits <= 0 else num_splits,
            )
        if return_partials:
            return d_out.unsqueeze(0), s_out.unsqueeze(0)
        return d_out, s_out

    # Split-K path: the backend writes one partial per split; they are
    # summed here unless the caller wants to reduce them itself.
    d_part = torch.empty(
        (num_splits, num_tokens, mhc_mult3), dtype=torch.float32, device=device
    )
    s_part = torch.empty((num_splits, num_tokens), dtype=torch.float32, device=device)
    if num_tokens > 0:
        deep_gemm.tf32_hc_prenorm_gemm(
            x_flat,
            weight,
            d_part,
            s_part,
            num_splits=num_splits,
        )
    if return_partials:
        return d_part, s_part
    return d_part.sum(dim=0), s_part.sum(dim=0)
def _check_big_fuse_tensor(
    name: str,
    tensor: torch.Tensor,
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: torch.device,
    *,
    contiguous: bool = True,
) -> None:
    """Raise ``ValueError`` unless ``tensor`` has the given shape/dtype/device (and contiguity)."""
    if tuple(tensor.shape) != tuple(shape):
        raise ValueError(
            f"{name} must have shape {list(shape)}, got {list(tensor.shape)}."
        )
    if tensor.dtype != dtype:
        raise ValueError(f"{name} must be {dtype}, got {tensor.dtype}.")
    if tensor.device != device:
        raise ValueError(
            f"Expected {name} to be on device {device}, got {tensor.device}."
        )
    if contiguous and not tensor.is_contiguous():
        raise ValueError(f"{name} must be contiguous.")


@mate_api
def mhc_pre_big_fuse(
    gemm_out_mul: torch.Tensor,
    gemm_out_sqrsum: torch.Tensor,
    mhc_scale: torch.Tensor,
    mhc_base: torch.Tensor,
    residual_flat: torch.Tensor,
    *,
    rms_eps: float = 1e-6,
    mhc_pre_eps: float = 1e-6,
    mhc_sinkhorn_eps: float = 1e-6,
    mhc_post_mult_value: float = 2.0,
    sinkhorn_repeat: int = 20,
    backend: str = MHC_BIG_FUSE_BACKEND_TILELANG,
    threads: int = 128,
    hidden_block: int = 512,
    pass_config: str = "safe",
    compile_profile: Optional[str] = None,
    post_mix: Optional[torch.Tensor] = None,
    comb_mix: Optional[torch.Tensor] = None,
    layer_input: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Run the fused MHC pre stage.

    This is the second stage of the DeepSeek V4 MHC pre path. It consumes the
    prenorm GEMM output ``mixes`` and row square-sum produced by
    ``mhc_prenorm_gemm_sqrsum`` or ``mate.deep_gemm.tf32_hc_prenorm_gemm``,
    applies RMS scaling, computes the MHC pre/post/comb branches, runs
    Sinkhorn normalization for the comb branch, and writes the fused layer
    input.

    ``gemm_out_mul`` and ``gemm_out_sqrsum`` may be either final outputs with
    shapes ``[num_tokens, mhc_mult * (2 + mhc_mult)]`` / ``[num_tokens]`` or
    split-K partials with shapes
    ``[num_splits, num_tokens, mhc_mult * (2 + mhc_mult)]`` /
    ``[num_splits, num_tokens]``. Passing split-K partials avoids a separate
    reduction before the fused stage.

    Parameters
    ----------
    gemm_out_mul:
        Contiguous float32 prenorm GEMM output or split-K partials.
    gemm_out_sqrsum:
        Contiguous float32 row square-sum output or split-K partials; its
        leading dims must match ``gemm_out_mul``'s.
    mhc_scale:
        Float32 scale tensor with shape ``[3]`` for pre/post/comb branches.
    mhc_base:
        Float32 bias tensor with shape ``[mhc_mult * (2 + mhc_mult)]``.
    residual_flat:
        Contiguous bfloat16 tensor with shape
        ``[num_tokens, mhc_mult, hidden_size]``.
    rms_eps:
        Epsilon added before RMS inverse square root.
    mhc_pre_eps:
        Epsilon added to the sigmoid pre branch.
    mhc_sinkhorn_eps:
        Epsilon used by Sinkhorn normalization.
    mhc_post_mult_value:
        Multiplicative scale applied to the post branch.
    sinkhorn_repeat:
        Number of Sinkhorn normalization iterations; must be ``> 0``.
    backend:
        Fused-stage backend selector (surrounding whitespace and case are
        ignored). Only ``"tilelang"`` is supported.
    threads:
        Kernel thread count.
    hidden_block:
        Hidden dimension tile size used by the layer-input reduction.
    pass_config:
        Backend pass configuration name.
    compile_profile:
        Optional MUSA compiler flag profile.
    post_mix:
        Optional contiguous float32 output buffer with shape
        ``[num_tokens, mhc_mult]``.
    comb_mix:
        Optional contiguous float32 output buffer with shape
        ``[num_tokens, mhc_mult * mhc_mult]``.
    layer_input:
        Optional contiguous bfloat16 output buffer with shape
        ``[num_tokens, hidden_size]``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        ``(post_mix, comb_mix, layer_input)`` with shapes
        ``[num_tokens, mhc_mult]``, ``[num_tokens, mhc_mult * mhc_mult]``,
        and ``[num_tokens, hidden_size]``. Provided output buffers are
        written in place and returned. With ``num_tokens == 0`` the kernel
        is not launched and the (empty) outputs are returned directly.

    Raises
    ------
    ValueError
        If any input or provided output buffer violates the shape, dtype,
        device, or contiguity contract above, or on unsupported
        ``backend`` / ``sinkhorn_repeat`` values.

    Notes
    -----
    Inputs and output buffers are required to be contiguous; this avoids
    hidden copies and keeps the indexing path specialized for the hot
    serving case. Output buffers may be provided to avoid allocations in hot
    serving paths or graph-capture paths. All checks here inspect tensor
    metadata only.
    """
    backend = backend.strip().lower()
    if residual_flat.dim() != 3:
        raise ValueError(
            "residual_flat must have shape [num_tokens, mhc_mult, hidden_size]."
        )
    if sinkhorn_repeat <= 0:
        raise ValueError(f"sinkhorn_repeat must be > 0, got {sinkhorn_repeat}.")
    if backend != MHC_BIG_FUSE_BACKEND_TILELANG:
        raise ValueError(
            f"Unsupported MHC big-fuse backend={backend!r}; expected 'tilelang'."
        )

    num_tokens, mhc_mult, hidden_size = residual_flat.shape
    mhc_mult3 = mhc_mult * (2 + mhc_mult)
    device = residual_flat.device

    _check_big_fuse_tensor(
        "residual_flat",
        residual_flat,
        (num_tokens, mhc_mult, hidden_size),
        torch.bfloat16,
        device,
    )

    # Prenorm inputs: final [M, N] / [M] or split-K [S, M, N] / [S, M], with
    # matching leading dims so the mixes and square-sums line up.
    if gemm_out_mul.dim() not in (2, 3):
        raise ValueError(
            "gemm_out_mul must have shape [num_tokens, mhc_mult * (2 + mhc_mult)] or "
            "[num_splits, num_tokens, mhc_mult * (2 + mhc_mult)], "
            f"got {list(gemm_out_mul.shape)}."
        )
    split_dims = tuple(gemm_out_mul.shape[:-2])
    _check_big_fuse_tensor(
        "gemm_out_mul",
        gemm_out_mul,
        (*split_dims, num_tokens, mhc_mult3),
        torch.float32,
        device,
    )
    _check_big_fuse_tensor(
        "gemm_out_sqrsum",
        gemm_out_sqrsum,
        (*split_dims, num_tokens),
        torch.float32,
        device,
    )
    _check_big_fuse_tensor(
        "mhc_scale", mhc_scale, (3,), torch.float32, device, contiguous=False
    )
    _check_big_fuse_tensor(
        "mhc_base", mhc_base, (mhc_mult3,), torch.float32, device, contiguous=False
    )

    # Output buffers: allocate when absent; otherwise validate, since the
    # kernel writes into them in place.
    if post_mix is None:
        post_mix = torch.empty(
            (num_tokens, mhc_mult), dtype=torch.float32, device=device
        )
    else:
        _check_big_fuse_tensor(
            "post_mix", post_mix, (num_tokens, mhc_mult), torch.float32, device
        )
    if comb_mix is None:
        comb_mix = torch.empty(
            (num_tokens, mhc_mult * mhc_mult), dtype=torch.float32, device=device
        )
    else:
        _check_big_fuse_tensor(
            "comb_mix",
            comb_mix,
            (num_tokens, mhc_mult * mhc_mult),
            torch.float32,
            device,
        )
    if layer_input is None:
        layer_input = torch.empty(
            (num_tokens, hidden_size), dtype=torch.bfloat16, device=device
        )
    else:
        _check_big_fuse_tensor(
            "layer_input",
            layer_input,
            (num_tokens, hidden_size),
            torch.bfloat16,
            device,
        )

    if num_tokens == 0:
        # Empty batch: nothing to compute, so skip the kernel launch entirely.
        return post_mix, comb_mix, layer_input

    mhc_pre_tilelang.run_mhc_pre_big_fuse(
        gemm_out_mul,
        gemm_out_sqrsum,
        mhc_scale,
        mhc_base,
        residual_flat,
        post_mix,
        comb_mix,
        layer_input,
        rms_eps=rms_eps,
        mhc_pre_eps=mhc_pre_eps,
        mhc_sinkhorn_eps=mhc_sinkhorn_eps,
        mhc_post_mult_value=mhc_post_mult_value,
        sinkhorn_repeat=sinkhorn_repeat,
        threads=threads,
        hidden_block=hidden_block,
        pass_config=pass_config,
        compile_profile=compile_profile,
    )
    return post_mix, comb_mix, layer_input
@mate_api
def mhc_pre(
    residual: torch.Tensor,
    hc_fn: torch.Tensor,
    mhc_scale: torch.Tensor,
    mhc_base: torch.Tensor,
    *,
    rms_eps: float = 1e-6,
    mhc_pre_eps: float = 1e-6,
    mhc_sinkhorn_eps: float = 1e-6,
    mhc_post_mult_value: float = 2.0,
    sinkhorn_repeat: int = 20,
    split_k: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Run MHC pre end to end.

    This convenience wrapper runs the prenorm stage followed by the fused MHC
    pre stage. It intentionally exposes only model-level semantics; backend
    tuning knobs remain on ``mhc_prenorm_gemm_sqrsum`` and
    ``mhc_pre_big_fuse`` for callers that need explicit control.

    Parameters
    ----------
    residual:
        Bfloat16 input tensor with shape ``[..., mhc_mult, hidden_size]``.
        It need not be contiguous; a contiguous flattened copy is made when
        required.
    hc_fn:
        Float32 prenorm GEMM weight with shape
        ``[mhc_mult * (2 + mhc_mult), mhc_mult * hidden_size]``.
    mhc_scale:
        Float32 scale tensor with shape ``[3]`` for pre/post/comb branches.
    mhc_base:
        Float32 bias tensor with shape ``[mhc_mult * (2 + mhc_mult)]``.
    rms_eps:
        Epsilon added before RMS inverse square root.
    mhc_pre_eps:
        Epsilon added to the sigmoid pre branch.
    mhc_sinkhorn_eps:
        Epsilon used by Sinkhorn normalization.
    mhc_post_mult_value:
        Multiplicative scale applied to the post branch.
    sinkhorn_repeat:
        Number of Sinkhorn normalization iterations.
    split_k:
        Optional prenorm split-K factor. If omitted, uses the decode/prefill
        default selected by ``mhc_prenorm_gemm_sqrsum``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        ``(post_mix, comb_mix, layer_input)`` with shapes
        ``[..., mhc_mult, 1]``, ``[..., mhc_mult, mhc_mult]``, and
        ``[..., hidden_size]`` (float32, float32, bfloat16). If any leading
        dim is 0, empty outputs are returned without launching kernels.

    Raises
    ------
    ValueError
        If the inputs fail rank, dtype, shape, or device validation.
    """
    outer_shape, num_tokens, mhc_mult, hidden_size = _validate_mhc_inputs(
        residual, hc_fn, mhc_scale, mhc_base
    )

    if num_tokens == 0:
        # Empty batch: return correctly shaped empty outputs (same dtypes as
        # the fused stage allocates) without launching either kernel.
        device = residual.device
        return (
            torch.empty((*outer_shape, mhc_mult, 1), dtype=torch.float32, device=device),
            torch.empty(
                (*outer_shape, mhc_mult, mhc_mult), dtype=torch.float32, device=device
            ),
            torch.empty((*outer_shape, hidden_size), dtype=torch.bfloat16, device=device),
        )

    # reshape() may return a strided view (e.g. when residual was sliced along
    # the mhc_mult axis), but the fused stage documents a contiguous
    # residual_flat requirement; contiguous() is a no-op when already satisfied.
    residual_flat = residual.reshape(num_tokens, mhc_mult, hidden_size).contiguous()
    gemm_out_mul, gemm_out_sqrsum = mhc_prenorm_gemm_sqrsum(
        residual_flat,
        hc_fn,
        split_k=split_k,
        return_partials=True,
    )
    post_mix, comb_mix, layer_input = mhc_pre_big_fuse(
        gemm_out_mul,
        gemm_out_sqrsum,
        mhc_scale,
        mhc_base,
        residual_flat,
        rms_eps=rms_eps,
        mhc_pre_eps=mhc_pre_eps,
        mhc_sinkhorn_eps=mhc_sinkhorn_eps,
        mhc_post_mult_value=mhc_post_mult_value,
        sinkhorn_repeat=sinkhorn_repeat,
    )
    return (
        post_mix.reshape(*outer_shape, mhc_mult, 1),
        comb_mix.reshape(*outer_shape, mhc_mult, mhc_mult),
        layer_input.reshape(*outer_shape, hidden_size),
    )