Blackwell Attention Example#
This example implements a custom attention kernel in Helion and PyTorch for efficient computation of scaled dot-product attention, tuned specifically for NVIDIA Blackwell GPUs.
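Before diving into the implementation, here is a rough sketch of the intended call pattern; the shapes and dtype are illustrative (they mirror the tests at the end of this example), and blackwell_attention is the wrapper defined below:

import torch

q = torch.randn(4, 32, 8192, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
# returns the attention output and the per-row log-sum-exp
out, lse = blackwell_attention(q, k, v)  # out: [4, 32, 8192, 64], lse: [4, 32, 8192]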
Imports#
from __future__ import annotations
import math
from typing import Callable
import torch
from triton.testing import do_bench
import helion
from helion._testing import run_example
from helion.autotuner.config_fragment import EnumFragment
import helion.language as hl
Utility Functions#
def _mul_f32x2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Vectorized F32 PTX MUL"""
return hl.inline_asm_elementwise(
"""
{
.reg .b64 ra, rb, rc;
mov.b64 ra, { $2, $3 };
mov.b64 rb, { $4, $5 };
mul.f32x2 rc, ra, rb;
mov.b64 { $0, $1 }, rc;
}
""",
"=r,=r,r,r,r,r",
[a, b],
dtype=torch.float32,
is_pure=True,
pack=2,
)
def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Vectorized F32 PTX FMA"""
    return hl.inline_asm_elementwise(
        """
        {
            .reg .b64 ra, rb, rc, rd;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mov.b64 rc, { $6, $7 };
            fma.rn.f32x2 rd, ra, rb, rc;
            mov.b64 { $0, $1 }, rd;
        }
        """,
        "=r,=r,r,r,r,r,r,r",
        [a, b, c],
        dtype=torch.float32,
        is_pure=True,
        pack=2,
    )
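Semantically these helpers are plain elementwise operations; the inline PTX simply requests the packed f32x2 instructions available on Blackwell. A rough PyTorch equivalent of what they compute (for reference only, not used by the kernel):

def _mul_f32x2_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # same result as _mul_f32x2: elementwise multiply, two fp32 lanes per instruction
    return a * b

def _fma_f32x2_ref(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    # same result as _fma_f32x2: fused multiply-add a * b + c
    return a * b + c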
Attention Kernel Implementation#
# pyrefly: ignore [no-matching-overload]
@helion.kernel(
configs=[
helion.Config(
block_sizes=[256, N],
range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True],
range_multi_buffers=[None, False],
pid_type="persistent_interleaved",
indexing="tensor_descriptor",
num_warps=4,
num_stages=3,
_triton_range_id_data_partition_factor=0,
_triton_range_value_data_partition_factor=2,
_triton_config_maxRegAutoWS=maxreg,
)
for N in [64, 128]
for OUTER_LOOP in [True]
for maxreg in [152, 192]
],
static_shapes=True,
autotune_accuracy_check=False,
)
def blackwell_attention_kernel(
q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, qk_scale: float
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Computes scaled dot-product attention.
Implements the attention mechanism: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
    Args:
        q_in: Query tensor of shape [..., seq_len_q, head_dim]
        k_in: Key tensor of shape [..., seq_len_k, head_dim]
        v_in: Value tensor of shape [..., seq_len_k, head_dim]
        qk_scale: Scaling applied to Q @ K^T (typically 1 / sqrt(head_dim))
    Returns:
        Tuple of the attention output of shape [..., seq_len_q, head_dim] and the
        log-sum-exp tensor of shape [..., seq_len_q]
    """
B, H, M, D = q_in.shape
Bk, Hk, N, Dk = k_in.shape
assert Dk == D
assert Bk == B
assert Hk == H
Bv, Hv, Nv, Dv = v_in.shape
assert Bv == B
assert Hv == Hk
assert Nv == N
D = hl.specialize(D)
Dv = hl.specialize(Dv)
q = q_in.reshape(-1, D)
k = k_in.reshape(-1, D)
v = v_in.reshape(-1, Dv)
MM = q.shape[0]
assert v.shape[0] == k.shape[0]
o = q.new_empty(MM, Dv)
lse = q.new_empty(MM, dtype=torch.float32)
block_m = hl.register_block_size(M)
block_n = hl.register_block_size(N)
assert M % block_m == 0
assert N % block_n == 0
hl.register_tunable(
"_triton_range_id_data_partition_factor", EnumFragment(choices=(0,))
)
hl.register_tunable(
"_triton_range_value_data_partition_factor", EnumFragment(choices=(2,))
)
hl.register_tunable("_triton_config_maxRegAutoWS", EnumFragment(choices=(152, 192)))
SUBTILING = True
VECT_MUL = 1
    qk_scale = qk_scale * 1.44269504  # 1/ln(2): lets the kernel use exp2/log2 instead of exp/log
for tile_m in hl.tile(MM, block_size=block_m):
m_i = hl.zeros([tile_m]) - float("inf")
l_i = hl.zeros([tile_m]) + 1.0
acc = hl.zeros([tile_m, Dv])
q_i = q[tile_m, :]
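        # q, k, v are flattened over (batch, head); tile_m.begin // M selects this tile's
        # (batch, head) group, and start_N is its first row in the flattened K/V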
start_N = tile_m.begin // M * N
for tile_n in hl.tile(N, block_size=block_n):
k_j = k[tile_n + start_N, :]
v_j = v[tile_n + start_N, :]
qk = hl.dot(q_i, k_j.T, out_dtype=torch.float32)
m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
if VECT_MUL == 2 or VECT_MUL == 3:
# pyrefly: ignore [bad-argument-type]
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
else:
qk = qk * qk_scale - m_ij[:, None]
p = torch.exp2(qk)
# -- compute correction factor
alpha = torch.exp2(m_i - m_ij)
l_ij = torch.sum(p, -1)
if SUBTILING:
acc0, acc1 = hl.split(
# pyrefly: ignore [no-matching-overload]
acc.reshape([tile_m, 2, Dv // 2]).permute(0, 2, 1)
)
if VECT_MUL == 1 or VECT_MUL == 3:
acc0 = _mul_f32x2(acc0, alpha[:, None])
acc1 = _mul_f32x2(acc1, alpha[:, None])
else:
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
acc = (
hl.join(acc0, acc1)
.permute(0, 2, 1)
.reshape(acc.size(0), acc.size(1))
)
else:
acc = acc * alpha[:, None]
# update m_i and l_i
# We can potentially move these to be before updating l_ij, so the dot
# is not blocked.
# prepare p and v for the dot
p = p.to(v.dtype)
            # note: using non-transposed V with FP8 inputs is only supported on Blackwell
acc = hl.dot(p, v_j, acc=acc)
l_i = l_i * alpha + l_ij
m_i = m_ij
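        # after the K/V loop: fold the running denominator into the (base-2) log-sum-exp
        # and normalize the accumulated output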
m_i += torch.log2(l_i)
acc = acc / l_i[:, None]
lse[tile_m] = m_i
o[tile_m, :] = acc
return o.reshape(B, H, M, Dv), lse.reshape(B, H, M)
def blackwell_attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
return blackwell_attention_kernel(q, k, v, qk_scale=math.sqrt(1.0 / q.shape[-1]))
def blackwell_attention_tritonbench(
tb_mod: object, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> Callable:
return lambda: blackwell_attention(q, k, v)
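For readers less familiar with the online-softmax bookkeeping above (the running maximum m_i, the running denominator l_i, and the alpha correction), the following is a minimal, unoptimized PyTorch sketch of the same update rule on plain 2-D tensors; the function name and block size are illustrative only:

import math

import torch

def streaming_softmax_attention_ref(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_n: int = 128
) -> torch.Tensor:
    """Unoptimized reference for the online-softmax update used by the kernel."""
    # q: [M, D], k/v: [N, D]; float32 throughout for clarity (an assumption of this sketch)
    qk_scale = 1.0 / math.sqrt(q.shape[-1]) * 1.44269504  # work in base 2, like the kernel
    m_i = torch.full((q.shape[0],), float("-inf"))  # running row maxima
    l_i = torch.zeros(q.shape[0])                   # running softmax denominators
    acc = torch.zeros(q.shape[0], v.shape[-1])      # running weighted sum of V
    for start in range(0, k.shape[0], block_n):
        k_j = k[start : start + block_n]
        v_j = v[start : start + block_n]
        qk = q @ k_j.T
        m_ij = torch.maximum(m_i, qk.amax(-1) * qk_scale)
        p = torch.exp2(qk * qk_scale - m_ij[:, None])
        alpha = torch.exp2(m_i - m_ij)              # rescale previously accumulated results
        acc = acc * alpha[:, None] + p @ v_j
        l_i = l_i * alpha + p.sum(-1)
        m_i = m_ij
    return acc / l_i[:, None]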
Testing Function#
def test(
z: int,
h: int,
n_ctx: int,
head_dim: int,
dtype: torch.dtype = torch.float32,
device: torch.device | str = "cuda",
) -> None:
"""
Test the attention kernel implementation against PyTorch's native attention functions.
Args:
z: Batch size
h: Number of attention heads
n_ctx: Sequence length (context size)
head_dim: Dimension of each attention head
dtype: Data type for the tensors
device: Device to run the test on
"""
q, k, v = [
torch.randn((z, h, n_ctx, head_dim), dtype=dtype, device=device)
for _ in range(3)
]
def ref_attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
"""Reference manual attention implementation"""
p = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
p = torch.softmax(p.float(), dim=-1).to(dtype)
return torch.matmul(p, v)
baselines = {
"torch": torch.nn.functional.scaled_dot_product_attention,
"ref": ref_attention,
}
run_example(
lambda *args: blackwell_attention(*args)[0],
baselines,
(q, k, v),
atol=0.1,
rtol=0.1,
)
# pyrefly: ignore [bad-assignment]
dur: float = do_bench(lambda: blackwell_attention(q, k, v))
print(
f"{z=} {h=} {n_ctx=} {head_dim=} tflops={z * h * n_ctx * n_ctx * head_dim * 4 / dur * 1e-9:.2f}"
)
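The factor of 4 in the TFLOP/s estimate accounts for the two matmuls (Q @ K^T and P @ V), each costing 2 * n_ctx * n_ctx * head_dim FLOPs per (batch, head); do_bench reports milliseconds, hence the 1e-9 factor. A quick sanity check for the first configuration (the 5 ms duration is purely hypothetical):

z, h, n_ctx, head_dim = 4, 32, 8192, 64
flops = 4 * z * h * n_ctx * n_ctx * head_dim  # ~2.2e12 FLOPs
dur_ms = 5.0  # hypothetical measurement
print(f"tflops={flops / dur_ms * 1e-9:.2f}")  # ~439.80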
Main Function#
def main() -> None:
"""
Main entry point that runs the attention kernel test with specific parameters.
    Tests with batch size 4, 32 heads, 8192 sequence length, and head dimensions 64 and 128 using bfloat16.
"""
test(4, 32, 8192, 64, torch.bfloat16)
test(4, 32, 8192, 128, torch.bfloat16)
if __name__ == "__main__":
main()