.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "examples/moe_matmul_ogs.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_examples_moe_matmul_ogs.py: Mixture-of-Experts (MoE) Matmul with Outer-Gather-Scatter (OGS) ================================================================ This example demonstrates a Helion kernel implementation of a Mixture-of-Experts matrix multiplication using an Outer-Gather-Scatter approach. It efficiently handles token routing to multiple experts with variable token counts per expert. The example includes: - The Helion kernel performing tiled matmul per expert with masking for variable token counts. - Helper functions to generate kernel arguments by sorting tokens by expert. - A reference PyTorch implementation for correctness comparison. - A check function to validate the Helion kernel against the reference. .. GENERATED FROM PYTHON SOURCE LINES 15-24 .. code-block:: Python from __future__ import annotations import torch import helion from helion._testing import run_example import helion.language as hl .. GENERATED FROM PYTHON SOURCE LINES 25-83 .. code-block:: Python @helion.kernel(static_shapes=False) def moe_matmul_ogs( A: torch.Tensor, # [T, K] - Input activations (T tokens, K features) W: torch.Tensor, # [E, K, N] - Expert weights (E experts, K input features, N output features) expert_token_counts: torch.Tensor, # [E] - Number of tokens assigned to each expert expert_token_offsets: torch.Tensor, # [E + 1] - Starting position of each expert's tokens in sorted order sorted_to_orig_token_idx: torch.Tensor, # [T] - Maps sorted token positions back to original positions max_T_per_expert_tensor: torch.Tensor, # [max_T_per_expert] - Dummy tensor whose size indicates max tokens per expert ) -> torch.Tensor: # [T, N] - Output activations """ Helion kernel implementing MoE matmul with Outer-Gather-Scatter. Args: A (torch.Tensor): Input activations of shape [T, K]. W (torch.Tensor): Expert weights of shape [E, K, N]. expert_token_counts (torch.Tensor): Number of tokens per expert [E]. expert_token_offsets (torch.Tensor): Starting offsets of tokens per expert [E+1]. sorted_to_orig_token_idx (torch.Tensor): Maps sorted token indices to original token indices [T]. max_T_per_expert_tensor (torch.Tensor): Dummy tensor to indicate max tokens per expert. Returns: torch.Tensor: Output activations of shape [T, N]. """ T, K = A.shape E, _, N = W.shape max_T_per_expert = max_T_per_expert_tensor.numel() C = torch.zeros( T, N, dtype=torch.promote_types(A.dtype, W.dtype), device=A.device, ) for e_idx in hl.grid(E): start = expert_token_offsets[e_idx] num_tokens = expert_token_counts[e_idx] if num_tokens != 0: for tile_t, tile_n in hl.tile([max_T_per_expert, N]): local_token_offsets = tile_t.index token_valid = local_token_offsets < num_tokens local_token_offsets_valid = torch.where( token_valid, local_token_offsets, 0 ) expert_sorted_token_indices = start + local_token_offsets_valid expert_orig_token_indices = sorted_to_orig_token_idx[ expert_sorted_token_indices.squeeze(0) ] acc = hl.zeros([tile_t, tile_n], dtype=torch.float32) for tile_k in hl.tile(K): A_frag = A[expert_orig_token_indices, tile_k] W_frag = W[e_idx, tile_k, tile_n] acc = torch.addmm(acc, A_frag, W_frag) block_T, block_N = acc.size() existing_values = C[expert_orig_token_indices, tile_n] mask_2d = token_valid.view(block_T, 1).expand(block_T, block_N) C[expert_orig_token_indices, tile_n] = torch.where( mask_2d, acc.to(C.dtype), existing_values ) return C .. GENERATED FROM PYTHON SOURCE LINES 84-124 .. code-block:: Python def moe_matmul_ogs_helion_kernel_args_gen( A: torch.Tensor, # [T, K] - Input activations W: torch.Tensor, # [E, K, N] - Expert weights top1_expert_per_token: torch.Tensor, # [T] - Expert assignment for each token (0 to E-1) ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor ]: """ Generates arguments for the Helion MoE matmul OGS kernel. Sorts tokens by expert, computes token counts and offsets per expert, and prepares a dummy tensor for max tokens per expert. Args: A (torch.Tensor): Input activations [T, K]. W (torch.Tensor): Expert weights [E, K, N]. top1_expert_per_token (torch.Tensor): Expert assignment per token [T]. Returns: Tuple of tensors to be passed as kernel arguments. """ E = W.size(0) device = A.device sorted_to_orig_token_idx = torch.argsort(top1_expert_per_token, stable=True).to( torch.int32 ) expert_token_counts = torch.bincount(top1_expert_per_token, minlength=E).to( torch.int32 ) expert_token_offsets = torch.empty(E + 1, dtype=torch.int32, device=device) expert_token_offsets[0] = 0 expert_token_offsets[1:] = torch.cumsum(expert_token_counts, 0, dtype=torch.int32) max_T_per_expert = int(expert_token_counts.max().item()) return ( A, W, expert_token_counts, expert_token_offsets, sorted_to_orig_token_idx, torch.empty(max_T_per_expert, device=device), ) .. GENERATED FROM PYTHON SOURCE LINES 125-151 .. code-block:: Python def moe_matmul_ogs_reference( A: torch.Tensor, W: torch.Tensor, top1_expert_per_token: torch.Tensor ) -> torch.Tensor: """ Reference PyTorch implementation of MoE matmul with OGS. Performs matmul per expert by selecting tokens assigned to each expert. Args: A (torch.Tensor): Input activations [T, K]. W (torch.Tensor): Expert weights [E, K, N]. top1_expert_per_token (torch.Tensor): Expert assignment per token [T]. Returns: torch.Tensor: Output activations [T, N]. """ T, K = A.shape N = W.size(2) device, dtype = A.device, torch.promote_types(A.dtype, W.dtype) C = torch.empty(T, N, device=device, dtype=dtype) n_experts = W.size(0) for e in range(n_experts): token_idx = (top1_expert_per_token == e).nonzero(as_tuple=True)[0] if token_idx.numel() == 0: continue C[token_idx] = A[token_idx] @ W[e] return C .. GENERATED FROM PYTHON SOURCE LINES 152-181 .. code-block:: Python def check(T: int, K: int, N: int, n_experts: int) -> None: """ Validates the Helion MoE matmul OGS kernel against the reference implementation. Generates random inputs and expert assignments, runs both implementations, and compares their outputs. Args: T (int): Number of tokens. K (int): Number of input features. N (int): Number of output features. n_experts (int): Number of experts. """ dtype = torch.float16 device = "cuda" if torch.cuda.is_available() else "cpu" A = torch.randn(T, K, device=device, dtype=dtype) W = torch.randn(n_experts, K, N, device=device, dtype=dtype) top1_expert_per_token = torch.randint(n_experts, (T,), device=device) helion_kernel_args = moe_matmul_ogs_helion_kernel_args_gen( A, W, top1_expert_per_token ) def helion_fn() -> torch.Tensor: return moe_matmul_ogs(*helion_kernel_args) def reference_fn() -> torch.Tensor: return moe_matmul_ogs_reference(A, W, top1_expert_per_token) run_example(helion_fn, reference_fn, ()) .. GENERATED FROM PYTHON SOURCE LINES 182-189 .. code-block:: Python def main() -> None: """ Main entry point to run the MoE matmul OGS kernel check with example parameters. """ check(1000, 500, 200, 30) .. GENERATED FROM PYTHON SOURCE LINES 190-192 .. code-block:: Python if __name__ == "__main__": main() .. _sphx_glr_download_examples_moe_matmul_ogs.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: moe_matmul_ogs.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: moe_matmul_ogs.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: moe_matmul_ogs.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_