.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/long_sum.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_long_sum.py:

Long Dimension Sum Example
==========================

This example demonstrates how to implement efficient sum reduction along a long dimension using Helion.

.. GENERATED FROM PYTHON SOURCE LINES 9-11

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 11-20

.. code-block:: Python

    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 21-23

Baseline Implementation
-----------------------

.. GENERATED FROM PYTHON SOURCE LINES 23-36

.. code-block:: Python

    def baseline_sum(x: torch.Tensor) -> torch.Tensor:
        """
        PyTorch baseline implementation of sum reduction along the last dimension.

        Args:
            x: Input tensor

        Returns:
            Tensor with sum of elements along the last dimension
        """
        return x.sum(-1)

.. GENERATED FROM PYTHON SOURCE LINES 37-39

Naive Reduction Kernel
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 39-68

.. code-block:: Python

    @helion.kernel(
        config=helion.Config(
            block_sizes=[1],
            reduction_loops=[None],
            num_warps=32,
            num_stages=4,
            indexing="block_ptr",
        )
    )
    def longsum(x: torch.Tensor) -> torch.Tensor:
        """
        Naive reduction kernel that sums elements along the last dimension.
        Loads the entire reduction dimension at once and reduces in registers.

        Args:
            x: Input tensor of shape [m, n]

        Returns:
            Output tensor of shape [m] containing the sum of each row
        """
        m, _ = x.size()
        out = torch.empty([m], dtype=x.dtype, device=x.device)
        for tile_m in hl.tile(m):
            out[tile_m] = x[tile_m, :].sum(-1)
        return out

.. GENERATED FROM PYTHON SOURCE LINES 69-71

Looped Reduction Kernel
-----------------------
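The naive kernel above materializes an entire row of length ``n`` before reducing it, which becomes wasteful for very long rows. The looped variant below instead asks Helion (via ``reduction_loops=[32768]``) to accumulate fixed-size chunks. As a rough intuition for what that config requests, here is a plain-PyTorch sketch of the same chunked-accumulation idea; ``chunked_row_sum`` and its ``chunk`` parameter are hypothetical names for this illustration, not part of Helion's API:

```python
import torch

def chunked_row_sum(x: torch.Tensor, chunk: int = 32768) -> torch.Tensor:
    # Accumulate partial sums over fixed-size chunks of the last dimension,
    # analogous to what reduction_loops=[32768] asks Helion to generate.
    m, n = x.shape
    acc = torch.zeros(m, dtype=x.dtype, device=x.device)
    for start in range(0, n, chunk):
        acc += x[:, start:start + chunk].sum(-1)
    return acc

x = torch.randn(4, 1000)
assert torch.allclose(chunked_row_sum(x, chunk=256), x.sum(-1), atol=1e-4)
```

Chunked accumulation can round differently from a single fused reduction, hence the small tolerance in the check; the Helion kernel itself makes the equivalent trade-off in-register.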
.. GENERATED FROM PYTHON SOURCE LINES 71-102

.. code-block:: Python

    @helion.kernel(
        config=helion.Config(
            block_sizes=[1],
            reduction_loops=[
                32768
            ],  # [None] for naive reduction, [tile_size] for looped reduction
            num_warps=16,
            num_stages=5,
            indexing="pointer",
        )
    )
    def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor:
        """
        Looped reduction kernel that sums elements along the last dimension.
        Uses a reduction loop with a specified tile size to handle large
        dimensions efficiently.

        Args:
            x: Input tensor of shape [m, n]

        Returns:
            Output tensor of shape [m] containing the sum of each row
        """
        m, _ = x.size()
        out = torch.empty([m], dtype=x.dtype, device=x.device)
        for tile_m in hl.tile(m):
            out[tile_m] = x[tile_m, :].sum(-1)
        return out

.. GENERATED FROM PYTHON SOURCE LINES 103-105

Manual Looped Reduction Kernel
------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 105-136

.. code-block:: Python

    @helion.kernel(
        config=helion.Config(
            block_sizes=[32768, 1], num_warps=16, num_stages=5, indexing="pointer"
        )
    )
    def longsum_manual(x: torch.Tensor) -> torch.Tensor:
        """
        Manual implementation of looped reduction for summing elements along the
        last dimension. Manually implements the reduction loop with explicit
        accumulation and final reduction.

        Args:
            x: Input tensor of shape [m, n]

        Returns:
            Output tensor of shape [m] containing the sum of each row
        """
        m, n = x.size()
        out = torch.empty([m], dtype=x.dtype, device=x.device)
        # Call register_block_size to know block_size_n outside of the reduction loop.
        block_size_n = hl.register_block_size(n)
        for tile_m in hl.tile(m):
            acc = hl.zeros([tile_m, block_size_n], dtype=x.dtype)
            for tile_n in hl.tile(n, block_size=block_size_n):  # Reduction loop
                acc += x[tile_m, tile_n]
            out[tile_m] = acc.sum(-1)
        return out

.. GENERATED FROM PYTHON SOURCE LINES 137-139

Verification Function
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 139-161
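Before wiring the kernels into the test harness, it can help to see the manual kernel's two-level strategy verified in plain PyTorch: keep a fixed-width 2D accumulator across the reduction loop, and collapse it to one value per row only at the end. In this CPU sketch, ``manual_style_sum``, its ``block_n`` parameter, and the explicit zero-padding of the ragged final tile are artifacts of the illustration, not of the Helion kernel (``hl.tile`` handles partial tiles inside the kernel):

```python
import torch

def manual_style_sum(x: torch.Tensor, block_n: int = 256) -> torch.Tensor:
    # Plain-PyTorch analogue of longsum_manual: accumulate into a fixed-width
    # [m, block_n] buffer, then reduce that buffer once per row at the end.
    m, n = x.shape
    acc = torch.zeros(m, block_n, dtype=x.dtype, device=x.device)
    for start in range(0, n, block_n):
        tile = x[:, start:start + block_n]
        # The last tile may be narrower than block_n; pad it with zeros so it
        # matches the accumulator width (only needed in this CPU emulation).
        if tile.shape[1] < block_n:
            tile = torch.nn.functional.pad(tile, (0, block_n - tile.shape[1]))
        acc += tile
    return acc.sum(-1)

x = torch.randn(4, 1000)
assert torch.allclose(manual_style_sum(x), x.sum(-1), atol=1e-4)
```

Keeping the accumulator two-dimensional is what makes the inner loop a cheap elementwise add; the expensive cross-lane reduction happens exactly once per row.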
.. code-block:: Python

    def check(m: int, n: int) -> None:
        """
        Verify the sum kernel implementations against PyTorch's native sum function.
        Tests all three kernel variants (naive, looped, manual) against the baseline.

        Args:
            m: First dimension of the test tensor
            n: Second dimension of the test tensor (reduction dimension)
        """
        x = torch.randn([m, n], device="cuda", dtype=torch.float32)
        # Test all three kernel variants against the baseline
        kernels = {
            "helion naive": longsum,
            "helion loop": longsum_w_red_loop,
            "helion manual": longsum_manual,
        }
        run_example(kernels, baseline_sum, (x,))

.. GENERATED FROM PYTHON SOURCE LINES 162-164

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 164-175

.. code-block:: Python

    def main() -> None:
        """
        Main entry point that runs the sum kernel verification with a large tensor.
        Tests with a tensor of shape [4, 130000] to demonstrate handling of long
        reduction dimensions.
        """
        check(4, 130000)  # seq_len = 128k


    if __name__ == "__main__":
        main()

.. _sphx_glr_download_examples_long_sum.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: long_sum.ipynb `

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: long_sum.py `

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: long_sum.zip `

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_