HyperConnection
MHC Pre
- mate.hyperconnection.mhc_pre(residual: torch.Tensor, hc_fn: torch.Tensor, mhc_scale: torch.Tensor, mhc_base: torch.Tensor, *, rms_eps: float = 1e-06, mhc_pre_eps: float = 1e-06, mhc_sinkhorn_eps: float = 1e-06, mhc_post_mult_value: float = 2.0, sinkhorn_repeat: int = 20, split_k: int | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Run MHC pre end to end.
This convenience wrapper runs the prenorm stage followed by the fused MHC pre stage. It intentionally exposes only model-level semantics; backend tuning knobs remain on mhc_prenorm_gemm_sqrsum and mhc_pre_big_fuse for callers that need explicit control.
- Parameters:
residual – Bfloat16 input tensor with shape [..., mhc_mult, hidden_size].
hc_fn – Float32 prenorm GEMM weight with shape [mhc_mult * (2 + mhc_mult), mhc_mult * hidden_size].
mhc_scale – Float32 scale tensor with shape [3] for the pre/post/comb branches.
mhc_base – Float32 bias tensor with shape [mhc_mult * (2 + mhc_mult)].
rms_eps – Epsilon added before the RMS inverse square root.
mhc_pre_eps – Epsilon added to the sigmoid pre branch.
mhc_sinkhorn_eps – Epsilon used by Sinkhorn normalization.
mhc_post_mult_value – Multiplicative scale applied to the post branch.
sinkhorn_repeat – Number of Sinkhorn normalization iterations.
split_k – Optional prenorm split-K factor. If omitted, the decode/prefill default selected by mhc_prenorm_gemm_sqrsum is used.
- Returns:
(post_mix, comb_mix, layer_input) with shapes [..., mhc_mult, 1], [..., mhc_mult, mhc_mult], and [..., hidden_size].
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
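The shape contract above can be summarized in a small helper. This is a sketch, not part of mate: the name mhc_pre_shapes is hypothetical, and it assumes the wrapper flattens all leading dimensions of residual into num_tokens before the fused stage (inferred from the [num_tokens, ...] shapes documented for mhc_pre_big_fuse, not stated here). It only computes shapes, so it runs without a GPU or the mate package.

```python
from math import prod
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def mhc_pre_shapes(residual_shape: Shape) -> Dict[str, Shape]:
    """Hypothetical helper: derive the shapes documented for mhc_pre.

    Assumes residual is [..., mhc_mult, hidden_size] and that all leading
    dims are flattened into num_tokens for the fused stage.
    """
    *lead_list, mhc_mult, hidden_size = residual_shape
    lead = tuple(lead_list)
    num_tokens = prod(lead)  # prod(()) == 1 when there are no leading dims
    n = mhc_mult * (2 + mhc_mult)  # assumed split: pre (m) + post (m) + comb (m * m)
    return {
        "hc_fn": (n, mhc_mult * hidden_size),
        "mhc_scale": (3,),
        "mhc_base": (n,),
        "residual_flat": (num_tokens, mhc_mult, hidden_size),
        "post_mix": lead + (mhc_mult, 1),
        "comb_mix": lead + (mhc_mult, mhc_mult),
        "layer_input": lead + (hidden_size,),
    }


shapes = mhc_pre_shapes((2, 8, 4, 64))  # batch 2, seq 8, mhc_mult 4, hidden 64
print(shapes["hc_fn"], shapes["layer_input"])  # (24, 256) (2, 8, 64)
```

A helper like this can be handy for pre-allocating buffers or asserting shapes in tests before calling into the GPU kernels.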
- mate.hyperconnection.mhc_prenorm_gemm_sqrsum(residual_flat: torch.Tensor, hc_fn: torch.Tensor, *, backend: str = 'deepgemm', split_k: int | None = None, return_partials: bool = False) → Tuple[torch.Tensor, torch.Tensor]
Run MHC prenorm GEMM and row square-sum.
- Parameters:
residual_flat – Tensor with shape (M, mhc_mult, hidden_size) or (M, K) and dtype torch.bfloat16.
hc_fn – Weight tensor with shape (N, K) and dtype torch.float32.
backend – Prenorm backend selector.
split_k – Prenorm backend split-K factor. If omitted, 32 is used for decode-like M <= 64 and 16 otherwise.
return_partials – When true, return tensors with a leading split dimension [S, M, ...].
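A pure-Python reference for what this stage computes may help when validating a backend. It is a hedged sketch under several assumptions: residual_flat is already flattened to (M, K); the GEMM output is x @ hc_fn.T and the square-sum is the sum of x * x over K; and K is split into S contiguous, equal-width slices (the real backend's split layout is not documented here). The names default_split_k and prenorm_gemm_sqrsum_ref are hypothetical.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def default_split_k(m: int) -> int:
    """Documented default: 32 for decode-like M <= 64, 16 otherwise."""
    return 32 if m <= 64 else 16


def prenorm_gemm_sqrsum_ref(
    x: Matrix,  # (M, K): residual_flat already flattened to 2-D
    hc_fn: Matrix,  # (N, K) float32 weight
    split_k: Optional[int] = None,
    return_partials: bool = False,
) -> Tuple[list, list]:
    m, k, n = len(x), len(x[0]), len(hc_fn)
    s = split_k if split_k is not None else default_split_k(m)
    width = -(-k // s)  # ceil(K / S); assumed contiguous, equal-width K slices
    mul = [[[0.0] * n for _ in range(m)] for _ in range(s)]  # [S, M, N]
    sqr = [[0.0] * m for _ in range(s)]  # [S, M]
    for p in range(s):
        lo, hi = p * width, min((p + 1) * width, k)  # empty slice when lo >= k
        for i, row in enumerate(x):
            sqr[p][i] = sum(v * v for v in row[lo:hi])
            for j, w in enumerate(hc_fn):
                mul[p][i][j] = sum(row[t] * w[t] for t in range(lo, hi))
    if return_partials:
        return mul, sqr
    # Reduce over the split dimension to get final [M, N] and [M] outputs.
    final_mul = [
        [sum(mul[p][i][j] for p in range(s)) for j in range(n)] for i in range(m)
    ]
    final_sqr = [sum(sqr[p][i] for p in range(s)) for i in range(m)]
    return final_mul, final_sqr


x = [[1.0, 2.0, 3.0, 4.0]]  # M = 1, K = 4
w = [[1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5]]  # N = 2
print(prenorm_gemm_sqrsum_ref(x, w, split_k=2))  # ([[4.0, 5.0]], [30.0])
```

Summing the partials over the leading split dimension reproduces the final outputs, which is why a downstream stage can accept either form.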
- mate.hyperconnection.mhc_pre_big_fuse(gemm_out_mul: torch.Tensor, gemm_out_sqrsum: torch.Tensor, mhc_scale: torch.Tensor, mhc_base: torch.Tensor, residual_flat: torch.Tensor, *, rms_eps: float = 1e-06, mhc_pre_eps: float = 1e-06, mhc_sinkhorn_eps: float = 1e-06, mhc_post_mult_value: float = 2.0, sinkhorn_repeat: int = 20, backend: str = 'tilelang', threads: int = 128, hidden_block: int = 512, pass_config: str = 'safe', compile_profile: str | None = None, post_mix: torch.Tensor | None = None, comb_mix: torch.Tensor | None = None, layer_input: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Run the fused MHC pre stage.
This is the second stage of the DeepSeek V4 MHC pre path. It consumes the prenorm GEMM output mixes and the row square-sum produced by mhc_prenorm_gemm_sqrsum or mate.deep_gemm.tf32_hc_prenorm_gemm, applies RMS scaling, computes the MHC pre/post/comb branches, runs Sinkhorn normalization for the comb branch, and writes the fused layer input.
gemm_out_mul and gemm_out_sqrsum may be either final outputs with shapes [num_tokens, mhc_mult * (2 + mhc_mult)] / [num_tokens] or split-K partials with shapes [num_splits, num_tokens, mhc_mult * (2 + mhc_mult)] / [num_splits, num_tokens]. Passing split-K partials avoids a separate reduction before the fused stage.
- Parameters:
gemm_out_mul – Float32 prenorm GEMM output or split-K partials.
gemm_out_sqrsum – Float32 row square-sum output or split-K partials.
mhc_scale – Float32 scale tensor with shape [3] for the pre/post/comb branches.
mhc_base – Float32 bias tensor with shape [mhc_mult * (2 + mhc_mult)].
residual_flat – Contiguous bfloat16 tensor with shape [num_tokens, mhc_mult, hidden_size].
rms_eps – Epsilon added before the RMS inverse square root.
mhc_pre_eps – Epsilon added to the sigmoid pre branch.
mhc_sinkhorn_eps – Epsilon used by Sinkhorn normalization.
mhc_post_mult_value – Multiplicative scale applied to the post branch.
sinkhorn_repeat – Number of Sinkhorn normalization iterations.
backend – Fused-stage backend selector.
threads – Kernel thread count.
hidden_block – Hidden dimension tile size used by the layer-input reduction.
pass_config – Backend pass configuration name.
compile_profile – Optional MUSA compiler flag profile.
post_mix – Optional contiguous float32 output buffer with shape [num_tokens, mhc_mult].
comb_mix – Optional contiguous float32 output buffer with shape [num_tokens, mhc_mult * mhc_mult].
layer_input – Optional contiguous bfloat16 output buffer with shape [num_tokens, hidden_size].
- Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor] – (post_mix, comb_mix, layer_input) with shapes [num_tokens, mhc_mult], [num_tokens, mhc_mult * mhc_mult], and [num_tokens, hidden_size].
residual_flat must be a contiguous tensor with shape [num_tokens, mhc_mult, hidden_size]. Prenorm inputs and optional output buffers must also be contiguous; this avoids hidden copies and keeps the indexing path specialized for the hot serving case. Optional output buffers may be provided to avoid allocations in hot serving paths or graph-capture paths.