.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/rms_norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_rms_norm.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_rms_norm.py:


Root Mean Square Normalization Example
======================================

This example shows how to implement Root Mean Square (RMS) normalization as a Helion kernel and how to verify it against a PyTorch reference implementation.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 12-21

.. code-block:: Python


    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 22-24

RMS Normalization Kernel
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 24-60

.. code-block:: Python


    @helion.kernel(static_shapes=True)
    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        """
        Performs Root Mean Square (RMS) normalization on the input tensor.

        RMS normalization divides each row by the root mean square of its
        elements and then scales the result by a learned weight:
            output = x / sqrt(mean(x^2) + eps) * weight

        Args:
            x: Input tensor of shape [M, N]
            weight: Scale parameter of shape [N]
            eps: Small constant added to mean(x^2) for numerical stability

        Returns:
            Output tensor of shape [M, N] with RMS normalization applied
        """
        m, n = x.size()
        assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
        out = torch.empty([m, n], dtype=x.dtype, device=x.device)

        for tile_m in hl.tile(m):
            # Load a tile of full rows and upcast to float32 for the reduction
            x_tile = x[tile_m, :].to(torch.float32)

            # Compute the inverse RMS: 1 / sqrt(mean(x^2) + eps)
            x_squared = x_tile * x_tile
            mean_x_squared = torch.mean(x_squared, dim=-1, keepdim=True)
            inv_rms = torch.rsqrt(mean_x_squared + eps)

            # Normalize, apply the weight, and cast back to the output dtype
            normalized = x_tile * inv_rms
            out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
        return out
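The per-tile computation in ``rms_norm`` can be easier to follow without the tensor API. The sketch below is a plain-Python illustration, not the code Helion generates: it assumes nested lists instead of tensors, and ``rms_norm_rows`` and its ``block_m`` argument are hypothetical names introduced only here, with ``block_m`` standing in for the row-tile size that ``hl.tile(m)`` picks. Each row is reduced to its mean square, turned into an inverse RMS (the quantity ``torch.rsqrt`` produces in the kernel), and scaled elementwise by the weight.

.. code-block:: Python

    import math
    from typing import List

    Matrix = List[List[float]]


    def rms_norm_rows(
        x: Matrix, weight: List[float], eps: float = 1e-5, block_m: int = 2
    ) -> Matrix:
        """Hypothetical pure-Python RMS norm that visits rows in tiles of block_m.

        block_m is an assumption standing in for the tile size chosen by
        hl.tile(m); it only changes how rows are grouped, not the result.
        """
        n = len(weight)
        out: Matrix = []
        # Outer loop: one iteration per tile of rows.
        for start in range(0, len(x), block_m):
            for row in x[start:start + block_m]:
                assert len(row) == n, f"row length {len(row)} != weight size {n}"
                # Reduce over the last dimension: mean of the squared elements.
                mean_sq = sum(v * v for v in row) / n
                # Inverse RMS: 1 / sqrt(mean(x^2) + eps).
                inv_rms = 1.0 / math.sqrt(mean_sq + eps)
                # Normalize the row and apply the per-column weight.
                out.append([v * inv_rms * w for v, w in zip(row, weight)])
        return out


    if __name__ == "__main__":
        x = [[3.0, 4.0], [1.0, -1.0]]
        weight = [1.0, 0.5]
        for row in rms_norm_rows(x, weight, eps=0.0):
            print([round(v, 4) for v in row])
        # prints [0.8485, 0.5657] then [1.0, -0.5]

For the first row, mean(x^2) = (9 + 16) / 2 = 12.5, so each element is divided by sqrt(12.5), about 3.5355, before the weight is applied. Because every row is normalized independently, changing ``block_m`` only regroups the loop iterations and leaves the output unchanged, so the row-tile size can be chosen for speed without affecting the result.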
.. GENERATED FROM PYTHON SOURCE LINES 61-63

Benchmark Wrapper
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 63-78

.. code-block:: Python


    def rms_norm_tritonbench(H: int, inp: torch.Tensor) -> torch.Tensor:
        """
        Wrapper for tritonbench that matches its expected interface.

        Args:
            H: Hidden dimension size; must equal the last dimension of inp
            inp: Input tensor

        Returns:
            Normalized tensor, computed with an all-ones weight and eps=1e-6
        """
        weight = torch.ones(H, device=inp.device, dtype=inp.dtype)
        return rms_norm(inp, weight, eps=1e-6)

.. GENERATED FROM PYTHON SOURCE LINES 79-81

Reference Implementation
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 81-102

.. code-block:: Python


    def rms_norm_pytorch(
        x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
    ) -> torch.Tensor:
        """
        PyTorch reference implementation of RMS normalization.

        Args:
            x: Input tensor
            weight: Scale parameter
            eps: Small constant for numerical stability

        Returns:
            Normalized tensor
        """
        input_dtype = x.dtype
        # Upcast to float32 so the mean of squares is accumulated accurately
        hidden_states = x.to(torch.float32)
        # mean(x^2) over the last dimension; called "variance" by convention,
        # although no mean is subtracted
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        # Cast back to the input dtype before applying the weight
        return weight * hidden_states.to(input_dtype)

.. GENERATED FROM PYTHON SOURCE LINES 103-105

Verification Function
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 105-118

.. code-block:: Python


    def check(m: int, n: int) -> None:
        """
        Verify the RMS norm kernel implementation against the PyTorch reference implementation.

        Args:
            m: Number of rows in the test tensor
            n: Number of columns in the test tensor (the normalized dimension)
        """
        # Random float16 inputs on the GPU; this requires a CUDA device
        x = torch.randn([m, n], device="cuda", dtype=torch.float16)
        weight = torch.randn([n], device="cuda", dtype=torch.float16)
        # Compare the Helion kernel against the reference on the same inputs
        run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5))

.. GENERATED FROM PYTHON SOURCE LINES 119-121

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 121-137

.. code-block:: Python


    def main() -> None:
        """
        Main entry point that runs the RMS norm kernel verification with different tensor sizes.
        Tests with three configurations:
        - 32x64
        - 128x256
        - 1024x1024
        """
        check(32, 64)
        check(128, 256)
        check(1024, 1024)


    if __name__ == "__main__":
        main()


.. _sphx_glr_download_examples_rms_norm.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: rms_norm.ipynb <rms_norm.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: rms_norm.py <rms_norm.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: rms_norm.zip <rms_norm.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_