FP8 Attention Example
This example demonstrates how to implement scaled dot-product attention using FP8 precision in Helion.
Imports
from __future__ import annotations
import math
from typing import Callable
import torch
import helion
import helion.language as hl
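Before the kernel itself, a small standalone illustration (assuming a recent PyTorch build with FP8 dtypes; the cast itself needs no GPU) of how coarse torch.float8_e4m3fn is. Each value is rounded to the nearest representable e4m3 number, with a relative rounding error of up to several percent, which is why the verification at the end of this example uses loose tolerances of 0.1.
x = torch.tensor([0.1234, 1.2345, 12.345, 123.45], dtype=torch.float16)
x_fp8 = x.to(torch.float8_e4m3fn)  # round-to-nearest onto the e4m3 grid
print(x_fp8.to(torch.float16))  # only a couple of significant digits survive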
@helion.kernel(static_shapes=True)
def fp8_attention_kernel(
q: torch.Tensor, # [batch*heads, seq, dim]
k: torch.Tensor, # [batch*heads, seq, dim]
v: torch.Tensor, # [batch*heads, dim, seq] - pre-transposed
batch: int,
heads: int,
) -> torch.Tensor:
"""
Computes scaled dot-product attention using FP8 precision.
    Implements attention with FP8 inputs and an online softmax for improved throughput and memory efficiency.
Args:
q: Query tensor of shape [batch*heads, seq, dim] in FP8 format
k: Key tensor of shape [batch*heads, seq, dim] in FP8 format
v: Value tensor of shape [batch*heads, dim, seq] (pre-transposed) in FP8 format
batch: Number of batches
heads: Number of attention heads
Returns:
Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format
"""
batch_heads = q.size(0)
seq_len = q.size(1)
head_dim = q.size(2)
# Output tensor with 4D shape in FP8 format
out = torch.empty(
[batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
)
# Scale factor for attention
sm_scale = 1.0 / math.sqrt(float(head_dim))
    # Fold log2(e) = 1/ln(2) = 1.44269504 into the scale so the exp2 calls below
    # compute a base-e softmax, matching the Triton reference kernel
    sm_scale = sm_scale * 1.44269504
# Process each batch*head in parallel
for bh in hl.grid(batch_heads):
# Calculate batch and head indices
b = bh // heads
h = bh % heads
# Process each query position
for tile_m in hl.tile(seq_len):
# Initialize for online softmax
m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32)
l_i = hl.full([tile_m], 0.0, dtype=torch.float32)
acc = hl.zeros([tile_m, head_dim], dtype=torch.float32)
# Load query tile - keep in FP8
q_tile = q[bh, tile_m, :] # [tile_m, dim]
# Compute attention scores for all keys
for tile_n in hl.tile(seq_len):
# Load key tile and transpose for Q @ K^T
k_tile = k[bh, tile_n, :] # [tile_n, dim] - keep in FP8
k_tile_t = k_tile.transpose(0, 1) # [dim, tile_n]
# Compute Q @ K^T with FP8 inputs, result in FP32
qk = hl.dot(q_tile, k_tile_t) # [tile_m, tile_n]
# Scale QK scores first
qk_scaled = qk * sm_scale # [tile_m, tile_n]
# Compute max of scaled scores
qk_max = torch.amax(qk_scaled, dim=-1) # [tile_m]
# Update global max
m_new = torch.maximum(m_i, qk_max)
# Shift by max for numerical stability
qk_shifted = qk_scaled - m_new[:, None]
                # Use exp2 to match the Triton reference implementation;
                # the log2(e) factor was already folded into sm_scale above
p = torch.exp2(qk_shifted) # [tile_m, tile_n]
# Sum of exponentials for this block
l_ij = torch.sum(p, dim=-1) # [tile_m]
# Update accumulators with correction factor
# Correction factor for previous blocks
alpha = torch.exp2(m_i - m_new)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
# Load values - V is [dim, seq]
v_tile = v[bh, :, tile_n] # [dim, tile_n] - keep in FP8
# Convert p to FP8 for FP8 GEMM
p_fp8 = p.to(v.dtype) # Convert to same FP8 type as V
# Accumulate attention @ V with FP8 GEMM
                # v_tile is [dim, tile_n]; transpose back to [tile_n, dim] so P @ v_t gives [tile_m, dim]
v_t = v_tile.t() # [tile_n, dim]
acc = hl.dot(p_fp8, v_t, acc=acc) # [tile_m, dim]
# Update max tracker
m_i = m_new
# Final normalization
acc = acc / l_i[:, None]
# Convert to FP8 before writing to output
out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)
return out
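The inner loop above is the standard online-softmax recurrence: a running row maximum m_i, a running denominator l_i, and a correction factor alpha that rescales previously accumulated blocks. The following minimal sketch, in plain float32 PyTorch with base-e exp instead of the exp2 trick and with arbitrarily chosen shapes, checks that the recurrence reproduces a one-shot softmax:
s = torch.randn(8, 32)  # scores for one query tile over the full key range
vals = torch.randn(32, 16)  # values
m_run = torch.full((8,), float("-inf"))  # running max
l_run = torch.zeros(8)  # running sum of exponentials
acc = torch.zeros(8, 16)  # running un-normalized output
for start in range(0, 32, 8):  # key blocks of size 8
    s_blk = s[:, start : start + 8]
    m_new = torch.maximum(m_run, s_blk.amax(dim=-1))
    p = torch.exp(s_blk - m_new[:, None])
    alpha = torch.exp(m_run - m_new)  # rescales previously accumulated blocks
    l_run = l_run * alpha + p.sum(dim=-1)
    acc = acc * alpha[:, None] + p @ vals[start : start + 8]
    m_run = m_new
torch.testing.assert_close(
    acc / l_run[:, None], torch.softmax(s, dim=-1) @ vals, rtol=1e-4, atol=1e-5
)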
def preprocess_fp8_attention_inputs(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preprocesses attention inputs by converting them to FP8 format and reshaping.
Args:
q: Query tensor of shape [batch, heads, seq_len, head_dim]
k: Key tensor of shape [batch, heads, seq_len, head_dim]
v: Value tensor of shape [batch, heads, seq_len, head_dim]
Returns:
Tuple of (q_fp8, k_fp8, v_fp8) where:
- q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
- k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
- v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed)
"""
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v = v.permute(0, 1, 3, 2)
v_fp8 = v.to(torch.float8_e4m3fn)
batch, heads, seq_len, head_dim = q.shape
q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
v_fp8_reshaped = v_fp8.reshape(batch * heads, head_dim, seq_len)
return q_fp8_reshaped, k_fp8_reshaped, v_fp8_reshaped
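As a shape reference, a hypothetical helper (not wired into main(); the name _shape_demo and the sizes batch=2, heads=4, seq_len=256, head_dim=64 are illustrative, and an FP8-capable CUDA device is assumed) showing the preprocessed layouts and a direct kernel call:
def _shape_demo() -> None:
    # Hypothetical, illustrative helper - requires an FP8-capable CUDA device.
    q = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
    print(q_fp8.shape, k_fp8.shape, v_fp8.shape)
    # -> torch.Size([8, 256, 64]) torch.Size([8, 256, 64]) torch.Size([8, 64, 256])
    out = fp8_attention_kernel(q_fp8, k_fp8, v_fp8, 2, 4)
    print(out.shape, out.dtype)  # -> torch.Size([2, 4, 256, 64]) torch.float8_e4m3fn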
def fp8_attention_tritonbench(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> Callable[[], torch.Tensor]:
"""
Creates a callable function for benchmarking FP8 attention with tritonbench.
Preprocesses inputs and returns a lambda function that calls the FP8 attention kernel.
Args:
q: Query tensor of shape [batch, heads, seq_len, head_dim]
k: Key tensor of shape [batch, heads, seq_len, head_dim]
v: Value tensor of shape [batch, heads, seq_len, head_dim]
Returns:
A callable function that executes the FP8 attention kernel
"""
batch, heads, seq_len, head_dim = q.shape
q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
# Return lambda that calls the kernel - preprocessing is done outside.
# This matches the tritonbench kernel timing measurement setup.
return lambda: fp8_attention_kernel(q_fp8, k_fp8, v_fp8, batch, heads)
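A hypothetical usage sketch of the wrapper (again with illustrative sizes and FP8-capable hardware assumed). Preprocessing runs once when the wrapper is built, so each call of the returned closure exercises only the kernel, which is what tritonbench times:
def _benchmark_demo() -> None:
    # Hypothetical, illustrative helper - requires an FP8-capable CUDA device.
    q = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    attn = fp8_attention_tritonbench(q, k, v)
    out = attn()  # only the kernel runs here
    print(out.shape, out.dtype)  # -> torch.Size([2, 4, 256, 64]) torch.float8_e4m3fn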
def _fp8_attention_pytorch_impl(
q_fp8: torch.Tensor,
k_fp8: torch.Tensor,
v_fp8: torch.Tensor,
batch: int,
heads: int,
seq_len: int,
head_dim: int,
) -> torch.Tensor:
"""
PyTorch implementation of FP8 attention for comparison with the kernel version.
Args:
q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim]
v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed)
batch: Number of batches
heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
Returns:
Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format
"""
sm_scale = 1.0 / math.sqrt(float(head_dim))
outputs = []
for i in range(batch * heads):
q_i = q_fp8[i] # [seq, dim] - already FP8
k_i = k_fp8[i] # [seq, dim] - already FP8
v_i = v_fp8[i] # [dim, seq] - pre-transposed, already FP8
# For Q @ K^T using torch._scaled_mm
# torch._scaled_mm requires column-major for second operand
# k_i is [seq, dim], we need K^T as [dim, seq] in column-major
# Direct conversion: k_i -> contiguous -> transpose view
kt_fp8_col_major = k_i.contiguous().t() # [dim, seq] in column-major
# Create scale tensors
scale_q = torch.tensor(1.0, device=q_i.device)
scale_k = torch.tensor(1.0, device=k_i.device)
# Q @ K^T using torch._scaled_mm
qk = torch._scaled_mm(
q_i,
kt_fp8_col_major,
scale_q,
scale_k,
use_fast_accum=False,
out_dtype=torch.float32,
)
# Compute max before scaling
qk_max = torch.amax(qk, dim=-1, keepdim=True)
        # Scale and shift in one operation, then exponentiate: exp2(x * log2(e)) == exp(x)
qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale
p = torch.exp2(qk_scaled_shifted * 1.44269504)
# Normalize
p_norm = p / p.sum(dim=-1, keepdim=True)
        # Attention @ V using FP8
# P is [seq, seq], V is [dim, seq]
# We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
p_fp8 = p_norm.to(torch.float8_e4m3fn) # row-major [seq, seq]
# v_i is [dim, seq], already FP8
# Direct conversion: v_i -> contiguous -> transpose view
vt_fp8_col_major = v_i.contiguous().t() # [seq, dim] in column-major
# Create scale tensors for P @ V^T
scale_p = torch.tensor(1.0, device=p_fp8.device)
scale_v = torch.tensor(1.0, device=v_i.device)
# P @ V^T using torch._scaled_mm
out_i = torch._scaled_mm(
p_fp8,
vt_fp8_col_major,
scale_p,
scale_v,
use_fast_accum=False,
out_dtype=torch.float32,
)
out_i = out_i.to(torch.float8_e4m3fn) # convert back to FP8 to match kernel
outputs.append(out_i)
# Stack and reshape back
out_stacked = torch.stack(outputs, dim=0) # [batch*heads, seq, dim]
return out_stacked.reshape(batch, heads, seq_len, head_dim)
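The column-major requirement is easy to get wrong, so here is an isolated, hypothetical torch._scaled_mm sketch (FP8-capable GPU assumed): the second operand is produced by transposing a contiguous row-major tensor, which yields exactly the column-major layout _scaled_mm expects.
def _scaled_mm_demo() -> None:
    # Hypothetical, illustrative helper - requires an FP8-capable GPU.
    a = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn)  # row-major [128, 64]
    b = torch.randn(256, 64, device="cuda").to(torch.float8_e4m3fn).t()  # column-major [64, 256]
    one = torch.tensor(1.0, device="cuda")
    c = torch._scaled_mm(a, b, one, one, use_fast_accum=False, out_dtype=torch.float32)
    print(c.shape)  # -> torch.Size([128, 256])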
def fp8_attention_pytorch(
q: torch.Tensor, # [batch, heads, seq, dim]
k: torch.Tensor, # [batch, heads, seq, dim]
v: torch.Tensor, # [batch, heads, seq, dim]
) -> Callable[[], torch.Tensor]:
"""
Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
"""
batch, heads, seq_len, head_dim = q.shape
q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
    # Return lambda that calls the reference implementation - preprocessing is done
    # outside, matching the Helion kernel's timing measurement setup.
return lambda: _fp8_attention_pytorch_impl(
q_fp8, k_fp8, v_fp8, batch, heads, seq_len, head_dim
)
def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
"""
Verifies the FP8 attention kernel implementation against the PyTorch reference implementation.
Args:
batch: Number of batches
heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
"""
torch.manual_seed(42)
q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
from helion._testing import run_example
helion_fn = fp8_attention_tritonbench(q, k, v)
pytorch_fn = fp8_attention_pytorch(q, k, v)
run_example(
helion_fn,
pytorch_fn,
(),
atol=0.1,
rtol=0.1,
)
def main() -> None:
"""
    Main entry point that verifies the FP8 attention kernel against the PyTorch
    reference on small, medium, and large attention configurations.
"""
check(1, 2, 128, 64)
check(2, 4, 256, 64)
check(4, 8, 512, 128)
if __name__ == "__main__":
main()