.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/layer_norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_layer_norm.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_layer_norm.py:


Helion Layer Normalization Forward Example
==========================================

This example demonstrates a Helion kernel implementation of 1D layer
normalization using FP16 inputs and compares it against PyTorch's built-in
layer_norm function.

.. GENERATED FROM PYTHON SOURCE LINES 9-18

.. code-block:: Python


    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 19-57

.. code-block:: Python


    @helion.kernel
    def layer_norm_fwd(
        x: torch.Tensor,
        normalized_shape: list[int],
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float = 1e-5,
    ) -> torch.Tensor:
        """
        Performs 1D layer normalization on the input tensor using Helion.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
            normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
            weight (torch.Tensor): Learnable scale parameter of shape [dim].
            bias (torch.Tensor): Learnable bias parameter of shape [dim].
            eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.

        Returns:
            torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
        """
        m, n = x.size()
        assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
        assert len(normalized_shape) == 1, (
            "Helion layer norm only supports 1D layer norm currently"
        )
        assert normalized_shape[0] == n, (
            f"normalized shape mismatch {normalized_shape[0]} != {n}"
        )
        out = torch.empty([m, n], dtype=torch.float16, device=x.device)
        for tile_m in hl.tile(m):
            # Upcast the FP16 tile to FP32 so the reduction is numerically stable.
            acc = x[tile_m, :].to(torch.float32)
            # Biased variance (correction=0), as layer normalization requires.
            var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
            normalized = (acc - mean) * torch.rsqrt(var + eps)
            acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
            # The store back into the FP16 output casts down from FP32.
            out[tile_m, :] = acc
        return out
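The body of the tile loop is the standard layer normalization recipe: upcast
the FP16 row tile to FP32, compute the biased mean and variance
(``correction=0``) over the last dimension, normalize with ``rsqrt(var + eps)``,
apply the affine ``weight`` and ``bias``, and store the result back in FP16.
As a quick sanity check, the plain-PyTorch sketch below reproduces the same
arithmetic without Helion; the helper name ``manual_layer_norm`` and the FP32
baseline comparison are illustrative additions, not part of the generated
example.

.. code-block:: Python


    import torch

    def manual_layer_norm(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float = 1e-5,
    ) -> torch.Tensor:
        # Hypothetical reference helper: the same math as the kernel's tile
        # loop, applied to the whole tensor at once.
        acc = x.to(torch.float32)  # upcast, mirroring the kernel's accumulator
        # Biased variance (correction=0), matching torch.var_mean in the kernel.
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
        normalized = (acc - mean) * torch.rsqrt(var + eps)
        out = normalized * weight.to(torch.float32) + bias.to(torch.float32)
        return out.to(torch.float16)  # store back in FP16, as the kernel does

    x = torch.randn(32, 64, dtype=torch.float16)
    weight = torch.randn(64, dtype=torch.float16)
    bias = torch.randn(64, dtype=torch.float16)
    # Compare against an FP32 layer_norm baseline, downcast to FP16.
    expected = torch.nn.functional.layer_norm(
        x.float(), [64], weight.float(), bias.float()
    ).to(torch.float16)
    torch.testing.assert_close(
        manual_layer_norm(x, weight, bias), expected, rtol=1e-3, atol=1e-3
    )

This check runs on CPU; the Helion kernel itself, exercised by ``main()``
below, expects a CUDA device.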
.. GENERATED FROM PYTHON SOURCE LINES 58-84

.. code-block:: Python


    def main() -> None:
        """
        Main execution function for the layer normalization example.

        - Generates random input, weight, and bias tensors.
        - Runs the Helion layer normalization kernel and compares its output
          to PyTorch's built-in layer_norm function using the run_example utility.
        - Prints comparison results and checks for correctness within
          specified tolerances.
        """
        batch_size = 32
        dim = 64
        device = "cuda"

        x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
        weight = torch.randn([dim], device=device, dtype=torch.float16)
        bias = torch.randn([dim], device=device, dtype=torch.float16)
        eps = 1e-4

        run_example(
            layer_norm_fwd,
            torch.nn.functional.layer_norm,
            (x, [dim], weight, bias, eps),
            kernel_name="helion",
            baseline_name="torch",
            rtol=1e-3,
            atol=1e-3,
        )

.. GENERATED FROM PYTHON SOURCE LINES 85-87

.. code-block:: Python


    if __name__ == "__main__":
        main()


.. _sphx_glr_download_examples_layer_norm.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: layer_norm.ipynb <layer_norm.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: layer_norm.py <layer_norm.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: layer_norm.zip <layer_norm.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_