.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/bmm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_bmm.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_bmm.py:

Batch Matrix Multiplication Example
===================================

This example demonstrates how to implement a batch matrix multiplication kernel using Helion.

.. GENERATED FROM PYTHON SOURCE LINES 9-11

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 11-20

.. code-block:: Python

    from __future__ import annotations

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 21-24

Batch Matrix Multiplication Kernel
----------------------------------

Passing ``static_shapes=True`` gives a performance boost for matmuls.

.. GENERATED FROM PYTHON SOURCE LINES 24-52

.. code-block:: Python

    @helion.kernel(static_shapes=True)
    def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs batch matrix multiplication.

        Args:
            A: Input tensor of shape [B, M, K]
            B: Input tensor of shape [B, K, N]

        Returns:
            Output tensor of shape [B, M, N] containing the result of batch matrix multiplication
        """
        # A: [B, M, K], B: [B, K, N], Out: [B, M, N] (dense bmm)
        b, m, k = A.size()
        b, k, n = B.size()
        out = torch.empty(
            [b, m, n], device=A.device, dtype=torch.promote_types(A.dtype, B.dtype)
        )
        for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
            acc = hl.zeros([tile_b, tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                acc = torch.baddbmm(
                    acc, A[tile_b, tile_m, tile_k], B[tile_b, tile_k, tile_n]
                )
            out[tile_b, tile_m, tile_n] = acc
        return out

.. GENERATED FROM PYTHON SOURCE LINES 53-55

Verification Function
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 55-70

.. code-block:: Python

    def check(b: int, m: int, k: int, n: int) -> None:
        """
        Verify the bmm kernel implementation against PyTorch's native bmm function.

        Args:
            b: Batch size
            m: Rows of the first matrix
            k: Columns of the first matrix / rows of the second matrix (shared inner dimension)
            n: Columns of the second matrix
        """
        x = torch.randn([b, m, k], device="cuda", dtype=torch.float16)
        y = torch.randn([b, k, n], device="cuda", dtype=torch.float16)
        run_example(bmm, torch.bmm, (x, y))

.. GENERATED FROM PYTHON SOURCE LINES 71-73

Main Function
-------------

.. GENERATED FROM PYTHON SOURCE LINES 73-86

.. code-block:: Python

    def main() -> None:
        """
        Main entry point that runs the bmm kernel verification with specific parameters.
        Tests with batch size 16, and matrices of dimensions 512x768 and 768x1024.
        Ensures torch version is at least 2.8 for 16-bit tensor support in baddbmm.
        """
        # torch.baddbmm support for 16-bit tensors requires torch 2.8+
        # (compare the version components as integers, not strings, so that
        # e.g. 2.10 is correctly treated as newer than 2.8)
        major, minor = (int(v) for v in torch.__version__.split(".")[:2])
        assert (major, minor) >= (2, 8), "Requires torch 2.8+"
        check(16, 512, 768, 1024)


    if __name__ == "__main__":
        main()

.. _sphx_glr_download_examples_bmm.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: bmm.ipynb <bmm.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: bmm.py <bmm.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: bmm.zip <bmm.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
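For intuition about what the tiled kernel computes, the same math can be mimicked in eager PyTorch on CPU: accumulate each K-block in float32 with ``torch.baddbmm`` (mirroring the inner ``hl.tile(k)`` loop) and cast to the promoted dtype at the end. This is a minimal illustrative sketch, not the Helion kernel; ``bmm_reference`` and ``block_k`` are names invented for this example.

```python
import torch


def bmm_reference(A: torch.Tensor, B: torch.Tensor, block_k: int = 32) -> torch.Tensor:
    """Eager-mode sketch of the kernel's math: accumulate K-blocks in
    float32 with torch.baddbmm, then cast to the promoted output dtype."""
    b, m, k = A.size()
    _, _, n = B.size()
    acc = torch.zeros([b, m, n], dtype=torch.float32)
    for k0 in range(0, k, block_k):
        # One float32 accumulation step per K-block,
        # analogous to one iteration of the inner hl.tile(k) loop.
        acc = torch.baddbmm(
            acc,
            A[:, :, k0 : k0 + block_k].float(),
            B[:, k0 : k0 + block_k, :].float(),
        )
    return acc.to(torch.promote_types(A.dtype, B.dtype))


# Sanity check against torch.bmm on small float32 inputs
A = torch.randn(2, 8, 48)
B = torch.randn(2, 48, 4)
torch.testing.assert_close(bmm_reference(A, B), torch.bmm(A, B))
```

Unlike the Helion version, this runs one Python iteration per K-block, so it is only a correctness reference, not a performance comparison.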