.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/matmul.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_matmul.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_matmul.py:


Helion Matmul Kernel Example
============================

This example demonstrates a Helion kernel implementation of matrix multiplication with
support for a customizable epilogue function. It includes autotuning, correctness checks
against PyTorch baselines, and integration with tritonbench.

.. GENERATED FROM PYTHON SOURCE LINES 10-25

.. code-block:: Python


    from __future__ import annotations

    from typing import TYPE_CHECKING

    import torch
    from torch import Tensor

    import helion
    from helion._testing import run_example
    import helion.language as hl

    if TYPE_CHECKING:
        from collections.abc import Callable

.. GENERATED FROM PYTHON SOURCE LINES 26-59

.. code-block:: Python


    @helion.kernel(
        # static_shapes=True gives a performance boost for matmuls
        static_shapes=True,
    )
    def matmul(
        x: Tensor,
        y: Tensor,
        epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor] = lambda acc, tile: acc,
    ) -> Tensor:
        """
        Performs matrix multiplication of x and y with an optional epilogue function.

        Args:
            x (Tensor): Left matrix of shape [m, k].
            y (Tensor): Right matrix of shape [k, n].
            epilogue (Callable, optional): Function applied to the accumulator and tile
                indices after the matmul. Defaults to identity (no change).

        Returns:
            Tensor: Resulting matrix of shape [m, n].
        """
        m, k = x.size()
        k2, n = y.size()
        assert k == k2, f"size mismatch {k} != {k2}"
        out = torch.empty(
            [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
        )
        for tile_m, tile_n in hl.tile([m, n]):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
            out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n))
        return out

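
The ``epilogue`` argument lets element-wise post-processing be fused into the kernel: it
receives the float32 accumulator tile and the ``(tile_m, tile_n)`` tile indices and returns
the values to store. A minimal usage sketch is shown below; it mirrors the bias-plus-ReLU
epilogue used by ``autotune`` and ``check`` later in this example, and the tensor shapes are
illustrative only.

.. code-block:: Python

    # Minimal usage sketch (assumes a CUDA device is available).
    x = torch.randn([256, 128], device="cuda", dtype=torch.float16)
    y = torch.randn([128, 64], device="cuda", dtype=torch.float16)
    bias = torch.randn([64], device="cuda", dtype=torch.float16)

    # tile[1] is the output-column tile, so bias is sliced to match the tile being stored.
    out = matmul(x, y, lambda acc, tile: torch.relu(acc + bias[tile[1]]))
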
""" x = torch.randn([m, k], device="cuda", dtype=torch.float16) y = torch.randn([k, n], device="cuda", dtype=torch.float16) bias = torch.randn([n], device="cuda", dtype=torch.float16) # Test without bias run_example(matmul, torch.matmul, (x, y)) # Test with bias def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor: return matmul(x, y, lambda acc, tile: acc + bias[tile[1]]) def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor: return torch.nn.functional.linear(x, y.T, bias) run_example(helion_linear, baseline_linear, (x, y, bias)) # Test more complex epilogue def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor: # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg return torch.relu(acc + bias[tile[1]]) def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor: return matmul(x, y, epilogue) def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor: return torch.relu(x @ y + bias) run_example( kernel_wrapper, baseline_wrapper, (x, y), ) .. GENERATED FROM PYTHON SOURCE LINES 124-141 .. code-block:: Python def matmul_tritonbench( a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None ) -> Callable: """ Wrapper for tritonbench that matches its interface. Args: a (torch.Tensor): Left matrix. b (torch.Tensor): Right matrix. bias (torch.Tensor or None): Optional bias to add in the epilogue. Returns: Callable: A callable that runs the matmul kernel with or without bias. """ if bias is not None: return lambda: matmul(a, b, lambda acc, tile: acc + bias[tile[1]]) return lambda: matmul(a, b) .. GENERATED FROM PYTHON SOURCE LINES 142-150 .. code-block:: Python def main() -> None: """ Main function to run autotuning (commented out) and correctness checks. """ # autotune(1024, 1024, 1024) check(1024, 1024, 1024) .. GENERATED FROM PYTHON SOURCE LINES 151-153 .. code-block:: Python if __name__ == "__main__": main() .. _sphx_glr_download_examples_matmul.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: matmul.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: matmul.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: matmul.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_