.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/matmul_layernorm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_matmul_layernorm.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_matmul_layernorm.py:


Matrix Multiplication with Layer Normalization Example
======================================================

This example demonstrates how to implement a fused matrix multiplication and
layer normalization operation using Helion.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 12-22

.. code-block:: Python

    from __future__ import annotations

    import torch
    import torch.nn.functional as F

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 23-26

MatMul-LayerNorm Kernel
-----------------------

``static_shapes=True`` gives a performance boost for matmuls.

.. GENERATED FROM PYTHON SOURCE LINES 26-64

.. code-block:: Python

    @helion.kernel(static_shapes=True)
    def matmul_layernorm(
        x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
    ) -> torch.Tensor:
        """
        Performs matrix multiplication followed by layer normalization.

        Args:
            x: First input tensor of shape [M, K]
            y: Second input tensor of shape [K, N]
            weight: Layer normalization weight parameter of shape [N]
            bias: Layer normalization bias parameter of shape [N]

        Returns:
            Output tensor of shape [M, N] containing the result of matrix
            multiplication followed by layer normalization
        """
        m, k = x.size()
        k2 = y.size(0)
        n = hl.register_reduction_dim(y.size(1))
        assert k == k2, f"size mismatch {k} != {k2}"
        assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
        out = torch.empty(
            [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
        )
        for tile_m in hl.tile(m):
            acc = hl.zeros([tile_m, n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                mm = torch.matmul(x[tile_m, tile_k], y[tile_k, :])
                acc = acc + mm
            eps = 1e-5
            var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
            normalized = (acc - mean) * torch.rsqrt(var + eps)
            acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
            out[tile_m, :] = acc
        return out

.. GENERATED FROM PYTHON SOURCE LINES 65-67

Reference Implementation
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 67-94

.. code-block:: Python

    def matmul_layernorm_pytorch(
        x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
    ) -> torch.Tensor:
        """
        PyTorch reference implementation of matrix multiplication followed by
        layer normalization.

        Args:
            x: First input tensor of shape [M, K]
            y: Second input tensor of shape [K, N]
            weight: Layer normalization weight parameter of shape [N]
            bias: Layer normalization bias parameter of shape [N]

        Returns:
            Output tensor of shape [M, N] containing the result of matrix
            multiplication followed by layer normalization
        """
        matmul_out = torch.matmul(x, y)
        ln_out = F.layer_norm(
            matmul_out.to(torch.float32),
            normalized_shape=(matmul_out.shape[-1],),
            weight=weight.to(torch.float32),
            bias=bias.to(torch.float32),
        )
        return ln_out.to(torch.promote_types(x.dtype, y.dtype))
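
Both implementations compute the same fused operation. Written out, with
:math:`\gamma` and :math:`\beta` the layer-norm weight and bias,
:math:`\varepsilon = 10^{-5}`, and :math:`\mu` and :math:`\sigma^{2}` the
per-row mean and biased variance (``correction=0``) of the product
:math:`XY`:

.. math::

   \text{out} = \frac{XY - \mu}{\sqrt{\sigma^{2} + \varepsilon}} \odot \gamma + \beta

where :math:`\odot` denotes elementwise multiplication, with :math:`\gamma`
and :math:`\beta` broadcast across rows.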
.. GENERATED FROM PYTHON SOURCE LINES 95-97

Verification Function
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 97-113

.. code-block:: Python

    def check(m: int, k: int, n: int) -> None:
        """
        Verify the matmul_layernorm kernel implementation against the PyTorch
        reference implementation.

        Args:
            m: First dimension of the first matrix
            k: Second dimension of the first matrix / First dimension of the second matrix
            n: Second dimension of the second matrix
        """
        x = torch.randn([m, k], device="cuda", dtype=torch.float16)
        y = torch.randn([k, n], device="cuda", dtype=torch.float16)
        weight = torch.randn([n], device="cuda", dtype=torch.float16)
        bias = torch.randn([n], device="cuda", dtype=torch.float16)
        run_example(matmul_layernorm, matmul_layernorm_pytorch, (x, y, weight, bias))

.. GENERATED FROM PYTHON SOURCE LINES 114-116

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 116-133

.. code-block:: Python

    def main() -> None:
        """
        Main entry point that runs the matmul_layernorm kernel verification
        with different matrix sizes.

        Tests with two configurations:

        - 32x64 * 64x200
        - 128x256 * 256x400
        """
        # TODO(yf225): n=64 or 128 throws error, need to investigate
        # check(32, 64, 64)
        # check(32, 64, 128)
        check(32, 64, 200)
        check(128, 256, 400)


    if __name__ == "__main__":
        main()

.. _sphx_glr_download_examples_matmul_layernorm.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: matmul_layernorm.ipynb <matmul_layernorm.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: matmul_layernorm.py <matmul_layernorm.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: matmul_layernorm.zip <matmul_layernorm.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
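
As a closing note, the two functions above can also be exercised directly,
without the ``run_example`` helper. A minimal sketch, assuming a CUDA device
is available; the tolerances are illustrative and not part of the original
example:

.. code-block:: Python

    # Hypothetical standalone check using the functions defined above.
    x = torch.randn([32, 64], device="cuda", dtype=torch.float16)
    y = torch.randn([64, 200], device="cuda", dtype=torch.float16)
    weight = torch.randn([200], device="cuda", dtype=torch.float16)
    bias = torch.randn([200], device="cuda", dtype=torch.float16)

    out = matmul_layernorm(x, y, weight, bias)
    ref = matmul_layernorm_pytorch(x, y, weight, bias)

    # float16 inputs accumulate rounding error, so compare with loose,
    # illustrative tolerances rather than exact equality.
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)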