HyperConnection

Use these MATE HyperConnection APIs when the higher-level wrappers do not cover your use case. For DeepGEMM-style prenorm GEMM flows, prefer the mate.deep_gemm wrapper when its interface matches your framework.

MHC Pre

mate.hyperconnection.mhc_pre(residual: torch.Tensor, hc_fn: torch.Tensor, mhc_scale: torch.Tensor, mhc_base: torch.Tensor, *, rms_eps: float = 1e-06, mhc_pre_eps: float = 1e-06, mhc_sinkhorn_eps: float = 1e-06, mhc_post_mult_value: float = 2.0, sinkhorn_repeat: int = 20, split_k: int | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Run MHC pre end to end.

This convenience wrapper runs the prenorm stage followed by the fused MHC pre stage. It intentionally exposes only model-level semantics; backend tuning knobs remain on mhc_prenorm_gemm_sqrsum and mhc_pre_big_fuse for callers that need explicit control.
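A minimal pure-Python sketch of how tensor shapes flow through the two stages follows. It assumes the wrapper flattens every leading dimension of residual into one token dimension M before the stages run and restores those dimensions on the outputs; trace_mhc_pre_shapes is a hypothetical helper, not part of MATE.

```python
from math import prod


def trace_mhc_pre_shapes(residual_shape, hc_fn_shape):
    """Hypothetical helper: trace shapes through the two MHC pre stages.

    Assumption: all leading dims of `residual` are flattened into M tokens.
    """
    *leading, mhc_mult, hidden_size = residual_shape
    m_tokens = prod(leading)            # prod([]) == 1 for an unbatched input
    n_mix = mhc_mult * (2 + mhc_mult)   # pre + post + comb mixes
    k = mhc_mult * hidden_size          # GEMM reduction dimension
    if tuple(hc_fn_shape) != (n_mix, k):
        raise ValueError(f"hc_fn must have shape {(n_mix, k)}, got {tuple(hc_fn_shape)}")
    return {
        # Stage 1: mhc_prenorm_gemm_sqrsum (final, non-split outputs)
        "gemm_out_mul": (m_tokens, n_mix),
        "gemm_out_sqrsum": (m_tokens,),
        # Stage 2: mhc_pre_big_fuse outputs
        "post_mix_flat": (m_tokens, mhc_mult),
        "comb_mix_flat": (m_tokens, mhc_mult * mhc_mult),
        "layer_input_flat": (m_tokens, hidden_size),
        # mhc_pre return shapes with the leading dims restored
        "post_mix": (*leading, mhc_mult, 1),
        "comb_mix": (*leading, mhc_mult, mhc_mult),
        "layer_input": (*leading, hidden_size),
    }


shapes = trace_mhc_pre_shapes((2, 8, 4, 4096), (24, 16384))
```

For mhc_mult = 4, hc_fn needs 4 * (2 + 4) = 24 rows, and the 2 x 8 leading dimensions become 16 tokens inside the stages.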

Parameters:
  • residual – Bfloat16 input tensor with shape [..., mhc_mult, hidden_size].

  • hc_fn – Float32 prenorm GEMM weight with shape [mhc_mult * (2 + mhc_mult), mhc_mult * hidden_size].

  • mhc_scale – Float32 scale tensor with shape [3] for pre/post/comb branches.

  • mhc_base – Float32 bias tensor with shape [mhc_mult * (2 + mhc_mult)].

  • rms_eps – Epsilon added before RMS inverse square root.

  • mhc_pre_eps – Epsilon added to the sigmoid pre branch.

  • mhc_sinkhorn_eps – Epsilon used by Sinkhorn normalization.

  • mhc_post_mult_value – Multiplicative scale applied to the post branch.

  • sinkhorn_repeat – Number of Sinkhorn normalization iterations.

  • split_k – Optional prenorm split-K factor. If omitted, uses the decode/prefill default selected by mhc_prenorm_gemm_sqrsum.
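Sinkhorn normalization alternately rescales the rows and columns of a positive matrix so that it approaches a doubly stochastic matrix; sinkhorn_repeat sets how many row/column rounds run, and mhc_sinkhorn_eps guards the divisions. The sketch below is a plain reference under stated assumptions: it starts from exponentiated comb logits, normalizes rows before columns, and adds eps to each denominator. MATE's kernel may order these steps or place eps differently, and the function name is hypothetical.

```python
import math


def sinkhorn(logits, repeat=20, eps=1e-6):
    """Hypothetical reference: Sinkhorn normalization of a square matrix."""
    n = len(logits)
    # Positive starting matrix (assumption: exp of the logits). Subtracting the
    # row max avoids overflow; the first row normalization cancels it (up to eps).
    mat = [[math.exp(x - max(row)) for x in row] for row in logits]
    for _ in range(repeat):
        # Row normalization
        row_sums = [sum(row) for row in mat]
        mat = [[x / (row_sums[i] + eps) for x in mat[i]] for i in range(n)]
        # Column normalization
        col_sums = [sum(mat[i][j] for i in range(n)) for j in range(n)]
        mat = [[mat[i][j] / (col_sums[j] + eps) for j in range(n)] for i in range(n)]
    return mat


mix = sinkhorn([[0.0, 0.5, 1.0], [0.2, 0.0, -0.3], [0.1, 0.4, 0.0]])
```

After the last round the column sums are 1 up to eps, and the row sums converge toward 1 as sinkhorn_repeat grows.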

Returns:

(post_mix, comb_mix, layer_input) with shapes [..., mhc_mult, 1], [..., mhc_mult, mhc_mult], and [..., hidden_size].

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]

mate.hyperconnection.mhc_prenorm_gemm_sqrsum(residual_flat: torch.Tensor, hc_fn: torch.Tensor, *, backend: str = 'deepgemm', split_k: int | None = None, return_partials: bool = False) → Tuple[torch.Tensor, torch.Tensor]

Run MHC prenorm GEMM and row square-sum.

Parameters:
  • residual_flat – Tensor with shape (M, mhc_mult, hidden_size) or (M, K), where K = mhc_mult * hidden_size, and dtype torch.bfloat16.

  • hc_fn – Weight tensor with shape (N, K), where N = mhc_mult * (2 + mhc_mult), and dtype torch.float32.

  • backend – Prenorm backend selector.

  • split_k – Prenorm backend split-K factor. If omitted, uses 32 for decode-like M <= 64 and 16 otherwise.

  • return_partials – When true, return the unreduced split-K partials, each with a leading split dimension [S, M, ...]. mhc_pre_big_fuse accepts these partials directly.
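The documented split_k default and the effect of return_partials on output shapes can be sketched as follows. default_split_k restates the documented rule; prenorm_output_shapes is hypothetical and assumes that S equals the split-K factor and that the reduced outputs have the final-output shapes [M, N] and [M] that mhc_pre_big_fuse accepts.

```python
def default_split_k(m_tokens):
    """Documented default: 32 for decode-like M <= 64, 16 otherwise."""
    return 32 if m_tokens <= 64 else 16


def prenorm_output_shapes(m_tokens, n_mix, split_k=None, return_partials=False):
    """Hypothetical helper: expected (mul, sqrsum) shapes.

    Assumption: the leading split dimension S equals the split-K factor.
    """
    s = default_split_k(m_tokens) if split_k is None else split_k
    if return_partials:
        return (s, m_tokens, n_mix), (s, m_tokens)
    return (m_tokens, n_mix), (m_tokens,)


decode_shapes = prenorm_output_shapes(8, 24, return_partials=True)
prefill_shapes = prenorm_output_shapes(4096, 24)
```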

mate.hyperconnection.mhc_pre_big_fuse(gemm_out_mul: torch.Tensor, gemm_out_sqrsum: torch.Tensor, mhc_scale: torch.Tensor, mhc_base: torch.Tensor, residual_flat: torch.Tensor, *, rms_eps: float = 1e-06, mhc_pre_eps: float = 1e-06, mhc_sinkhorn_eps: float = 1e-06, mhc_post_mult_value: float = 2.0, sinkhorn_repeat: int = 20, backend: str = 'tilelang', threads: int = 128, hidden_block: int = 512, pass_config: str = 'safe', compile_profile: str | None = None, post_mix: torch.Tensor | None = None, comb_mix: torch.Tensor | None = None, layer_input: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Run the fused MHC pre stage.

This is the second stage of the DeepSeek V4 MHC pre path. It consumes the prenorm GEMM output mixes and row square-sum produced by mhc_prenorm_gemm_sqrsum or mate.deep_gemm.tf32_hc_prenorm_gemm, applies RMS scaling, computes the MHC pre/post/comb branches, runs Sinkhorn normalization for the comb branch, and writes the fused layer input.
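A per-token pure-Python reference for these steps may help when checking kernel outputs. It is a sketch under assumptions this reference does not state: the mix vector is laid out as [pre | post | comb], RMS uses the mean square over K = mhc_mult * hidden_size, each branch multiplies by its mhc_scale entry before adding mhc_base, and layer_input is the pre-weighted sum of the mhc_mult residual streams. Sinkhorn normalization of the comb logits is left out, and the function name is hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mhc_pre_token_reference(mul, sqrsum, scale, base, residual,
                            rms_eps=1e-6, pre_eps=1e-6, post_mult_value=2.0):
    """Hypothetical per-token reference for the fused stage (no Sinkhorn).

    mul: mhc_mult * (2 + mhc_mult) prenorm GEMM outputs for one token.
    sqrsum: square-sum of that token's flattened residual row.
    residual: mhc_mult rows of hidden_size values.
    """
    m = len(residual)
    hidden = len(residual[0])
    if not (len(mul) == m * (2 + m) == len(base)):
        raise ValueError("mul and base must have mhc_mult * (2 + mhc_mult) entries")
    # RMS scaling: the GEMM ran on the un-normalized residual, and scaling the
    # input by a constant scales the GEMM output by the same constant.
    rms_inv = 1.0 / math.sqrt(sqrsum / (m * hidden) + rms_eps)
    mixes = [x * rms_inv for x in mul]
    pre = [sigmoid(mixes[j] * scale[0] + base[j]) + pre_eps for j in range(m)]
    post = [sigmoid(mixes[m + j] * scale[1] + base[m + j]) * post_mult_value
            for j in range(m)]
    # Comb logits before Sinkhorn normalization (m * m entries).
    comb_logits = [mixes[2 * m + i] * scale[2] + base[2 * m + i] for i in range(m * m)]
    # Layer input: pre-weighted sum over the mhc_mult residual streams.
    layer_input = [sum(pre[j] * residual[j][h] for j in range(m)) for h in range(hidden)]
    return pre, post, comb_logits, layer_input
```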

gemm_out_mul and gemm_out_sqrsum may be either final outputs with shapes [num_tokens, mhc_mult * (2 + mhc_mult)] / [num_tokens] or split-K partials with shapes [num_splits, num_tokens, mhc_mult * (2 + mhc_mult)] / [num_splits, num_tokens]. Passing split-K partials avoids a separate reduction before the fused stage.
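Split-K partitions the GEMM reduction dimension K into chunks and produces one partial result per chunk. Because both the GEMM output and the square-sum are sums over K, adding the partials over the split dimension recovers the final outputs. The sketch below shows that reduction for a single token; the helper names are hypothetical and it assumes contiguous K chunks of equal size (the last one possibly shorter).

```python
def split_k_partials(x_row, w_rows, num_splits):
    """Hypothetical illustration: per-chunk partial GEMM outputs and square-sums."""
    k = len(x_row)
    chunk = -(-k // num_splits)  # ceil division
    mul_parts, sq_parts = [], []
    for s in range(num_splits):
        lo, hi = s * chunk, min((s + 1) * chunk, k)
        xs = x_row[lo:hi]
        mul_parts.append([sum(a * b for a, b in zip(xs, w[lo:hi])) for w in w_rows])
        sq_parts.append(sum(a * a for a in xs))
    return mul_parts, sq_parts


def reduce_partials(mul_parts, sq_parts):
    """The reduction a caller would otherwise run before the fused stage."""
    mul = [sum(col) for col in zip(*mul_parts)]
    return mul, sum(sq_parts)


parts = split_k_partials([1.0, 2.0, 3.0, 4.0, 5.0], [[1, 0, 1, 0, 1], [0.5] * 5], 2)
mul, sqrsum = reduce_partials(*parts)
```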

Parameters:
  • gemm_out_mul – Float32 prenorm GEMM output or split-K partials.

  • gemm_out_sqrsum – Float32 row square-sum output or split-K partials.

  • mhc_scale – Float32 scale tensor with shape [3] for pre/post/comb branches.

  • mhc_base – Float32 bias tensor with shape [mhc_mult * (2 + mhc_mult)].

  • residual_flat – Contiguous bfloat16 tensor with shape [num_tokens, mhc_mult, hidden_size].

  • rms_eps – Epsilon added before RMS inverse square root.

  • mhc_pre_eps – Epsilon added to the sigmoid pre branch.

  • mhc_sinkhorn_eps – Epsilon used by Sinkhorn normalization.

  • mhc_post_mult_value – Multiplicative scale applied to the post branch.

  • sinkhorn_repeat – Number of Sinkhorn normalization iterations.

  • backend – Fused-stage backend selector.

  • threads – Kernel thread count.

  • hidden_block – Hidden dimension tile size used by the layer-input reduction.

  • pass_config – Backend pass configuration name.

  • compile_profile – Optional MUSA compiler flag profile.

  • post_mix – Optional contiguous float32 output buffer with shape [num_tokens, mhc_mult].

  • comb_mix – Optional contiguous float32 output buffer with shape [num_tokens, mhc_mult * mhc_mult].

  • layer_input – Optional contiguous bfloat16 output buffer with shape [num_tokens, hidden_size].
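The layer-input reduction combines the mhc_mult residual streams independently for each hidden element, so the hidden dimension can be processed in separate tiles of hidden_block elements. The hypothetical pure-Python sketch below assumes the same pre-weighted sum used as a reference elsewhere on this page and says nothing about how MATE maps tiles to threads.

```python
def layer_input_tiled(pre, residual, hidden_block=512):
    """Hypothetical tiling of the layer-input reduction over the hidden dim.

    Tiles cover disjoint column ranges, so they can be computed independently.
    The last tile may be partial.
    """
    m = len(residual)
    hidden = len(residual[0])
    out = [0.0] * hidden
    num_tiles = -(-hidden // hidden_block)  # ceil division
    for t in range(num_tiles):
        lo, hi = t * hidden_block, min((t + 1) * hidden_block, hidden)
        for h in range(lo, hi):
            out[h] = sum(pre[j] * residual[j][h] for j in range(m))
    return out, num_tiles


out, tiles = layer_input_tiled([0.25, 0.75], [[float(h) for h in range(1000)], [1.0] * 1000])
```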

Returns:

(post_mix, comb_mix, layer_input) with shapes [num_tokens, mhc_mult], [num_tokens, mhc_mult * mhc_mult], and [num_tokens, hidden_size].

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]

residual_flat must be a contiguous tensor with shape [num_tokens, mhc_mult, hidden_size]. The prenorm inputs and any optional output buffers must also be contiguous; this avoids hidden copies and keeps the indexing path specialized for the hot serving case. Optional output buffers may be provided to avoid allocations in hot serving paths or graph-capture paths.
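In PyTorch, Tensor.is_contiguous() reports whether a tensor is laid out densely in row-major order, and Tensor.contiguous() returns such a copy when it is not. The hypothetical helpers below spell out that row-major stride check in pure Python. They show, for example, why a slice taken along hidden_size keeps the parent's strides and therefore fails the contiguity requirement above.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) element strides for a shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if strides match row-major layout; size-1 dims are ignored,
    as PyTorch does."""
    expected = contiguous_strides(shape)
    return all(s == e for d, s, e in zip(shape, strides, expected) if d != 1)


# residual_flat[..., :2048] of a [16, 4, 4096] tensor keeps strides (16384, 4096, 1).
sliced_ok = is_contiguous((16, 4, 2048), (16384, 4096, 1))
```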