.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/fp8_attention.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_fp8_attention.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_fp8_attention.py:


FP8 Attention Example
=====================

This example demonstrates how to implement scaled dot-product attention
using FP8 precision in Helion.

.. GENERATED FROM PYTHON SOURCE LINES 8-10

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 10-21

.. code-block:: Python

    from __future__ import annotations

    import math
    from typing import Callable

    import torch

    import helion
    import helion.language as hl

.. GENERATED FROM PYTHON SOURCE LINES 22-109

.. code-block:: Python

    @helion.kernel(static_shapes=True)
    def fp8_attention_kernel(
        q: torch.Tensor,  # [batch*heads, seq, dim]
        k: torch.Tensor,  # [batch*heads, seq, dim]
        v: torch.Tensor,  # [batch*heads, dim, seq] - pre-transposed
        batch: int,
        heads: int,
    ) -> torch.Tensor:
        """
        Computes scaled dot-product attention using FP8 precision.

        Implements attention with FP8 tensors for improved performance
        and memory efficiency.

        Args:
            q: Query tensor of shape [batch*heads, seq, dim] in FP8 format
            k: Key tensor of shape [batch*heads, seq, dim] in FP8 format
            v: Value tensor of shape [batch*heads, dim, seq] (pre-transposed) in FP8 format
            batch: Number of batches
            heads: Number of attention heads

        Returns:
            Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format
        """
        batch_heads = q.size(0)
        seq_len = q.size(1)
        head_dim = q.size(2)

        # Output tensor with 4D shape in FP8 format
        out = torch.empty(
            [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
        )

        # Scale factor for attention
        sm_scale = 1.0 / math.sqrt(float(head_dim))
        # Triton kernel multiplies sm_scale by 1.44269504 (1/log(2)) for exp2
        sm_scale = sm_scale * 1.44269504

        # Process each batch*head in parallel
        for bh in hl.grid(batch_heads):
            # Calculate batch and head indices
            b = bh // heads
            h = bh % heads

            # Process each query position
            for tile_m in hl.tile(seq_len):
                # Initialize for online softmax
                m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32)
                l_i = hl.full([tile_m], 0.0, dtype=torch.float32)
                acc = hl.zeros([tile_m, head_dim], dtype=torch.float32)

                # Load query tile - keep in FP8
                q_tile = q[bh, tile_m, :]  # [tile_m, dim]

                # Compute attention scores for all keys
                for tile_n in hl.tile(seq_len):
                    # Load key tile and transpose for Q @ K^T
                    k_tile = k[bh, tile_n, :]  # [tile_n, dim] - keep in FP8
                    k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

                    # Compute Q @ K^T with FP8 inputs, result in FP32
                    qk = hl.dot(q_tile, k_tile_t)  # [tile_m, tile_n]

                    # Scale QK scores first
                    qk_scaled = qk * sm_scale  # [tile_m, tile_n]

                    # Compute max of scaled scores
                    qk_max = torch.amax(qk_scaled, dim=-1)  # [tile_m]

                    # Update global max
                    m_new = torch.maximum(m_i, qk_max)

                    # Shift by max for numerical stability
                    qk_shifted = qk_scaled - m_new[:, None]

                    # Use exp2 to match Triton kernel's implementation
                    # Note: Triton kernel already multiplies sm_scale by 1.44269504
                    p = torch.exp2(qk_shifted)  # [tile_m, tile_n]

                    # Sum of exponentials for this block
                    l_ij = torch.sum(p, dim=-1)  # [tile_m]

                    # Update accumulators with correction factor
                    # Correction factor for previous blocks
                    alpha = torch.exp2(m_i - m_new)
                    l_i = l_i * alpha + l_ij
                    acc = acc * alpha[:, None]

                    # Load values - V is [dim, seq]
                    v_tile = v[bh, :, tile_n]  # [dim, tile_n] - keep in FP8

                    # Convert p to FP8 for FP8 GEMM
                    p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V

                    # Accumulate attention @ V with FP8 GEMM
                    # v_tile is [dim, tile_n], we need to transpose for P @ V^T
                    v_t = v_tile.t()  # [tile_n, dim]
                    acc = hl.dot(p_fp8, v_t, acc=acc)  # [tile_m, dim]

                    # Update max tracker
                    m_i = m_new

                # Final normalization
                acc = acc / l_i[:, None]
                # Convert to FP8 before writing to output
                out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)

        return out
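The kernel expects inputs that are already flattened to ``[batch*heads, ...]`` and
cast to FP8; ``preprocess_fp8_attention_inputs`` below produces exactly this layout.
The following is a minimal sketch of calling the kernel directly. The shapes, the
FP16 starting dtype, and the CUDA device are illustrative assumptions, not part of
the generated example.

.. code-block:: Python

    # Illustrative sketch only: assumes a CUDA device with FP8 support.
    batch, heads, seq_len, head_dim = 2, 4, 256, 64
    q = torch.randn(batch * heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch * heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    # V is passed pre-transposed as [batch*heads, head_dim, seq_len]
    v = torch.randn(batch * heads, head_dim, seq_len, dtype=torch.float16, device="cuda")

    out = fp8_attention_kernel(
        q.to(torch.float8_e4m3fn),
        k.to(torch.float8_e4m3fn),
        v.to(torch.float8_e4m3fn),
        batch,
        heads,
    )
    # out: [batch, heads, seq_len, head_dim] in torch.float8_e4m3fn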
.. GENERATED FROM PYTHON SOURCE LINES 110-136

.. code-block:: Python

    def preprocess_fp8_attention_inputs(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preprocesses attention inputs by converting them to FP8 format and reshaping.

        Args:
            q: Query tensor of shape [batch, heads, seq_len, head_dim]
            k: Key tensor of shape [batch, heads, seq_len, head_dim]
            v: Value tensor of shape [batch, heads, seq_len, head_dim]

        Returns:
            Tuple of (q_fp8, k_fp8, v_fp8) where:
                - q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
                - k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
                - v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed)
        """
        q_fp8 = q.to(torch.float8_e4m3fn)
        k_fp8 = k.to(torch.float8_e4m3fn)

        v = v.permute(0, 1, 3, 2)
        v_fp8 = v.to(torch.float8_e4m3fn)

        batch, heads, seq_len, head_dim = q.shape
        q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
        k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
        v_fp8_reshaped = v_fp8.reshape(batch * heads, head_dim, seq_len)

        return q_fp8_reshaped, k_fp8_reshaped, v_fp8_reshaped

.. GENERATED FROM PYTHON SOURCE LINES 137-157

.. code-block:: Python

    def fp8_attention_tritonbench(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Callable[[], torch.Tensor]:
        """
        Creates a callable function for benchmarking FP8 attention with tritonbench.

        Preprocesses inputs and returns a lambda function that calls the FP8 attention kernel.

        Args:
            q: Query tensor of shape [batch, heads, seq_len, head_dim]
            k: Key tensor of shape [batch, heads, seq_len, head_dim]
            v: Value tensor of shape [batch, heads, seq_len, head_dim]

        Returns:
            A callable function that executes the FP8 attention kernel
        """
        batch, heads, seq_len, head_dim = q.shape
        q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
        # Return lambda that calls the kernel - preprocessing is done outside.
        # This matches the tritonbench kernel timing measurement setup.
        return lambda: fp8_attention_kernel(q_fp8, k_fp8, v_fp8, batch, heads)
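The reshaping performed by ``preprocess_fp8_attention_inputs`` is easy to verify in
isolation. The sketch below uses arbitrary illustrative shapes and a CUDA device;
it is not part of the generated example.

.. code-block:: Python

    # Illustrative shapes: batch=2, heads=4, seq_len=128, head_dim=64.
    q = torch.randn(2, 4, 128, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 4, 128, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 4, 128, 64, dtype=torch.float16, device="cuda")

    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
    assert q_fp8.shape == (8, 128, 64)   # [batch*heads, seq_len, head_dim]
    assert k_fp8.shape == (8, 128, 64)   # [batch*heads, seq_len, head_dim]
    assert v_fp8.shape == (8, 64, 128)   # [batch*heads, head_dim, seq_len], pre-transposed
    assert q_fp8.dtype == torch.float8_e4m3fn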
.. GENERATED FROM PYTHON SOURCE LINES 158-241

.. code-block:: Python

    def _fp8_attention_pytorch_impl(
        q_fp8: torch.Tensor,
        k_fp8: torch.Tensor,
        v_fp8: torch.Tensor,
        batch: int,
        heads: int,
        seq_len: int,
        head_dim: int,
    ) -> torch.Tensor:
        """
        PyTorch implementation of FP8 attention for comparison with the kernel version.

        Args:
            q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
            k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
            v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed)
            batch: Number of batches
            heads: Number of attention heads
            seq_len: Sequence length
            head_dim: Dimension of each attention head

        Returns:
            Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format
        """
        sm_scale = 1.0 / math.sqrt(float(head_dim))

        outputs = []

        for i in range(batch * heads):
            q_i = q_fp8[i]  # [seq, dim] - already FP8
            k_i = k_fp8[i]  # [seq, dim] - already FP8
            v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

            # For Q @ K^T using torch._scaled_mm
            # torch._scaled_mm requires column-major for second operand
            # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
            # Direct conversion: k_i -> contiguous -> transpose view
            kt_fp8_col_major = k_i.contiguous().t()  # [dim, seq] in column-major

            # Create scale tensors
            scale_q = torch.tensor(1.0, device=q_i.device)
            scale_k = torch.tensor(1.0, device=k_i.device)

            # Q @ K^T using torch._scaled_mm
            qk = torch._scaled_mm(
                q_i,
                kt_fp8_col_major,
                scale_q,
                scale_k,
                use_fast_accum=False,
                out_dtype=torch.float32,
            )

            # Compute max before scaling
            qk_max = torch.amax(qk, dim=-1, keepdim=True)

            # Scale and shift in one operation, then use exp2
            qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale
            p = torch.exp2(qk_scaled_shifted * 1.44269504)

            # Normalize
            p_norm = p / p.sum(dim=-1, keepdim=True)

            # Step 2: Attention @ V using FP8
            # P is [seq, seq], V is [dim, seq]
            # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
            p_fp8 = p_norm.to(torch.float8_e4m3fn)  # row-major [seq, seq]

            # v_i is [dim, seq], already FP8
            # Direct conversion: v_i -> contiguous -> transpose view
            vt_fp8_col_major = v_i.contiguous().t()  # [seq, dim] in column-major

            # Create scale tensors for P @ V^T
            scale_p = torch.tensor(1.0, device=p_fp8.device)
            scale_v = torch.tensor(1.0, device=v_i.device)

            # P @ V^T using torch._scaled_mm
            out_i = torch._scaled_mm(
                p_fp8,
                vt_fp8_col_major,
                scale_p,
                scale_v,
                use_fast_accum=False,
                out_dtype=torch.float32,
            )

            out_i = out_i.to(torch.float8_e4m3fn)  # convert back to FP8 to match kernel
            outputs.append(out_i)

        # Stack and reshape back
        out_stacked = torch.stack(outputs, dim=0)  # [batch*heads, seq, dim]
        return out_stacked.reshape(batch, heads, seq_len, head_dim)

.. GENERATED FROM PYTHON SOURCE LINES 242-259

.. code-block:: Python

    def fp8_attention_pytorch(
        q: torch.Tensor,  # [batch, heads, seq, dim]
        k: torch.Tensor,  # [batch, heads, seq, dim]
        v: torch.Tensor,  # [batch, heads, seq, dim]
    ) -> Callable[[], torch.Tensor]:
        """
        Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
        """
        batch, heads, seq_len, head_dim = q.shape
        q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
        # Return lambda that calls the kernel - preprocessing is done outside.
        # This matches the Helion kernel timing measurement setup.
        return lambda: _fp8_attention_pytorch_impl(
            q_fp8, k_fp8, v_fp8, batch, heads, seq_len, head_dim
        )
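Both entry points return zero-argument callables so that FP8 conversion and
reshaping stay outside the timed region. The sketch below shows one way to compare
them directly; it mirrors what ``check`` below does via ``run_example``, using
illustrative shapes and the example's loose tolerances.

.. code-block:: Python

    # Illustrative sketch: compare the Helion kernel against the _scaled_mm baseline.
    q = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")

    helion_fn = fp8_attention_tritonbench(q, k, v)
    baseline_fn = fp8_attention_pytorch(q, k, v)

    out_helion = helion_fn().to(torch.float32)
    out_baseline = baseline_fn().to(torch.float32)
    # Loose tolerances: both paths round intermediate results to FP8.
    torch.testing.assert_close(out_helion, out_baseline, atol=0.1, rtol=0.1)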
.. GENERATED FROM PYTHON SOURCE LINES 260-286

.. code-block:: Python

    def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
        """
        Verifies the FP8 attention kernel implementation against the PyTorch
        reference implementation.

        Args:
            batch: Number of batches
            heads: Number of attention heads
            seq_len: Sequence length
            head_dim: Dimension of each attention head
        """
        torch.manual_seed(42)
        q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
        k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
        v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")

        from helion._testing import run_example

        helion_fn = fp8_attention_tritonbench(q, k, v)
        pytorch_fn = fp8_attention_pytorch(q, k, v)
        run_example(
            helion_fn,
            pytorch_fn,
            (),
            atol=0.1,
            rtol=0.1,
        )

.. GENERATED FROM PYTHON SOURCE LINES 287-297

.. code-block:: Python

    def main() -> None:
        """
        Main entry point that runs the FP8 attention kernel verification
        with different configurations.

        Tests with small, medium, and large attention configurations.
        """
        check(1, 2, 128, 64)
        check(2, 4, 256, 64)
        check(4, 8, 512, 128)

.. GENERATED FROM PYTHON SOURCE LINES 298-300

.. code-block:: Python

    if __name__ == "__main__":
        main()


.. _sphx_glr_download_examples_fp8_attention.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: fp8_attention.ipynb <fp8_attention.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: fp8_attention.py <fp8_attention.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: fp8_attention.zip <fp8_attention.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_