Note
Go to the end to download the full example code
Helion Matmul Kernel Example#
This example demonstrates a Helion kernel implementation of matrix multiplication with support for a customizable epilogue function. It includes autotuning, correctness checks against PyTorch baselines, and integration with tritonbench.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
import torch
from torch import Tensor
import helion
from helion._compat import use_tileir_tunables
from helion._testing import DEVICE
from helion._testing import HALF_DTYPE
from helion._testing import run_example
import helion.language as hl
if TYPE_CHECKING:
from collections.abc import Callable
@helion.kernel(
# static_shapes=True gives a performance boost for matmuls
static_shapes=True,
# Disable autotung over unrolling/range_num_stages
# tl.dot is pipelined with num_stages
# Note: range_unroll_factors and range_num_stages are not supported for tileir backend
autotune_config_overrides={
"range_unroll_factors": [0, 0],
"range_num_stages": [0, 0],
}
if not use_tileir_tunables()
else {},
)
def matmul(
x: Tensor,
y: Tensor,
epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor] = lambda acc, tile: acc,
) -> Tensor:
"""
Performs matrix multiplication of x and y with an optional epilogue function.
Args:
x (Tensor): Left matrix of shape [m, k].
y (Tensor): Right matrix of shape [k, n].
epilogue (Callable, optional): Function applied to the accumulator and tile indices
after the matmul. Defaults to identity (no change).
Returns:
Tensor: Resulting matrix of shape [m, n].
"""
m, k = x.size()
k2, n = y.size()
assert k == k2, f"size mismatch {k} != {k2}"
out = torch.empty(
[m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
)
for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n))
return out
class MatMulFunction(torch.autograd.Function):
@staticmethod
def forward( # pyrefly: ignore [bad-override]
ctx: Any, # noqa: ANN401
mat1: Tensor,
mat2: Tensor,
) -> Tensor:
"""Forward pass for matrix multiplication."""
result = matmul(mat1, mat2)
ctx.save_for_backward(mat1, mat2)
return result
@staticmethod
def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None]:
grad_out = grad_outputs[0]
mat1, mat2 = ctx.saved_tensors
grad_mat1 = matmul(grad_out, mat2.T) if ctx.needs_input_grad[0] else None
grad_mat2 = matmul(mat1.T, grad_out) if ctx.needs_input_grad[1] else None
return grad_mat1, grad_mat2
def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor:
"""Matrix multiplication with forward + backward support."""
return MatMulFunction.apply(mat1, mat2) # type: ignore[no-any-return]
class AddMMFunction(torch.autograd.Function):
@staticmethod
def forward( # pyrefly: ignore [bad-override]
ctx: Any, # noqa: ANN401
bias: Tensor,
mat1: Tensor,
mat2: Tensor,
alpha: float = 1.0,
beta: float = 1.0,
) -> Tensor:
"""Forward pass for addmm operation using helion matmul with epilogue."""
m, k = mat1.size()
k2, n = mat2.size()
input_broadcasted = torch.broadcast_to(bias, [m, n])
# Define epilogue that adds bias: alpha * acc + beta * bias
def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
return alpha * acc + beta * input_broadcasted[tile[0], tile[1]]
result = matmul(mat1, mat2, addmm_epilogue)
ctx.save_for_backward(bias, mat1, mat2)
ctx.alpha = alpha
ctx.beta = beta
return result
@staticmethod
def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
grad_out = grad_outputs[0]
bias, mat1, mat2 = ctx.saved_tensors
alpha = ctx.alpha
beta = ctx.beta
def scale_by_alpha(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
return alpha * acc
grad_bias = (beta * grad_out) if ctx.needs_input_grad[0] else None
grad_mat1 = (
matmul(grad_out, mat2.T, scale_by_alpha)
if ctx.needs_input_grad[1]
else None
)
grad_mat2 = (
matmul(mat1.T, grad_out, scale_by_alpha)
if ctx.needs_input_grad[2]
else None
)
return grad_bias, grad_mat1, grad_mat2, None, None
def addmm_autograd(
bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
) -> Tensor:
"""AddMM operation with forward + backward support."""
return AddMMFunction.apply(bias, mat1, mat2, alpha, beta) # type: ignore[no-any-return]
def autotune(m: int, k: int, n: int) -> None:
"""
Runs autotuning on the matmul kernel with a ReLU epilogue and saves the best config.
Args:
m (int): Number of rows in matrix x.
k (int): Number of columns in matrix x and rows in matrix y.
n (int): Number of columns in matrix y.
"""
x = torch.randn([m, k], device=DEVICE, dtype=HALF_DTYPE)
y = torch.randn([k, n], device=DEVICE, dtype=HALF_DTYPE)
bias = torch.randn([n], device=DEVICE, dtype=HALF_DTYPE)
args = (x, y, lambda acc, tile: torch.relu(acc + bias[tile[1]]))
best_config = matmul.autotune(args, force=True)
print(f"Best config: {best_config}")
best_config.save("best_config.json")
def check(m: int, k: int, n: int) -> None:
"""
Checks the correctness of the matmul kernel against PyTorch baselines.
Tests:
- Plain matmul without bias.
- Matmul with bias added in the epilogue.
- Matmul with a more complex epilogue applying ReLU after bias addition.
Args:
m (int): Number of rows in matrix x.
k (int): Number of columns in matrix x and rows in matrix y.
n (int): Number of columns in matrix y.
"""
x = torch.randn([m, k], device=DEVICE, dtype=HALF_DTYPE)
y = torch.randn([k, n], device=DEVICE, dtype=HALF_DTYPE)
bias = torch.randn([n], device=DEVICE, dtype=HALF_DTYPE)
bias_scalar = torch.randn([1], device=DEVICE, dtype=HALF_DTYPE)
# Test without bias
run_example(matmul, torch.matmul, (x, y))
# Test for addmm with scalar bias
def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
m, k = mat1.size()
k2, n = mat2.size()
bias = torch.broadcast_to(bias, [m, n])
return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
run_example(addmm, torch.addmm, (bias_scalar, x, y))
# Test with bias
def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
return torch.nn.functional.linear(x, y.T, bias)
run_example(helion_linear, baseline_linear, (x, y, bias))
# Test more complex epilogue
def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
# The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
return torch.relu(acc + bias[tile[1]])
def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor:
return matmul(x, y, epilogue)
def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
return torch.relu(x @ y + bias)
run_example(
kernel_wrapper,
baseline_wrapper,
(x, y),
)
# Test matmul forward + backward pass
print("\n\n=== MatMul Forward + Backward Pass Test ===")
x_grad = torch.randn([m, k], device=DEVICE, dtype=HALF_DTYPE, requires_grad=True)
y_grad = torch.randn([k, n], device=DEVICE, dtype=HALF_DTYPE, requires_grad=True)
run_example(
matmul_autograd,
torch.matmul,
(x_grad, y_grad),
kernel_name="helion_matmul_autograd",
baseline_name="torch",
rtol=1e-2,
atol=1e-2,
bwd=True,
)
# Test addmm forward + backward pass
print("\n\n=== AddMM Forward + Backward Pass Test ===")
input_grad = torch.randn(
[m, n], device=DEVICE, dtype=HALF_DTYPE, requires_grad=True
)
mat1_grad = torch.randn([m, k], device=DEVICE, dtype=HALF_DTYPE, requires_grad=True)
mat2_grad = torch.randn([k, n], device=DEVICE, dtype=HALF_DTYPE, requires_grad=True)
# Use lambda to handle the keyword argument format for torch.addmm
run_example(
addmm_autograd,
lambda bias, mat1, mat2, alpha, beta: torch.addmm(
bias, mat1, mat2, alpha=alpha, beta=beta
),
(input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
kernel_name="helion_addmm_autograd",
baseline_name="torch",
rtol=1e-2,
atol=1e-2,
bwd=True,
)
# Test addmm forward + backward with different alpha/beta values
print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
run_example(
addmm_autograd,
lambda bias, mat1, mat2, alpha, beta: torch.addmm(
bias, mat1, mat2, alpha=alpha, beta=beta
),
(input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
kernel_name="helion_addmm_autograd_scaled",
baseline_name="torch",
rtol=1e-2,
atol=1e-2,
bwd=True,
)
def matmul_tritonbench(
tb_op: object, a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None
) -> Callable:
"""
Wrapper for tritonbench that matches its interface.
Args:
tb_op: TritonBench operator instance
a (torch.Tensor): Left matrix.
b (torch.Tensor): Right matrix.
bias (torch.Tensor or None): Optional bias to add in the epilogue.
Returns:
Callable: A callable that runs the matmul kernel with or without bias.
"""
if bias is not None:
# For gemm with bias, use matmul_autograd and add bias
return lambda: matmul_autograd(a, b) + bias
return lambda: matmul_autograd(a, b)
def addmm_tritonbench(
tb_op: object, bias: Tensor, mat1: Tensor, mat2: Tensor
) -> Callable:
"""
Wrapper for tritonbench that performs a matrix multiplication of the matrices
`mat1` and `mat2` followed by adding `bias` to the result.
Args:
bias (torch.Tensor): Bias to add in the epilogue.
mat1 (torch.Tensor): Left matrix.
mat2 (torch.Tensor): Right matrix.
Returns:
Callable: A callable that runs the addmm autograd function with bias.
"""
return lambda: addmm_autograd(bias, mat1, mat2)
def main() -> None:
"""
Main function to run autotuning (commented out) and correctness checks.
"""
# autotune(1024, 1024, 1024)
check(1024, 1024, 1024)
if __name__ == "__main__":
main()
Total running time of the script: (0 minutes 0.000 seconds)