.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_softmax.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_softmax.py:

Helion Softmax Kernel Examples
==============================

This example demonstrates multiple Helion kernel implementations of the softmax
function: a simple wrapper around PyTorch's softmax, a decomposed version using
explicit exponentiation and normalization, and an optimized two-pass version.
It also includes a check function that compares these kernels against PyTorch's
built-in softmax for correctness.

.. GENERATED FROM PYTHON SOURCE LINES 12-21

.. code-block:: Python


    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 22-38

.. code-block:: Python


    @helion.kernel()
    def softmax(x: torch.Tensor) -> torch.Tensor:
        """
        Simple Helion kernel wrapping PyTorch's softmax function.

        Args:
            x (torch.Tensor): Input tensor of shape [n, m].

        Returns:
            torch.Tensor: Softmax output tensor of the same shape.
        """
        n, _m = x.size()
        out = torch.empty_like(x)
        for tile_n in hl.tile(n):
            out[tile_n, :] = torch.nn.functional.softmax(x[tile_n, :], dim=1)
        return out

.. GENERATED FROM PYTHON SOURCE LINES 39-60

.. code-block:: Python


    @helion.kernel()
    def softmax_decomposed(x: torch.Tensor) -> torch.Tensor:
        """
        Helion kernel implementing softmax by decomposing it into explicit max,
        exp, and normalization steps, rather than relying on PyTorch's built-in
        softmax decomposition.

        Args:
            x (torch.Tensor): Input tensor of shape [n, m].

        Returns:
            torch.Tensor: Softmax output tensor of the same shape.
""" n, _m = x.size() out = torch.empty_like(x) for tile_n in hl.tile(n): values = x[tile_n, :] amax = torch.amax(values, dim=1, keepdim=True) exp = torch.exp(values - amax) sum_exp = torch.sum(exp, dim=1, keepdim=True) out[tile_n, :] = exp / sum_exp return out .. GENERATED FROM PYTHON SOURCE LINES 61-92 .. code-block:: Python @helion.kernel() def softmax_two_pass(x: torch.Tensor) -> torch.Tensor: """ Numerically optimized Helion kernel performing softmax in two passes. This version uses fewer passes but is less numerically stable. Args: x (torch.Tensor): Input tensor of shape [m, n]. Returns: torch.Tensor: Softmax output tensor of the same shape. """ m, n = x.size() out = torch.empty_like(x) block_size_m = hl.register_block_size(m) block_size_n = hl.register_block_size(n) for tile_m in hl.tile(m, block_size=block_size_m): mi = hl.full([tile_m], float("-inf"), dtype=torch.float32) di = hl.zeros([tile_m], dtype=torch.float32) for tile_n in hl.tile(n, block_size=block_size_n): values = x[tile_m, tile_n] local_amax = torch.amax(values, dim=1) mi_next = torch.maximum(mi, local_amax) di = di * torch.exp(mi - mi_next) + torch.exp( values - mi_next[:, None] ).sum(dim=1) mi = mi_next for tile_n in hl.tile(n, block_size=block_size_n): values = x[tile_m, tile_n] out[tile_m, tile_n] = torch.exp(values - mi[:, None]) / di[:, None] return out .. GENERATED FROM PYTHON SOURCE LINES 93-109 .. code-block:: Python def check(m: int, n: int) -> None: """ Runs correctness checks comparing Helion softmax kernels against PyTorch's softmax. Args: m (int): Number of rows in input tensor. n (int): Number of columns in input tensor. """ x = torch.randn([m, n], device="cuda", dtype=torch.float16) kernels = { "helion simple": softmax, # "helion decomposed": softmax_decomposed, # Disabled due to possible issues "helion two pass": softmax_two_pass, } run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,)) .. GENERATED FROM PYTHON SOURCE LINES 110-117 .. 
.. code-block:: Python


    def main() -> None:
        """
        Main function to run the softmax kernel correctness check with an
        example input size.
        """
        check(1024, 1024)

.. GENERATED FROM PYTHON SOURCE LINES 118-120

.. code-block:: Python


    if __name__ == "__main__":
        main()

.. _sphx_glr_download_examples_softmax.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: softmax.ipynb <softmax.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: softmax.py <softmax.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: softmax.zip <softmax.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
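The online max/sum recurrence that ``softmax_two_pass`` applies per row tile can
be sketched in plain PyTorch on the CPU. This is an illustrative sketch, not the
Helion kernel: the function name ``online_softmax`` and the explicit chunking
loop are assumptions of this example and not part of Helion's API.

```python
import torch


def online_softmax(x: torch.Tensor, block_n: int = 32) -> torch.Tensor:
    """Softmax over dim=1, streaming over column blocks of width block_n."""
    m, n = x.shape
    mi = torch.full((m,), float("-inf"))  # running row-wise max
    di = torch.zeros(m)                   # running sum of exp(x - mi)
    # Pass 1: a single sweep updates the max and the rescaled sum together.
    for start in range(0, n, block_n):
        values = x[:, start:start + block_n]
        mi_next = torch.maximum(mi, values.amax(dim=1))
        # Rescale the old partial sum to the new max before adding this block.
        di = di * torch.exp(mi - mi_next) + torch.exp(
            values - mi_next[:, None]
        ).sum(dim=1)
        mi = mi_next
    # Pass 2: normalize each block with the final max and sum.
    out = torch.empty_like(x)
    for start in range(0, n, block_n):
        values = x[:, start:start + block_n]
        out[:, start:start + block_n] = torch.exp(values - mi[:, None]) / di[:, None]
    return out


x = torch.randn(4, 100)
assert torch.allclose(online_softmax(x), torch.softmax(x, dim=1), atol=1e-6)
```

The rescaling step is what lets the kernel fuse the max and sum passes: whenever
a block raises the running max from ``mi`` to ``mi_next``, the accumulated sum is
multiplied by ``exp(mi - mi_next)`` so all terms stay expressed relative to the
current max, matching the same-named update inside ``softmax_two_pass``.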