.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/fp8_gemm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_fp8_gemm.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_fp8_gemm.py:


FP8 General Matrix Multiplication (GEMM) with Helion
====================================================

This example demonstrates an FP8 GEMM kernel implemented in Helion. The kernel
performs matrix multiplication on FP8 inputs, accumulates the results in FP32
for accuracy, and writes the output in FP16. The example also includes a
reference PyTorch implementation using torch._scaled_mm for correctness
comparison and a test function that validates the kernel.

.. GENERATED FROM PYTHON SOURCE LINES 11-28

.. code-block:: Python

    from __future__ import annotations

    import os

    import torch

    import helion
    from helion._testing import run_example
    import helion.language as hl

    # Override default config to work around Triton tl.dot requirement:
    # `AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 32`
    config = None
    if os.environ.get("HELION_USE_DEFAULT_CONFIG") == "1":
        config = helion.Config(block_sizes=[32, 32, 32])


.. GENERATED FROM PYTHON SOURCE LINES 29-60

.. code-block:: Python

    @helion.kernel(static_shapes=True, config=config)
    def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        FP8 General Matrix Multiplication (GEMM).

        This kernel demonstrates FP8 computation in Helion. When lowered to Triton,
        the tl.dot operation will handle FP8 inputs natively and accumulate to FP32.

        Args:
            x (torch.Tensor): Input tensor of shape [m, k] in FP8 format.
            y (torch.Tensor): Input tensor of shape [k, n] in FP8 format.

        Returns:
            torch.Tensor: Output tensor of shape [m, n] in FP16 format.
        """
        m, k = x.size()
        k2, n = y.size()
        assert k == k2, f"size mismatch {k} != {k2}"
        # Output is in FP16 to match tritonbench behavior
        out = torch.empty([m, n], dtype=torch.float16, device=x.device)
        for tile_m, tile_n in hl.tile([m, n]):
            # Accumulate in FP32 for accuracy
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                # Load FP8 tiles directly - no conversion needed
                x_tile = x[tile_m, tile_k]
                y_tile = y[tile_k, tile_n]
                # Use hl.dot for FP8 GEMM
                acc = hl.dot(x_tile, y_tile, acc=acc)
            out[tile_m, tile_n] = acc.to(torch.float16)
        return out


.. GENERATED FROM PYTHON SOURCE LINES 61-81

.. code-block:: Python

    def reference_fp8_gemm_pytorch(
        x_fp8: torch.Tensor, y_fp8: torch.Tensor
    ) -> torch.Tensor:
        """
        Reference implementation using torch._scaled_mm.

        Args:
            x_fp8 (torch.Tensor): Input tensor in FP8 format.
            y_fp8 (torch.Tensor): Input tensor in FP8 format.

        Returns:
            torch.Tensor: Output tensor in FP16 format.
        """
        # torch._scaled_mm requires the second operand to be column-major
        y_fp8_t = y_fp8.T.contiguous().T
        scale_a = torch.tensor(1.0, device=x_fp8.device)
        scale_b = torch.tensor(1.0, device=x_fp8.device)
        return torch._scaled_mm(
            x_fp8, y_fp8_t, scale_a, scale_b, use_fast_accum=False, out_dtype=torch.float16
        )


.. GENERATED FROM PYTHON SOURCE LINES 82-94

.. code-block:: Python

    def fp8_gemm_tritonbench(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """
        Wrapper for TritonBench compatibility.

        Args:
            a (torch.Tensor): Left input tensor in FP8 format.
            b (torch.Tensor): Right input tensor in FP8 format.

        Returns:
            torch.Tensor: Output tensor in FP16 format.
        """
        return fp8_gemm(a, b)
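Note that ``torch._scaled_mm`` depends on backend FP8 support and may not be
available on every GPU. The snippet below is a minimal sketch of an alternative
reference, assuming only that your PyTorch build exposes the FP8 dtypes used in
this example: it upcasts the FP8 inputs to FP32, runs a plain matmul, and casts
the result to FP16. The helper name ``reference_fp8_gemm_fp32`` is illustrative
and is not part of the example source.

.. code-block:: Python

    def reference_fp8_gemm_fp32(x_fp8: torch.Tensor, y_fp8: torch.Tensor) -> torch.Tensor:
        # Hypothetical fallback reference: upcast the FP8 inputs to FP32, multiply
        # with a plain matmul, then cast to FP16 to match fp8_gemm's output dtype.
        return (x_fp8.to(torch.float32) @ y_fp8.to(torch.float32)).to(torch.float16)

Because ``fp8_gemm`` also accumulates in FP32, such a fallback should agree with
the kernel up to FP16 rounding and accumulation-order differences, so a small
tolerance is still needed when comparing outputs. If desired, it could be passed
to ``run_example`` in place of ``reference_fp8_gemm_pytorch`` in the ``check``
function below.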
.. GENERATED FROM PYTHON SOURCE LINES 95-112

.. code-block:: Python

    def check(m: int, k: int, n: int) -> None:
        """
        Test the FP8 GEMM implementation against the PyTorch reference.

        Args:
            m (int): Number of rows in the left input matrix.
            k (int): Shared dimension.
            n (int): Number of columns in the right input matrix.
        """
        # Create random FP32 inputs
        x = torch.randn([m, k], device="cuda", dtype=torch.float32)
        y = torch.randn([k, n], device="cuda", dtype=torch.float32)
        # Convert to FP8 format (e4m3fn is commonly used for the forward pass)
        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn)
        run_example(fp8_gemm, reference_fp8_gemm_pytorch, (x_fp8, y_fp8))


.. GENERATED FROM PYTHON SOURCE LINES 113-122

.. code-block:: Python

    def main() -> None:
        """
        Main function to run tests with different matrix sizes.
        """
        check(256, 256, 256)
        check(512, 512, 512)
        check(1024, 1024, 1024)


.. GENERATED FROM PYTHON SOURCE LINES 123-125

.. code-block:: Python

    if __name__ == "__main__":
        main()


.. _sphx_glr_download_examples_fp8_gemm.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: fp8_gemm.ipynb <fp8_gemm.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: fp8_gemm.py <fp8_gemm.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: fp8_gemm.zip <fp8_gemm.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_